diff options
Diffstat (limited to 'compiler/simplCore/SAT.hs')
-rw-r--r-- | compiler/simplCore/SAT.hs | 431 |
1 files changed, 431 insertions, 0 deletions
{-
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998


************************************************************************

               Static Argument Transformation pass

************************************************************************

May be seen as removing invariants from loops:
Arguments of recursive functions that do not change in recursive
calls are removed from the recursion, which is done locally
and only passes the arguments which effectively change.

Example:
map = /\ ab -> \f -> \xs -> case xs of
                 []    -> []
                 (a:b) -> f a : map f b

as map is recursively called with the same argument f (unmodified)
we transform it to

map = /\ ab -> \f -> \xs -> let map' ys = case ys of
                                  []    -> []
                                  (a:b) -> f a : map' b
                            in map' xs

Notice that for a compiler that uses lambda lifting this is
useless as map' will be transformed back to what map was.

We could possibly do the same for big lambdas, but we don't as
they will eventually be removed in later stages of the compiler,
therefore there is no penalty in keeping them.

We only apply the SAT when the number of static *value* arguments is
greater than one (i.e. at least two); static type and coercion arguments
do not count towards this threshold.  This produces few bad cases.  See
                should_transform
in saTransformMaybe.

Here are the headline nofib results:
                  Size    Allocs   Runtime
Min             +0.0%    -13.7%    -21.4%
Max             +0.1%     +0.0%     +5.4%
Geometric Mean  +0.0%     -0.2%     -6.9%

The previous patch, to fix polymorphic floatout demand signatures, is
essential to make this work well!
-}

{-# LANGUAGE CPP #-}
module SAT ( doStaticArgs ) where

import Var
import CoreSyn
import CoreUtils
import Type
import Coercion
import Id
import Name
import VarEnv
import UniqSupply
import Util
import UniqFM
import VarSet
import Unique
import UniqSet
import Outputable

import Data.List
import FastString

#include "HsVersions.h"

-- | Run the static argument transformation over every top-level binding
-- of a program.
--
-- The unique supply is split once per binding: one half is threaded on to
-- the next binding, the other half drives the 'SatM' computation for the
-- current binding.  Each top-level binding starts with an empty set of
-- interesting Ids, and the 'IdSATInfo' it produces is discarded.
doStaticArgs :: UniqSupply -> CoreProgram -> CoreProgram
doStaticArgs us binds = snd $ mapAccumL sat_bind_threaded_us us binds
  where
    sat_bind_threaded_us us0 bind =
        let (us1, us2) = splitUniqSupply us0
        in (us1, fst $ runSAT us2 (satBind bind emptyUniqSet))

-- We don't bother to SAT recursive groups since it can lead
-- to massive code expansion: see Andre Santos' thesis for details.
-- This means we only apply the actual SAT to Rec groups of one element,
-- but we want to recurse into the others anyway to discover other binds

-- | Analyse (and possibly transform) one binding.
--
-- The 'IdSet' holds the recursive binders currently in scope whose
-- applications are being tracked.  The result pairs the rewritten binding
-- with the staticness information gathered for those Ids.
--
-- * 'NonRec': only the right-hand side is traversed.
-- * 'Rec' with exactly one pair: the binder is added to the interesting
--   set, the body under the leading lambdas is analysed, and the binding
--   is handed to 'saTransformMaybe'.
-- * Any other 'Rec': the right-hand sides are traversed but the group
--   itself is never transformed.
satBind :: CoreBind -> IdSet -> SatM (CoreBind, IdSATInfo)
satBind (NonRec binder expr) interesting_ids = do
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    return (NonRec binder expr', finalizeApp expr_app sat_info_expr)
satBind (Rec [(binder, rhs)]) interesting_ids = do
    let interesting_ids' = interesting_ids `addOneToUniqSet` binder
        (rhs_binders, rhs_body) = collectBinders rhs
    (rhs_body', sat_info_rhs_body) <- satTopLevelExpr rhs_body interesting_ids'
    let -- The lambda binders themselves are recorded as an all-static
        -- "call", so that a recursive call is static in some position only
        -- if it passes exactly the corresponding binder.
        sat_info_rhs_from_args = unitVarEnv binder (bindersToSATInfo rhs_binders)
        sat_info_rhs' = mergeIdSATInfo sat_info_rhs_from_args sat_info_rhs_body

        -- If this binder was already interesting it shadows an outer one,
        -- so its info must not leak out to the enclosing binding.
        shadowing = binder `elementOfUniqSet` interesting_ids
        sat_info_rhs'' = if shadowing
                            then sat_info_rhs' `delFromUFM` binder -- For safety
                            else sat_info_rhs'

    -- The decision uses sat_info_rhs' (before the shadowing deletion).
    bind' <- saTransformMaybe binder (lookupUFM sat_info_rhs' binder)
                              rhs_binders rhs_body'
    return (bind', sat_info_rhs'')
satBind (Rec pairs) interesting_ids = do
    let (binders, rhss) = unzip pairs
    rhss_SATed <- mapM (\e -> satTopLevelExpr e interesting_ids) rhss
    let (rhss', sat_info_rhss') = unzip rhss_SATed
    return (Rec (zipEqual "satBind" binders rhss'), mergeIdSATInfos sat_info_rhss')

-- | One argument seen in an application: a variable, a type or a coercion.
data App = VarApp Id | TypeApp Type | CoApp Coercion

-- | Whether an argument position is (so far) static, and if so with what.
data Staticness a = Static a | NotStatic

-- | An interesting Id together with the arguments it has been applied to
-- so far (outermost application last).
type IdAppInfo = (Id, SATInfo)

-- | Staticness of each argument position, in argument order.
type SATInfo = [Staticness App]

-- | Staticness information for each interesting Id seen so far.
type IdSATInfo = IdEnv SATInfo

emptyIdSATInfo :: IdSATInfo
emptyIdSATInfo = emptyUFM

{-
pprIdSATInfo id_sat_info = vcat (map pprIdAndSATInfo (Map.toList id_sat_info))
  where pprIdAndSATInfo (v, sat_info) = hang (ppr v <> colon) 4 (pprSATInfo sat_info)
-}

-- | Render a 'SATInfo' as a run of two-letter codes (see 'pprStaticness').
pprSATInfo :: SATInfo -> SDoc
pprSATInfo staticness = hcat $ map pprStaticness staticness

-- | SV / ST / SC for a static variable / type / coercion argument,
-- NS for a non-static one.
pprStaticness :: Staticness App -> SDoc
pprStaticness (Static (VarApp _))  = ptext (sLit "SV")
pprStaticness (Static (TypeApp _)) = ptext (sLit "ST")
pprStaticness (Static (CoApp _))   = ptext (sLit "SC")
pprStaticness NotStatic            = ptext (sLit "NS")


-- | Combine the staticness seen at two call sites, position by position.
--
-- A position stays 'Static' only if both sides are static with an equal
-- argument of the same kind.  Because this is a 'zipWith', the result is
-- only as long as the shorter input.  Two static arguments of different
-- kinds (e.g. a variable against a type) cause a panic.
mergeSATInfo :: SATInfo -> SATInfo -> SATInfo
mergeSATInfo l r = zipWith mergeSA l r
  where
    mergeSA NotStatic _ = NotStatic
    mergeSA _ NotStatic = NotStatic
    mergeSA (Static (VarApp v)) (Static (VarApp v'))
      | v == v'   = Static (VarApp v)
      | otherwise = NotStatic
    mergeSA (Static (TypeApp t)) (Static (TypeApp t'))
      | t `eqType` t' = Static (TypeApp t)
      | otherwise     = NotStatic
    mergeSA (Static (CoApp c)) (Static (CoApp c'))
      | c `coreEqCoercion` c' = Static (CoApp c)
      | otherwise             = NotStatic
    mergeSA _ _  = pprPanic "mergeSATInfo" $
                          ptext (sLit "Left:")
                       <> pprSATInfo l <> ptext (sLit ", ")
                       <> ptext (sLit "Right:")
                       <> pprSATInfo r

-- | Merge two environments, combining entries for the same Id with
-- 'mergeSATInfo'.
mergeIdSATInfo :: IdSATInfo -> IdSATInfo -> IdSATInfo
mergeIdSATInfo = plusUFM_C mergeSATInfo

-- | Merge a list of environments, left to right.
mergeIdSATInfos :: [IdSATInfo] -> IdSATInfo
mergeIdSATInfos = foldl' mergeIdSATInfo emptyIdSATInfo

-- | Describe a function's own lambda binders as an all-static argument
-- list: each binder becomes the 'App' that would pass it on unchanged.
--
-- NOTE(review): the guards are tried in order, so the 'CoApp' branch is
-- only reached for binders that are neither Ids nor type variables.  If
-- 'isId' also holds for coercion variables, that branch is dead and such
-- binders are recorded as 'VarApp' -- confirm against Var.isId.
bindersToSATInfo :: [Id] -> SATInfo
bindersToSATInfo vs = map (Static . binderToApp) vs
    where binderToApp v | isId v    = VarApp v
                        | isTyVar v = TypeApp $ mkTyVarTy v
                        | otherwise = CoApp $ mkCoVarCo v

-- | Commit a pending application (if any) to the environment: the
-- application's argument staticness is merged with whatever is already
-- recorded for that Id, or inserted as-is if nothing is.
finalizeApp :: Maybe IdAppInfo -> IdSATInfo -> IdSATInfo
finalizeApp Nothing id_sat_info = id_sat_info
finalizeApp (Just (v, sat_info')) id_sat_info =
    let sat_info'' = case lookupUFM id_sat_info v of
                        Nothing -> sat_info'
                        Just sat_info -> mergeSATInfo sat_info sat_info'
    in extendVarEnv id_sat_info v sat_info''

-- | Like 'satExpr', but for an expression that is not itself the function
-- part of an enclosing application: any pending application is finalized
-- straight away.
satTopLevelExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo)
satTopLevelExpr expr interesting_ids = do
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    return (expr', finalizeApp expr_app sat_info_expr)

-- | Traverse an expression, transforming nested bindings and collecting
-- staticness information about calls to the interesting Ids.
--
-- The result is the rewritten expression, the information for all
-- *completed* applications inside it, and -- when the expression is itself
-- a (possibly partial) application headed by an interesting Id -- that
-- still-growing application, which the caller either extends with another
-- argument or commits with 'finalizeApp'.
satExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
satExpr var@(Var v) interesting_ids = do
    -- An interesting Id starts a new application with no arguments yet.
    let app_info = if v `elementOfUniqSet` interesting_ids
                   then Just (v, [])
                   else Nothing
    return (var, emptyIdSATInfo, app_info)

satExpr lit@(Lit _) _ = do
    return (lit, emptyIdSATInfo, Nothing)

satExpr (Lam binders body) interesting_ids = do
    (body', sat_info, this_app) <- satExpr body interesting_ids
    return (Lam binders body', finalizeApp this_app sat_info, Nothing)

satExpr (App fn arg) interesting_ids = do
    -- Function first, then argument: the order of these two traversals
    -- determines the order in which uniques are drawn.
    (fn', sat_info_fn, fn_app)    <- satExpr fn interesting_ids
    (arg', sat_info_arg, arg_app) <- satExpr arg interesting_ids
    let sat_info_arg' = finalizeApp arg_app sat_info_arg
        sat_info      = mergeIdSATInfo sat_info_fn sat_info_arg'
        -- If the function part is a pending application of an interesting
        -- Id, extend it with this argument's staticness.
        -- TODO: remove this use of append somehow (use a data structure
        -- with O(1) append but a left-to-right kind of interface)
        app_info = case fn_app of
            Nothing                   -> Nothing
            Just (fn_id, fn_app_info) -> Just (fn_id, fn_app_info ++ [arg_staticness])
    return (App fn' arg', sat_info, app_info)
  where
    -- Only atomic arguments can be static; anything else is NotStatic.
    arg_staticness :: Staticness App
    arg_staticness = case arg of
        Type t     -> Static (TypeApp t)
        Coercion c -> Static (CoApp c)
        Var v      -> Static (VarApp v)
        _          -> NotStatic

satExpr (Case expr bndr ty alts) interesting_ids = do
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    let sat_info_expr' = finalizeApp expr_app sat_info_expr

    zipped_alts' <- mapM satAlt alts
    let (alts', sat_infos_alts) = unzip zipped_alts'
    return (Case expr' bndr ty alts', mergeIdSATInfo sat_info_expr' (mergeIdSATInfos sat_infos_alts), Nothing)
  where
    satAlt (con, bndrs, expr) = do
        (expr', sat_info_expr) <- satTopLevelExpr expr interesting_ids
        return ((con, bndrs, expr'), sat_info_expr)

satExpr (Let bind body) interesting_ids = do
    -- The body is traversed before the binding; a pending application in
    -- the body is passed up through the Let unchanged.
    (body', sat_info_body, body_app) <- satExpr body interesting_ids
    (bind', sat_info_bind) <- satBind bind interesting_ids
    return (Let bind' body', mergeIdSATInfo sat_info_body sat_info_bind, body_app)

satExpr (Tick tickish expr) interesting_ids = do
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    return (Tick tickish expr', sat_info_expr, expr_app)

satExpr ty@(Type _) _ = do
    return (ty, emptyIdSATInfo, Nothing)

satExpr co@(Coercion _) _ = do
    return (co, emptyIdSATInfo, Nothing)

satExpr (Cast expr coercion) interesting_ids = do
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    return (Cast expr' coercion, sat_info_expr, expr_app)

{-
************************************************************************

                Static Argument Transformation Monad

************************************************************************
-}

-- | The SAT monad is just a unique-supply monad.
type SatM result = UniqSM result

-- | Run a 'SatM' computation with the given supply, discarding the
-- leftover supply.
runSAT :: UniqSupply -> SatM a -> a
runSAT = initUs_

-- | Draw a fresh unique.
newUnique :: SatM Unique
newUnique = getUniqueM

{-
************************************************************************

                Static Argument Transformation

************************************************************************

To do the transformation, the game plan is to:

1. Create a small nonrecursive RHS that takes the
   original arguments to the function but discards
   the ones that are static and makes a call to the
   SATed version with the remainder. We intend that
   this will be inlined later, removing the overhead

2. Bind this nonrecursive RHS over the original body
   WITH THE SAME UNIQUE as the original body so that
   any recursive calls to the original now go via
   the small wrapper

3. Rebind the original function to a new one which contains
   our SATed function and just makes a call to it:
   we call the thing making this call the local body

Example: transform this

    map :: forall a b. (a->b) -> [a] -> [b]
    map = /\ab. \(f:a->b) (as:[a]) -> body[map]
to
    map :: forall a b. (a->b) -> [a] -> [b]
    map = /\ab. \(f:a->b) (as:[a]) ->
         letrec map' :: [a] -> [b]
                    -- The "worker function"
                map' = \(as:[a]) ->
                         let map :: forall a' b'. (a -> b) -> [a] -> [b]
                                -- The "shadow function"
                             map = /\a'b'. \(f':(a->b)) (as:[a]).
                                   map' as
                         in body[map]
         in map' as

Note [Shadow binding]
~~~~~~~~~~~~~~~~~~~~~
The calls to the inner map inside body[map] should get inlined
by the local re-binding of 'map'.  We call this the "shadow binding".

But we can't use the original binder 'map' unchanged, because
it might be exported, in which case the shadow binding won't be
discarded as dead code after it is inlined.

So we use a hack: we make a new SysLocal binder with the *same* unique
as binder.  (Another alternative would be to reset the export flag.)

Note [Binder type capture]
~~~~~~~~~~~~~~~~~~~~~~~~~~
Notice that in the inner map (the "shadow function"), the static arguments
are discarded -- it's as if they were underscores.  Instead, mentions
of these arguments (notably in the types of dynamic arguments) are bound
by the *outer* lambdas of the main function.  So we must make up fresh
names for the static arguments so that they do not capture variables
mentioned in the types of dynamic args.

In the map example, the shadow function must clone the static type
argument a,b, giving a',b', to ensure that in the \(as:[a]), the 'a'
is bound by the outer forall.  We clone f' too for consistency, but
that doesn't matter either way because static Id arguments aren't
mentioned in the shadow binding at all.

If we don't we get something like this:

[Exported]
[Arity 3]
GHC.Base.until =
  \ (@ a_aiK)
    (p_a6T :: a_aiK -> GHC.Types.Bool)
    (f_a6V :: a_aiK -> a_aiK)
    (x_a6X :: a_aiK) ->
    letrec {
      sat_worker_s1aU :: a_aiK -> a_aiK
      []
      sat_worker_s1aU =
        \ (x_a6X :: a_aiK) ->
          let {
            sat_shadow_r17 :: forall a_a3O.
                              (a_a3O -> GHC.Types.Bool) -> (a_a3O -> a_a3O) -> a_a3O -> a_a3O
            []
            sat_shadow_r17 =
              \ (@ a_aiK)
                (p_a6T :: a_aiK -> GHC.Types.Bool)
                (f_a6V :: a_aiK -> a_aiK)
                (x_a6X :: a_aiK) ->
                sat_worker_s1aU x_a6X } in
          case p_a6T x_a6X of wild_X3y [ALWAYS Dead Nothing] {
            GHC.Types.False -> GHC.Base.until @ a_aiK p_a6T f_a6V (f_a6V x_a6X);
            GHC.Types.True -> x_a6X
          }; } in
    sat_worker_s1aU x_a6X

Where sat_shadow has captured the type variables of x_a6X etc as it has an
a_aiK type argument. This is bad because it means the application
sat_worker_s1aU x_a6X is not well typed.
-}

-- | Transform a singleton recursive binding if it is worthwhile.
--
-- The binding is transformed only when staticness information is
-- available for the binder and more than one argument position is a
-- static *value* argument (see 'isStaticValue'); otherwise the original
-- recursive binding is rebuilt from its binders and (already traversed)
-- body.
saTransformMaybe :: Id -> Maybe SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
saTransformMaybe binder maybe_arg_staticness rhs_binders rhs_body
  | Just arg_staticness <- maybe_arg_staticness
  , should_transform arg_staticness
  = saTransform binder arg_staticness rhs_binders rhs_body
  | otherwise
  = return (Rec [(binder, mkLams rhs_binders rhs_body)])
  where
    should_transform staticness = n_static_args > 1 -- THIS IS THE DECISION POINT
      where
        n_static_args = length (filter isStaticValue staticness)

-- | Perform the transformation described above, turning a singleton
-- recursive binding into a non-recursive one whose right-hand side
-- contains the recursive worker and the shadow binding.
saTransform :: Id -> SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
saTransform binder arg_staticness rhs_binders rhs_body
  = do  { shadow_lam_bndrs <- mapM clone binders_w_staticness
        ; uniq             <- newUnique
        ; return (NonRec binder (mk_new_rhs uniq shadow_lam_bndrs)) }
  where
    -- Running example: foldr
    -- foldr \alpha \beta c n xs = e, for some e
    -- arg_staticness = [Static TypeApp, Static TypeApp, Static VarApp, Static VarApp, NotStatic]
    -- rhs_binders = [\alpha, \beta, c, n, xs]
    -- rhs_body = e

    binders_w_staticness = rhs_binders `zip` (arg_staticness ++ repeat NotStatic)
                                        -- Any extra args are assumed NotStatic

    non_static_args :: [Var]
            -- non_static_args = [xs]
            -- shadow_lam_bndrs = [\alpha', \beta', c', n', xs]
    non_static_args = [v | (v, NotStatic) <- binders_w_staticness]

    -- Every static binder (type, coercion or value) gets a fresh unique;
    -- non-static binders are reused as they are.
    clone (bndr, NotStatic) = return bndr
    clone (bndr, _        ) = do { uniq <- newUnique
                                 ; return (setVarUnique bndr uniq) }

    -- new_rhs = \alpha beta c n xs ->
    --           let sat_worker = \xs -> let sat_shadow = \alpha' beta' c' n' xs ->
    --                                       sat_worker xs
    --                                   in e
    --           in sat_worker xs
    mk_new_rhs uniq shadow_lam_bndrs
        = mkLams rhs_binders $
          Let (Rec [(rec_body_bndr, rec_body)])
          local_body
        where
          local_body = mkVarApps (Var rec_body_bndr) non_static_args

          rec_body = mkLams non_static_args $
                     Let (NonRec shadow_bndr shadow_rhs) rhs_body

            -- See Note [Binder type capture]
          shadow_rhs = mkLams shadow_lam_bndrs local_body
            -- shadow_rhs = \alpha' beta' c' n' xs -> sat_worker xs

          rec_body_bndr = mkSysLocal (fsLit "sat_worker") uniq (exprType rec_body)
            -- rec_body_bndr = sat_worker

            -- See Note [Shadow binding]; make a SysLocal
          shadow_bndr = mkSysLocal (occNameFS (getOccName binder))
                                   (idUnique binder)
                                   (exprType shadow_rhs)

-- | True only for a static *variable* argument; static type and coercion
-- arguments, and non-static arguments, do not count.
isStaticValue :: Staticness App -> Bool
isStaticValue (Static (VarApp _)) = True
isStaticValue _                   = False