summaryrefslogtreecommitdiff
path: root/compiler/simplCore/AnfiseCore.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/simplCore/AnfiseCore.hs')
-rw-r--r--compiler/simplCore/AnfiseCore.hs90
1 files changed, 90 insertions, 0 deletions
diff --git a/compiler/simplCore/AnfiseCore.hs b/compiler/simplCore/AnfiseCore.hs
new file mode 100644
index 0000000000..ef4cfb5404
--- /dev/null
+++ b/compiler/simplCore/AnfiseCore.hs
@@ -0,0 +1,90 @@
+-- | Convert Core into A-normal form (ANF).
+module AnfiseCore ( anfiseProgram ) where
+
+import BasicTypes
+import Type
+import Id
+import VarEnv
+import CoreUtils
+import CoreSyn
+import FastString
+import Unique
+
+import Data.Bifunctor
+import Data.Either
+import Control.Monad.Trans.State
+
+anfiseProgram :: CoreProgram -> CoreProgram
+anfiseProgram top_binds = map goTopLvl top_binds
+ where
+ goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
+ goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
+
+ in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds top_binds
+
+ go :: InScopeSet -> CoreExpr -> CoreExpr
+ go _ e@(Var{}) = e
+ go _ e@(Lit {}) = e
+ go _ e@(Type {}) = e
+ go _ e@(Coercion {}) = e
+
+ go in_scope e@(App e1 e2)
+ | Var f_id <- f
+ , isJoinId f_id
+ = dont_bind
+ | otherwise
+ = let bound_args :: [Either CoreExpr (Id, CoreExpr)]
+ bound_args = evalState (mapM bind_arg args) in_scope
+ where
+ bind_arg :: CoreExpr -> State InScopeSet (Either CoreExpr (Id, CoreExpr))
+ bind_arg arg
+ | not should_bind = return $ Left arg
+ | otherwise = do
+ bndr <- mkAnfId ty
+ nowInScope bndr
+ return $ Right (bndr, arg)
+ where
+ ty = exprType arg
+ should_bind = isValArg arg && not (isUnliftedType ty) && not (exprIsTrivial arg)
+ binds = map (uncurry NonRec) (rights bound_args)
+ to_arg = either id (Var . fst)
+ in mkLets binds $ mkApps f (map to_arg bound_args)
+ where
+ (f, args) = collectArgs e
+ dont_bind = App (go in_scope e1) (go in_scope e2)
+ go in_scope (Lam v e') = Lam v (go in_scope' e')
+ where in_scope' = in_scope `extendInScopeSet` v
+ go in_scope (Case scrut bndr ty alts)
+ = Case (go in_scope scrut) bndr ty (map (goAlt in_scope') alts)
+ where in_scope' = in_scope `extendInScopeSet` bndr
+ go in_scope (Cast e' c) = Cast (go in_scope e') c
+ go in_scope (Tick t e') = Tick t (go in_scope e')
+ go in_scope (Let bind body) = goBind in_scope bind (go in_scope' body)
+ where in_scope' = in_scope `extendInScopeSetList` bindersOf bind
+
+ goAlt :: InScopeSet -> CoreAlt -> CoreAlt
+ goAlt in_scope (dc, pats, rhs) = (dc, pats, go in_scope' rhs)
+ where in_scope' = in_scope `extendInScopeSetList` pats
+
+ goBind :: InScopeSet -> CoreBind -> (CoreExpr -> CoreExpr)
+ goBind in_scope (NonRec v rhs) = Let (NonRec v (go in_scope rhs))
+ goBind in_scope (Rec pairs) = Let (Rec pairs')
+ where pairs' = map (second (go in_scope')) pairs
+ in_scope' = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
+
+nowInScope :: Id -> State InScopeSet ()
+nowInScope id = modify (`extendInScopeSet` id)
+
+mkAnfId :: Type -> State InScopeSet Id
+mkAnfId ty = do
+ in_scope <- get
+ return $ uniqAway in_scope id_tmpl
+ where
+ id_tmpl = mkSysLocal (fsLit "anf") initExitJoinUnique ty
+ `setIdOccInfo` occ_info
+ occ_info =
+ OneOcc { occ_in_lam = insideLam
+ , occ_one_br = oneBranch
+ , occ_int_cxt = False
+ , occ_tail = NoTailCallInfo
+ }