{-# LANGUAGE TemplateHaskellQuotes #-}

-- | Derive 'Mutable' instances for a given data type.
module Test.Mutagen.TH.Mutable
  ( deriveMutable
  )
where

import Control.Monad (forM, guard)
import Data.List (sortOn)
import Language.Haskell.TH
  ( Lit (..)
  , Name
  , Q
  , newName
  )
import Language.Haskell.TH.Desugar
  ( DClause (..)
  , DCon (..)
  , DDec (..)
  , DExp (..)
  , DLetDec (..)
  , DPat (..)
  , DType (..)
  , dLamE
  , dPatToDExp
  , mkTupleDExp
  )
import Test.Mutagen.Fragment.Store
  ( sampleFragments
  )
import Test.Mutagen.Mutant
  ( Mutant (..)
  )
import Test.Mutagen.Mutation
  ( Mutable
  , def
  , inside
  , invalidPosition
  , mutate
  , node
  , positions
  , wrap
  )
import Test.Mutagen.TH.Util
  ( applyTyVars
  , createDPat
  , dConFields
  , dConFieldsNum
  , dConFieldsTypes
  , dConName
  , dTyVarBndrName
  , mkConDExp
  , mkListDExp
  , mutagenError
  , reifyTypeDef
  )

{-------------------------------------------------------------------------------
-- * Deriving Mutable instances
-------------------------------------------------------------------------------}

-- | Derive a 'Mutable' instance for the given data type.
--
-- The generated instance has a @'Mutable' a@ constraint for every type
-- parameter of the data type. It defines 'def', 'positions', 'inside' and
-- 'mutate' over the constructors that are not listed in @ignoredCons@.
deriveMutable
  :: Name
  -- ^ name of the data type to derive the instance for
  -> [Name]
  -- ^ constructors to leave out of the derived methods
  -> Maybe Name
  -- ^ optional user-supplied value to use as 'def'
  -> Q [DDec]
deriveMutable typeName ignoredCons mbDef = do
  -- Reify the type definition
  (dtvbs, dcons) <- reifyTypeDef typeName
  -- Apply the type variables to the type name to get the 'Type'-kinded
  -- target type to derive the instance for
  let targetType = applyTyVars typeName dtvbs
  -- Keep only the constructors that are not ignored
  let wantedCons = filter (\con -> dConName con `notElem` ignoredCons) dcons
  -- With no constructors left, 'positions' and 'mutate' would be emitted as
  -- function bindings with zero clauses, which GHC rejects with an unhelpful
  -- message, so report a clear error up front instead.
  case wantedCons of
    [] ->
      mutagenError
        ( "cannot derive a Mutable instance: no constructors are left "
            <> "after removing the ignored ones"
        )
        [typeName]
    _ -> return ()
  -- Derive each method of the Mutable class separately
  insDef <- deriveDef targetType mbDef wantedCons
  insPositions <- derivePositions wantedCons
  insInside <- deriveInside wantedCons
  insMutate <- deriveMutate wantedCons
  -- Build the Mutable instance: require 'Mutable' for every type variable
  let insCxt = [DConT ''Mutable `DAppT` DVarT (dTyVarBndrName tvb) | tvb <- dtvbs]
  let insTy = DConT ''Mutable `DAppT` targetType
  let insBody = concat [insDef, insPositions, insInside, insMutate]
  return [DInstanceD Nothing Nothing insCxt insTy insBody]

-- | Derive the 'def' method for the 'Mutable' type class.
--
-- If a default value name was supplied, 'def' is bound to it. Otherwise
-- 'def' is built from a constructor that has no field of exactly the target
-- type (i.e. no direct self-recursion). Among those, the one with the fewest
-- fields is chosen and applied to 'def' for every field. Ties go to the
-- earliest constructor in @cons@, since 'sortOn' is stable.
deriveDef :: DType -> Maybe Name -> [DCon] -> Q [DDec]
deriveDef dty mbDef cons = do
  defValue <-
    case mbDef of
      Just var -> return (DVarE var)
      Nothing -> do
        let nonRecursive = filter (notElem dty . dConFieldsTypes . dConFields) cons
        case sortOn (dConFieldsNum . dConFields) nonRecursive of
          [] ->
            -- Report the target type: the candidate list is necessarily
            -- empty here and would carry no information.
            mutagenError
              ( "could not find a proper constructor to derive 'def' with, "
                  <> "please define a default value manually via 'optDefault'"
              )
              [dty]
          smallest : _ ->
            return $
              mkConDExp
                (dConName smallest)
                (replicate (dConFieldsNum (dConFields smallest)) (DVarE 'def))
  return [DLetDec (DFunD 'def [DClause [] defValue])]

-- | Derive the 'positions' method for the 'Mutable' type class.
--
-- Each constructor gets one clause returning a 'node'. The node's children
-- are the constructor's bound field variables, indexed from 0 in the order
-- 'createDPat' binds them. Each index is paired with the 'positions' of
-- that field.
derivePositions :: [DCon] -> Q [DDec]
derivePositions cons = do
  clauses <- forM cons $ \con -> do
    (vars, pat) <- createDPat con
    let children =
          [ mkTupleDExp [DLitE (IntegerL n), DVarE 'positions `DAppE` DVarE var]
          | (n, var) <- zip [0 ..] vars
          ]
    return (DClause [pat] (DVarE 'node `DAppE` mkListDExp children))
  return [DLetDec (DFunD 'positions clauses)]

-- | Derive the 'inside' method for the 'Mutable' type class.
--
-- The generated function takes a position, a mutation and a value:
--
-- * an empty position applies the mutation directly to the value;
-- * a position @i : rest@ matching a constructor with a field at index @i@
--   calls 'inside' on that field with @rest@. It combines the result, via
--   'wrap', with a function that rebuilds the constructor around a new
--   field value;
-- * any other position is passed to 'invalidPosition'.
deriveInside :: [DCon] -> Q [DDec]
deriveInside cons = do
  pos_ <- newName "pos"
  mut_ <- newName "mut"
  x_ <- newName "x"
  -- inside [] mut x = mut x
  let firstClause =
        DClause
          [DConP '[] [] [], DVarP mut_, DVarP x_]
          (DVarE mut_ `DAppE` DVarE x_)
  -- One clause per (constructor, field index) pair
  conClauses <- forM cons $ \con -> do
    (vars, pat) <- createDPat con
    let fieldClause :: Int -> Name -> DClause
        fieldClause idx var =
          let posPat =
                DConP '(:) [] [DLitP (IntegerL (fromIntegral idx)), DVarP pos_]
              insideExpr =
                DVarE 'inside
                  `DAppE` DVarE pos_
                  `DAppE` DVarE mut_
                  `DAppE` DVarE var
              -- \x -> Con v0 .. x .. vn, with x in place of field idx
              rebuild =
                dLamE
                  [DVarP x_]
                  (mkConDExp (dConName con) (map DVarE (replaceAt idx x_ vars)))
           in DClause
                [posPat, DVarP mut_, pat]
                (DVarE 'wrap `DAppE` insideExpr `DAppE` rebuild)
    -- Pair each field with its index directly instead of indexing with (!!)
    return (zipWith fieldClause [0 ..] vars)
  -- Any other position is invalid
  let lastClause =
        DClause
          [DVarP pos_, DWildP, DWildP]
          (DVarE 'invalidPosition `DAppE` DVarE pos_)
  let clauses = [firstClause] <> concat conClauses <> [lastClause]
  return [DLetDec (DFunD 'inside clauses)]

-- | Derive the 'mutate' method for the 'Mutable' type class.
--
-- Each constructor gets one clause returning a list with two parts:
--
-- * first, a 'Frag' mutant that applies 'sampleFragments' to the matched
--   value (rebuilt from the clause pattern with 'dPatToDExp');
-- * then, one 'Pure' mutant per expression produced by 'mutateCon' from
--   this constructor's fields.
deriveMutate :: [DCon] -> Q [DDec]
deriveMutate cons = do
  clauses <- forM cons $ \con -> do
    (vars, pat) <- createDPat con
    let fieldTypes = zip vars (dConFieldsTypes (dConFields con))
        -- Fragment mutation
        fragMutant =
          DConE 'Frag `DAppE` (DVarE 'sampleFragments `DAppE` dPatToDExp pat)
        -- Pure mutations
        pureMutants =
          map (DConE 'Pure `DAppE`) (mutateCon (dConName con) fieldTypes cons)
    return (DClause [pat] (mkListDExp (fragMutant : pureMutants)))
  return [DLetDec (DFunD 'mutate clauses)]

-- | Generate mutated constructor expressions.
--
-- The arguments are:
--
-- * the name of the matched constructor;
-- * its bound field variables, paired with their declared types;
-- * every constructor of the type.
--
-- For each constructor in @cons@, every way of filling its fields is
-- produced. Each field takes a variable from @fieldTypes@ with exactly the
-- same type, or 'def' when no such variable exists. The expression that
-- simply rebuilds the matched value unchanged is filtered out.
--
-- NOTE: this looks a bit funky because it is written in the list monad, so
-- each binding below enumerates alternatives.
mutateCon :: Name -> [(Name, DType)] -> [DCon] -> [DExp]
mutateCon name fieldTypes cons = do
  con <- cons
  mutation <- validMutations con
  guard (mutation /= nullMutation)
  return mutation
  where
    -- Saturated applications of a constructor, one per combination of field
    -- substitutions ('sequence' on a list of lists is the cartesian product).
    -- They are built with 'mkConDExp', exactly like 'nullMutation', so the
    -- structural comparison above reliably recognises the unchanged value.
    validMutations :: DCon -> [DExp]
    validMutations con =
      [ mkConDExp (dConName con) (map DVarE fields)
      | fields <- sequence (validFieldSubstitutions con)
      ]
    -- For each field of a constructor, find valid substitutions of the same
    -- type, using this instance's 'def' value if none are found
    validFieldSubstitutions :: DCon -> [[Name]]
    validFieldSubstitutions con =
      [ if null subst then ['def] else subst
      | fty <- dConFieldsTypes (dConFields con)
      , let subst = validSubstitutions fty
      ]
    -- Find all variables that can substitute a given field type
    validSubstitutions :: DType -> [Name]
    validSubstitutions ty' = do
      (field, ty) <- fieldTypes
      guard (ty' == ty)
      return field
    -- The null mutation is the one that leaves all fields unchanged
    nullMutation :: DExp
    nullMutation = mkConDExp name [DVarE field | (field, _) <- fieldTypes]

-- | Replace the element at the given (0-based) index in a list.
--
-- An index outside the list leaves it unchanged. The list is consumed
-- lazily, so this also works on infinite lists.
replaceAt :: Int -> a -> [a] -> [a]
replaceAt n y = zipWith pick [0 ..]
  where
    pick i x
      | i == n = y
      | otherwise = x