summaryrefslogtreecommitdiff
path: root/compiler/vectorise
diff options
context:
space:
mode:
authorGabriele Keller <keller@cse.unsw.edu.au>2012-04-24 12:15:31 +1000
committerGabriele Keller <keller@cse.unsw.edu.au>2012-04-24 23:36:43 +1000
commitedd95cc954e91a93133a72c22e1293ffdf6b4169 (patch)
tree43dcb50d9b46cad2da1ba106acec1615ca6a902b /compiler/vectorise
parent981e4f134a7a9a6a0f469bad3d0d13de45c6f2c7 (diff)
downloadhaskell-edd95cc954e91a93133a72c22e1293ffdf6b4169.tar.gz
Vectorisation Avoidance
Switched off by default. Use -favoid-vect to activate
Diffstat (limited to 'compiler/vectorise')
-rw-r--r--compiler/vectorise/Vectorise/Env.hs2
-rw-r--r--compiler/vectorise/Vectorise/Exp.hs509
2 files changed, 328 insertions, 183 deletions
diff --git a/compiler/vectorise/Vectorise/Env.hs b/compiler/vectorise/Vectorise/Env.hs
index dfa20698c8..a887e7736f 100644
--- a/compiler/vectorise/Vectorise/Env.hs
+++ b/compiler/vectorise/Vectorise/Env.hs
@@ -32,7 +32,7 @@ import NameEnv
import FastString
import TysPrim
import TysWiredIn
-import DataCon
+
import Data.Maybe
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs
index 66912de784..3012a3fae1 100644
--- a/compiler/vectorise/Vectorise/Exp.hs
+++ b/compiler/vectorise/Vectorise/Exp.hs
@@ -33,7 +33,6 @@ import TyCon
import TcType
import Type
import PrelNames
-import NameSet
import Var
import VarEnv
import VarSet
@@ -48,9 +47,8 @@ import Control.Monad
import Control.Applicative
import Data.Maybe
import Data.List
-
-
-import Debug.Trace
+import TcRnMonad (doptM)
+import DynFlags (DynFlag(Opt_AvoidVect))
-- For prototyping, the VITree is a separate data structure with the same shape as the corresponding expression
@@ -59,6 +57,7 @@ import Debug.Trace
data VectInfo = VIParr
| VISimple
| VIComplex
+ | VIEncaps
deriving (Eq, Show)
data VITree = VITNode VectInfo [VITree]
@@ -71,25 +70,27 @@ viTrace ce vi vTs =
viOr :: [VITree] -> Bool
viOr = or . (map (\(VITNode vi _) -> vi == VIParr))
-
+
+-- TODO: free scalar vars don't actually need to be passed through, since encapsulations makes sure, that there are
+-- no free variables in encapsulated lambda expressions
vectInfo:: CoreExprWithFVs -> VM VITree
vectInfo ce@(_, AnnVar v)
- = do { vi <- vectInfoType $ exprType $ deAnnotate ce
+ = do { vi <- vectInfoType $ exprType $ deAnnotate ce
; viTrace ce vi []
; traceVt "vectInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce))
; return $ VITNode vi []
}
-vectInfo ce@(_, AnnLit _)
+vectInfo ce@(_, AnnLit _)
= do { vi <- vectInfoType $ exprType $ deAnnotate ce
; viTrace ce vi []
; traceVt "vectInfo AnnLit" (ppr $ exprType $ deAnnotate ce)
; return $ VITNode vi []
}
-vectInfo ce@(_, AnnApp e1 e2)
- = do { vt1 <- vectInfo e1
- ; vt2 <- vectInfo e2
+vectInfo ce@(_, AnnApp e1 e2)
+ = do { vt1 <- vectInfo e1
+ ; vt2 <- vectInfo e2
; vi <- if viOr [vt1, vt2]
then return VIParr
else vectInfoType $ exprType $ deAnnotate ce
@@ -97,15 +98,17 @@ vectInfo ce@(_, AnnApp e1 e2)
; return $ VITNode vi [vt1, vt2]
}
-vectInfo ce@(_, AnnLam _ body)
- = do { vt@(VITNode vi _) <- vectInfo body
- ; viTrace ce vi [vt]
- ; return $ VITNode vi [vt]
+vectInfo ce@(_, AnnLam _var body)
+ = do { vt@(VITNode vi _) <- vectInfo body
+ ; viTrace ce vi [vt]
+ ; if (vi == VIParr)
+ then return $ VITNode vi [vt]
+ else return $ VITNode VIComplex [vt]
}
-vectInfo ce@(_, AnnLet (AnnNonRec _ expr) body)
- = do { vtE <- vectInfo expr
- ; vtB <- vectInfo body
+vectInfo ce@(_, AnnLet (AnnNonRec _var expr) body)
+ = do { vtE <- vectInfo expr
+ ; vtB <- vectInfo body
; vi <- if viOr [vtE, vtB]
then return VIParr
else vectInfoType $ exprType $ deAnnotate ce
@@ -113,39 +116,45 @@ vectInfo ce@(_, AnnLet (AnnNonRec _ expr) body)
; return $ VITNode vi [vtE, vtB]
}
-vectInfo ce@(_, AnnLet (AnnRec bnds) body)
- = do { vtB <- vectInfo body
- ; let exprs = snd $ unzip bnds
- ; vtBnds <- mapM vectInfo exprs
- ; ni <- if viOr (vtB : vtBnds)
- then return VIParr
- else vectInfoType $ exprType $ deAnnotate ce
- ; viTrace ce ni (vtB : vtBnds)
- ; return $ VITNode ni (vtB : vtBnds)
+vectInfo ce@(_, AnnLet (AnnRec bnds) body)
+ = do { let (_, exprs) = unzip bnds
+ ; vtBnds <- mapM (\e -> vectInfo e) exprs
+ ; if (viOr vtBnds)
+ then do { vtBnds' <- mapM (\e -> vectInfo e) exprs
+ ; vtB <- vectInfo body
+ ; return (VITNode VIParr (vtB: vtBnds'))
+ }
+ else do { vtB@(VITNode vib _) <- vectInfo body
+ ; ni <- if (vib == VIParr)
+ then return VIParr
+ else vectInfoType $ exprType $ deAnnotate ce
+ ; viTrace ce ni (vtB : vtBnds)
+ ; return $ VITNode ni (vtB : vtBnds)
+ }
}
vectInfo ce@(_, AnnCase expr _var _ty alts)
- = do { vtExpr <- vectInfo expr
+ = do { vtExpr <- vectInfo expr
; vtAlts <- mapM (\(_, _, e) -> vectInfo e) alts
; ni <- if viOr (vtExpr : vtAlts)
then return VIParr
else vectInfoType $ exprType $ deAnnotate ce
; viTrace ce ni (vtExpr : vtAlts)
- ; return $ VITNode ni (vtExpr : vtAlts)
+ ; return $ VITNode ni (vtExpr: vtAlts)
}
-vectInfo (_, AnnCast expr _)
- = do { vt@(VITNode vi _) <- vectInfo expr
+vectInfo (_, AnnCast expr _)
+ = do { vt@(VITNode vi _) <- vectInfo expr
; return $ VITNode vi [vt]
}
vectInfo (_, AnnTick _ expr )
- = do { vt@(VITNode vi _) <- vectInfo expr
+ = do { vt@(VITNode vi _) <- vectInfo expr
; return $ VITNode vi [vt]
}
-vectInfo (_, AnnType {})
+vectInfo (_, AnnType {})
= return $ VITNode VISimple []
vectInfo (_, AnnCoercion {})
@@ -176,11 +185,13 @@ maybeParrTy _ = False
isSimpleType:: Type -> VM Bool
isSimpleType ty
- | Just (c, _cs) <- splitTyConApp_maybe ty
+ | Just (c, _cs) <- splitTyConApp_maybe ty = return $ (tyConName c) `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
+{-
= do { globals <- globalScalarTyCons
- ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c)
- ; return (elemNameSet (tyConName c) globals )
- }
+ ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c)
+ ; return (elemNameSet (tyConName c) globals )
+ }
+ -}
| Nothing <- splitTyConApp_maybe ty
= return False
isSimpleType ty
@@ -201,27 +212,48 @@ vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr)
; return (inline, isScalarFn, vTick tickish expr')
}
+
+
vectPolyExpr loop_breaker recFns expr
- = do { let vectAvoidance = True
- ; (tvs, mono) <- if vectAvoidance
- then do { vi <- vectInfo expr
- ; extExpr <- encapsulateScalar vi expr
- ; traceVt "vectPolyExpr extended:" (ppr $ deAnnotate extExpr)
- ; return $ collectAnnTypeBinders extExpr
- }
- else return $ collectAnnTypeBinders expr
+ = do { vectAvoidance <- liftDs $ doptM Opt_AvoidVect
+ ; vi <- vectInfo expr
+ ; ((tvs, mono), vi') <-
+ if vectAvoidance
+ then do { (extExpr, vi') <- encapsulateScalar vi expr
+ ; traceVt "vectPolyExpr extended:" (ppr $ deAnnotate extExpr)
+ ; return $ (collectAnnTypeBinders extExpr , vi')
+ }
+ else return $ (collectAnnTypeBinders expr, vi)
; arity <- polyArity tvs
; polyAbstract tvs $ \args ->
- do { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
+ do {(inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vi'
; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono')
}
}
-
-
-
+-- todo: clean this
+
+vectPolyExprVT:: Bool -> [Var] -> CoreExprWithFVs -> VITree
+ -> VM (Inline, Bool, VExpr)
--- | encapsulate every purely sequentail subexpression with a simple return type
+-- vectPolyExprVT _loop_breaker _recFns e vi | not (checkTree vi (deAnnotate e))
+-- = pprPanic "vectPolyExprVT" (ppr $ deAnnotate e)
+vectPolyExprVT loop_breaker recFns (_, AnnTick tickish expr) (VITNode _ [vit])
+ = do { (inline, isScalarFn, expr') <- vectPolyExprVT loop_breaker recFns expr vit
+ ; return (inline, isScalarFn, vTick tickish expr')
+ }
+
+vectPolyExprVT loop_breaker recFns expr vi
+ = do { -- checkTreeAnnM vi expr ;
+ let (tvs, mono) = collectAnnTypeBinders expr
+ ; arity <- polyArity tvs
+ ; polyAbstract tvs $ \args ->
+ do { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vi
+ ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono')
+ }
+ }
+
+-- | encapsulate every purely sequential subexpression with a simple return type
-- of a (potentially) parallel expression into a lambda abstraction over all its
-- free variables followed by the corresponding application to those variables.
-- Condition:
@@ -229,90 +261,103 @@ vectPolyExpr loop_breaker recFns expr
-- the expression is 'complex enough', which is, for now, every expression
-- which is not constant and contains at least one operation.
--
-encapsulateScalar :: VITree -> CoreExprWithFVs -> VM CoreExprWithFVs
-encapsulateScalar _ ce@(_, AnnType _ty)
- = return ce
+encapsulateScalar :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree)
+encapsulateScalar vit ce@(_, AnnType _ty)
+ = return (ce, vit)
-encapsulateScalar _ ce@(_, AnnVar _v)
- = return ce
+encapsulateScalar vit ce@(_, AnnVar _v)
+ = return (ce, vit)
-encapsulateScalar _ ce@(_, AnnLit _)
- = return ce
+encapsulateScalar vit ce@(_, AnnLit _)
+ = return (ce, vit)
-encapsulateScalar (VITNode _ [vit]) (fvs, AnnTick tck expr)
- = do { extExpr <- encapsulateScalar vit expr
- ; return (fvs, AnnTick tck extExpr)
+encapsulateScalar (VITNode vi [vit]) (fvs, AnnTick tck expr)
+ = do { (extExpr, vit') <- encapsulateScalar vit expr
+ ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit'])
}
encapsulateScalar _ (_fvs, AnnTick _tck _expr)
= panic "encapsulateScalar AnnTick doesn't match up"
-encapsulateScalar (VITNode _ [vit]) (fvs, AnnLam bndr expr)
- = do { extExpr <- encapsulateScalar vit expr
- ; return (fvs, AnnLam bndr extExpr)
+encapsulateScalar (VITNode vi [vit]) ce@(fvs, AnnLam bndr expr)
+ = do { varsS <- varsSimple fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> do { let (e', vit') = encaps vit ce
+ ; return (e', vit')
+ }
+ _ -> do { (extExpr, vit') <- encapsulateScalar vit expr
+ ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit'])
+ }
}
encapsulateScalar _ (_fvs, AnnLam _bndr _expr)
= panic "encapsulateScalar AnnLam doesn't match up"
-encapsulateScalar (VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2)
+encapsulateScalar vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2)
= do { varsS <- varsSimple fvs
; case (vi, varsS) of
- (VISimple, True) -> return $ encaps ce
- _ -> do { etaCe1 <- encapsulateScalar vit1 ce1
- ; etaCe2 <- encapsulateScalar vit2 ce2
- ; return (fvs, AnnApp etaCe1 etaCe2)
+ (VISimple, True) -> do { let (e', vt') = encaps vt ce
+ -- ; checkTreeAnnM vt' e'
+ -- ; traceVt "Passed checkTree test!!" (ppr $ deAnnotate e')
+ ; return (e', vt')
+ }
+ _ -> do { (etaCe1, vit1') <- encapsulateScalar vit1 ce1
+ ; (etaCe2, vit2') <- encapsulateScalar vit2 ce2
+ ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2'])
}
}
encapsulateScalar _ (_fvs, AnnApp _ce1 _ce2)
= panic "encapsulateScalar AnnApp doesn't match up"
-encapsulateScalar (VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts)
+encapsulateScalar vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts)
= do { varsS <- varsSimple fvs
; case (vi, varsS) of
- (VISimple, True) -> return $ encaps ce
- _ -> do { extScrut <- encapsulateScalar scrutVit scrut
- ; extAlts <- zipWithM expAlt altVits alts
- ; return (fvs, AnnCase extScrut bndr ty extAlts)
+ (VISimple, True) -> return $ encaps vt ce
+ _ -> do { (extScrut, scrutVit') <- encapsulateScalar scrutVit scrut
+ ; extAltsVits <- zipWithM expAlt altVits alts
+ ; let (extAlts, altVits') = unzip extAltsVits
+ ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit': altVits'))
}
}
where
expAlt vt (con, bndrs, expr)
- = do { extExpr <- encapsulateScalar vt expr
- ; return (con, bndrs, extExpr)
+ = do { (extExpr, vt') <- encapsulateScalar vt expr
+ ; return ((con, bndrs, extExpr), vt')
}
encapsulateScalar _ (_fvs, AnnCase _scrut _bndr _ty _alts)
= panic "encapsulateScalar AnnCase doesn't match up"
-encapsulateScalar (VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2)
+encapsulateScalar vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2)
= do { varsS <- varsSimple fvs
; case (vi, varsS) of
- (VISimple, True) -> return $ encaps ce
- _ -> do { extExpr1 <- encapsulateScalar vt1 expr1
- ; extExpr2 <- encapsulateScalar vt2 expr2
- ; return (fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2)
+ (VISimple, True) -> return $ encaps vt ce
+ _ -> do { (extExpr1, vt1') <- encapsulateScalar vt1 expr1
+ ; (extExpr2, vt2') <- encapsulateScalar vt2 expr2
+ ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2'])
}
}
encapsulateScalar _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2)
= panic "encapsulateScalar AnnLet nonrec doesn't match up"
-encapsulateScalar (VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr)
+encapsulateScalar vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr)
= do { varsS <- varsSimple fvs
; case (vi, varsS) of
- (VISimple, True) -> return $ encaps ce
- _ -> do { extBnds <- zipWithM expBndg vtBnds bndngs
- ; extExpr <- encapsulateScalar vtB expr
- ; return (fvs, AnnLet (AnnRec extBnds) extExpr)
+ (VISimple, True) -> return $ encaps vt ce
+ _ -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs
+ ; let (extBnds, vtBnds') = unzip extBndsVts
+ ; (extExpr, vtB') <- encapsulateScalar vtB expr
+ ; let vt' = VITNode vi (vtB':vtBnds')
+ ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt')
}
}
where
expBndg vit (bndr, expr)
- = do { extExpr <- encapsulateScalar vit expr
- ; return (bndr, extExpr)
+ = do { (extExpr, vit') <- encapsulateScalar vit expr
+ ; return ((bndr, extExpr), vit')
}
encapsulateScalar _ (_fvs, AnnLet (AnnRec _) _expr2)
@@ -320,9 +365,9 @@ encapsulateScalar _ (_fvs, AnnLet (AnnRec _) _expr2)
-encapsulateScalar (VITNode _ [vit]) (fvs, AnnCast expr coercion)
- = do { extExpr <- encapsulateScalar vit expr
- ; return (fvs, AnnCast extExpr coercion)
+encapsulateScalar (VITNode vi [vit]) (fvs, AnnCast expr coercion)
+ = do { (extExpr, vit') <- encapsulateScalar vit expr
+ ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit'])
}
encapsulateScalar _ (_fvs, AnnCast _expr _coercion)
@@ -333,56 +378,75 @@ encapsulateScalar _ _
= panic "encapsulateScalar case not handled"
-mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet
-mkAnnLam bndr ce = AnnLam bndr ce
-mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
-mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check!
-mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs
-
-mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet)
-mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v))
-
-mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
-mkAnnApps (fv, aex') [] = (fv, aex')
-mkAnnApps ae (v:vs) =
- let
- (fv, aex') = mkAnnApps ae vs
- in (extendVarSet fv v, mkAnnApp (fv, aex') v)
-- CoreExprWithFVs, -- = AnnExpr Id VarSet
-- AnnExpr bndr VarSet = (annot, AnnExpr' bndr VarSet)
-- AnnLam :: bndr -> (AnnExpr bndr VarSet) -> AnnExpr' bndr VarSet
-- AnnLam bndr (AnnExpr bndr annot)
-encaps :: CoreExprWithFVs -> CoreExprWithFVs
-encaps (fvs, AnnCase expr bndr t alts)
+encaps :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree)
+encaps (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts)
| Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr),
(not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- TODO: globalScalarTyCons
- = (fvs, AnnCase expr bndr t (map (\(ac, bndrs, aex) -> (ac, bndrs, encaps aex)) alts))
+ = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits'))
-encaps ae@(fvs, _annEx)
- = let
- vars = varSetElems fvs
- in mkAnnApps (mkAnnLams ae vars) vars
+ where
+ (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $
+ zipWith (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, encaps altVi aex)) alts altVits
+
+encaps viTree ae@(fvs, _annEx)
+ = (mkAnnApps (mkAnnLams ae vars) vars, viTree')
+ where
+ mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits
+ mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs]
+ mkViTreeApps vi [] = vi
+ mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []]
+
+ vars = varSetElems fvs
+ viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars
+
+ mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet
+ mkAnnLam bndr ce = AnnLam bndr ce
+
+ mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
+ mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check!
+ mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs
+
+ mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet)
+ mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v))
+
+ mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
+ mkAnnApps (fv, aex') [] = (fv, aex')
+ mkAnnApps ae (v:vs) =
+ let
+ (fv, aex') = mkAnnApps ae vs
+ in (extendVarSet fv v, mkAnnApp (fv, aex') v)
+
+
+
+
+
-- |Vectorise an expression.
--
-vectExpr :: CoreExprWithFVs -> VM VExpr
-
-vectExpr (_, AnnVar v)
+vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr
+-- vectExpr e vi | not (checkTree vi (deAnnotate e))
+-- = pprPanic "vectExpr" (ppr $ deAnnotate e)
+
+vectExpr (_, AnnVar v) _
= vectVar v
-vectExpr (_, AnnLit lit)
+vectExpr (_, AnnLit lit) _
= vectConst $ Lit lit
-vectExpr e@(_, AnnLam bndr _)
- | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e
+vectExpr e@(_, AnnLam bndr _) vt
+ | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt
-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
-- FIXME: can't be do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
-vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
+vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _
| v == pAT_ERROR_ID
= do { (vty, lty) <- vectAndLiftType ty
; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
@@ -392,13 +456,13 @@ vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
-- type application (handle multiple consecutive type applications simultaneously to ensure the
-- PA dictionaries are put at the right places)
-vectExpr e@(_, AnnApp _ arg)
+vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _])
| isAnnTypeArg arg
= vectPolyApp e
-- 'Int', 'Float', or 'Double' literal
-- FIXME: this needs to be generalised
-vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
+vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _
| Just con <- isDataConId_maybe v
, is_special_con con
= do
@@ -409,17 +473,17 @@ vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
-- value application (dictionary or user value)
-vectExpr e@(_, AnnApp fn arg)
+vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2])
| isPredTy arg_ty -- dictionary application (whose result is not a dictionary)
= vectPolyApp e
| otherwise -- user value
= do { -- vectorise the types
- ; varg_ty <- vectType arg_ty
+ ; varg_ty <- vectType arg_ty
; vres_ty <- vectType res_ty
-- vectorise the function and argument expression
- ; vfn <- vectExpr fn
- ; varg <- vectExpr arg
+ ; vfn <- vectExpr fn vit1
+ ; varg <- vectExpr arg vit2
-- the vectorised function is a closure; apply it to the vectorised argument
; mkClosureApp varg_ty vres_ty vfn varg
@@ -427,42 +491,43 @@ vectExpr e@(_, AnnApp fn arg)
where
(arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
-vectExpr (_, AnnCase scrut bndr ty alts)
+vectExpr (_, AnnCase scrut bndr ty alts) vt
| Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
, isAlgTyCon tycon
- = vectAlgCase tycon ty_args scrut bndr ty alts
+ = vectAlgCase tycon ty_args scrut bndr ty alts vt
| otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty)
where
scrut_ty = exprType (deAnnotate scrut)
-vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
+vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2])
= do
- vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExprVT False [] rhs vt1
+ (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2)
return $ vLet (vNonRec vbndr vrhs) vbody
-vectExpr (_, AnnLet (AnnRec bs) body)
+vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds))
= do
(vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
$ liftM2 (,)
- (zipWithM vect_rhs bndrs rhss)
- (vectExpr body)
+ (zipWith3M vect_rhs bndrs rhss vtBnds)
+ (vectExpr body vtB)
return $ vLet (vRec vbndrs vrhss) vbody
where
(bndrs, rhss) = unzip bs
- vect_rhs bndr rhs = localV
- . inBind bndr
- . liftM (\(_,_,z)->z)
- $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs
+ vect_rhs bndr rhs vt = localV
+ . inBind bndr
+ . liftM (\(_,_,z)->z)
+ $ vectPolyExprVT (isStrongLoopBreaker $ idOccInfo bndr) [] rhs vt
+ zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs)
-vectExpr (_, AnnTick tickish expr)
- = liftM (vTick tickish) (vectExpr expr)
+vectExpr (_, AnnTick tickish expr) (VITNode _ [vit])
+ = liftM (vTick tickish) (vectExpr expr vit)
-vectExpr (_, AnnType ty)
+vectExpr (_, AnnType ty) _
= liftM vType (vectType ty)
-vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
+vectExpr e _ = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
-- |Vectorise an expression that *may* have an outer lambda abstraction.
--
@@ -475,23 +540,29 @@ vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether
-> Bool -- ^ Whether the binding is a loop breaker
-> [Var] -- ^ Names of function in same recursive binding group
-> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam`
+ -> VITree
-> VM (Inline, Bool, VExpr)
-vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body)
+
+-- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e))
+-- = pprPanic "vectFnExpr" (ppr $ deAnnotate e)
+
+
+vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt'])
-- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
| isId bndr
&& isPredTy (idType bndr)
= do { vBndr <- vectBndr bndr
- ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body
+ ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body vt'
; return (inline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody)
}
-- non-predicate abstraction: vectorise (try to vectorise as a scalar computation)
| isId bndr
- = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
+ = mark DontInline True (vectScalarFunVT False recFns (deAnnotate expr) vt)
`orElseV`
- mark inlineMe False (vectLam inline loop_breaker expr)
-vectFnExpr _ _ _ e
+ mark inlineMe False (vectLam inline loop_breaker expr vt)
+vectFnExpr _ _ _ e vt
-- not an abstraction: vectorise as a vanilla expression
- = mark DontInline False $ vectExpr e
+ = mark DontInline False $ vectExpr e vt
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
@@ -611,7 +682,7 @@ vectDictExpr (Coercion coe)
-- |Vectorise an expression of functional type, where all arguments and the result are of primitive
-- types (i.e., 'Int', 'Float', 'Double' etc., which have instances of the 'Scalar' type class) and
-- which does not contain any subcomputations that involve parallel arrays. Such functionals do not
--- requires the full blown vectorisation transformation; instead, they can be lifted by application
+-- require the full blown vectorisation transformation; instead, they can be lifted by application
-- of a member of the zipWith family (i.e., 'map', 'zipWith', zipWith3', etc.)
--
-- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised,
@@ -622,7 +693,18 @@ vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user?
-> [Var] -- ^ Functions names in same recursive binding group
-> CoreExpr -- ^ Expression to be vectorised
-> VM VExpr
-vectScalarFun forceScalar recFns expr
+vectScalarFun forceScalar recFns expr
+ = vectScalarFunVT forceScalar recFns expr (VITNode VISimple [])
+
+
+
+
+vectScalarFunVT :: Bool -- ^ Was the function marked as scalar by the user?
+ -> [Var] -- ^ Functions names in same recursive binding group
+ -> CoreExpr -- ^ Expression to be vectorised
+ -> VITree
+ -> VM VExpr
+vectScalarFunVT forceScalar recFns expr (VITNode vi _)
= do { gscalarVars <- globalScalarVars
; scalarTyCons <- globalScalarTyCons
; let scalarVars = gscalarVars `extendVarSetList` recFns
@@ -632,12 +714,15 @@ vectScalarFun forceScalar recFns expr
"\n\tall tycons scalar? : " ++ (show $all (is_scalar_ty scalarTyCons) arg_tys) ++
"\n\tresult scalar? : " ++ (show $is_scalar_ty scalarTyCons res_ty) ++
"\n\tscalar body? : " ++ (show $is_scalar scalarVars (is_scalar_ty scalarTyCons) expr) ++
- "\n\tuses vars? : " ++ (show $uses scalarVars expr)
+ "\n\tuses vars? : " ++ (show $uses scalarVars expr) ++
+ "\n\t is encaps? : " ++ (show vi)
)
(ppr expr)
; onlyIfV (ptext (sLit "not a scalar function"))
(forceScalar -- user asserts the functions is scalar
||
+ (vi == VIEncaps) -- should only be true if all the foll. cond are hold
+ ||
all (is_scalar_ty scalarTyCons) arg_tys -- check whether the function is scalar
&& is_scalar_ty scalarTyCons res_ty
&& is_scalar scalarVars (is_scalar_ty scalarTyCons) expr
@@ -664,25 +749,17 @@ vectScalarFun forceScalar recFns expr
-- any 'scalarTyCons', but can't at the moment, as those argument and result types
-- need to be members of the 'Scalar' class (that in its current form would better
-- be called 'Primitive'). *ALSO* the hardcoded list of types is ugly!
- is_primitive_ty ty
- | isPredTy ty -- dictionaries never get into the environment
- = True
- | Just (tycon, _) <- splitTyConApp_maybe ty
- = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName]
- | otherwise
- = False
- -}
-
+ -}
is_scalar_ty _scalarTyCons ty
| isPredTy ty -- dictionaries never get into the environment
= True
| Just (tycon, []) <- splitTyConApp_maybe ty -- TODO: FIX THIS!
- = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName]
--- = tyConName tycon `elemNameSet` scalarTyCons
+ = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
+-- FIXME: = tyConName tycon `elemNameSet` scalarTyCons
| Just (tycon, _) <- splitTyConApp_maybe ty
- = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName]
+ = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
--- = tyConName tycon `elemNameSet` scalarTyCons
+-- FIXME: = tyConName tycon `elemNameSet` scalarTyCons
| otherwise
= False
@@ -706,7 +783,9 @@ vectScalarFun forceScalar recFns expr
| maybe_parr_ty (varType var) = False
| otherwise = is_scalar (scalars `extendVarSet` var)
isScalarTC body
- is_scalar scalars isScalarTC (Let bind body) = trace ("is_scalar LET " ++ (show bindsAreScalar ) ++ " " ++ (show $ is_scalar scalars' isScalarTC body) ++ (show $ showSDoc $ ppr bind)) $
+ is_scalar scalars isScalarTC (Let bind body) = trace ("is_scalar LET " ++ (show bindsAreScalar ) ++
+ " " ++ (show $ is_scalar scalars' isScalarTC body) ++
+ (show $ showSDoc $ ppr bind)) $
bindsAreScalar &&
is_scalar scalars' isScalarTC body
where
@@ -884,8 +963,9 @@ unVectDict ty e
vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
-> Bool -- ^ Whether the binding is a loop breaker.
-> CoreExprWithFVs -- ^ Body of abstraction.
+ -> VITree
-> VM VExpr
-vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
+vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
= do { let (bndrs, body) = collectAnnValBinders expr
-- grab the in-scope type variables
@@ -913,13 +993,18 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
. hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
$ do { -- generate the vectorised body of the lambda abstraction
; lc <- builtin liftingContext
- ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body)
+ ; let viBody = stripLams expr vi
+ -- ; checkTreeAnnM vi expr
+ ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody)
; vbody' <- break_loop lc res_ty vbody
; return $ vLams lc vbndrs vbody'
}
}
where
+ stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt
+ stripLams _ vi = vi
+
maybe_inline n | inline = Inline n
| otherwise = DontInline
@@ -937,7 +1022,7 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
(LitAlt (mkMachInt 0), [], empty)])
}
| otherwise = return (ve, le)
-vectLam _ _ _ = panic "vectLam"
+vectLam _ _ _ _ = panic "vectLam"
-- | Vectorise an algebraic case expression.
-- We convert
@@ -955,31 +1040,31 @@ vectLam _ _ _ = panic "vectLam"
--
-- FIXME: this is too lazy
-vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
- -> [(AltCon, [Var], CoreExprWithFVs)]
+vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs-> Var -> Type
+ -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree
-> VM VExpr
-vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
+vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit]))
= do
- vscrut <- vectExpr scrut
+ vscrut <- vectExpr scrut scrutVit
(vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
+vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit]))
= do
- vscrut <- vectExpr scrut
+ vscrut <- vectExpr scrut scrutVit
(vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
+vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit]))
= do
(vty, lty) <- vectAndLiftType ty
- vexpr <- vectExpr scrut
+ vexpr <- vectExpr scrut scrutVit
(vbndr, (vbndrs, (vect_body, lift_body)))
<- vect_scrut_bndr
. vectBndrsIn bndrs
- $ vectExpr body
+ $ vectExpr body altVit
let (vect_bndrs, lift_bndrs) = unzip vbndrs
(vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
vect_dc <- maybeV dataConErr (lookupDataCon dc)
@@ -997,7 +1082,7 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
-vectAlgCase tycon _ty_args scrut bndr ty alts
+vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
= do
vect_tc <- maybeV tyConErr (lookupTyCon tycon)
(vty, lty) <- vectAndLiftType ty
@@ -1008,10 +1093,10 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
let sel = Var sel_bndr
(vbndr, valts) <- vect_scrut_bndr
- $ mapM (proc_alt arity sel vty lty) alts'
+ $ mapM (proc_alt arity sel vty lty) (zip alts' altVits)
let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
- vexpr <- vectExpr scrut
+ vexpr <- vectExpr scrut scrutVit
(vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
let (vect_bodies, lift_bodies) = unzip vbodies
@@ -1043,7 +1128,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
cmp _ DEFAULT = GT
cmp _ _ = panic "vectAlgCase/cmp"
- proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
+ proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi)
= do
vect_dc <- maybeV dataConErr (lookupDataCon dc)
let ntag = dataConTagZ vect_dc
@@ -1061,7 +1146,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
binds <- mapM (pack_var (Var lc) sel_tags tag)
. filter isLocalId
$ varSetElems fvs
- (ve, le) <- vectExpr body
+ (ve, le) <- vectExpr body vi
return (ve, Case (elems `App` sel) lc lty
[(DEFAULT, [], (mkLets (concat binds) le))])
-- empty <- emptyPD vty
@@ -1092,4 +1177,64 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
return [(NonRec lv' expr)]
_ -> return []
+
+vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ [])
+ = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon)
+
+---- Sanity check of the
+{-
+checkTree :: VITree -> CoreExpr -> Bool
+checkTree (VITNode _ []) (Type _ty)
+ = True
+
+checkTree (VITNode _ []) (Var _v)
+ = True
+
+
+checkTree (VITNode _ []) (Lit _)
+ = True
+
+
+checkTree (VITNode _ [vit]) (Tick _ expr)
+ = checkTree vit expr
+
+
+
+checkTree (VITNode _ [vit]) (Lam _ expr)
+ = checkTree vit expr
+
+
+checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2)
+ = (checkTree vit1 ce1) && (checkTree vit2 ce2)
+
+
+
+checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts)
+ = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts)
+ where
+ checkAlt vt (_, _, expr) = checkTree vt expr
+
+
+checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2)
+ = (checkTree vt1 expr1) && (checkTree vt2 expr2)
+
+
+
+checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr)
+ = (and $ zipWith checkBndr vtBnds bndngs) &&
+ (checkTree vtB expr)
+ where
+ checkBndr vt (_, e) = checkTree vt e
+
+
+checkTree (VITNode _ [vit]) (Cast expr _)
+ = checkTree vit expr
+
+checkTree _ _ = False
+checkTreeAnnM:: VITree -> CoreExprWithFVs -> VM ()
+checkTreeAnnM vi e =
+ if not (checkTree vi $ deAnnotate e)
+ then error ("checkTreeAnnM : \n " ++ show vi)
+ else return ()
+-}