summaryrefslogtreecommitdiff
path: root/compiler/vectorise/Vectorise.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/vectorise/Vectorise.hs')
-rw-r--r--compiler/vectorise/Vectorise.hs34
1 files changed, 33 insertions, 1 deletions
diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs
index cd1f429454..bee160c467 100644
--- a/compiler/vectorise/Vectorise.hs
+++ b/compiler/vectorise/Vectorise.hs
@@ -264,12 +264,44 @@ vectExpr (_, AnnLet (AnnRec bs) body)
$ vectExpr rhs
vectExpr e@(fvs, AnnLam bndr _)
- | isId bndr = vectLam fvs bs body
+ | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
+ `orElseV` vectLam fvs bs body
where
(bs,body) = collectAnnValBinders e
vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
+vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
+vectScalarLam args body
+ = do
+ scalars <- globalScalars
+ onlyIfV (all is_scalar_ty arg_tys
+ && is_scalar_ty res_ty
+ && is_scalar (extendVarSetList scalars args) body)
+ $ do
+ fn_var <- hoistExpr (fsLit "fn") (mkLams args body)
+ zipf <- zipScalars arg_tys res_ty
+ clo <- scalarClosure arg_tys res_ty (Var fn_var)
+ (zipf `App` Var fn_var)
+ clo_var <- hoistExpr (fsLit "clo") clo
+ lclo <- liftPA (Var clo_var)
+ return (Var clo_var, lclo)
+ where
+ arg_tys = map idType args
+ res_ty = exprType body
+
+ is_scalar_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty
+ = tycon == intTyCon
+ || tycon == floatTyCon
+ || tycon == doubleTyCon
+
+ | otherwise = False
+
+ is_scalar vs (Var v) = v `elemVarSet` vs
+ is_scalar _ e@(Lit l) = is_scalar_ty $ exprType e
+ is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
+ is_scalar _ _ = False
+
vectLam :: VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
vectLam fvs bs body
= do