summaryrefslogtreecommitdiff
path: root/compiler/vectorise/Vectorise/Vect.hs
blob: 1b0e57167c8298e987912d1e469a4326f3a00e0b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
-- |Simple vectorised constructors and projections.
--
-- Everything in this module works on 'Vect' pairs, which hold the vectorised
-- and the lifted version of a thing side by side.
--
module Vectorise.Vect
  ( Vect, VVar, VExpr, VBind

    -- Projections out of, and mapping over, a 'Vect' pair.
  , vectorised
  , lifted
  , mapVect

    -- Inspecting and building paired variables, bindings and expressions.
  , vVarType
  , vNonRec
  , vRec
  , vVar
  , vType
  , vTick
  , vLet
  , vLams
  , vVarApps
  , vCaseDEFAULT
  )
where

import GhcPrelude

import CoreSyn
import Type           ( Type )
import Var

-- |Contains the vectorised and lifted versions of some thing.
--
-- The first component is the vectorised version and the second component is
-- the lifted version (see 'vectorised' and 'lifted').
--
type Vect a = (a,a)
-- |A 'Var' in its vectorised and lifted versions.
type VVar   = Vect Var
-- |A 'CoreExpr' in its vectorised and lifted versions.
type VExpr  = Vect CoreExpr
-- |A 'CoreBind' in its vectorised and lifted versions.
type VBind  = Vect CoreBind

-- |Project out the vectorised version, i.e. the first component of the pair.
--
vectorised :: Vect a -> a
vectorised (v, _) = v

-- |Project out the lifted version, i.e. the second component of the pair.
--
lifted :: Vect a -> a
lifted (_, l) = l

-- |Transform the vectorised and the lifted version with the same function.
--
mapVect :: (a -> b) -> Vect a -> Vect b
mapVect f pair
  = case pair of
      (v, l) -> (f v, f l)

-- |Combine two pairs componentwise: the vectorised versions are combined with
-- each other, and likewise the lifted versions.
--
zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
zipWithVect f left right
  = case left of
      (lv, ll) -> case right of
        (rv, rl) -> (f lv rv, f ll rl)

-- |Get the type of a vectorised variable.
--
-- The type is read off the vectorised component; the lifted component is not
-- consulted.
--
vVarType :: VVar -> Type
vVarType vv = varType (vectorised vv)

-- |Turn a vectorised variable into a vectorised expression by wrapping each
-- component in 'Var'.
--
vVar :: VVar -> VExpr
vVar (v, l) = (Var v, Var l)

-- |Wrap a type as a vectorised expression.
--
-- The same 'Type' expression is used for both the vectorised and the lifted
-- component.
--
vType :: Type -> VExpr
vType ty = (tyExpr, tyExpr)
  where
    tyExpr = Type ty

-- |Wrap both the vectorised and the lifted expression in the given tick.
--
vTick :: Tickish Id -> VExpr -> VExpr
vTick tickish = mapVect (Tick tickish)

-- |Make a vectorised non-recursive binding.
--
-- Each version of the variable is bound to the matching version of the
-- expression.
--
vNonRec :: VVar -> VExpr -> VBind
vNonRec (vv, lv) (ve, le) = (NonRec vv ve, NonRec lv le)

-- |Make a vectorised recursive binding group.
--
-- Variables and right-hand sides are paired up positionally with 'zip',
-- separately for the vectorised and for the lifted versions.
--
vRec :: [VVar] -> [VExpr] -> VBind
vRec vs es
  = let (vecVars, liftVars) = unzip vs
        (vecRhss, liftRhss) = unzip es
    in  (Rec (vecVars `zip` vecRhss), Rec (liftVars `zip` liftRhss))

-- |Make a vectorised let expression.
--
-- Each version of the binding scopes over the matching version of the body.
--
vLet :: VBind -> VExpr -> VExpr
vLet (vbind, lbind) (vbody, lbody) = (Let vbind vbody, Let lbind lbody)

-- |Make a vectorised lambda abstraction.
--
-- The vectorised version abstracts over the vectorised parameters only.  The
-- lifted version abstracts over the lifting context 'lc' first, followed by
-- the lifted parameters.
--
vLams :: Var      -- ^ Var bound to the lifting context.
      -> [VVar]   -- ^ Parameter vars for the abstraction.
      -> VExpr    -- ^ Body of the abstraction.
      -> VExpr
vLams lc params (vbody, lbody)
  = let (vparams, lparams) = unzip params
    in  (mkLams vparams vbody, mkLams (lc : lparams) lbody)

-- |Apply an expression to a list of argument variables.
--
-- The vectorised version is applied to the vectorised arguments only.  The
-- lifted version is applied to the lifting context variable first, followed
-- by the lifted arguments.
--
vVarApps :: Var -> VExpr -> [VVar] -> VExpr
vVarApps lc (vfun, lfun) args
  = let (vargs, largs) = unzip args
    in  (mkVarApps vfun vargs, mkVarApps lfun (lc : largs))


-- |Make a vectorised case expression whose only alternative is @DEFAULT@.
--
-- The vectorised and the lifted case are built the same way, each from its
-- own scrutinee, binder, type and alternative body.
--
vCaseDEFAULT :: VExpr  -- ^ Scrutinee.
             -> VVar   -- ^ Case binder.
             -> Type   -- ^ Type of the vectorised version.
             -> Type   -- ^ Type of the lifted version.
             -> VExpr  -- ^ Body of the alternative.
             -> VExpr
vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
  = ( defaultOnly vscrut vbndr vty vbody
    , defaultOnly lscrut lbndr lty lbody
    )
  where
    -- A case with exactly one alternative: DEFAULT, binding no variables.
    defaultOnly scrut bndr ty body = Case scrut bndr ty [(DEFAULT, [], body)]