diff options
Diffstat (limited to 'compiler/vectorise/Vectorise/Exp.hs')
| -rw-r--r-- | compiler/vectorise/Vectorise/Exp.hs | 88 |
1 files changed, 46 insertions, 42 deletions
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index 6d6a473b44..2b7accc646 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -26,6 +26,7 @@ import CoreFVs import DataCon import TyCon import Type +import NameSet import Var import VarEnv import VarSet @@ -42,11 +43,11 @@ import Data.List -- | Vectorise a polymorphic expression. -- -vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that - -- binding is a loop breaker. - -> [Var] - -> CoreExprWithFVs - -> VM (Inline, Bool, VExpr) +vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that + -- binding is a loop breaker. + -> [Var] + -> CoreExprWithFVs + -> VM (Inline, Bool, VExpr) vectPolyExpr loop_breaker recFns (_, AnnNote note expr) = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr return (inline, isScalarFn, vNote note expr') @@ -194,26 +195,24 @@ vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user? -> CoreExpr -- ^ Expression to be vectorised -> VM VExpr vectScalarFun forceScalar recFns expr - = do { gscalars <- globalScalars - ; let scalars = gscalars `extendVarSetList` recFns + = do { gscalarVars <- globalScalarVars + ; scalarTyCons <- globalScalarTyCons + ; let scalarVars = gscalarVars `extendVarSetList` recFns (arg_tys, res_ty) = splitFunTys (exprType expr) ; MASSERT( not $ null arg_tys ) - ; onlyIfV (forceScalar -- user asserts the functions is scalar + ; onlyIfV (forceScalar -- user asserts the functions is scalar || - all is_prim_ty arg_tys -- check whether the function is scalar - && is_prim_ty res_ty - && is_scalar scalars expr - && uses scalars expr) + all (is_scalar_ty scalarTyCons) arg_tys -- check whether the function is scalar + && is_scalar_ty scalarTyCons res_ty + && is_scalar scalarVars (is_scalar_ty scalarTyCons) expr + && uses scalarVars expr) $ mkScalarFun arg_tys res_ty expr } where - -- FIXME: This is woefully insufficient!!! We need a scalar pragma for types!!! - is_prim_ty ty - | Just (tycon, []) <- splitTyConApp_maybe ty - = tycon == intTyCon - || tycon == floatTyCon - || tycon == doubleTyCon - | otherwise = False + is_scalar_ty scalarTyCons ty + | Just (tycon, _) <- splitTyConApp_maybe ty + = tyConName tycon `elemNameSet` scalarTyCons + | otherwise = False -- Checks whether an expression contain a non-scalar subexpression. -- @@ -223,40 +222,45 @@ vectScalarFun forceScalar recFns expr -- them to the list of scalar variables) and then check them. If one of them turns out not to -- be scalar, the entire group is regarded as not being scalar. -- - -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous - -- data constructor as scalar. Should be changed once scalar types are passed - -- through VectInfo. + -- The second argument is a predicate that checks whether a type is scalar. -- - is_scalar :: VarSet -> CoreExpr -> Bool - is_scalar scalars (Var v) = v `elemVarSet` scalars - is_scalar _scalars (Lit _) = True - is_scalar scalars e@(App e1 e2) - | maybe_parr_ty (exprType e) = False - | otherwise = is_scalar scalars e1 && is_scalar scalars e2 - is_scalar scalars (Lam var body) - | maybe_parr_ty (varType var) = False - | otherwise = is_scalar (scalars `extendVarSet` var) body - is_scalar scalars (Let bind body) = bindsAreScalar && is_scalar scalars' body + is_scalar :: VarSet -> (Type -> Bool) -> CoreExpr -> Bool + is_scalar scalars _isScalarTC (Var v) = v `elemVarSet` scalars + is_scalar _scalars _isScalarTC (Lit _) = True + is_scalar scalars isScalarTC e@(App e1 e2) + | maybe_parr_ty (exprType e) = False + | otherwise = is_scalar scalars isScalarTC e1 && + is_scalar scalars isScalarTC e2 + is_scalar scalars isScalarTC (Lam var body) + | maybe_parr_ty (varType var) = False + | otherwise = is_scalar (scalars `extendVarSet` var) + isScalarTC body + is_scalar scalars isScalarTC (Let bind body) = bindsAreScalar && + is_scalar scalars' isScalarTC body where - (bindsAreScalar, scalars') = is_scalar_bind scalars bind - is_scalar scalars (Case e var ty alts) - | is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts + (bindsAreScalar, scalars') = is_scalar_bind scalars isScalarTC bind + is_scalar scalars isScalarTC (Case e var ty alts) + | isScalarTC ty = is_scalar scalars' isScalarTC e && + all (is_scalar_alt scalars' isScalarTC) alts | otherwise = False where scalars' = scalars `extendVarSet` var - is_scalar scalars (Cast e _coe) = is_scalar scalars e - is_scalar scalars (Note _ e ) = is_scalar scalars e - is_scalar _scalars (Type {}) = True - is_scalar _scalars (Coercion {}) = True + is_scalar scalars isScalarTC (Cast e _coe) = is_scalar scalars isScalarTC e + is_scalar scalars isScalarTC (Note _ e ) = is_scalar scalars isScalarTC e + is_scalar _scalars _isScalarTC (Type {}) = True + is_scalar _scalars _isScalarTC (Coercion {}) = True -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group) - is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var) - is_scalar_bind scalars (Rec bnds) = (all (is_scalar scalars') es, scalars') + is_scalar_bind scalars isScalarTCs (NonRec var e) = (is_scalar scalars isScalarTCs e, + scalars `extendVarSet` var) + is_scalar_bind scalars isScalarTCs (Rec bnds) = (all (is_scalar scalars' isScalarTCs) es, + scalars') where (vars, es) = unzip bnds scalars' = scalars `extendVarSetList` vars - is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e + is_scalar_alt scalars isScalarTCs (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) + isScalarTCs e -- Checks whether the type might be a parallel array type. In particular, if the outermost -- constructor is a type family, we conservatively assume that it may be a parallel array type. |
