summaryrefslogtreecommitdiff
path: root/backend/src/llvm/llvm_passes.cpp
blob: f5d9052183aeb6f12c0a52b98a493ab9598f174a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
/* 
 * Copyright © 2012 Intel Corporation
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 * Author: Benjamin Segovia <benjamin.segovia@intel.com>
 *         Heldge RHodin <alice.rhodin@alice-dsl.net>
 */

/**
 * \file llvm_passes.cpp
 * \author Benjamin Segovia <benjamin.segovia@intel.com>
 * \author Heldge RHodin <alice.rhodin@alice-dsl.net>
 */

/* THIS CODE IS DERIVED FROM GPL LLVM PTX BACKEND. CODE IS HERE:
 * http://sourceforge.net/scm/?type=git&group_id=319085
 * Note that however, the original author, Heldge Rhodin, granted me (Benjamin
 * Segovia) the right to use another license for it (MIT here)
 */

#include "llvm_includes.hpp"

#include "llvm/llvm_gen_backend.hpp"
#include "ir/unit.hpp"
#include "sys/map.hpp"

using namespace llvm;

namespace gbe
{
  /**
   * Tell whether F is an OpenCL kernel entry point.
   *
   * With LLVM 3.9 a function is reported as a kernel when it carries
   * "kernel_arg_name" metadata. Otherwise the module-level "opencl.kernels"
   * named metadata is searched for a node whose first operand is F.
   *
   * \param F function to classify
   * \return true when F is listed as a kernel, false otherwise (including
   *         when F has no parent module or the module has no kernel list)
   */
  bool isKernelFunction(const llvm::Function &F) {
#if LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9
    return F.getMetadata("kernel_arg_name") != NULL;
#else
    const Module *module = F.getParent();
    if (module == NULL)
      return false;

    // Direct lookup by name; this avoids comparing StringRef::data() with
    // strcmp, which relies on the buffer being null terminated.
    const NamedMDNode *kernels = module->getNamedMetadata("opencl.kernels");
    if (kernels == NULL)
      return false;

    for (uint32_t x = 0, ops = kernels->getNumOperands(); x < ops; ++x) {
      const MDNode *node = kernels->getOperand(x);
      // Skip malformed entries instead of asserting inside getOperand/cast.
      if (node == NULL || node->getNumOperands() == 0)
        continue;
#if LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR <= 5
      const Value *op = node->getOperand(0);
#else
      const ValueAsMetadata *wrapped =
        dyn_cast_or_null<ValueAsMetadata>(node->getOperand(0).get());
      const Value *op = wrapped != NULL ? wrapped->getValue() : NULL;
#endif
      if (op == &F)
        return true;
    }
    return false;
#endif
  }

  /**
   * Read the OpenCL version recorded in the module's "opencl.ocl.version"
   * named metadata.
   *
   * The first node of that metadata is expected to hold two integer constants
   * (major, minor); the result is encoded as major * 100 + minor * 10.
   *
   * \param M module to inspect
   * \return the encoded version, or 120 when the metadata is absent, empty,
   *         or its first node does not provide both components
   */
  uint32_t getModuleOclVersion(const llvm::Module *M) {
    const uint32_t defaultVersion = 120;
    NamedMDNode *version = M->getNamedMetadata("opencl.ocl.version");
    if (version == NULL || version->getNumOperands() == 0)
      return defaultVersion;

    MDNode* node = version->getOperand(0);
    // Both operands are read below; bail out rather than index past the end.
    if (node == NULL || node->getNumOperands() < 2)
      return defaultVersion;

    uint32_t major = 0, minor = 0;
#if LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 6
    major = mdconst::extract<ConstantInt>(node->getOperand(0))->getZExtValue();
    minor = mdconst::extract<ConstantInt>(node->getOperand(1))->getZExtValue();
#else
    // Before LLVM 3.6 metadata operands are plain Values. This branch used to
    // reference an undeclared identifier (MD) instead of node.
    major = cast<ConstantInt>(node->getOperand(0))->getZExtValue();
    minor = cast<ConstantInt>(node->getOperand(1))->getZExtValue();
#endif
    return major * 100 + minor * 10;
  }

  /**
   * Distance from offset to the next multiple of align.
   *
   * Both arguments may be negative as long as they share a sign: the result
   * then has that sign too, which lets callers walk a layout backwards
   * (e.g. getPadding(-4, -8) == -4).
   *
   * \param offset current position
   * \param align  required alignment; 0 means "no constraint"
   * \return the amount to add to offset to make it a multiple of align,
   *         or 0 when align is 0
   */
  int32_t getPadding(int32_t offset, int32_t align) {
    // getAlignmentByte yields 0 for a struct without members; a modulo by
    // zero would be undefined behaviour, so treat it as already aligned.
    if (align == 0)
      return 0;
    return (align - (offset % align)) % align;
  }

  /**
   * Alignment, in bytes, that this backend assigns to an LLVM type.
   *
   *  - vectors: element byte size times the element count, where a count of 3
   *    is treated as 4 (OCL spec)
   *  - pointers, integers and floating point scalars: their bit size / 8
   *  - arrays: the alignment of the element type
   *  - structs: the largest alignment among the members (0 when there are none)
   *
   * Void and any other type id go through NOT_SUPPORTED.
   *
   * \param unit IR unit, forwarded to the size helpers
   * \param Ty   type to query
   * \return the alignment in bytes
   */
  uint32_t getAlignmentByte(const ir::Unit &unit, Type* Ty)
  {
    switch (Ty->getTypeID()) {
      case Type::VoidTyID: NOT_SUPPORTED;
      case Type::VectorTyID:
      {
        const VectorType* vectorTy = cast<VectorType>(Ty);
        const uint32_t declaredLanes = vectorTy->getNumElements();
        // OCL spec: a 3-component vector is laid out like a 4-component one
        const uint32_t lanes = (declaredLanes == 3) ? 4 : declaredLanes;
        return lanes * getTypeByteSize(unit, vectorTy->getElementType());
      }
      case Type::PointerTyID:
      case Type::IntegerTyID:
      case Type::FloatTyID:
      case Type::DoubleTyID:
      case Type::HalfTyID:
        return getTypeBitSize(unit, Ty)/8;
      case Type::ArrayTyID:
        return getAlignmentByte(unit, cast<ArrayType>(Ty)->getElementType());
      case Type::StructTyID:
      {
        const StructType* structTy = cast<StructType>(Ty);
        uint32_t widest = 0;
        for (uint32_t member = 0; member < structTy->getNumElements(); ++member) {
          const uint32_t memberAlign =
            getAlignmentByte(unit, structTy->getElementType(member));
          if (memberAlign > widest)
            widest = memberAlign;
        }
        return widest;
      }
      default: NOT_SUPPORTED;
    }
    return 0u;
  }

  /**
   * Size, in bits, that this backend assigns to an LLVM type.
   *
   *  - pointers: unit.getPointerSize()
   *  - integers: their bit width, except that 1-bit integers count as 16
   *  - half / float / double: 16 / 32 / 64
   *  - vectors: element size times the element count (3 is treated as 4)
   *  - arrays: elements plus alignment padding between consecutive elements
   *    (no padding is added after the last one)
   *  - structs: members in order, each preceded by the padding needed to
   *    reach its alignment (no padding is added after the last member)
   *
   * Void and any other type id go through NOT_SUPPORTED.
   *
   * \param unit IR unit providing the pointer size
   * \param Ty   type to measure
   * \return the size in bits
   */
  uint32_t getTypeBitSize(const ir::Unit &unit, Type* Ty)
  {
    switch (Ty->getTypeID()) {
      case Type::VoidTyID:    NOT_SUPPORTED;
      case Type::PointerTyID: return unit.getPointerSize();
      case Type::IntegerTyID:
      {
        const int width = cast<IntegerType>(Ty)->getBitWidth();
        // use S16 to represent SLM bool variables.
        if (width == 1)
          return 16;
        return width;
      }
      case Type::HalfTyID:    return 16;
      case Type::FloatTyID:   return 32;
      case Type::DoubleTyID:  return 64;
      case Type::VectorTyID:
      {
        const VectorType* vectorTy = cast<VectorType>(Ty);
        const uint32_t declaredLanes = vectorTy->getNumElements();
        // OCL spec: a 3-component vector occupies the room of 4 components
        const uint32_t lanes = (declaredLanes == 3) ? 4 : declaredLanes;
        return lanes * getTypeBitSize(unit, vectorTy->getElementType());
      }
      case Type::ArrayTyID:
      {
        const ArrayType* arrayTy = cast<ArrayType>(Ty);
        Type* itemTy = arrayTy->getElementType();
        const uint32_t itemBits = getTypeBitSize(unit, itemTy);
        const auto itemCount = arrayTy->getNumElements();
        uint32_t totalBits = itemCount * itemBits;
        const uint32_t alignBits = 8 * getAlignmentByte(unit, itemTy);
        // one gap between each pair of neighbouring elements
        totalBits += (itemCount-1) * getPadding(itemBits, alignBits);
        return totalBits;
      }
      case Type::StructTyID:
      {
        const StructType* structTy = cast<StructType>(Ty);
        uint32_t totalBits = 0;
        for (uint32_t member = 0; member < structTy->getNumElements(); ++member) {
          Type* memberTy = structTy->getElementType(member);
          const uint32_t alignBits = 8 * getAlignmentByte(unit, memberTy);
          // align the running size first, then append the member itself
          totalBits += getPadding(totalBits, alignBits);
          totalBits += getTypeBitSize(unit, memberTy);
        }
        return totalBits;
      }
      default: NOT_SUPPORTED;
    }
    return 0u;
  }

  /**
   * Size of a type in bytes, derived from getTypeBitSize.
   *
   * Asserts that the bit size is a whole number of bytes.
   *
   * \param unit IR unit, forwarded to getTypeBitSize
   * \param Ty   type to measure
   * \return the bit size divided by 8
   */
  uint32_t getTypeByteSize(const ir::Unit &unit, Type* Ty)
  {
    const uint32_t size_bit = getTypeBitSize(unit, Ty);
    assert((size_bit%8==0) && "no multiple of 8");
    const uint32_t size_byte = size_bit / 8;
    return size_byte;
  }

  /**
   * Byte offset contributed by one constant GEP index applied to CompTy.
   *
   * For sequential types the offset is TypeIndex times the element size,
   * where the element size is first rounded up to the element alignment.
   * For anything else CompTy must be a struct (asserted): the members are
   * walked from index 0 towards TypeIndex, accumulating each member's size
   * and alignment padding, and the result is finally padded to the alignment
   * of the member that is accessed.
   *
   * \param unit      IR unit, forwarded to the size/alignment helpers
   * \param CompTy    composite type being indexed
   * \param TypeIndex constant index into CompTy
   * \return the signed byte offset
   */
  int32_t getGEPConstOffset(const ir::Unit &unit, CompositeType *CompTy, int32_t TypeIndex) {
    int32_t offset = 0;
    SequentialType * seqType = dyn_cast<SequentialType>(CompTy);
    if (seqType != NULL) {
      // All elements share one type, so the offset is a simple product.
      // Index 0 keeps the initial offset of 0.
      if (TypeIndex != 0) {
        Type *elementType = seqType->getElementType();
        uint32_t elementSize = getTypeByteSize(unit, elementType);
        uint32_t align = getAlignmentByte(unit, elementType);
        // stride = element size rounded up to its alignment
        elementSize += getPadding(elementSize, align);
        offset = elementSize * TypeIndex;
      }
    } else {
      // Walking direction: +1 for a positive index, -1 otherwise (for index 0
      // the loop below does not run at all).
      int32_t step = TypeIndex > 0 ? 1 : -1;
      GBE_ASSERT(CompTy->isStructTy());
      for(int32_t ty_i=0; ty_i != TypeIndex; ty_i += step)
      {
        Type* elementType = CompTy->getTypeAtIndex(ty_i);
        uint32_t align = getAlignmentByte(unit, elementType);
        // NOTE(review): align is unsigned, so align * step is evaluated in
        // unsigned arithmetic and only becomes signed again when converted to
        // getPadding's int32_t parameter.
        offset += getPadding(offset, align * step);
        offset += getTypeByteSize(unit, elementType) * step;
      }

      //add padding for accessed type
      const uint32_t align = getAlignmentByte(unit, CompTy->getTypeAtIndex(TypeIndex));
      offset += getPadding(offset, align * step);
    }
    return offset;
  }

  class GenRemoveGEPPasss : public BasicBlockPass
  {

   public:
    static char ID;
    GenRemoveGEPPasss(const ir::Unit &unit) :
      BasicBlockPass(ID),
      unit(unit) {}
    const ir::Unit &unit;
    void getAnalysisUsage(AnalysisUsage &AU) const {
      AU.setPreservesCFG();
    }

    virtual const char *getPassName() const {
      return "SPIR backend: insert special spir instructions";
    }

    bool simplifyGEPInstructions(GetElementPtrInst* GEPInst);

    virtual bool runOnBasicBlock(BasicBlock &BB)
    {
      bool changedBlock = false;
      iplist<Instruction>::iterator I = BB.getInstList().begin();
      for (auto nextI = I, E = --BB.getInstList().end(); I != E; I = nextI) {
        iplist<Instruction>::iterator I = nextI++;
        if(GetElementPtrInst* gep = dyn_cast<GetElementPtrInst>(&*I))
          changedBlock = (simplifyGEPInstructions(gep) || changedBlock);
      }
      return changedBlock;
    }
  };

  char GenRemoveGEPPasss::ID = 0;

  /**
   * Replace one getelementptr instruction by explicit address arithmetic.
   *
   * The base pointer is converted to a pointer-sized integer. Constant
   * indices are folded into a single byte offset with getGEPConstOffset.
   * Every non-constant index is multiplied by the padded element size and
   * added to the running address. The result is converted back with
   * inttoptr, all uses of the GEP are redirected to it and the GEP is erased.
   * All new instructions are inserted right before the GEP.
   *
   * \param GEPInst instruction to lower; it is erased when true is returned
   * \return false when the base operand is missing or its type is not a
   *         composite type (nothing is changed), true otherwise
   */
  bool GenRemoveGEPPasss::simplifyGEPInstructions(GetElementPtrInst* GEPInst)
  {
    const uint32_t ptrSize = unit.getPointerSize();
    Value* parentPointer = GEPInst->getOperand(0);
    // dyn_cast rather than cast: cast<> never yields NULL, which made the
    // check below dead code.
    CompositeType* CompTy = parentPointer ? dyn_cast<CompositeType>(parentPointer->getType()) : NULL;
    if(!CompTy)
      return false;

    // Integer type as wide as a pointer; all address math is done in it.
    IntegerType* ptrIntTy = IntegerType::get(GEPInst->getContext(), ptrSize);

    Value* currentAddrInst =
      new PtrToIntInst(parentPointer, ptrIntTy, "", GEPInst);

    int32_t constantOffset = 0;

    for(uint32_t op=1; op<GEPInst->getNumOperands(); ++op)
    {
      // An index can only be applied to a type that still has sub-elements.
      GBE_ASSERT(CompTy != NULL);

      int32_t TypeIndex;
      ConstantInt* ConstOP = dyn_cast<ConstantInt>(GEPInst->getOperand(op));
      if (ConstOP != NULL) {
        // GEP indices are signed: sign extend so that negative indices
        // narrower than 32 bits are not turned into large positive ones.
        TypeIndex = static_cast<int32_t>(ConstOP->getSExtValue());
        constantOffset += getGEPConstOffset(unit, CompTy, TypeIndex);
      }
      else {
        // we only have array/vectors here,
        // therefore all elements have the same size
        TypeIndex = 0;

        Type* elementType = CompTy->getTypeAtIndex(TypeIndex);
        uint32_t size = getTypeByteSize(unit, elementType);

        //add padding
        uint32_t align = getAlignmentByte(unit, elementType);
        size += getPadding(size, align);

        Constant* newConstSize = ConstantInt::get(ptrIntTy, size);

        // The index may be wider or narrower than a pointer, while the binary
        // operators below need both operands to have the same type.
        // Sign extend or truncate it to the pointer-sized integer first.
        Value *operand = GEPInst->getOperand(op);
        if (operand->getType() != ptrIntTy) {
          operand = CastInst::CreateIntegerCast(operand, ptrIntTy,
                                                true /* signed */, "", GEPInst);
        }

        Value* tmpMul = operand;
        if (size != 1) {
          tmpMul = BinaryOperator::Create(Instruction::Mul, newConstSize, operand,
                                         "", GEPInst);
        }
        currentAddrInst =
          BinaryOperator::Create(Instruction::Add, currentAddrInst, tmpMul,
              "", GEPInst);
      }

      //step down in type hierarchy
      CompTy = dyn_cast<CompositeType>(CompTy->getTypeAtIndex(TypeIndex));
    }

    //insert addition of new offset before GEPInst when it is not zero
    if (constantOffset != 0) {
      Constant* newConstOffset = ConstantInt::get(ptrIntTy, constantOffset);
      currentAddrInst =
        BinaryOperator::Create(Instruction::Add, currentAddrInst,
            newConstOffset, "", GEPInst);
    }

    //convert offset to ptr type (nop)
    IntToPtrInst* intToPtrInst =
      new IntToPtrInst(currentAddrInst,GEPInst->getType(),"", GEPInst);

    //replace uses of the GEP instruction with the newly calculated pointer
    GEPInst->replaceAllUsesWith(intToPtrInst);
    GEPInst->dropAllReferences();
    GEPInst->eraseFromParent();

    return true;
  }

  /**
   * Factory for the GEP removal pass.
   *
   * \param unit IR unit the pass keeps a reference to
   * \return a newly allocated pass; who deletes it is not visible here
   */
  BasicBlockPass *createRemoveGEPPass(const ir::Unit &unit) {
    BasicBlockPass *pass = new GenRemoveGEPPasss(unit);
    return pass;
  }
} /* namespace gbe */