diff options
author | Manuel M T Chakravarty <chak@cse.unsw.edu.au> | 2012-05-08 00:12:00 +1000 |
---|---|---|
committer | Manuel M T Chakravarty <chak@cse.unsw.edu.au> | 2012-05-08 00:14:21 +1000 |
commit | 209e375051e557b34e99a0bbef06c7ba6459f5d9 (patch) | |
tree | 4f9e5bb775984bf827cd7d1cb3609b11b16e992b | |
parent | b26a1b3f5aeeb5a0f8c862d529288c859198d1a6 (diff) | |
download | haskell-209e375051e557b34e99a0bbef06c7ba6459f5d9.tar.gz |
Fix #6080 & house keeping in Vectorise.Exp
-rw-r--r-- | compiler/vectorise/Vectorise.hs | 6 | ||||
-rw-r--r-- | compiler/vectorise/Vectorise/Exp.hs | 797 |
2 files changed, 338 insertions, 465 deletions
diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index 8f6e32130f..3ac9c5105f 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -361,18 +361,18 @@ vectTopRhs recFs var expr rhs _globalScalar _isDFun (Just (_, expr')) -- Case (1) = return (inlineMe, False, expr') rhs True False Nothing -- Case (2) - = do { expr' <- vectScalarFun recFs expr + = do { expr' <- vectScalarFun expr ; return (inlineMe, True, vectorised expr') } rhs True True Nothing -- Case (3) - = do { expr' <- vectScalarDFun var recFs + = do { expr' <- vectScalarDFun var ; return (DontInline, True, expr') } rhs False False Nothing -- Case (4) — not a dfun = do { let exprFvs = freeVars expr ; (inline, isScalar, vexpr) <- inBind var $ - vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs exprFvs + vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs exprFvs Nothing ; return (inline, isScalar, vectorised vexpr) } rhs False True Nothing -- Case (4) — is a dfun diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index 0764c3b255..e75cf0e009 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -51,271 +51,120 @@ import TcRnMonad (doptM) import DynFlags (DynFlag(Opt_AvoidVect)) --- For prototyping, the VITree is a separate data structure with the same shape as the corresponding expression --- tree. This will become part of the annotation - -data VectInfo = VIParr - | VISimple - | VIComplex - | VIEncaps - deriving (Eq, Show) - -data VITree = VITNode VectInfo [VITree] - deriving (Show) - -viTrace :: CoreExprWithFVs -> VectInfo -> [VITree] -> VM () -viTrace ce vi vTs = - -- return () - traceVt ("vitrace " ++ (show vi) ++ "[" ++ (concat $ map (\(VITNode vi _) -> show vi ++ " ") vTs) ++"]") (ppr $ deAnnotate ce) - -viOr :: [VITree] -> Bool -viOr = or . (map (\(VITNode vi _) -> vi == VIParr)) - --- TODO: free scalar vars don't actually need to be passed through, since encapsulations makes sure, that there are --- no free variables in encapsulated lambda expressions -vectInfo:: CoreExprWithFVs -> VM VITree -vectInfo ce@(_, AnnVar v) - = do { vi <- vectInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [] - ; traceVt "vectInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce)) - ; return $ VITNode vi [] - } - -vectInfo ce@(_, AnnLit _) - = do { vi <- vectInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [] - ; traceVt "vectInfo AnnLit" (ppr $ exprType $ deAnnotate ce) - ; return $ VITNode vi [] - } - -vectInfo ce@(_, AnnApp e1 e2) - = do { vt1 <- vectInfo e1 - ; vt2 <- vectInfo e2 - ; vi <- if viOr [vt1, vt2] - then return VIParr - else vectInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [vt1, vt2] - ; return $ VITNode vi [vt1, vt2] - } - -vectInfo ce@(_, AnnLam _var body) - = do { vt@(VITNode vi _) <- vectInfo body - ; viTrace ce vi [vt] - ; if (vi == VIParr) - then return $ VITNode vi [vt] - else return $ VITNode VIComplex [vt] - } - -vectInfo ce@(_, AnnLet (AnnNonRec _var expr) body) - = do { vtE <- vectInfo expr - ; vtB <- vectInfo body - ; vi <- if viOr [vtE, vtB] - then return VIParr - else vectInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [vtE, vtB] - ; return $ VITNode vi [vtE, vtB] - } - -vectInfo ce@(_, AnnLet (AnnRec bnds) body) - = do { let (_, exprs) = unzip bnds - ; vtBnds <- mapM (\e -> vectInfo e) exprs - ; if (viOr vtBnds) - then do { vtBnds' <- mapM (\e -> vectInfo e) exprs - ; vtB <- vectInfo body - ; return (VITNode VIParr (vtB: vtBnds')) - } - else do { vtB@(VITNode vib _) <- vectInfo body - ; ni <- if (vib == VIParr) - then return VIParr - else vectInfoType $ exprType $ deAnnotate ce - ; viTrace ce ni (vtB : vtBnds) - ; return $ VITNode ni (vtB : vtBnds) - } - } - -vectInfo ce@(_, AnnCase expr _var _ty alts) - = do { vtExpr <- vectInfo expr - ; vtAlts <- mapM (\(_, _, e) -> vectInfo e) alts - ; ni <- if viOr (vtExpr : vtAlts) - then return VIParr - else vectInfoType $ exprType $ deAnnotate ce - ; viTrace ce ni (vtExpr : vtAlts) - ; return $ VITNode ni (vtExpr: vtAlts) - } - - -vectInfo (_, AnnCast expr _) - = do { vt@(VITNode vi _) <- vectInfo expr - ; return $ VITNode vi [vt] - } - -vectInfo (_, AnnTick _ expr ) - = do { vt@(VITNode vi _) <- vectInfo expr - ; return $ VITNode vi [vt] - } - -vectInfo (_, AnnType {}) - = return $ VITNode VISimple [] - -vectInfo (_, AnnCoercion {}) - = return $ VITNode VISimple [] - - - -vectInfoType:: Type -> VM VectInfo -vectInfoType ty - | maybeParrTy ty = return VIParr - | otherwise - = do { sType <- isSimpleType ty - ; if sType - then return VISimple - else return VIComplex - } - - --- Checks whether the type might be a parallel array type. In particular, if the outermost --- constructor is a type family, we conservatively assume that it may be a parallel array type. -maybeParrTy :: Type -> Bool -maybeParrTy ty - | Just ty' <- coreView ty = maybeParrTy ty' - | Just (tyCon, ts) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon - || or (map maybeParrTy ts) -maybeParrTy _ = False - +-- Main entry point to vectorise expressions ----------------------------------- -isSimpleType:: Type -> VM Bool -isSimpleType ty - | Just (c, _cs) <- splitTyConApp_maybe ty = return $ (tyConName c) `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] -{- - = do { globals <- globalScalarTyCons - ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c) - ; return (elemNameSet (tyConName c) globals ) - } - -} - | Nothing <- splitTyConApp_maybe ty - = return False -isSimpleType ty - = pprPanic "Vectorise.Exp.isSimpleType not handled" (ppr ty) - -varsSimple :: VarSet -> VM Bool -varsSimple vs - = do { varTypes <- mapM isSimpleType $ map varType $ varSetElems vs - ; return $ and varTypes - } - - --- | Vectorise a polymorphic expression. -vectPolyExpr:: Bool -> [Var] -> CoreExprWithFVs - -> VM (Inline, Bool, VExpr) -vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) - = do { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr - ; return (inline, isScalarFn, vTick tickish expr') - } - - - -vectPolyExpr loop_breaker recFns expr - = do { vectAvoidance <- liftDs $ doptM Opt_AvoidVect - ; vi <- vectInfo expr - ; ((tvs, mono), vi') <- - if vectAvoidance - then do { (extExpr, vi') <- encapsulateScalar vi expr - ; traceVt "vectPolyExpr extended:" (ppr $ deAnnotate extExpr) - ; return $ (collectAnnTypeBinders extExpr , vi') - } - else return $ (collectAnnTypeBinders expr, vi) - ; arity <- polyArity tvs - ; polyAbstract tvs $ \args -> - do {(inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vi' - ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') - } - } - --- todo: clean this - -vectPolyExprVT:: Bool -> [Var] -> CoreExprWithFVs -> VITree - -> VM (Inline, Bool, VExpr) - --- vectPolyExprVT _loop_breaker _recFns e vi | not (checkTree vi (deAnnotate e)) --- = pprPanic "vectPolyExprVT" (ppr $ deAnnotate e) -vectPolyExprVT loop_breaker recFns (_, AnnTick tickish expr) (VITNode _ [vit]) - = do { (inline, isScalarFn, expr') <- vectPolyExprVT loop_breaker recFns expr vit - ; return (inline, isScalarFn, vTick tickish expr') - } - -vectPolyExprVT loop_breaker recFns expr vi - = do { -- checkTreeAnnM vi expr ; - let (tvs, mono) = collectAnnTypeBinders expr - ; arity <- polyArity tvs - ; polyAbstract tvs $ \args -> - do { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vi - ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') - } - } +-- |Vectorise a polymorphic expression. +-- +-- If not yet available, precompute vectorisation avoidance information before vectorising. If +-- the vectorisation avoidance optimisation is enabled, also use the vectorisation avoidance +-- information to encapsulated subexpression that do not need to be vectorised. +-- +vectPolyExpr :: Bool -> [Var] -> CoreExprWithFVs -> Maybe VITree + -> VM (Inline, Bool, VExpr) + -- precompute vectorisation avoidance information (and possibly encapsulated subexpressions) +vectPolyExpr loop_breaker recFns expr Nothing + = do + { vectAvoidance <- liftDs $ doptM Opt_AvoidVect + ; vi <- vectAvoidInfo expr + ; (expr', vi') <- + if vectAvoidance + then do + { (expr', vi') <- encapsulateScalars vi expr + ; traceVt "vectPolyExpr encapsulated:" (ppr $ deAnnotate expr') + ; return (expr', vi') + } + else return (expr, vi) + ; vectPolyExpr loop_breaker recFns expr' (Just vi') + } + + -- traverse through ticks +vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) (Just (VITNode _ [vit])) + = do + { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr (Just vit) + ; return (inline, isScalarFn, vTick tickish expr') + } + + -- collect and vectorise type abstractions; then, descent into the body +vectPolyExpr loop_breaker recFns expr (Just vit) + = do + { let (tvs, mono) = collectAnnTypeBinders expr + vit' = stripLevels (length tvs) vit + ; arity <- polyArity tvs + ; polyAbstract tvs $ \args -> + do + { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vit' + ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') + } + } + where + stripLevels 0 vit = vit + stripLevels n (VITNode _ [vit]) = stripLevels (n - 1) vit + stripLevels _ vit = pprPanic "vectPolyExpr: stripLevels:" (text (show vit)) --- | encapsulate every purely sequential subexpression with a simple return type --- of a (potentially) parallel expression into a lambda abstraction over all its --- free variables followed by the corresponding application to those variables. --- Condition: --- all free variables and the result type must be of `simple' type --- the expression is 'complex enough', which is, for now, every expression --- which is not constant and contains at least one operation. +-- Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a +-- into a lambda abstraction over all its free variables followed by the corresponding application +-- to those variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions. +-- +-- Preconditions: +-- +-- * All free variables and the result type must be /simple/ types. +-- * The expression is sufficientlt complex (top warrant special treatment). For now, that is +-- every expression that is not constant and contains at least one operation. -- -encapsulateScalar :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree) -encapsulateScalar vit ce@(_, AnnType _ty) +encapsulateScalars :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree) +encapsulateScalars vit ce@(_, AnnType _ty) = return (ce, vit) -encapsulateScalar vit ce@(_, AnnVar _v) +encapsulateScalars vit ce@(_, AnnVar _v) = return (ce, vit) -encapsulateScalar vit ce@(_, AnnLit _) +encapsulateScalars vit ce@(_, AnnLit _) = return (ce, vit) - -encapsulateScalar (VITNode vi [vit]) (fvs, AnnTick tck expr) - = do { (extExpr, vit') <- encapsulateScalar vit expr +encapsulateScalars (VITNode vi [vit]) (fvs, AnnTick tck expr) + = do { (extExpr, vit') <- encapsulateScalars vit expr ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit']) } -encapsulateScalar _ (_fvs, AnnTick _tck _expr) +encapsulateScalars _ (_fvs, AnnTick _tck _expr) = panic "encapsulateScalar AnnTick doesn't match up" -encapsulateScalar (VITNode vi [vit]) ce@(fvs, AnnLam bndr expr) +encapsulateScalars (VITNode vi [vit]) ce@(fvs, AnnLam bndr expr) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> do { let (e', vit') = encaps vit ce + (VISimple, True) -> do { let (e', vit') = liftSimple vit ce ; return (e', vit') } - _ -> do { (extExpr, vit') <- encapsulateScalar vit expr + _ -> do { (extExpr, vit') <- encapsulateScalars vit expr ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit']) } } -encapsulateScalar _ (_fvs, AnnLam _bndr _expr) - = panic "encapsulateScalar AnnLam doesn't match up" +encapsulateScalars _ (_fvs, AnnLam _bndr _expr) + = panic "encapsulateScalars AnnLam doesn't match up" - -encapsulateScalar vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2) +encapsulateScalars vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> do { let (e', vt') = encaps vt ce + (VISimple, True) -> do { let (e', vt') = liftSimple vt ce -- ; checkTreeAnnM vt' e' -- ; traceVt "Passed checkTree test!!" (ppr $ deAnnotate e') ; return (e', vt') } - _ -> do { (etaCe1, vit1') <- encapsulateScalar vit1 ce1 - ; (etaCe2, vit2') <- encapsulateScalar vit2 ce2 + _ -> do { (etaCe1, vit1') <- encapsulateScalars vit1 ce1 + ; (etaCe2, vit2') <- encapsulateScalars vit2 ce2 ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2']) } } -encapsulateScalar _ (_fvs, AnnApp _ce1 _ce2) - = panic "encapsulateScalar AnnApp doesn't match up" + +encapsulateScalars _ (_fvs, AnnApp _ce1 _ce2) + = panic "encapsulateScalars AnnApp doesn't match up" -encapsulateScalar vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts) +encapsulateScalars vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> return $ encaps vt ce - _ -> do { (extScrut, scrutVit') <- encapsulateScalar scrutVit scrut + (VISimple, True) -> return $ liftSimple vt ce + _ -> do { (extScrut, scrutVit') <- encapsulateScalars scrutVit scrut ; extAltsVits <- zipWithM expAlt altVits alts ; let (extAlts, altVits') = unzip extAltsVits ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit': altVits')) @@ -323,110 +172,100 @@ encapsulateScalar vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bn } where expAlt vt (con, bndrs, expr) - = do { (extExpr, vt') <- encapsulateScalar vt expr + = do { (extExpr, vt') <- encapsulateScalars vt expr ; return ((con, bndrs, extExpr), vt') } -encapsulateScalar _ (_fvs, AnnCase _scrut _bndr _ty _alts) - = panic "encapsulateScalar AnnCase doesn't match up" +encapsulateScalars _ (_fvs, AnnCase _scrut _bndr _ty _alts) + = panic "encapsulateScalars AnnCase doesn't match up" -encapsulateScalar vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2) +encapsulateScalars vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> return $ encaps vt ce - _ -> do { (extExpr1, vt1') <- encapsulateScalar vt1 expr1 - ; (extExpr2, vt2') <- encapsulateScalar vt2 expr2 + (VISimple, True) -> return $ liftSimple vt ce + _ -> do { (extExpr1, vt1') <- encapsulateScalars vt1 expr1 + ; (extExpr2, vt2') <- encapsulateScalars vt2 expr2 ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2']) } } -encapsulateScalar _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2) - = panic "encapsulateScalar AnnLet nonrec doesn't match up" +encapsulateScalars _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2) + = panic "encapsulateScalars AnnLet nonrec doesn't match up" -encapsulateScalar vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr) +encapsulateScalars vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> return $ encaps vt ce + (VISimple, True) -> return $ liftSimple vt ce _ -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs ; let (extBnds, vtBnds') = unzip extBndsVts - ; (extExpr, vtB') <- encapsulateScalar vtB expr + ; (extExpr, vtB') <- encapsulateScalars vtB expr ; let vt' = VITNode vi (vtB':vtBnds') ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt') } } where expBndg vit (bndr, expr) - = do { (extExpr, vit') <- encapsulateScalar vit expr + = do { (extExpr, vit') <- encapsulateScalars vit expr ; return ((bndr, extExpr), vit') } -encapsulateScalar _ (_fvs, AnnLet (AnnRec _) _expr2) - = panic "encapsulateScalar AnnLet rec doesn't match up" +encapsulateScalars _ (_fvs, AnnLet (AnnRec _) _expr2) + = panic "encapsulateScalars AnnLet rec doesn't match up" - - -encapsulateScalar (VITNode vi [vit]) (fvs, AnnCast expr coercion) - = do { (extExpr, vit') <- encapsulateScalar vit expr +encapsulateScalars (VITNode vi [vit]) (fvs, AnnCast expr coercion) + = do { (extExpr, vit') <- encapsulateScalars vit expr ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit']) } -encapsulateScalar _ (_fvs, AnnCast _expr _coercion) - = panic "encapsulateScalar AnnCast rec doesn't match up" - - -encapsulateScalar _ _ - = panic "encapsulateScalar case not handled" +encapsulateScalars _ (_fvs, AnnCast _expr _coercion) + = panic "encapsulateScalars AnnCast rec doesn't match up" +encapsulateScalars _ _ + = panic "encapsulateScalars case not handled" - - --- CoreExprWithFVs, -- = AnnExpr Id VarSet --- AnnExpr bndr VarSet = (annot, AnnExpr' bndr VarSet) --- AnnLam :: bndr -> (AnnExpr bndr VarSet) -> AnnExpr' bndr VarSet --- AnnLam bndr (AnnExpr bndr annot) -encaps :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree) -encaps (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts) +-- Lambda-lift the given expression and apply it to the abstracted free variables. +-- +-- If the expression is a case expression scrutinising anything but a primitive type, then lift +-- each alternative individually. +-- +liftSimple :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree) +liftSimple (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts) | Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr), - (not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- TODO: globalScalarTyCons - = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits')) - - where - (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $ - zipWith (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, encaps altVi aex)) alts altVits - -encaps viTree ae@(fvs, _annEx) + (not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- FIXME: shouldn't be hardcoded + = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits')) + where + (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $ + zipWith (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, liftSimple altVi aex)) alts altVits + +liftSimple viTree ae@(fvs, _annEx) = (mkAnnApps (mkAnnLams ae vars) vars, viTree') where - mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits - mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs] + mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits + mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs] - mkViTreeApps vi [] = vi - mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []] + mkViTreeApps vi [] = vi + mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []] + + vars = varSetElems fvs + viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars + + mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet + mkAnnLam bndr ce = AnnLam bndr ce - vars = varSetElems fvs - viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars + mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs + mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check! + mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs - mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet - mkAnnLam bndr ce = AnnLam bndr ce - - mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs - mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check! - mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs - - mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet) - mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v)) - - mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs - mkAnnApps (fv, aex') [] = (fv, aex') - mkAnnApps ae (v:vs) = - let - (fv, aex') = mkAnnApps ae vs - in (extendVarSet fv v, mkAnnApp (fv, aex') v) - - - + mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet) + mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v)) + + mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs + mkAnnApps (fv, aex') [] = (fv, aex') + mkAnnApps ae (v:vs) = + let + (fv, aex') = mkAnnApps ae vs + in (extendVarSet fv v, mkAnnApp (fv, aex') v) - -- |Vectorise an expression. -- vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr @@ -441,6 +280,7 @@ vectExpr (_, AnnLit lit) _ vectExpr e@(_, AnnLam bndr _) vt | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt + | otherwise = cantVectorise "Unexpected type lambda (vectExpr)" (ppr (deAnnotate e)) -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty'; -- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint @@ -501,7 +341,7 @@ vectExpr (_, AnnCase scrut bndr ty alts) vt vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2]) = do - vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExprVT False [] rhs vt1 + vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (Just vt1) (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2) return $ vLet (vNonRec vbndr vrhs) vbody @@ -518,7 +358,7 @@ vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds)) vect_rhs bndr rhs vt = localV . inBind bndr . liftM (\(_,_,z)->z) - $ vectPolyExprVT (isStrongLoopBreaker $ idOccInfo bndr) [] rhs vt + $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs (Just vt) zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs) vectExpr (_, AnnTick tickish expr) (VITNode _ [vit]) @@ -527,7 +367,7 @@ vectExpr (_, AnnTick tickish expr) (VITNode _ [vit]) vectExpr (_, AnnType ty) _ = liftM vType (vectType ty) -vectExpr e _ = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e) +vectExpr e vit = cantVectorise "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e) $$ text (" " ++ show vit)) -- |Vectorise an expression that *may* have an outer lambda abstraction. -- @@ -542,11 +382,8 @@ vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam` -> VITree -> VM (Inline, Bool, VExpr) - -- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e)) -- = pprPanic "vectFnExpr" (ppr $ deAnnotate e) - - vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt']) -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type | isId bndr @@ -557,7 +394,7 @@ vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode } -- non-predicate abstraction: vectorise (try to vectorise as a scalar computation) | isId bndr - = mark DontInline True (vectScalarFunVT False recFns (deAnnotate expr) vt) + = mark DontInline True (vectScalarFunMaybe (deAnnotate expr) vt) `orElseV` mark inlineMe False (vectLam inline loop_breaker expr vt) vectFnExpr _ _ _ e vt @@ -689,144 +526,28 @@ vectDictExpr (Coercion coe) -- instead they become dictionaries of vectorised methods). We treat them differently, though see -- "Note [Scalar dfuns]" in 'Vectorise'. -- -vectScalarFun :: [Var] -- ^ Functions names in same recursive binding group - -> CoreExpr -- ^ Expression to be vectorised - -> VM VExpr -vectScalarFun recFns expr - -- this is an external call to vectScalarFun, so we pass a dummy vt tree. The only - -- relevant bit is that the node info is *not* VIEncaps - = vectScalarFunVT True recFns expr (VITNode VISimple []) - - -vectScalarFunVT :: Bool -- ^ Was the function marked as scalar by the user? - -> [Var] -- ^ Functions names in same recursive binding group - -> CoreExpr -- ^ Expression to be vectorised - -> VITree - -> VM VExpr -vectScalarFunVT forceScalar recFns expr (VITNode vi _) - = do { gscalarVars <- globalScalarVars - ; scalarTyCons <- globalScalarTyCons - ; let scalarVars = gscalarVars `extendVarSetList` recFns - (arg_tys, res_ty) = splitFunTys (exprType expr) - ; MASSERT( not $ null arg_tys ) - ; traceVt ("vectScalarFun - not scalar? " ++ - "\n\tall tycons scalar? : " ++ (show $all (is_scalar_ty scalarTyCons) arg_tys) ++ - "\n\tresult scalar? : " ++ (show $is_scalar_ty scalarTyCons res_ty) ++ - "\n\tscalar body? : " ++ (show $is_scalar scalarVars (is_scalar_ty scalarTyCons) expr) ++ - "\n\tuses vars? : " ++ (show $uses scalarVars expr) ++ - "\n\t is encaps? (same as & of all prev cond): " ++ (show vi) - ) - (ppr expr) - ; onlyIfV (ptext (sLit "not a scalar function")) - (forceScalar -- user asserts the functions is scalar - || - (vi == VIEncaps)) -- should only be true if all the foll. cond are hold - -{- || - all (is_scalar_ty scalarTyCons) arg_tys -- check whether the function is scalar - && is_scalar_ty scalarTyCons res_ty - && is_scalar scalarVars (is_scalar_ty scalarTyCons) expr - && uses scalarVars expr) - -} - $ do { traceVt "vectScalarFun - is scalar" (ppr expr) - ; mkScalarFun arg_tys res_ty expr - } - } - where - {- - -- !!!FIXME: We would like to allow scalar functions with arguments and results that can be - -- any 'scalarTyCons', but can't at the moment, as those argument and result types - -- need to be members of the 'Scalar' class (that in its current form would better - -- be called 'Primitive'). *ALSO* the hardcoded list of types is ugly! - -} - is_scalar_ty _scalarTyCons ty - | isPredTy ty -- dictionaries never get into the environment - = True - | Just (tycon, []) <- splitTyConApp_maybe ty -- TODO: FIX THIS! - = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] --- FIXME: = tyConName tycon `elemNameSet` scalarTyCons - | Just (tycon, _) <- splitTyConApp_maybe ty - = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] - --- FIXME: = tyConName tycon `elemNameSet` scalarTyCons - | otherwise - = False - - -- Checks whether an expression contain a non-scalar subexpression. - -- - -- Precodition: The variables in the first argument are scalar. - -- - -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding - -- them to the list of scalar variables) and then check them. If one of them turns out not to - -- be scalar, the entire group is regarded as not being scalar. - -- - -- The second argument is a predicate that checks whether a type is scalar. - -- - is_scalar :: VarSet -> (Type -> Bool) -> CoreExpr -> Bool - is_scalar scalars _isScalarTC (Var v) = - v `elemVarSet` scalars - is_scalar _scalars _isScalarTC (Lit _) = True - is_scalar scalars isScalarTC (App e1 e2) = is_scalar scalars isScalarTC e1 && - is_scalar scalars isScalarTC e2 - is_scalar scalars isScalarTC (Lam var body) - | maybe_parr_ty (varType var) = False - | otherwise = is_scalar (scalars `extendVarSet` var) - isScalarTC body - is_scalar scalars isScalarTC (Let bind body) = trace ("is_scalar LET " ++ (show bindsAreScalar ) ++ - " " ++ (show $ is_scalar scalars' isScalarTC body) ++ - (show $ showSDoc $ ppr bind)) $ - bindsAreScalar && - is_scalar scalars' isScalarTC body - where - (bindsAreScalar, scalars') = is_scalar_bind scalars isScalarTC bind - is_scalar scalars isScalarTC (Case e var ty alts) - | isScalarTC ty = is_scalar scalars' isScalarTC e && - all (is_scalar_alt scalars' isScalarTC) alts - | otherwise = False - where - scalars' = scalars `extendVarSet` var - is_scalar scalars isScalarTC (Cast e _coe) = is_scalar scalars isScalarTC e - is_scalar scalars isScalarTC (Tick _ e ) = is_scalar scalars isScalarTC e - is_scalar _scalars _isScalarTC (Type {}) = True - is_scalar _scalars _isScalarTC (Coercion {}) = True - - -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group) - is_scalar_bind scalars isScalarTCs (NonRec var e) = (is_scalar scalars isScalarTCs e, - scalars `extendVarSet` var) - is_scalar_bind scalars isScalarTCs (Rec bnds) = (all (is_scalar scalars' isScalarTCs) es, - scalars') - where - (vars, es) = unzip bnds - scalars' = scalars `extendVarSetList` vars - - is_scalar_alt scalars isScalarTCs (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) - isScalarTCs e - - -- Checks whether the type might be a parallel array type. In particular, if the outermost - -- constructor is a type family, we conservatively assume that it may be a parallel array type. - maybe_parr_ty :: Type -> Bool - maybe_parr_ty ty - | Just ty' <- coreView ty = maybe_parr_ty ty' - | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon - maybe_parr_ty _ = False - - -- FIXME: I'm not convinced that this reasoning is (always) sound. If the identify functions - -- is called by some other function that is otherwise scalar, it would be very bad - -- that just this call to the identity makes it not be scalar. - -- A scalar function has to actually compute something. Without the check, - -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to - -- (map (\x -> x)) which is very bad. Normal lifting transforms it to - -- (\n# x -> x) which is what we want. - uses funs (Var v) = v `elemVarSet` funs - uses funs (App e1 e2) = uses funs e1 || uses funs e2 - uses funs (Lam b body) = uses (funs `extendVarSet` b) body - uses funs (Let (NonRec _b letExpr) body) - = uses funs letExpr || uses funs body - uses funs (Case e _eId _ty alts) - = uses funs e || any (uses_alt funs) alts - uses _ _ = False - - uses_alt funs (_, _bs, e) = uses funs e +vectScalarFunMaybe :: CoreExpr -- ^ Expression to be vectorised + -> VITree -- ^ Vectorisation information + -> VM VExpr +vectScalarFunMaybe expr (VITNode VIEncaps _) = vectScalarFun expr +vectScalarFunMaybe _expr _ = noV $ ptext (sLit "not a scalar function") + +-- |Vectorise an expression of functional type by lifting it by an application of a member of the +-- zipWith family (i.e., 'map', 'zipWith', zipWith3', etc.) This is only a valid strategy if the +-- function does not contain parallel subcomputations and has only 'Scalar' types in its result and +-- arguments — this is a predcondition for calling this function. +-- +-- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised, +-- instead they become dictionaries of vectorised methods). We treat them differently, though see +-- "Note [Scalar dfuns]" in 'Vectorise'. +-- +vectScalarFun :: CoreExpr -> VM VExpr +vectScalarFun expr + = do + { traceVt "vectScalarFun" (ppr expr) + ; let (arg_tys, res_ty) = splitFunTys (exprType expr) + ; mkScalarFun arg_tys res_ty expr + } -- Generate code for a scalar function by generating a scalar closure. If the function is a -- dictionary function, vectorise it as dictionary code. @@ -883,9 +604,8 @@ mkScalarFun arg_tys res_ty expr -- the application of the unvectorised dfun, to enable the dictionary selection rules to fire. -- vectScalarDFun :: Var -- ^ Original dfun - -> [Var] -- ^ Functions names in same recursive binding group -> VM CoreExpr -vectScalarDFun var recFns +vectScalarDFun var = do { -- bring the type variables into scope ; mapM_ defLocalTyVar tvs @@ -901,7 +621,7 @@ vectScalarDFun var recFns dict = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars scsOps = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict]) selIds - ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun recFns e) scsOps + ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps -- vectorised applications of the class-dictionary data constructor ; Just vDataCon <- lookupDataCon dataCon @@ -943,7 +663,7 @@ unVectDict ty e Nothing -> panic "Vectorise.Exp.unVectDict: no class" selIds = classAllSelIds cls --- |Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures. +-- Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures. -- -- All non-dictionary free variables go into the closure's environment, whereas the dictionary -- variables are passed explicit (as conventional arguments) into the body during closure @@ -1013,8 +733,9 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi | otherwise = return (ve, le) vectLam _ _ _ _ = panic "vectLam" --- | Vectorise an algebraic case expression. --- We convert +-- Vectorise an algebraic case expression. +-- +-- We convert -- -- case e :: t of v { ... } -- @@ -1167,9 +888,172 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) _ -> return [] -vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ []) +vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ _) = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon) + +-- Support to compute information for vectorisation avoidance ------------------ + +-- Annotation for Core AST nodes that describes how they should be handled during vectorisation +-- and especially if vectorisation of the corresponding computation can be avoided. +-- +data VectAvoidInfo = VIParr -- tree contains parallel computations + | VISimple -- result type is scalar & no parallel subcomputation + | VIComplex -- any result type, no parallel subcomputation + | VIEncaps -- tree encapsulated by 'liftSimple' + deriving (Eq, Show) + +-- Instead of integrating the vectorisation avoidance information into Core expression, we keep +-- them in a separate tree (that structurally mirrors the Core expression that it annotates). +-- +data VITree = VITNode VectAvoidInfo [VITree] + deriving (Show) + +-- Is any of the tree nodes a 'VIPArr' node? +-- +anyVIPArr :: [VITree] -> Bool +anyVIPArr = or . (map (\(VITNode vi _) -> vi == VIParr)) + +-- Compute Core annotations to determine for which subexpressions we can avoid vectorisation +-- +-- FIXME: free scalar vars don't actually need to be passed through, since encapsulations makes sure, +-- that there are no free variables in encapsulated lambda expressions +vectAvoidInfo :: CoreExprWithFVs -> VM VITree +vectAvoidInfo ce@(_, AnnVar v) + = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce + ; viTrace ce vi [] + ; traceVt "vectAvoidInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce)) + ; return $ VITNode vi [] + } + +vectAvoidInfo ce@(_, AnnLit _) + = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce + ; viTrace ce vi [] + ; traceVt "vectAvoidInfo AnnLit" (ppr $ exprType $ deAnnotate ce) + ; return $ VITNode vi [] + } + +vectAvoidInfo ce@(_, AnnApp e1 e2) + = do { vt1 <- vectAvoidInfo e1 + ; vt2 <- vectAvoidInfo e2 + ; vi <- if anyVIPArr [vt1, vt2] + then return VIParr + else vectAvoidInfoType $ exprType $ deAnnotate ce + ; viTrace ce vi [vt1, vt2] + ; return $ VITNode vi [vt1, vt2] + } + +vectAvoidInfo ce@(_, AnnLam _var body) + = do { vt@(VITNode vi _) <- vectAvoidInfo body + ; viTrace ce vi [vt] + ; let resultVI | vi == VIParr = VIParr + | otherwise = VIComplex + ; return $ VITNode resultVI [vt] + } + +vectAvoidInfo ce@(_, AnnLet (AnnNonRec _var expr) body) + = do { vtE <- vectAvoidInfo expr + ; vtB <- vectAvoidInfo body + ; vi <- if anyVIPArr [vtE, vtB] + then return VIParr + else vectAvoidInfoType $ exprType $ deAnnotate ce + ; viTrace ce vi [vtE, vtB] + ; return $ VITNode vi [vtE, vtB] + } + +vectAvoidInfo ce@(_, AnnLet (AnnRec bnds) body) + = do { let (_, exprs) = unzip bnds + ; vtBnds <- mapM (\e -> vectAvoidInfo e) exprs + ; if (anyVIPArr vtBnds) + then do { vtBnds' <- mapM (\e -> vectAvoidInfo e) exprs + ; vtB <- vectAvoidInfo body + ; return (VITNode VIParr (vtB: vtBnds')) + } + else do { vtB@(VITNode vib _) <- vectAvoidInfo body + ; ni <- if (vib == VIParr) + then return VIParr + else vectAvoidInfoType $ exprType $ deAnnotate ce + ; viTrace ce ni (vtB : vtBnds) + ; return $ VITNode ni (vtB : vtBnds) + } + } + +vectAvoidInfo ce@(_, AnnCase expr _var _ty alts) + = do { vtExpr <- vectAvoidInfo expr + ; vtAlts <- mapM (\(_, _, e) -> vectAvoidInfo e) alts + ; ni <- if anyVIPArr (vtExpr : vtAlts) + then return VIParr + else vectAvoidInfoType $ exprType $ deAnnotate ce + ; viTrace ce ni (vtExpr : vtAlts) + ; return $ VITNode ni (vtExpr: vtAlts) + } + +vectAvoidInfo (_, AnnCast expr _) + = do { vt@(VITNode vi _) <- vectAvoidInfo expr + ; return $ VITNode vi [vt] + } + +vectAvoidInfo (_, AnnTick _ expr) + = do { vt@(VITNode vi _) <- vectAvoidInfo expr + ; return $ VITNode vi [vt] + } + +vectAvoidInfo (_, AnnType {}) + = return $ VITNode VISimple [] + +vectAvoidInfo (_, AnnCoercion {}) + = return $ VITNode VISimple [] + +-- Compute vectorisation avoidance information for a type. +-- +vectAvoidInfoType :: Type -> VM VectAvoidInfo +vectAvoidInfoType ty + | maybeParrTy ty = return VIParr + | otherwise + = do { sType <- isSimpleType ty + ; if sType + then return VISimple + else return VIComplex + } + +-- Checks whether the type might be a parallel array type. In particular, if the outermost +-- constructor is a type family, we conservatively assume that it may be a parallel array type. +-- +maybeParrTy :: Type -> Bool +maybeParrTy ty + | Just ty' <- coreView ty = maybeParrTy ty' + | Just (tyCon, ts) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon + || or (map maybeParrTy ts) +maybeParrTy _ = False + +-- FIXME: This should not be hardcoded. +isSimpleType :: Type -> VM Bool +isSimpleType ty + | Just (c, _cs) <- splitTyConApp_maybe ty + = return $ (tyConName c) `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] +{- + = do { globals <- globalScalarTyCons + ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c) + ; return (elemNameSet (tyConName c) globals ) + } + -} + | Nothing <- splitTyConApp_maybe ty + = return False +isSimpleType ty + = pprPanic "Vectorise.Exp.isSimpleType not handled" (ppr ty) + +varsSimple :: VarSet -> VM Bool +varsSimple vs + = do { varTypes <- mapM isSimpleType $ map varType $ varSetElems vs + ; return $ and varTypes + } + +viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [VITree] -> VM () +viTrace ce vi vTs + = traceVt ("vitrace " ++ (show vi) ++ "[" ++ (concat $ map (\(VITNode vi _) -> show vi ++ " ") vTs) ++"]") + (ppr $ deAnnotate ce) + + {- ---- Sanity check of the tree, for debugging only checkTree :: VITree -> CoreExpr -> Bool @@ -1178,44 +1062,33 @@ checkTree (VITNode _ []) (Type _ty) checkTree (VITNode _ []) (Var _v) = True - checkTree (VITNode _ []) (Lit _) = True - checkTree (VITNode _ [vit]) (Tick _ expr) = checkTree vit expr - - checkTree (VITNode _ [vit]) (Lam _ expr) = checkTree vit expr - checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2) = (checkTree vit1 ce1) && (checkTree vit2 ce2) - - - + checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts) = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts) where checkAlt vt (_, _, expr) = checkTree vt expr - checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2) = (checkTree vt1 expr1) && (checkTree vt2 expr2) - - checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr) = (and $ zipWith checkBndr vtBnds bndngs) && (checkTree vtB expr) where checkBndr vt (_, e) = checkTree vt e - - + checkTree (VITNode _ [vit]) (Cast expr _) = checkTree vit expr |