-- | Tracing executed AST nodes via some mechanism.
module Test.Mutagen.Tracer.Trace
  ( -- * Tracing
    TraceNode
  , trace
  , Trace (..)
  , withTrace
  , truncateTrace
  )
where

import Data.IORef (IORef, atomicModifyIORef', newIORef, readIORef)
import System.IO.Unsafe (unsafePerformIO)

{-------------------------------------------------------------------------------
-- * Tracing
-------------------------------------------------------------------------------}

-- | AST node identifiers.
--
-- A plain alias for 'Int'; each value names one instrumented node and is the
-- element type stored in a 'Trace'.
type TraceNode = Int

-- | Trace the evaluation of a given 'TraceNode'.
--
-- This function is intended to be used by the GHC plugin to instrument code.
--
-- @trace n expr@ returns @expr@ unchanged. As a side effect, performed when
-- the resulting value is demanded, the node @n@ is recorded in the global
-- trace via 'addTraceNode'.
--
-- The side effect is smuggled into pure code with 'unsafePerformIO' so that
-- instrumented expressions keep their original (pure) types.
trace :: TraceNode -> a -> a
trace n expr = unsafePerformIO (addTraceNode n >> return expr)
{-# INLINE trace #-}

-- | A dynamic trace keeping track of executed trace nodes.
--
-- Wraps a list of 'TraceNode's; 'unTrace' unwraps it again.
newtype Trace = Trace {unTrace :: [TraceNode]}
  deriving (Show)

-- | Global 'IORef' holding the current execution trace.
--
-- Starts out as an empty 'Trace'. The reference is created with
-- 'unsafePerformIO' so it can live as a top-level value; the @NOINLINE@
-- pragma is required so that GHC does not inline this binding, which could
-- otherwise result in more than one 'IORef' being created.
globalTraceRef :: IORef Trace
globalTraceRef = unsafePerformIO (newIORef (Trace []))
{-# NOINLINE globalTraceRef #-}

-- | Add a new entry to the current global trace.
--
-- The node is prepended to the stored list, so the list held in
-- 'globalTraceRef' is in reverse order of insertion (most recent first).
-- The update is performed atomically with 'atomicModifyIORef''.
addTraceNode :: TraceNode -> IO ()
addTraceNode n =
  atomicModifyIORef' globalTraceRef $ \(Trace entries) ->
    (Trace (n : entries), ())

-- | Reset the global trace.
--
-- Atomically replaces whatever is stored in 'globalTraceRef' with an empty
-- 'Trace'; the previous contents are discarded.
resetTraceRef :: IO ()
resetTraceRef =
  atomicModifyIORef' globalTraceRef (const (Trace [], ()))

-- | Read the current global trace.
--
-- The stored list is reversed before being returned. Since 'addTraceNode'
-- prepends new entries, the resulting 'Trace' lists nodes in the order in
-- which they were recorded. The global reference itself is left untouched.
readTraceRef :: IO Trace
readTraceRef = Trace . reverse . unTrace <$> readIORef globalTraceRef

-- | Run a computation and obtain its trace.
--
-- The global trace is cleared first, then the action is run, and finally the
-- trace accumulated so far is read back and returned alongside the action's
-- result.
--
-- NOTE(review): only nodes recorded before the final read are included;
-- presumably callers must make sure the action forces the instrumented
-- values it cares about -- confirm against call sites.
withTrace :: IO a -> IO (a, Trace)
withTrace io = do
  resetTraceRef
  result <- io
  tr <- readTraceRef
  return (result, tr)

-- | Truncate a trace to a given length.
--
-- Keeps at most the first @n@ entries of the trace (as per 'take'); a trace
-- that is already no longer than @n@ is returned unchanged, and a
-- non-positive @n@ yields an empty trace.
truncateTrace :: Int -> Trace -> Trace
truncateTrace n (Trace entries) = Trace (take n entries)