summaryrefslogtreecommitdiff
path: root/compiler/vectorise/Vectorise/Exp.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/vectorise/Vectorise/Exp.hs')
-rw-r--r--compiler/vectorise/Vectorise/Exp.hs88
1 files changed, 46 insertions, 42 deletions
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs
index 6d6a473b44..2b7accc646 100644
--- a/compiler/vectorise/Vectorise/Exp.hs
+++ b/compiler/vectorise/Vectorise/Exp.hs
@@ -26,6 +26,7 @@ import CoreFVs
import DataCon
import TyCon
import Type
+import NameSet
import Var
import VarEnv
import VarSet
@@ -42,11 +43,11 @@ import Data.List
-- | Vectorise a polymorphic expression.
--
-vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that
- -- binding is a loop breaker.
- -> [Var]
- -> CoreExprWithFVs
- -> VM (Inline, Bool, VExpr)
+vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that
+ -- binding is a loop breaker.
+ -> [Var]
+ -> CoreExprWithFVs
+ -> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
= do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
return (inline, isScalarFn, vNote note expr')
@@ -194,26 +195,24 @@ vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user?
-> CoreExpr -- ^ Expression to be vectorised
-> VM VExpr
vectScalarFun forceScalar recFns expr
- = do { gscalars <- globalScalars
- ; let scalars = gscalars `extendVarSetList` recFns
+ = do { gscalarVars <- globalScalarVars
+ ; scalarTyCons <- globalScalarTyCons
+ ; let scalarVars = gscalarVars `extendVarSetList` recFns
(arg_tys, res_ty) = splitFunTys (exprType expr)
; MASSERT( not $ null arg_tys )
- ; onlyIfV (forceScalar -- user asserts the functions is scalar
+ ; onlyIfV (forceScalar -- user asserts the functions is scalar
||
- all is_prim_ty arg_tys -- check whether the function is scalar
- && is_prim_ty res_ty
- && is_scalar scalars expr
- && uses scalars expr)
+ all (is_scalar_ty scalarTyCons) arg_tys -- check whether the function is scalar
+ && is_scalar_ty scalarTyCons res_ty
+ && is_scalar scalarVars (is_scalar_ty scalarTyCons) expr
+ && uses scalarVars expr)
$ mkScalarFun arg_tys res_ty expr
}
where
- -- FIXME: This is woefully insufficient!!! We need a scalar pragma for types!!!
- is_prim_ty ty
- | Just (tycon, []) <- splitTyConApp_maybe ty
- = tycon == intTyCon
- || tycon == floatTyCon
- || tycon == doubleTyCon
- | otherwise = False
+ is_scalar_ty scalarTyCons ty
+ | Just (tycon, _) <- splitTyConApp_maybe ty
+ = tyConName tycon `elemNameSet` scalarTyCons
+ | otherwise = False
-- Checks whether an expression contain a non-scalar subexpression.
--
@@ -223,40 +222,45 @@ vectScalarFun forceScalar recFns expr
-- them to the list of scalar variables) and then check them. If one of them turns out not to
-- be scalar, the entire group is regarded as not being scalar.
--
- -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
- -- data constructor as scalar. Should be changed once scalar types are passed
- -- through VectInfo.
+ -- The second argument is a predicate that checks whether a type is scalar.
--
- is_scalar :: VarSet -> CoreExpr -> Bool
- is_scalar scalars (Var v) = v `elemVarSet` scalars
- is_scalar _scalars (Lit _) = True
- is_scalar scalars e@(App e1 e2)
- | maybe_parr_ty (exprType e) = False
- | otherwise = is_scalar scalars e1 && is_scalar scalars e2
- is_scalar scalars (Lam var body)
- | maybe_parr_ty (varType var) = False
- | otherwise = is_scalar (scalars `extendVarSet` var) body
- is_scalar scalars (Let bind body) = bindsAreScalar && is_scalar scalars' body
+ is_scalar :: VarSet -> (Type -> Bool) -> CoreExpr -> Bool
+ is_scalar scalars _isScalarTC (Var v) = v `elemVarSet` scalars
+ is_scalar _scalars _isScalarTC (Lit _) = True
+ is_scalar scalars isScalarTC e@(App e1 e2)
+ | maybe_parr_ty (exprType e) = False
+ | otherwise = is_scalar scalars isScalarTC e1 &&
+ is_scalar scalars isScalarTC e2
+ is_scalar scalars isScalarTC (Lam var body)
+ | maybe_parr_ty (varType var) = False
+ | otherwise = is_scalar (scalars `extendVarSet` var)
+ isScalarTC body
+ is_scalar scalars isScalarTC (Let bind body) = bindsAreScalar &&
+ is_scalar scalars' isScalarTC body
where
- (bindsAreScalar, scalars') = is_scalar_bind scalars bind
- is_scalar scalars (Case e var ty alts)
- | is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts
+ (bindsAreScalar, scalars') = is_scalar_bind scalars isScalarTC bind
+ is_scalar scalars isScalarTC (Case e var ty alts)
+ | isScalarTC ty = is_scalar scalars' isScalarTC e &&
+ all (is_scalar_alt scalars' isScalarTC) alts
| otherwise = False
where
scalars' = scalars `extendVarSet` var
- is_scalar scalars (Cast e _coe) = is_scalar scalars e
- is_scalar scalars (Note _ e ) = is_scalar scalars e
- is_scalar _scalars (Type {}) = True
- is_scalar _scalars (Coercion {}) = True
+ is_scalar scalars isScalarTC (Cast e _coe) = is_scalar scalars isScalarTC e
+ is_scalar scalars isScalarTC (Note _ e ) = is_scalar scalars isScalarTC e
+ is_scalar _scalars _isScalarTC (Type {}) = True
+ is_scalar _scalars _isScalarTC (Coercion {}) = True
-- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
- is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
- is_scalar_bind scalars (Rec bnds) = (all (is_scalar scalars') es, scalars')
+ is_scalar_bind scalars isScalarTCs (NonRec var e) = (is_scalar scalars isScalarTCs e,
+ scalars `extendVarSet` var)
+ is_scalar_bind scalars isScalarTCs (Rec bnds) = (all (is_scalar scalars' isScalarTCs) es,
+ scalars')
where
(vars, es) = unzip bnds
scalars' = scalars `extendVarSetList` vars
- is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
+ is_scalar_alt scalars isScalarTCs (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars)
+ isScalarTCs e
-- Checks whether the type might be a parallel array type. In particular, if the outermost
-- constructor is a type family, we conservatively assume that it may be a parallel array type.