summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Graf <sebastian.graf@kit.edu>2022-06-08 21:35:28 +0200
committerSebastian Graf <sebastian.graf@kit.edu>2022-06-08 21:35:28 +0200
commit28e7e187e5ad9b39f50a399fe063fd3dd20ea8f5 (patch)
tree31bcfa07a646556ce8f5c3461340059a84ffed09
parent66ec7c65736c7826f8488e104a9fbf860e7d3810 (diff)
downloadhaskell-wip/T14620.tar.gz
Stashing away stuffwip/T14620
-rw-r--r--compiler/GHC/Core/SimpleOpt.hs155
1 files changed, 102 insertions, 53 deletions
diff --git a/compiler/GHC/Core/SimpleOpt.hs b/compiler/GHC/Core/SimpleOpt.hs
index e999541b5e..13993f5340 100644
--- a/compiler/GHC/Core/SimpleOpt.hs
+++ b/compiler/GHC/Core/SimpleOpt.hs
@@ -394,7 +394,7 @@ tryJoinPointWWs in_scope body_ty binds
where
go jph = ([(join_bndr jph, join_rhs jph)], join_wrapper jph)
join_wrapper jph@JoinPointAfterMono{} -- Rare: A join point after we inline a wrapper
- = [(join_wrapper_bndr jph, join_wrapper_body jph)]
+ = [(join_wrapper_bndr jph, join_rule_rhs jph)]
join_wrapper DefinitelyJoinPoint{} -- Common: Regular join point. No wrapper
= []
@@ -866,16 +866,16 @@ and again its arity increases (#15517)
-- | Indicates that a binding can be transformed into a join point.
data JoinPointHood
= DefinitelyJoinPoint -- ^ A join point by nature
- { join_bndr :: !InBndr
- , join_rhs :: !InExpr }
+ { djp_bndr :: !InBndr
+ , djp_rhs :: !InExpr }
| JoinPointAfterMono
-- ^ A join point after we have instantiated the forall binders occuring in
-- the result type. See Note [Join point worker/wrapper].
- { join_bndr :: !InBndr
- , join_rhs :: !InExpr
- , join_wrapper_bndr :: !InBndr
+ { jpam_fun_bndr :: !InBndr
+ , jpam_fun_rhs :: !InExpr
+ , jpam_spec_bndr :: !InBndr
-- ^ the bndr of a wrapper that needs to be inlined unconditionally
- , join_wrapper_body :: !InExpr }
+ , jpam_spec_rhs :: !InExpr }
-- | An element of the result list of 'matchJoinResTy'.
-- Corresponds to a join binder of what is going to be the new join point.
@@ -893,9 +893,9 @@ instance Outputable JoinWorkerBinder where
ppr (InstBinder ty) = text "Inst" <+> ppr ty
ppr (SubstBinder bndr) = text "Subst" <+> ppr bndr
-isSubstBinder :: JoinWorkerBinder -> Bool
-isSubstBinder SubstBinder{} = True
-isSubstBinder _ = False
+isInstBinder :: JoinWorkerBinder -> Bool
+isInstBinder SubstBinder{} = True
+isInstBinder _ = False
-- | Returns Just jph if the binding is a join point:
-- If it's a JoinId, just return @DefinitelyJoinPoint bndr rhs@.
@@ -926,34 +926,40 @@ joinPointBindings_maybe in_scope body_type binds
| AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
, not (exprIsTrivial rhs)
- , (lam_bndrs, rhs') <- etaExpandToJoinPoint join_arity rhs
- , let eta_rhs' = mkLams lam_bndrs rhs'
- , let inst_tys = matchJoinResTy join_arity (idType bndr) body_type
- , let new_join_arity = count isSubstBinder inst_tys
- , let no_mono = new_join_arity == join_arity
- , let worker_body = mk_worker_body lam_bndrs inst_tys eta_rhs'
- -- we need an in-scope set as if the worker was defined inside the RHS of the wrapper (as is the case with SAT)
- , let in_scope' = extendInScopeSetList in_scope lam_bndrs
- , let new_bndr = uniqAway in_scope' bndr -- only used in else branch
- `setIdType` exprType worker_body
- , let wrapper_body = mk_wrapper_body new_bndr lam_bndrs inst_tys
- , let wrapper_bndr = bndr
- -- , no_mono || pprTrace "always tail called:" (vcat [ppr in_scope', ppr bndr, ppr (idType bndr), ppr body_type, ppr rhs, ppr new_bndr, ppr (exprType worker_body), ppr join_arity, ppr inst_tys, ppr new_bndr, ppr wrapper_body, ppr worker_body]) True
- = Just $! if no_mono
- then ( in_scope
- , DefinitelyJoinPoint
- { join_bndr = adjust_id_info bndr lam_bndrs join_arity
- , join_rhs = eta_rhs' } )
- else ( extendInScopeSet in_scope new_bndr
- , JoinPointAfterMono
- { join_bndr = adjust_id_info new_bndr lam_bndrs new_join_arity
- , join_rhs = worker_body
- , join_wrapper_bndr = wrapper_bndr
- , join_wrapper_body = wrapper_body } )
+ = Just $! determine_join_point_hood in_scope body_type join_arity bndr rhs
| otherwise
= Nothing
+determine_join_point_hood in_scope body_type join_arity bndr rhs
+ | res_ty_ok
+ = (in_scope, DefinitelyJoinPoint
+ { djp_bndr = adjust_id_info bndr lam_bndrs join_arity
+ , djp_rhs = mkLams lam_bndrs body } )
+ | otherwise
+ -- , pprTrace "needs rewrite:" (vcat [ppr in_scope', ppr bndr, ppr (idType bndr), ppr body_type, ppr rhs, ppr new_bndr, ppr (exprType spec_fun), ppr join_arity, ppr inst_tys, ppr new_bndr, ppr rule_rhs, ppr spec_fun]) True
+ --
+ -- Let's take the following running example:
+ --
+ -- let fn :: forall a b. [a] -> forall c. b -> Maybe c -> [(a,c)]
+ -- fn = \@a @b (xs :: [a]) @c b (mc :: Maybe c) -> <rhs>
+ -- in (<body> :: [(Bool, Char)])
+ --
+ = (extendInScopeSet in_scope spec_bndr
+ , JoinPointAfterMono
+ { dpam_fun_bndr = bndr
+ , dpam_fun_rhs = rhs
+ , dpam_spec_bndr = spec_bndr
+ , dpam_spec_rhs = spec_fun } )
+ where
+ (lam_bndrs, body) = etaExpandToJoinPoint join_arity rhs
+ res_ty_ok = not $ any isInstBinder inst_tys
+ inst_tys = matchJoinResTy join_arity (idType bndr) body_type
+ -- Only stuff used in the otherwise branch from here-on
+ spec_fun = mk_spec_fun lam_bndrs inst_tys eta_rhs'
+ spec_bndr = uniqAway in_scope bndr -- only used in else branch
+ `setIdType` exprType spec_fun
+
adjust_id_info :: InBndr -> [InBndr] -> JoinArity -> InBndr
adjust_id_info bndr lam_bndrs join_arity = zapStableUnfolding $ -- TODO: Discuss! Type errors otherwise.
let str_sig = idStrictness bndr
@@ -961,30 +967,45 @@ joinPointBindings_maybe in_scope body_type binds
in bndr `asJoinId` join_arity
`setIdStrictness` etaConvertStrictSig str_arity str_sig
- mk_wrapper_body :: InBndr -> [InBndr] -> [JoinWorkerBinder] -> InExpr
+ mk_spec_fun :: [JoinWorkerBinder] -> [InBndr] -> InExpr -> InExpr
-- See Note [Join point worker/wrapper].
- mk_wrapper_body new_bndr lam_bndrs inst_tys
- = ASSERT( lam_bndrs `equalLength` inst_tys )
- -- pprTraceWith "mk_wrapper_body" (\e -> ppr lam_bndrs $$ ppr inst_tys $$ ppr e) $
- go (Var new_bndr) $ zipEqual "mk_wrapper_body" lam_bndrs inst_tys
+ mk_spec_fun inst_tys lam_bndrs rhs
+ = -- pprTraceWith "mk_spec_fun" (\e -> ppr e) $
+ go rhs $ zipEqual "mk_spec_fun" lam_bndrs inst_tys
where
go e [] = e
- go e ((lb,SubstBinder{}):prs) -- non-instantiated parameter
- | isId lb -- value paramater xs
- = Lam lb (go (App e (Var lb)) prs)
- | otherwise -- type paramater @b
- = Lam lb (go (App e (Type (mkTyVarTy lb))) prs)
- go e ((lb,InstBinder{}):prs) -- instantiated parameter, @a or @c
- = ASSERT( isTyVar lb )
- Lam lb (go e prs)
+ go e ((lb,SubstBinder bndr):prs) -- non-instantiated parameter
+ | Anon _ (Scaled _ ty) <- bndr -- value paramater xs
+ , let lb' = lb `setIdType` ty
+ = Lam lb' (go (App e (Var lb')) prs)
+ | Named (binderVar -> tcv) <- bndr -- type paramater @b
+ = Lam tcv (go (App e (Type (mkTyVarTy tcv))) prs)
+ go e ((_ ,InstBinder ty):prs) -- instantiated paramater, @a or @c
+ = go (App e (Type ty)) prs
- mk_worker_body :: [InBndr] -> [JoinWorkerBinder] -> InExpr -> InExpr
- -- See Note [Join point worker/wrapper].
- mk_worker_body lam_bndrs inst_tys rhs
- = -- pprTraceWith "mk_worker_body" (\e -> ppr e) $
- go rhs $ zipEqual "mk_worker_body" lam_bndrs inst_tys
+ -- let fn :: forall a b. [a] -> forall c. b -> Maybe c -> [(a,c)]
+ -- fn = \@a @b (xs :: [a]) @c b (mc :: Maybe c) -> <rhs>
+ -- in (<body> :: [(Bool, Char)])
+ --
+ -- mk_spec_rhs [I ]
+ -- mk_spec_rhs
+ -- [I Bool, S (b::*), S (_::[Bool]), I Char, S (_::b), S (_::Maybe Char)]
+ -- [(a::*),(b::*),(xs::[a]),(c::*),(b::b),(mc::Maybe c)]
+ -- rhs
+ --
+ -- returns (lams', rhs'), where
+ --
+ -- lams' = [(b::*),(xs::[Bool]),(b'::b),(mc::Maybe Char)]
+ -- rhs' = rhs @Bool @b xs @Char b' mc
+ mk_spec_rhs :: [JoinWorkerBinder] -> [InBndr] -> InExpr -> ([InBndr], InExpr)
+ mk_spec_rhs inst_tys lam_bndrs rhs
+ = -- pprTraceWith "mk_spec_fun" (\e -> ppr e) $
+ foldr go ([], rhs) $ zipEqual "mk_spec_fun" lam_bndrs inst_tys
where
- go e [] = e
+ go (lb,jwb) (lams', rhs') = case jwb of
+ SubstBinder bndr
+
+
go e ((lb,SubstBinder bndr):prs) -- non-instantiated parameter
| Anon _ (Scaled _ ty) <- bndr -- value paramater xs
, let lb' = lb `setIdType` ty
@@ -994,6 +1015,34 @@ joinPointBindings_maybe in_scope body_type binds
go e ((_ ,InstBinder ty):prs) -- instantiated paramater, @a or @c
= go (App e (Type ty)) prs
+ mk_spec_rule :: InBndr -> InBndr -> [InBndr] -> [JoinWorkerBinder] -> InExpr
+ -- See Note [Join point worker/wrapper].
+ mk_spec_rule bndr spec_bndr lam_bndrs inst_tys
+ = ASSERT( lam_bndrs `equalLength` inst_tys )
+ -- pprTraceWith "mk_spec_rule" (\e -> ppr lam_bndrs $$ ppr inst_tys $$ ppr e) $
+ go (Var new_bndr) $ zipEqual "mk_spec_rule" lam_bndrs inst_tys
+ where
+ go e [] = e
+ go e ((lb,SubstBinder{}):prs) -- non-instantiated parameter
+ | isId lb -- value paramater xs
+ = Lam lb (go (App e (Var lb)) prs)
+ | otherwise -- type paramater @b
+ = Lam lb (go (App e (Type (mkTyVarTy lb))) prs)
+ go e ((lb,InstBinder{}):prs) -- instantiated parameter, @a or @c
+ = ASSERT( isTyVar lb )
+ Lam lb (go e prs)
+
+ fn = bndr
+
+ fn_name = idName fn
+ fn_loc = nameSrcSpan fn_name
+ fn_occ = nameOccName fn_name
+ rule_name = mkFastString ("JOIN:" ++ occNameString fn_occ)
+ spec_name = mkInternalName spec_uniq spec_occ fn_loc
+ rule = mkRule this_mod True {- Auto -} True {- Local -}
+ rule_name inline_act fn_name qvars pats rule_rhs
+ addIdSpecialisations bndr [rule]
+
-- | Figures out how to monomorphise the result type of a join point.
--
-- @matchJoinResTy ja join_ty body_ty@ computes the result type of @join_ty@ by