module Lava.MyST
  ( ST
  , STRef
  , newSTRef
  , readSTRef
  , writeSTRef
  , runST
  , fixST

  , unsafePerformST
  , unsafeInterleaveST
  , unsafeIOtoST
  )
 where

import Control.Applicative (Applicative (..))
import Control.Monad (ap)
import System.IO
import System.IO.Unsafe
import Data.IORef

-- | A state-thread computation, represented as a plain 'IO' action.
--
-- The @s@ parameter does not occur on the right-hand side; it exists only at
-- the type level so that computations and references can be tagged with it.
newtype ST s a = ST (IO a)

-- | Strip the 'ST' wrapper and expose the underlying 'IO' action.
unST :: ST s a -> IO a
unST (ST io) = io

-- | Mapping over an 'ST' computation maps over the wrapped 'IO' action.
instance Functor (ST s) where
  fmap f (ST io) = ST (fmap f io)

-- | Applicative structure for 'ST'.
--
-- 'pure' is defined directly on the wrapped 'IO' action rather than in terms
-- of 'return', which is the canonical direction between the two classes.
-- '<*>' is derived from the 'Monad' instance via 'ap'.
instance Applicative (ST s) where
  pure a = ST (pure a)
  (<*>)  = ap

-- | Sequencing of 'ST' computations is sequencing of the wrapped 'IO' actions.
instance Monad (ST s) where
  -- Kept as an explicit definition on the wrapped 'IO': an 'Applicative'
  -- instance that defines 'pure' via 'return' would loop if this were
  -- written as @return = pure@.
  return a    = ST (return a)
  -- Run the first action, feed its result to the continuation, and run the
  -- 'IO' action inside the 'ST' value the continuation produces.
  ST io >>= k = ST (io >>= unST . k)

-- | A mutable reference, represented as an 'IORef'.
--
-- The @s@ parameter does not occur on the right-hand side; it only tags the
-- reference at the type level.
newtype STRef s a = STRef (IORef a)

-- | Two 'STRef's are equal exactly when their underlying 'IORef's are equal.
--
-- No 'Eq' constraint on the stored type is needed: the comparison is delegated
-- to the 'Eq' instance of 'IORef', not to the contents.
instance Eq (STRef s a) where
  STRef r1 == STRef r2 = r1 == r2

-- | Create a new reference holding the given initial value.
--
-- Implemented with 'newIORef'; the resulting 'IORef' is wrapped in 'STRef'.
newSTRef :: a -> ST s (STRef s a)
newSTRef a = ST (STRef `fmap` newIORef a)

-- | Read the current contents of a reference (via 'readIORef').
readSTRef :: STRef s a -> ST s a
readSTRef (STRef r) = ST (readIORef r)

-- | Overwrite the contents of a reference with a new value (via 'writeIORef').
writeSTRef :: STRef s a -> a -> ST s ()
writeSTRef (STRef r) a = ST (writeIORef r a)

-- | Run an 'ST' computation and return its result as a pure value.
--
-- The argument must be polymorphic in @s@, so its result type cannot mention
-- the state-thread tag. The computation is executed by handing it to
-- 'unsafePerformST'.
--
-- NOTE(review): the rank-2 signature needs @RankNTypes@ (or @Rank2Types@);
-- no LANGUAGE pragma is visible in this file, so it is presumably enabled in
-- the build configuration. Confirm.
runST :: (forall s . ST s a) -> a
runST st = unsafePerformST st

-- | Fixed point of an 'ST' computation: the result of the computation is fed
-- back to it as its own argument.
--
-- Implemented with 'fixIO' on the unwrapped 'IO' function.
fixST :: (a -> ST s a) -> ST s a
fixST f = ST (fixIO (unST . f))

-- | Run an 'ST' computation as a pure value, without the rank-2 restriction
-- that 'runST' imposes on the @s@ parameter.
--
-- This is a direct wrapper around 'unsafePerformIO' and carries the same
-- caveats. It is marked @NOINLINE@ following the advice in the
-- 'unsafePerformIO' documentation for functions that call it, so that the
-- wrapped action is not duplicated by inlining.
unsafePerformST :: ST s a -> a
unsafePerformST (ST io) = unsafePerformIO io
{-# NOINLINE unsafePerformST #-}

-- | Defer an 'ST' computation by passing its wrapped 'IO' action through
-- 'unsafeInterleaveIO'; the caveats of that function apply here as well.
unsafeInterleaveST :: ST s a -> ST s a
unsafeInterleaveST (ST io) = ST (unsafeInterleaveIO io)

-- | Embed an arbitrary 'IO' action into 'ST'.
--
-- This is just the 'ST' constructor, so the action is wrapped unchanged and
-- nothing restricts which 'IO' effects it may perform.
unsafeIOtoST :: IO a -> ST s a
unsafeIOtoST = ST