From 28e7e187e5ad9b39f50a399fe063fd3dd20ea8f5 Mon Sep 17 00:00:00 2001 From: Sebastian Graf Date: Wed, 8 Jun 2022 21:35:28 +0200 Subject: Stashing away stuff --- compiler/GHC/Core/SimpleOpt.hs | 155 +++++++++++++++++++++++++++-------------- 1 file changed, 102 insertions(+), 53 deletions(-) diff --git a/compiler/GHC/Core/SimpleOpt.hs b/compiler/GHC/Core/SimpleOpt.hs index e999541b5e..13993f5340 100644 --- a/compiler/GHC/Core/SimpleOpt.hs +++ b/compiler/GHC/Core/SimpleOpt.hs @@ -394,7 +394,7 @@ tryJoinPointWWs in_scope body_ty binds where go jph = ([(join_bndr jph, join_rhs jph)], join_wrapper jph) join_wrapper jph@JoinPointAfterMono{} -- Rare: A join point after we inline a wrapper - = [(join_wrapper_bndr jph, join_wrapper_body jph)] + = [(join_wrapper_bndr jph, join_rule_rhs jph)] join_wrapper DefinitelyJoinPoint{} -- Common: Regular join point. No wrapper = [] @@ -866,16 +866,16 @@ and again its arity increases (#15517) -- | Indicates that a binding can be transformed into a join point. data JoinPointHood = DefinitelyJoinPoint -- ^ A join point by nature - { join_bndr :: !InBndr - , join_rhs :: !InExpr } + { djp_bndr :: !InBndr + , djp_rhs :: !InExpr } | JoinPointAfterMono -- ^ A join point after we have instantiated the forall binders occuring in -- the result type. See Note [Join point worker/wrapper]. - { join_bndr :: !InBndr - , join_rhs :: !InExpr - , join_wrapper_bndr :: !InBndr + { jpam_fun_bndr :: !InBndr + , jpam_fun_rhs :: !InExpr + , jpam_spec_bndr :: !InBndr -- ^ the bndr of a wrapper that needs to be inlined unconditionally - , join_wrapper_body :: !InExpr } + , jpam_spec_rhs :: !InExpr } -- | An element of the result list of 'matchJoinResTy'. -- Corresponds to a join binder of what is going to be the new join point. @@ -893,9 +893,9 @@ instance Outputable JoinWorkerBinder where ppr (InstBinder ty) = text "Inst" <+> ppr ty ppr (SubstBinder bndr) = text "Subst" <+> ppr bndr -isSubstBinder :: JoinWorkerBinder -> Bool -isSubstBinder SubstBinder{} = True -isSubstBinder _ = False +isInstBinder :: JoinWorkerBinder -> Bool +isInstBinder SubstBinder{} = True +isInstBinder _ = False -- | Returns Just jph if the binding is a join point: -- If it's a JoinId, just return @DefinitelyJoinPoint bndr rhs@. @@ -926,34 +926,40 @@ joinPointBindings_maybe in_scope body_type binds | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr) , not (exprIsTrivial rhs) - , (lam_bndrs, rhs') <- etaExpandToJoinPoint join_arity rhs - , let eta_rhs' = mkLams lam_bndrs rhs' - , let inst_tys = matchJoinResTy join_arity (idType bndr) body_type - , let new_join_arity = count isSubstBinder inst_tys - , let no_mono = new_join_arity == join_arity - , let worker_body = mk_worker_body lam_bndrs inst_tys eta_rhs' - -- we need an in-scope set as if the worker was defined inside the RHS of the wrapper (as is the case with SAT) - , let in_scope' = extendInScopeSetList in_scope lam_bndrs - , let new_bndr = uniqAway in_scope' bndr -- only used in else branch - `setIdType` exprType worker_body - , let wrapper_body = mk_wrapper_body new_bndr lam_bndrs inst_tys - , let wrapper_bndr = bndr - -- , no_mono || pprTrace "always tail called:" (vcat [ppr in_scope', ppr bndr, ppr (idType bndr), ppr body_type, ppr rhs, ppr new_bndr, ppr (exprType worker_body), ppr join_arity, ppr inst_tys, ppr new_bndr, ppr wrapper_body, ppr worker_body]) True - = Just $! if no_mono - then ( in_scope - , DefinitelyJoinPoint - { join_bndr = adjust_id_info bndr lam_bndrs join_arity - , join_rhs = eta_rhs' } ) - else ( extendInScopeSet in_scope new_bndr - , JoinPointAfterMono - { join_bndr = adjust_id_info new_bndr lam_bndrs new_join_arity - , join_rhs = worker_body - , join_wrapper_bndr = wrapper_bndr - , join_wrapper_body = wrapper_body } ) + = Just $! determine_join_point_hood in_scope body_type join_arity bndr rhs | otherwise = Nothing +determine_join_point_hood in_scope body_type join_arity bndr rhs + | res_ty_ok + = (in_scope, DefinitelyJoinPoint + { djp_bndr = adjust_id_info bndr lam_bndrs join_arity + , djp_rhs = mkLams lam_bndrs body } ) + | otherwise + -- , pprTrace "needs rewrite:" (vcat [ppr in_scope', ppr bndr, ppr (idType bndr), ppr body_type, ppr rhs, ppr new_bndr, ppr (exprType spec_fun), ppr join_arity, ppr inst_tys, ppr new_bndr, ppr rule_rhs, ppr spec_fun]) True + -- + -- Let's take the following running example: + -- + -- let fn :: forall a b. [a] -> forall c. b -> Maybe c -> [(a,c)] + -- fn = \@a @b (xs :: [a]) @c b (mc :: Maybe c) -> + -- in ( :: [(Bool, Char)]) + -- + = (extendInScopeSet in_scope spec_bndr + , JoinPointAfterMono + { dpam_fun_bndr = bndr + , dpam_fun_rhs = rhs + , dpam_spec_bndr = spec_bndr + , dpam_spec_rhs = spec_fun } ) + where + (lam_bndrs, body) = etaExpandToJoinPoint join_arity rhs + res_ty_ok = not $ any isInstBinder inst_tys + inst_tys = matchJoinResTy join_arity (idType bndr) body_type + -- Only stuff used in the otherwise branch from here-on + spec_fun = mk_spec_fun lam_bndrs inst_tys eta_rhs' + spec_bndr = uniqAway in_scope bndr -- only used in else branch + `setIdType` exprType spec_fun + adjust_id_info :: InBndr -> [InBndr] -> JoinArity -> InBndr adjust_id_info bndr lam_bndrs join_arity = zapStableUnfolding $ -- TODO: Discuss! Type errors otherwise. let str_sig = idStrictness bndr @@ -961,30 +967,45 @@ joinPointBindings_maybe in_scope body_type binds in bndr `asJoinId` join_arity `setIdStrictness` etaConvertStrictSig str_arity str_sig - mk_wrapper_body :: InBndr -> [InBndr] -> [JoinWorkerBinder] -> InExpr + mk_spec_fun :: [JoinWorkerBinder] -> [InBndr] -> InExpr -> InExpr -- See Note [Join point worker/wrapper]. - mk_wrapper_body new_bndr lam_bndrs inst_tys - = ASSERT( lam_bndrs `equalLength` inst_tys ) - -- pprTraceWith "mk_wrapper_body" (\e -> ppr lam_bndrs $$ ppr inst_tys $$ ppr e) $ - go (Var new_bndr) $ zipEqual "mk_wrapper_body" lam_bndrs inst_tys + mk_spec_fun inst_tys lam_bndrs rhs + = -- pprTraceWith "mk_spec_fun" (\e -> ppr e) $ + go rhs $ zipEqual "mk_spec_fun" lam_bndrs inst_tys where go e [] = e - go e ((lb,SubstBinder{}):prs) -- non-instantiated parameter - | isId lb -- value paramater xs - = Lam lb (go (App e (Var lb)) prs) - | otherwise -- type paramater @b - = Lam lb (go (App e (Type (mkTyVarTy lb))) prs) - go e ((lb,InstBinder{}):prs) -- instantiated parameter, @a or @c - = ASSERT( isTyVar lb ) - Lam lb (go e prs) + go e ((lb,SubstBinder bndr):prs) -- non-instantiated parameter + | Anon _ (Scaled _ ty) <- bndr -- value paramater xs + , let lb' = lb `setIdType` ty + = Lam lb' (go (App e (Var lb')) prs) + | Named (binderVar -> tcv) <- bndr -- type paramater @b + = Lam tcv (go (App e (Type (mkTyVarTy tcv))) prs) + go e ((_ ,InstBinder ty):prs) -- instantiated paramater, @a or @c + = go (App e (Type ty)) prs - mk_worker_body :: [InBndr] -> [JoinWorkerBinder] -> InExpr -> InExpr - -- See Note [Join point worker/wrapper]. - mk_worker_body lam_bndrs inst_tys rhs - = -- pprTraceWith "mk_worker_body" (\e -> ppr e) $ - go rhs $ zipEqual "mk_worker_body" lam_bndrs inst_tys + -- let fn :: forall a b. [a] -> forall c. b -> Maybe c -> [(a,c)] + -- fn = \@a @b (xs :: [a]) @c b (mc :: Maybe c) -> + -- in ( :: [(Bool, Char)]) + -- + -- mk_spec_rhs [I ] + -- mk_spec_rhs + -- [I Bool, S (b::*), S (_::[Bool]), I Char, S (_::b), S (_::Maybe Char)] + -- [(a::*),(b::*),(xs::[a]),(c::*),(b::b),(mc::Maybe c)] + -- rhs + -- + -- returns (lams', rhs'), where + -- + -- lams' = [(b::*),(xs::[Bool]),(b'::b),(mc::Maybe Char)] + -- rhs' = rhs @Bool @b xs @Char b' mc + mk_spec_rhs :: [JoinWorkerBinder] -> [InBndr] -> InExpr -> ([InBndr], InExpr) + mk_spec_rhs inst_tys lam_bndrs rhs + = -- pprTraceWith "mk_spec_fun" (\e -> ppr e) $ + foldr go ([], rhs) $ zipEqual "mk_spec_fun" lam_bndrs inst_tys where - go e [] = e + go (lb,jwb) (lams', rhs') = case jwb of + SubstBinder bndr + + go e ((lb,SubstBinder bndr):prs) -- non-instantiated parameter | Anon _ (Scaled _ ty) <- bndr -- value paramater xs , let lb' = lb `setIdType` ty @@ -994,6 +1015,34 @@ joinPointBindings_maybe in_scope body_type binds go e ((_ ,InstBinder ty):prs) -- instantiated paramater, @a or @c = go (App e (Type ty)) prs + mk_spec_rule :: InBndr -> InBndr -> [InBndr] -> [JoinWorkerBinder] -> InExpr + -- See Note [Join point worker/wrapper]. + mk_spec_rule bndr spec_bndr lam_bndrs inst_tys + = ASSERT( lam_bndrs `equalLength` inst_tys ) + -- pprTraceWith "mk_spec_rule" (\e -> ppr lam_bndrs $$ ppr inst_tys $$ ppr e) $ + go (Var new_bndr) $ zipEqual "mk_spec_rule" lam_bndrs inst_tys + where + go e [] = e + go e ((lb,SubstBinder{}):prs) -- non-instantiated parameter + | isId lb -- value paramater xs + = Lam lb (go (App e (Var lb)) prs) + | otherwise -- type paramater @b + = Lam lb (go (App e (Type (mkTyVarTy lb))) prs) + go e ((lb,InstBinder{}):prs) -- instantiated parameter, @a or @c + = ASSERT( isTyVar lb ) + Lam lb (go e prs) + + fn = bndr + + fn_name = idName fn + fn_loc = nameSrcSpan fn_name + fn_occ = nameOccName fn_name + rule_name = mkFastString ("JOIN:" ++ occNameString fn_occ) + spec_name = mkInternalName spec_uniq spec_occ fn_loc + rule = mkRule this_mod True {- Auto -} True {- Local -} + rule_name inline_act fn_name qvars pats rule_rhs + addIdSpecialisations bndr [rule] + -- | Figures out how to monomorphise the result type of a join point. -- -- @matchJoinResTy ja join_ty body_ty@ computes the result type of @join_ty@ by -- cgit v1.2.1