summaryrefslogtreecommitdiff
path: root/compiler/vectorise/VectCore.hs
blob: 248bcb62e10fced651139bcce77cdf8ce67df56c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
{-# OPTIONS -w #-}
-- The above warning suppression flag is a temporary kludge.
-- While working on this module you are encouraged to remove it and fix
-- any warnings in the module. See
--     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
-- for details

module VectCore (
  Vect, VVar, VExpr, VBind,

  vectorised, lifted,
  mapVect,

  vNonRec, vRec,

  vVar, vType, vNote, vLet,
  vLams, vLamsWithoutLC, vVarApps,
  vCaseDEFAULT, vCaseProd
) where

#include "HsVersions.h"

import CoreSyn
import CoreUtils      ( exprType )
import DataCon        ( DataCon )
import Type           ( Type )
import Id             ( mkWildId )
import Var

-- | A vectorised entity paired with its lifted counterpart.  The first
-- component is the vectorised form, the second the lifted form (see
-- 'vectorised' and 'lifted').
type Vect a = (a, a)

-- | A variable together with its lifted counterpart.
type VVar = Vect Var

-- | A Core expression together with its lifted counterpart.
type VExpr = Vect CoreExpr

-- | A Core binding together with its lifted counterpart.
type VBind = Vect CoreBind

-- | Project out the vectorised (first) component of a pair.
vectorised :: Vect a -> a
vectorised (v, _) = v

-- | Project out the lifted (second) component of a pair.
lifted :: Vect a -> a
lifted (_, l) = l

-- | Apply the same function to both the vectorised and the lifted
-- component.  The pair itself is forced; its components are not.
mapVect :: (a -> b) -> Vect a -> Vect b
mapVect f v = case v of
  (vx, lx) -> (f vx, f lx)

-- | Combine two pairs component-wise: vectorised with vectorised and
-- lifted with lifted.  Both argument pairs are forced, first one first.
zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
zipWithVect f p q = case p of
  (pv, pl) -> case q of
    (qv, ql) -> (f pv qv, f pl ql)

-- | Turn a pair of variables into a pair of variable-occurrence expressions.
vVar :: VVar -> VExpr
vVar (v, l) = (Var v, Var l)

-- | Embed a type argument; the same 'Type' expression is used for both
-- the vectorised and the lifted component.
vType :: Type -> VExpr
vType ty = let tyArg = Type ty in (tyArg, tyArg)

-- | Wrap both components of an expression pair in the same 'Note'.
vNote :: Note -> VExpr -> VExpr
vNote note (ve, le) = (Note note ve, Note note le)

-- | Build a non-recursive binding for each component, binding each
-- variable to the matching expression.
vNonRec :: VVar -> VExpr -> VBind
vNonRec (vv, lv) (ve, le) = (NonRec vv ve, NonRec lv le)

-- | Build a recursive binding group for each component.  The i-th variable
-- pair is bound to the i-th expression pair; as with 'zip', any surplus
-- elements of the longer list are dropped.
vRec :: [VVar] -> [VExpr] -> VBind
vRec vars exprs =
  -- Lazy pattern bindings: nothing is forced until a component is demanded.
  let (varsV, varsL)   = unzip vars
      (exprsV, exprsL) = unzip exprs
  in  (Rec (zip varsV exprsV), Rec (zip varsL exprsL))

-- | Wrap each body in the matching binding from the binding pair.
vLet :: VBind -> VExpr -> VExpr
vLet (vb, lb) (ve, le) = (Let vb ve, Let lb le)

-- | Abstract both expressions over a list of binder pairs.  The vectorised
-- expression is abstracted over the vectorised binders.  The lifted
-- expression is abstracted over @lc@ followed by the lifted binders
-- (@lc@ is presumably the lifting-context variable -- TODO confirm).
vLams :: Var -> [VVar] -> VExpr -> VExpr
vLams lc vs (ve, le) =
  let (vbndrs, lbndrs) = unzip vs
  in  (mkLams vbndrs ve, mkLams (lc : lbndrs) le)

-- | Like 'vLams', but with no extra leading binder on the lifted side.
-- Each expression is abstracted only over its own components of the
-- binder pairs.
vLamsWithoutLC :: [VVar] -> VExpr -> VExpr
vLamsWithoutLC vvs (ve, le) =
  let (vbndrs, lbndrs) = unzip vvs
  in  (mkLams vbndrs ve, mkLams lbndrs le)

-- | Apply each expression to its components of the variable pairs.  The
-- lifted expression is applied to @lc@ first, then to the lifted
-- variables.
vVarApps :: Var -> VExpr -> [VVar] -> VExpr
vVarApps lc (ve, le) vvs =
  let (vargs, largs) = unzip vvs
  in  (mkVarApps ve vargs, mkVarApps le (lc : largs))

-- | Build, for each component, a 'Case' whose only alternative is
-- 'DEFAULT'.
--
-- Arguments: the scrutinee pair, the case-binder pair, the types used as
-- the type annotation of the vectorised and lifted 'Case' respectively,
-- and the body pair.
vCaseDEFAULT :: VExpr -> VVar -> Type -> Type -> VExpr -> VExpr
vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
  = ( defaultCase vscrut vbndr vty vbody
    , defaultCase lscrut lbndr lty lbody )
  where
    -- A case expression with a single DEFAULT alternative binding nothing.
    defaultCase scrut bndr ty body = Case scrut bndr ty [(DEFAULT, [], body)]

-- | Build, for each component, a 'Case' with a single data-constructor
-- alternative.
--
-- The vectorised case matches @vdc@ and binds the vectorised halves of
-- @bndrs@.  The lifted case matches @ldc@ and binds @sh_bndrs@ followed by
-- the lifted halves of @bndrs@.  Each case binder is a wildcard built with
-- 'mkWildId' from the type of its own scrutinee.  @vty@ and @lty@ are used
-- as the type annotation of the respective 'Case'.
vCaseProd :: VExpr -> Type -> Type
          -> DataCon -> DataCon -> [Var] -> [VVar] -> VExpr -> VExpr
vCaseProd (vscrut, lscrut) vty lty vdc ldc sh_bndrs bndrs
          (vbody, lbody)
  = ( prodCase vscrut vty vdc vfields vbody
    , prodCase lscrut lty ldc (sh_bndrs ++ lfields) lbody )
  where
    (vfields, lfields) = unzip bndrs

    -- A case on a single data constructor, with a wildcard case binder.
    prodCase scrut ty dc fields body
      = Case scrut (mkWildId (exprType scrut)) ty [(DataAlt dc, fields, body)]