summaryrefslogtreecommitdiff
path: root/compiler/vectorise/Vectorise/Vect.hs
blob: 6dcffa2509168b69c7ee1d2f98630eb33f91d97e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

-- | Simple vectorised constructors and projections.
module Vectorise.Vect (
	Vect, VVar, VExpr, VBind,

	vectorised,
	lifted,
	mapVect,

	vVarType,
	vNonRec,
	vRec,
	vVar,
	vType,
	vNote,
	vLet,
	vLams,
	vLamsWithoutLC,
	vVarApps,
	vCaseDEFAULT
) where
import CoreSyn
import Type           ( Type )
import Var

-- | A pair of two versions of the same thing: the vectorised version is the
--   first component and the lifted version is the second.
type Vect a = (a, a)
-- | A variable paired with its lifted counterpart.
type VVar  = Vect Var

-- | A Core expression paired with its lifted counterpart.
type VExpr = Vect CoreExpr

-- | A Core binding paired with its lifted counterpart.
type VBind = Vect CoreBind


-- | Project out the vectorised (first) component.
vectorised :: Vect a -> a
vectorised (vect, _) = vect


-- | Project out the lifted (second) component.
lifted :: Vect a -> a
lifted (_, lift) = lift


-- | Transform the vectorised and the lifted component with the same function.
mapVect :: (a -> b) -> Vect a -> Vect b
mapVect fn pair =
  case pair of
    (vect, lift) -> (fn vect, fn lift)


-- | Pointwise combination of two pairs: the vectorised components are
--   combined with each other, and so are the lifted components.
zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
zipWithVect combine lhs rhs =
  case lhs of
    (lhsVect, lhsLift) ->
      case rhs of
        (rhsVect, rhsLift) -> (combine lhsVect rhsVect, combine lhsLift rhsLift)


-- | The type of the vectorised component of a variable pair; the lifted
--   variable is not consulted.
vVarType :: VVar -> Type
vVarType vvar = varType (vectorised vvar)


-- | Turn each variable of the pair into a variable-occurrence expression.
vVar :: VVar -> VExpr
vVar (vectVar, liftVar) = (Var vectVar, Var liftVar)


-- | Embed a type as an expression pair whose vectorised and lifted versions
--   are the same type argument.
vType :: Type -> VExpr
vType ty = (tyExpr, tyExpr)
  where
    tyExpr = Type ty


-- | Attach the same note to both the vectorised and the lifted expression.
vNote :: Note -> VExpr -> VExpr
vNote note (vectExpr, liftExpr) = (Note note vectExpr, Note note liftExpr)


-- | Build a non-recursive binding pair: the vectorised variable is bound to
--   the vectorised expression, and the lifted variable to the lifted one.
vNonRec :: VVar -> VExpr -> VBind
vNonRec (vectVar, liftVar) (vectExpr, liftExpr) =
  (NonRec vectVar vectExpr, NonRec liftVar liftExpr)


-- | Build a recursive binding-group pair. Variables and expressions are
--   paired positionally, separately for the vectorised and lifted versions;
--   as with 'zip', any surplus variables or expressions are dropped.
vRec :: [VVar] -> [VExpr] -> VBind
vRec vars exprs = (Rec vectPairs, Rec liftPairs)
  where
    vectPairs              = zip vectVars vectExprs
    liftPairs              = zip liftVars liftExprs
    (vectVars,  liftVars)  = unzip vars
    (vectExprs, liftExprs) = unzip exprs


-- | Build a let-expression pair: the vectorised binding scopes over the
--   vectorised body, and the lifted binding over the lifted body.
vLet :: VBind -> VExpr -> VExpr
vLet (vectBind, liftBind) (vectBody, liftBody) =
  (Let vectBind vectBody, Let liftBind liftBody)


-- | Build a lambda-abstraction pair. The vectorised version abstracts over
--   the vectorised parameters only. The lifted version abstracts first over
--   the lifting-context variable and then over the lifted parameters.
vLams :: Var     -- ^ Variable bound to the lifting context.
      -> [VVar]  -- ^ Parameters of the abstraction.
      -> VExpr   -- ^ Body of the abstraction.
      -> VExpr
vLams lcVar params (vectBody, liftBody) =
  ( mkLams vectParams vectBody
  , mkLams (lcVar : liftParams) liftBody
  )
  where
    (vectParams, liftParams) = unzip params


-- | Variant of 'vLams' whose lifted version does not abstract over a
--   lifting-context variable: each version abstracts only over its own
--   parameters.
vLamsWithoutLC :: [VVar] -> VExpr -> VExpr
vLamsWithoutLC params (vectBody, liftBody) =
  ( mkLams vectParams vectBody
  , mkLams liftParams liftBody
  )
  where
    (vectParams, liftParams) = unzip params


-- | Apply an expression pair to argument variables. The lifted version is
--   applied first to the lifting-context variable and then to the lifted
--   arguments.
vVarApps :: Var -> VExpr -> [VVar] -> VExpr
vVarApps lcVar (vectFun, liftFun) args =
  ( mkVarApps vectFun vectArgs
  , mkVarApps liftFun (lcVar : liftArgs)
  )
  where
    (vectArgs, liftArgs) = unzip args


-- | Build a case-expression pair in which each version has exactly one,
--   'DEFAULT', alternative. The vectorised and lifted halves each use
--   their own scrutinee, case binder, type and body.
vCaseDEFAULT
  :: VExpr  -- ^ Scrutinee.
  -> VVar   -- ^ Case binder.
  -> Type   -- ^ Type of the vectorised case expression.
  -> Type   -- ^ Type of the lifted case expression.
  -> VExpr  -- ^ Body of the single alternative.
  -> VExpr
vCaseDEFAULT (vectScrut, liftScrut) (vectBndr, liftBndr) vectTy liftTy (vectBody, liftBody) =
  ( defaultOnly vectScrut vectBndr vectTy vectBody
  , defaultOnly liftScrut liftBndr liftTy liftBody
  )
  where
    -- A case whose only alternative is DEFAULT, with no pattern binders.
    defaultOnly scrut bndr ty body = Case scrut bndr ty [(DEFAULT, [], body)]