#!/usr/bin/env nix-script
#!>haskell

{- Learn You a Haskell exercise: a simple probability monad over exact Rational weights. -}

module Data.Probability 
 ( Event, Prob
 , flatten
 , onProb
 , sumCheck
 , probs
 , group
 ) where

import Control.Applicative
import Control.Monad (ap)
import qualified Control.Monad.Fail as Fail
import Data.Function (on)
import Data.List     (groupBy)
import Data.Ratio


-- | A single outcome paired with its exact ('Rational') probability weight.
type    Event a = (Rational, a)
-- | A discrete distribution represented as a list of weighted outcomes.
-- The type itself does not enforce that the weights sum to 1 or that the
-- outcomes are distinct; 'sumCheck' can be used to test the former.
newtype Prob  a = Prob { getProb :: [Event a] }


-- | Render a distribution as its raw event list, without the 'Prob' wrapper.
instance Show a => Show (Prob a) where
  show = show . getProb

-- | Transform every outcome while leaving its probability untouched.
instance Functor Prob where
  fmap g = Prob . map (\(p, x) -> (p, g x)) . getProb

-- | 'pure' is the certain event (probability 1). '<*>' is derived from the
-- 'Monad' instance via 'ap', so the two instances always agree.
instance Applicative Prob where
  pure x = Prob [(1, x)]
  (<*>)  = ap

-- | Binding applies the continuation to every outcome and then flattens the
-- nested distribution, scaling each inner event by its outer probability.
-- 'return' is intentionally left to its default, which is 'pure'.
instance Monad Prob where
  m >>= f = flatten (f <$> m)

-- | A failed pattern match in @do@ notation yields the empty distribution.
-- Since GHC 8.8 'fail' is a method of 'MonadFail', not of 'Monad', so it
-- must live in its own instance.
instance Fail.MonadFail Prob where
  fail _ = Prob []


-- | Flatten a distribution of distributions by one level.
--
-- Every inner event's probability is multiplied by the probability of the
-- outer event that contains it, and the results are concatenated in order.
flatten :: Prob (Prob a) -> Prob a
flatten = onProb $ concatMap (uncurry scaleBy)
  where
    -- Scale every event of an inner distribution by the outer weight @p@.
    scaleBy p = map (\(r, x) -> (p * r, x)) . getProb


-- | Lift a transformation of raw event lists to a transformation of
-- distributions by unwrapping the 'Prob' newtype, applying the function,
-- and re-wrapping the result.
onProb :: ([Event a] -> [Event b]) -> (Prob a -> Prob b)
onProb f (Prob events) = Prob (f events)


-- | Get the probability of every event, in list order. Events with equal
-- outcomes are not merged.
probs :: Prob a -> [Rational]
probs (Prob events) = fst <$> events


-- | Check whether the event probabilities add up to exactly 1 (the 1-norm
-- of the distribution). The test is exact because the weights are 'Rational'.
sumCheck :: Prob a -> Bool
sumCheck = (== 1) . sum . probs


-- | Merge all events that share an outcome into a single event.
--
-- The merged event's probability is the sum of the merged probabilities, so
-- the total probability mass is preserved. Outcomes keep the order of their
-- first appearance. Equal outcomes need not be adjacent.
--
-- Only 'Eq' is required, so this is quadratic in the number of events.
group :: Eq a => Prob a -> Prob a
group = onProb merge
  where
    merge []            = []
    merge ((p, x) : es) =
      let same  = [q | (q, y) <- es, y == x]    -- weights to fold into x
          other = [e | e@(_, y) <- es, y /= x]  -- still-unmerged events
      in  (p + sum same, x) : merge other