diff options
author | Gabriele Keller <keller@cse.unsw.edu.au> | 2012-04-24 12:15:31 +1000 |
---|---|---|
committer | Gabriele Keller <keller@cse.unsw.edu.au> | 2012-04-24 23:36:43 +1000 |
commit | edd95cc954e91a93133a72c22e1293ffdf6b4169 (patch) | |
tree | 43dcb50d9b46cad2da1ba106acec1615ca6a902b /compiler/vectorise | |
parent | 981e4f134a7a9a6a0f469bad3d0d13de45c6f2c7 (diff) | |
download | haskell-edd95cc954e91a93133a72c22e1293ffdf6b4169.tar.gz |
Vectorisation Avoidance
Switched off by default. Use -favoid-vect to activate
Diffstat (limited to 'compiler/vectorise')
-rw-r--r-- | compiler/vectorise/Vectorise/Env.hs | 2 | ||||
-rw-r--r-- | compiler/vectorise/Vectorise/Exp.hs | 509 |
2 files changed, 328 insertions, 183 deletions
diff --git a/compiler/vectorise/Vectorise/Env.hs b/compiler/vectorise/Vectorise/Env.hs index dfa20698c8..a887e7736f 100644 --- a/compiler/vectorise/Vectorise/Env.hs +++ b/compiler/vectorise/Vectorise/Env.hs @@ -32,7 +32,7 @@ import NameEnv import FastString import TysPrim import TysWiredIn -import DataCon + import Data.Maybe diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index 66912de784..3012a3fae1 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -33,7 +33,6 @@ import TyCon import TcType import Type import PrelNames -import NameSet import Var import VarEnv import VarSet @@ -48,9 +47,8 @@ import Control.Monad import Control.Applicative import Data.Maybe import Data.List - - -import Debug.Trace +import TcRnMonad (doptM) +import DynFlags (DynFlag(Opt_AvoidVect)) -- For prototyping, the VITree is a separate data structure with the same shape as the corresponding expression @@ -59,6 +57,7 @@ import Debug.Trace data VectInfo = VIParr | VISimple | VIComplex + | VIEncaps deriving (Eq, Show) data VITree = VITNode VectInfo [VITree] @@ -71,25 +70,27 @@ viTrace ce vi vTs = viOr :: [VITree] -> Bool viOr = or . (map (\(VITNode vi _) -> vi == VIParr)) - + +-- TODO: free scalar vars don't actually need to be passed through, since encapsulations makes sure, that there are +-- no free variables in encapsulated lambda expressions vectInfo:: CoreExprWithFVs -> VM VITree vectInfo ce@(_, AnnVar v) - = do { vi <- vectInfoType $ exprType $ deAnnotate ce + = do { vi <- vectInfoType $ exprType $ deAnnotate ce ; viTrace ce vi [] ; traceVt "vectInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce)) ; return $ VITNode vi [] } -vectInfo ce@(_, AnnLit _) +vectInfo ce@(_, AnnLit _) = do { vi <- vectInfoType $ exprType $ deAnnotate ce ; viTrace ce vi [] ; traceVt "vectInfo AnnLit" (ppr $ exprType $ deAnnotate ce) ; return $ VITNode vi [] } -vectInfo ce@(_, AnnApp e1 e2) - = do { vt1 <- vectInfo e1 - ; vt2 <- vectInfo e2 +vectInfo ce@(_, AnnApp e1 e2) + = do { vt1 <- vectInfo e1 + ; vt2 <- vectInfo e2 ; vi <- if viOr [vt1, vt2] then return VIParr else vectInfoType $ exprType $ deAnnotate ce @@ -97,15 +98,17 @@ vectInfo ce@(_, AnnApp e1 e2) ; return $ VITNode vi [vt1, vt2] } -vectInfo ce@(_, AnnLam _ body) - = do { vt@(VITNode vi _) <- vectInfo body - ; viTrace ce vi [vt] - ; return $ VITNode vi [vt] +vectInfo ce@(_, AnnLam _var body) + = do { vt@(VITNode vi _) <- vectInfo body + ; viTrace ce vi [vt] + ; if (vi == VIParr) + then return $ VITNode vi [vt] + else return $ VITNode VIComplex [vt] } -vectInfo ce@(_, AnnLet (AnnNonRec _ expr) body) - = do { vtE <- vectInfo expr - ; vtB <- vectInfo body +vectInfo ce@(_, AnnLet (AnnNonRec _var expr) body) + = do { vtE <- vectInfo expr + ; vtB <- vectInfo body ; vi <- if viOr [vtE, vtB] then return VIParr else vectInfoType $ exprType $ deAnnotate ce @@ -113,39 +116,45 @@ vectInfo ce@(_, AnnLet (AnnNonRec _ expr) body) ; return $ VITNode vi [vtE, vtB] } -vectInfo ce@(_, AnnLet (AnnRec bnds) body) - = do { vtB <- vectInfo body - ; let exprs = snd $ unzip bnds - ; vtBnds <- mapM vectInfo exprs - ; ni <- if viOr (vtB : vtBnds) - then return VIParr - else vectInfoType $ exprType $ deAnnotate ce - ; viTrace ce ni (vtB : vtBnds) - ; return $ VITNode ni (vtB : vtBnds) +vectInfo ce@(_, AnnLet (AnnRec bnds) body) + = do { let (_, exprs) = unzip bnds + ; vtBnds <- mapM (\e -> vectInfo e) exprs + ; if (viOr vtBnds) + then do { vtBnds' <- mapM (\e -> vectInfo e) exprs + ; vtB <- vectInfo body + ; return (VITNode VIParr (vtB: vtBnds')) + } + else do { vtB@(VITNode vib _) <- vectInfo body + ; ni <- if (vib == VIParr) + then return VIParr + else vectInfoType $ exprType $ deAnnotate ce + ; viTrace ce ni (vtB : vtBnds) + ; return $ VITNode ni (vtB : vtBnds) + } } vectInfo ce@(_, AnnCase expr _var _ty alts) - = do { vtExpr <- vectInfo expr + = do { vtExpr <- vectInfo expr ; vtAlts <- mapM (\(_, _, e) -> vectInfo e) alts ; ni <- if viOr (vtExpr : vtAlts) then return VIParr else vectInfoType $ exprType $ deAnnotate ce ; viTrace ce ni (vtExpr : vtAlts) - ; return $ VITNode ni (vtExpr : vtAlts) + ; return $ VITNode ni (vtExpr: vtAlts) } -vectInfo (_, AnnCast expr _) - = do { vt@(VITNode vi _) <- vectInfo expr +vectInfo (_, AnnCast expr _) + = do { vt@(VITNode vi _) <- vectInfo expr ; return $ VITNode vi [vt] } vectInfo (_, AnnTick _ expr ) - = do { vt@(VITNode vi _) <- vectInfo expr + = do { vt@(VITNode vi _) <- vectInfo expr ; return $ VITNode vi [vt] } -vectInfo (_, AnnType {}) +vectInfo (_, AnnType {}) = return $ VITNode VISimple [] vectInfo (_, AnnCoercion {}) @@ -176,11 +185,13 @@ maybeParrTy _ = False isSimpleType:: Type -> VM Bool isSimpleType ty - | Just (c, _cs) <- splitTyConApp_maybe ty + | Just (c, _cs) <- splitTyConApp_maybe ty = return $ (tyConName c) `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] +{- = do { globals <- globalScalarTyCons - ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c) - ; return (elemNameSet (tyConName c) globals ) - } + ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c) + ; return (elemNameSet (tyConName c) globals ) + } + -} | Nothing <- splitTyConApp_maybe ty = return False isSimpleType ty @@ -201,27 +212,48 @@ vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) ; return (inline, isScalarFn, vTick tickish expr') } + + vectPolyExpr loop_breaker recFns expr - = do { let vectAvoidance = True - ; (tvs, mono) <- if vectAvoidance - then do { vi <- vectInfo expr - ; extExpr <- encapsulateScalar vi expr - ; traceVt "vectPolyExpr extended:" (ppr $ deAnnotate extExpr) - ; return $ collectAnnTypeBinders extExpr - } - else return $ collectAnnTypeBinders expr + = do { vectAvoidance <- liftDs $ doptM Opt_AvoidVect + ; vi <- vectInfo expr + ; ((tvs, mono), vi') <- + if vectAvoidance + then do { (extExpr, vi') <- encapsulateScalar vi expr + ; traceVt "vectPolyExpr extended:" (ppr $ deAnnotate extExpr) + ; return $ (collectAnnTypeBinders extExpr , vi') + } + else return $ (collectAnnTypeBinders expr, vi) ; arity <- polyArity tvs ; polyAbstract tvs $ \args -> - do { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono + do {(inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vi' ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') } } - - - +-- todo: clean this + +vectPolyExprVT:: Bool -> [Var] -> CoreExprWithFVs -> VITree + -> VM (Inline, Bool, VExpr) --- | encapsulate every purely sequentail subexpression with a simple return type +-- vectPolyExprVT _loop_breaker _recFns e vi | not (checkTree vi (deAnnotate e)) +-- = pprPanic "vectPolyExprVT" (ppr $ deAnnotate e) +vectPolyExprVT loop_breaker recFns (_, AnnTick tickish expr) (VITNode _ [vit]) + = do { (inline, isScalarFn, expr') <- vectPolyExprVT loop_breaker recFns expr vit + ; return (inline, isScalarFn, vTick tickish expr') + } + +vectPolyExprVT loop_breaker recFns expr vi + = do { -- checkTreeAnnM vi expr ; + let (tvs, mono) = collectAnnTypeBinders expr + ; arity <- polyArity tvs + ; polyAbstract tvs $ \args -> + do { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vi + ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') + } + } + +-- | encapsulate every purely sequential subexpression with a simple return type -- of a (potentially) parallel expression into a lambda abstraction over all its -- free variables followed by the corresponding application to those variables. -- Condition: @@ -229,90 +261,103 @@ vectPolyExpr loop_breaker recFns expr -- the expression is 'complex enough', which is, for now, every expression -- which is not constant and contains at least one operation. -- -encapsulateScalar :: VITree -> CoreExprWithFVs -> VM CoreExprWithFVs -encapsulateScalar _ ce@(_, AnnType _ty) - = return ce +encapsulateScalar :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree) +encapsulateScalar vit ce@(_, AnnType _ty) + = return (ce, vit) -encapsulateScalar _ ce@(_, AnnVar _v) - = return ce +encapsulateScalar vit ce@(_, AnnVar _v) + = return (ce, vit) -encapsulateScalar _ ce@(_, AnnLit _) - = return ce +encapsulateScalar vit ce@(_, AnnLit _) + = return (ce, vit) -encapsulateScalar (VITNode _ [vit]) (fvs, AnnTick tck expr) - = do { extExpr <- encapsulateScalar vit expr - ; return (fvs, AnnTick tck extExpr) +encapsulateScalar (VITNode vi [vit]) (fvs, AnnTick tck expr) + = do { (extExpr, vit') <- encapsulateScalar vit expr + ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit']) } encapsulateScalar _ (_fvs, AnnTick _tck _expr) = panic "encapsulateScalar AnnTick doesn't match up" -encapsulateScalar (VITNode _ [vit]) (fvs, AnnLam bndr expr) - = do { extExpr <- encapsulateScalar vit expr - ; return (fvs, AnnLam bndr extExpr) +encapsulateScalar (VITNode vi [vit]) ce@(fvs, AnnLam bndr expr) + = do { varsS <- varsSimple fvs + ; case (vi, varsS) of + (VISimple, True) -> do { let (e', vit') = encaps vit ce + ; return (e', vit') + } + _ -> do { (extExpr, vit') <- encapsulateScalar vit expr + ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit']) + } } encapsulateScalar _ (_fvs, AnnLam _bndr _expr) = panic "encapsulateScalar AnnLam doesn't match up" -encapsulateScalar (VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2) +encapsulateScalar vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> return $ encaps ce - _ -> do { etaCe1 <- encapsulateScalar vit1 ce1 - ; etaCe2 <- encapsulateScalar vit2 ce2 - ; return (fvs, AnnApp etaCe1 etaCe2) + (VISimple, True) -> do { let (e', vt') = encaps vt ce + -- ; checkTreeAnnM vt' e' + -- ; traceVt "Passed checkTree test!!" (ppr $ deAnnotate e') + ; return (e', vt') + } + _ -> do { (etaCe1, vit1') <- encapsulateScalar vit1 ce1 + ; (etaCe2, vit2') <- encapsulateScalar vit2 ce2 + ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2']) } } encapsulateScalar _ (_fvs, AnnApp _ce1 _ce2) = panic "encapsulateScalar AnnApp doesn't match up" -encapsulateScalar (VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts) +encapsulateScalar vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> return $ encaps ce - _ -> do { extScrut <- encapsulateScalar scrutVit scrut - ; extAlts <- zipWithM expAlt altVits alts - ; return (fvs, AnnCase extScrut bndr ty extAlts) + (VISimple, True) -> return $ encaps vt ce + _ -> do { (extScrut, scrutVit') <- encapsulateScalar scrutVit scrut + ; extAltsVits <- zipWithM expAlt altVits alts + ; let (extAlts, altVits') = unzip extAltsVits + ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit': altVits')) } } where expAlt vt (con, bndrs, expr) - = do { extExpr <- encapsulateScalar vt expr - ; return (con, bndrs, extExpr) + = do { (extExpr, vt') <- encapsulateScalar vt expr + ; return ((con, bndrs, extExpr), vt') } encapsulateScalar _ (_fvs, AnnCase _scrut _bndr _ty _alts) = panic "encapsulateScalar AnnCase doesn't match up" -encapsulateScalar (VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2) +encapsulateScalar vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> return $ encaps ce - _ -> do { extExpr1 <- encapsulateScalar vt1 expr1 - ; extExpr2 <- encapsulateScalar vt2 expr2 - ; return (fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2) + (VISimple, True) -> return $ encaps vt ce + _ -> do { (extExpr1, vt1') <- encapsulateScalar vt1 expr1 + ; (extExpr2, vt2') <- encapsulateScalar vt2 expr2 + ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2']) } } encapsulateScalar _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2) = panic "encapsulateScalar AnnLet nonrec doesn't match up" -encapsulateScalar (VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr) +encapsulateScalar vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr) = do { varsS <- varsSimple fvs ; case (vi, varsS) of - (VISimple, True) -> return $ encaps ce - _ -> do { extBnds <- zipWithM expBndg vtBnds bndngs - ; extExpr <- encapsulateScalar vtB expr - ; return (fvs, AnnLet (AnnRec extBnds) extExpr) + (VISimple, True) -> return $ encaps vt ce + _ -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs + ; let (extBnds, vtBnds') = unzip extBndsVts + ; (extExpr, vtB') <- encapsulateScalar vtB expr + ; let vt' = VITNode vi (vtB':vtBnds') + ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt') } } where expBndg vit (bndr, expr) - = do { extExpr <- encapsulateScalar vit expr - ; return (bndr, extExpr) + = do { (extExpr, vit') <- encapsulateScalar vit expr + ; return ((bndr, extExpr), vit') } encapsulateScalar _ (_fvs, AnnLet (AnnRec _) _expr2) @@ -320,9 +365,9 @@ encapsulateScalar _ (_fvs, AnnLet (AnnRec _) _expr2) -encapsulateScalar (VITNode _ [vit]) (fvs, AnnCast expr coercion) - = do { extExpr <- encapsulateScalar vit expr - ; return (fvs, AnnCast extExpr coercion) +encapsulateScalar (VITNode vi [vit]) (fvs, AnnCast expr coercion) + = do { (extExpr, vit') <- encapsulateScalar vit expr + ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit']) } encapsulateScalar _ (_fvs, AnnCast _expr _coercion) @@ -333,56 +378,75 @@ encapsulateScalar _ _ = panic "encapsulateScalar case not handled" -mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet -mkAnnLam bndr ce = AnnLam bndr ce -mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs -mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check! -mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs - -mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet) -mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v)) - -mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs -mkAnnApps (fv, aex') [] = (fv, aex') -mkAnnApps ae (v:vs) = - let - (fv, aex') = mkAnnApps ae vs - in (extendVarSet fv v, mkAnnApp (fv, aex') v) -- CoreExprWithFVs, -- = AnnExpr Id VarSet -- AnnExpr bndr VarSet = (annot, AnnExpr' bndr VarSet) -- AnnLam :: bndr -> (AnnExpr bndr VarSet) -> AnnExpr' bndr VarSet -- AnnLam bndr (AnnExpr bndr annot) -encaps :: CoreExprWithFVs -> CoreExprWithFVs -encaps (fvs, AnnCase expr bndr t alts) +encaps :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree) +encaps (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts) | Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr), (not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- TODO: globalScalarTyCons - = (fvs, AnnCase expr bndr t (map (\(ac, bndrs, aex) -> (ac, bndrs, encaps aex)) alts)) + = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits')) -encaps ae@(fvs, _annEx) - = let - vars = varSetElems fvs - in mkAnnApps (mkAnnLams ae vars) vars + where + (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $ + zipWith (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, encaps altVi aex)) alts altVits + +encaps viTree ae@(fvs, _annEx) + = (mkAnnApps (mkAnnLams ae vars) vars, viTree') + where + mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits + mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs] + mkViTreeApps vi [] = vi + mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []] + + vars = varSetElems fvs + viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars + + mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet + mkAnnLam bndr ce = AnnLam bndr ce + + mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs + mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check! + mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs + + mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet) + mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v)) + + mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs + mkAnnApps (fv, aex') [] = (fv, aex') + mkAnnApps ae (v:vs) = + let + (fv, aex') = mkAnnApps ae vs + in (extendVarSet fv v, mkAnnApp (fv, aex') v) + + + + + -- |Vectorise an expression. -- -vectExpr :: CoreExprWithFVs -> VM VExpr - -vectExpr (_, AnnVar v) +vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr +-- vectExpr e vi | not (checkTree vi (deAnnotate e)) +-- = pprPanic "vectExpr" (ppr $ deAnnotate e) + +vectExpr (_, AnnVar v) _ = vectVar v -vectExpr (_, AnnLit lit) +vectExpr (_, AnnLit lit) _ = vectConst $ Lit lit -vectExpr e@(_, AnnLam bndr _) - | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e +vectExpr e@(_, AnnLam bndr _) vt + | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty'; -- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint -- happy. -- FIXME: can't be do this with a VECTORISE pragma on 'pAT_ERROR_ID' now? -vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) +vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _ | v == pAT_ERROR_ID = do { (vty, lty) <- vectAndLiftType ty ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err']) @@ -392,13 +456,13 @@ vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) -- type application (handle multiple consecutive type applications simultaneously to ensure the -- PA dictionaries are put at the right places) -vectExpr e@(_, AnnApp _ arg) +vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _]) | isAnnTypeArg arg = vectPolyApp e -- 'Int', 'Float', or 'Double' literal -- FIXME: this needs to be generalised -vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) +vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _ | Just con <- isDataConId_maybe v , is_special_con con = do @@ -409,17 +473,17 @@ vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon] -- value application (dictionary or user value) -vectExpr e@(_, AnnApp fn arg) +vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2]) | isPredTy arg_ty -- dictionary application (whose result is not a dictionary) = vectPolyApp e | otherwise -- user value = do { -- vectorise the types - ; varg_ty <- vectType arg_ty + ; varg_ty <- vectType arg_ty ; vres_ty <- vectType res_ty -- vectorise the function and argument expression - ; vfn <- vectExpr fn - ; varg <- vectExpr arg + ; vfn <- vectExpr fn vit1 + ; varg <- vectExpr arg vit2 -- the vectorised function is a closure; apply it to the vectorised argument ; mkClosureApp varg_ty vres_ty vfn varg @@ -427,42 +491,43 @@ vectExpr e@(_, AnnApp fn arg) where (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn -vectExpr (_, AnnCase scrut bndr ty alts) +vectExpr (_, AnnCase scrut bndr ty alts) vt | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty , isAlgTyCon tycon - = vectAlgCase tycon ty_args scrut bndr ty alts + = vectAlgCase tycon ty_args scrut bndr ty alts vt | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) where scrut_ty = exprType (deAnnotate scrut) -vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) +vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2]) = do - vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) + vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExprVT False [] rhs vt1 + (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2) return $ vLet (vNonRec vbndr vrhs) vbody -vectExpr (_, AnnLet (AnnRec bs) body) +vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds)) = do (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ liftM2 (,) - (zipWithM vect_rhs bndrs rhss) - (vectExpr body) + (zipWith3M vect_rhs bndrs rhss vtBnds) + (vectExpr body vtB) return $ vLet (vRec vbndrs vrhss) vbody where (bndrs, rhss) = unzip bs - vect_rhs bndr rhs = localV - . inBind bndr - . liftM (\(_,_,z)->z) - $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs + vect_rhs bndr rhs vt = localV + . inBind bndr + . liftM (\(_,_,z)->z) + $ vectPolyExprVT (isStrongLoopBreaker $ idOccInfo bndr) [] rhs vt + zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs) -vectExpr (_, AnnTick tickish expr) - = liftM (vTick tickish) (vectExpr expr) +vectExpr (_, AnnTick tickish expr) (VITNode _ [vit]) + = liftM (vTick tickish) (vectExpr expr vit) -vectExpr (_, AnnType ty) +vectExpr (_, AnnType ty) _ = liftM vType (vectType ty) -vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e) +vectExpr e _ = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e) -- |Vectorise an expression that *may* have an outer lambda abstraction. -- @@ -475,23 +540,29 @@ vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether -> Bool -- ^ Whether the binding is a loop breaker -> [Var] -- ^ Names of function in same recursive binding group -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam` + -> VITree -> VM (Inline, Bool, VExpr) -vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) + +-- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e)) +-- = pprPanic "vectFnExpr" (ppr $ deAnnotate e) + + +vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt']) -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type | isId bndr && isPredTy (idType bndr) = do { vBndr <- vectBndr bndr - ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body + ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body vt' ; return (inline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody) } -- non-predicate abstraction: vectorise (try to vectorise as a scalar computation) | isId bndr - = mark DontInline True (vectScalarFun False recFns (deAnnotate expr)) + = mark DontInline True (vectScalarFunVT False recFns (deAnnotate expr) vt) `orElseV` - mark inlineMe False (vectLam inline loop_breaker expr) -vectFnExpr _ _ _ e + mark inlineMe False (vectLam inline loop_breaker expr vt) +vectFnExpr _ _ _ e vt -- not an abstraction: vectorise as a vanilla expression - = mark DontInline False $ vectExpr e + = mark DontInline False $ vectExpr e vt mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a) mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) } @@ -611,7 +682,7 @@ vectDictExpr (Coercion coe) -- |Vectorise an expression of functional type, where all arguments and the result are of primitive -- types (i.e., 'Int', 'Float', 'Double' etc., which have instances of the 'Scalar' type class) and -- which does not contain any subcomputations that involve parallel arrays. Such functionals do not --- requires the full blown vectorisation transformation; instead, they can be lifted by application +-- require the full blown vectorisation transformation; instead, they can be lifted by application -- of a member of the zipWith family (i.e., 'map', 'zipWith', zipWith3', etc.) -- -- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised, @@ -622,7 +693,18 @@ vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user? -> [Var] -- ^ Functions names in same recursive binding group -> CoreExpr -- ^ Expression to be vectorised -> VM VExpr -vectScalarFun forceScalar recFns expr +vectScalarFun forceScalar recFns expr + = vectScalarFunVT forceScalar recFns expr (VITNode VISimple []) + + + + +vectScalarFunVT :: Bool -- ^ Was the function marked as scalar by the user? + -> [Var] -- ^ Functions names in same recursive binding group + -> CoreExpr -- ^ Expression to be vectorised + -> VITree + -> VM VExpr +vectScalarFunVT forceScalar recFns expr (VITNode vi _) = do { gscalarVars <- globalScalarVars ; scalarTyCons <- globalScalarTyCons ; let scalarVars = gscalarVars `extendVarSetList` recFns @@ -632,12 +714,15 @@ vectScalarFun forceScalar recFns expr "\n\tall tycons scalar? : " ++ (show $all (is_scalar_ty scalarTyCons) arg_tys) ++ "\n\tresult scalar? : " ++ (show $is_scalar_ty scalarTyCons res_ty) ++ "\n\tscalar body? : " ++ (show $is_scalar scalarVars (is_scalar_ty scalarTyCons) expr) ++ - "\n\tuses vars? : " ++ (show $uses scalarVars expr) + "\n\tuses vars? : " ++ (show $uses scalarVars expr) ++ + "\n\t is encaps? : " ++ (show vi) ) (ppr expr) ; onlyIfV (ptext (sLit "not a scalar function")) (forceScalar -- user asserts the functions is scalar || + (vi == VIEncaps) -- should only be true if all the foll. cond are hold + || all (is_scalar_ty scalarTyCons) arg_tys -- check whether the function is scalar && is_scalar_ty scalarTyCons res_ty && is_scalar scalarVars (is_scalar_ty scalarTyCons) expr @@ -664,25 +749,17 @@ vectScalarFun forceScalar recFns expr -- any 'scalarTyCons', but can't at the moment, as those argument and result types -- need to be members of the 'Scalar' class (that in its current form would better -- be called 'Primitive'). *ALSO* the hardcoded list of types is ugly! - is_primitive_ty ty - | isPredTy ty -- dictionaries never get into the environment - = True - | Just (tycon, _) <- splitTyConApp_maybe ty - = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName] - | otherwise - = False - -} - + -} is_scalar_ty _scalarTyCons ty | isPredTy ty -- dictionaries never get into the environment = True | Just (tycon, []) <- splitTyConApp_maybe ty -- TODO: FIX THIS! - = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName] --- = tyConName tycon `elemNameSet` scalarTyCons + = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] +-- FIXME: = tyConName tycon `elemNameSet` scalarTyCons | Just (tycon, _) <- splitTyConApp_maybe ty - = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName] + = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] --- = tyConName tycon `elemNameSet` scalarTyCons +-- FIXME: = tyConName tycon `elemNameSet` scalarTyCons | otherwise = False @@ -706,7 +783,9 @@ vectScalarFun forceScalar recFns expr | maybe_parr_ty (varType var) = False | otherwise = is_scalar (scalars `extendVarSet` var) isScalarTC body - is_scalar scalars isScalarTC (Let bind body) = trace ("is_scalar LET " ++ (show bindsAreScalar ) ++ " " ++ (show $ is_scalar scalars' isScalarTC body) ++ (show $ showSDoc $ ppr bind)) $ + is_scalar scalars isScalarTC (Let bind body) = trace ("is_scalar LET " ++ (show bindsAreScalar ) ++ + " " ++ (show $ is_scalar scalars' isScalarTC body) ++ + (show $ showSDoc $ ppr bind)) $ bindsAreScalar && is_scalar scalars' isScalarTC body where @@ -884,8 +963,9 @@ unVectDict ty e vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. -> Bool -- ^ Whether the binding is a loop breaker. -> CoreExprWithFVs -- ^ Body of abstraction. + -> VITree -> VM VExpr -vectLam inline loop_breaker expr@(fvs, AnnLam _ _) +vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi = do { let (bndrs, body) = collectAnnValBinders expr -- grab the in-scope type variables @@ -913,13 +993,18 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity) $ do { -- generate the vectorised body of the lambda abstraction ; lc <- builtin liftingContext - ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body) + ; let viBody = stripLams expr vi + -- ; checkTreeAnnM vi expr + ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody) ; vbody' <- break_loop lc res_ty vbody ; return $ vLams lc vbndrs vbody' } } where + stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt + stripLams _ vi = vi + maybe_inline n | inline = Inline n | otherwise = DontInline @@ -937,7 +1022,7 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) (LitAlt (mkMachInt 0), [], empty)]) } | otherwise = return (ve, le) -vectLam _ _ _ = panic "vectLam" +vectLam _ _ _ _ = panic "vectLam" -- | Vectorise an algebraic case expression. -- We convert @@ -955,31 +1040,31 @@ vectLam _ _ _ = panic "vectLam" -- -- FIXME: this is too lazy -vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type - -> [(AltCon, [Var], CoreExprWithFVs)] +vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs-> Var -> Type + -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree -> VM VExpr -vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] +vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit])) = do - vscrut <- vectExpr scrut + vscrut <- vectExpr scrut scrutVit (vty, lty) <- vectAndLiftType ty - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) + (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit) return $ vCaseDEFAULT vscrut vbndr vty lty vbody -vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] +vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit])) = do - vscrut <- vectExpr scrut + vscrut <- vectExpr scrut scrutVit (vty, lty) <- vectAndLiftType ty - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) + (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit) return $ vCaseDEFAULT vscrut vbndr vty lty vbody -vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] +vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit])) = do (vty, lty) <- vectAndLiftType ty - vexpr <- vectExpr scrut + vexpr <- vectExpr scrut scrutVit (vbndr, (vbndrs, (vect_body, lift_body))) <- vect_scrut_bndr . vectBndrsIn bndrs - $ vectExpr body + $ vectExpr body altVit let (vect_bndrs, lift_bndrs) = unzip vbndrs (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) vect_dc <- maybeV dataConErr (lookupDataCon dc) @@ -997,7 +1082,7 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc) -vectAlgCase tycon _ty_args scrut bndr ty alts +vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) = do vect_tc <- maybeV tyConErr (lookupTyCon tycon) (vty, lty) <- vectAndLiftType ty @@ -1008,10 +1093,10 @@ vectAlgCase tycon _ty_args scrut bndr ty alts let sel = Var sel_bndr (vbndr, valts) <- vect_scrut_bndr - $ mapM (proc_alt arity sel vty lty) alts' + $ mapM (proc_alt arity sel vty lty) (zip alts' altVits) let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts - vexpr <- vectExpr scrut + vexpr <- vectExpr scrut scrutVit (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) let (vect_bodies, lift_bodies) = unzip vbodies @@ -1043,7 +1128,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts cmp _ DEFAULT = GT cmp _ _ = panic "vectAlgCase/cmp" - proc_alt arity sel _ lty (DataAlt dc, bndrs, body) + proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi) = do vect_dc <- maybeV dataConErr (lookupDataCon dc) let ntag = dataConTagZ vect_dc @@ -1061,7 +1146,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts binds <- mapM (pack_var (Var lc) sel_tags tag) . filter isLocalId $ varSetElems fvs - (ve, le) <- vectExpr body + (ve, le) <- vectExpr body vi return (ve, Case (elems `App` sel) lc lty [(DEFAULT, [], (mkLets (concat binds) le))]) -- empty <- emptyPD vty @@ -1092,4 +1177,64 @@ vectAlgCase tycon _ty_args scrut bndr ty alts return [(NonRec lv' expr)] _ -> return [] + +vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ []) + = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon) + +---- Sanity check of the +{- +checkTree :: VITree -> CoreExpr -> Bool +checkTree (VITNode _ []) (Type _ty) + = True + +checkTree (VITNode _ []) (Var _v) + = True + + +checkTree (VITNode _ []) (Lit _) + = True + + +checkTree (VITNode _ [vit]) (Tick _ expr) + = checkTree vit expr + + + +checkTree (VITNode _ [vit]) (Lam _ expr) + = checkTree vit expr + + +checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2) + = (checkTree vit1 ce1) && (checkTree vit2 ce2) + + + +checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts) + = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts) + where + checkAlt vt (_, _, expr) = checkTree vt expr + + +checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2) + = (checkTree vt1 expr1) && (checkTree vt2 expr2) + + + +checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr) + = (and $ zipWith checkBndr vtBnds bndngs) && + (checkTree vtB expr) + where + checkBndr vt (_, e) = checkTree vt e + + +checkTree (VITNode _ [vit]) (Cast expr _) + = checkTree vit expr + +checkTree _ _ = False +checkTreeAnnM:: VITree -> CoreExprWithFVs -> VM () +checkTreeAnnM vi e = + if not (checkTree vi $ deAnnotate e) + then error ("checkTreeAnnM : \n " ++ show vi) + else return () +-} |