summaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode/Writer/IRNumbering.h
blob: aeb624e58ba0c1c492cb7d97cb8f9a65439861bf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains various utilities that number IR structures in preparation
// for bytecode emission.
//
//===----------------------------------------------------------------------===//

#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H

#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringMap.h"

namespace mlir {
class BytecodeDialectInterface;
class BytecodeWriterConfig;

namespace bytecode {
namespace detail {
struct DialectNumbering;

//===----------------------------------------------------------------------===//
// Attribute and Type Numbering
//===----------------------------------------------------------------------===//

/// This class represents a numbering entry for an Attribute or Type.
struct AttrTypeNumbering {
  AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {}

  /// The concrete value.
  PointerUnion<Attribute, Type> value;

  /// The number assigned to this value.
  unsigned number = 0;

  /// The number of references to this value.
  unsigned refCount = 1;

  /// The dialect of this value.
  DialectNumbering *dialect = nullptr;
};
struct AttributeNumbering : public AttrTypeNumbering {
  AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
  Attribute getValue() const { return value.get<Attribute>(); }
};
struct TypeNumbering : public AttrTypeNumbering {
  TypeNumbering(Type value) : AttrTypeNumbering(value) {}
  Type getValue() const { return value.get<Type>(); }
};

//===----------------------------------------------------------------------===//
// OpName Numbering
//===----------------------------------------------------------------------===//

/// This class represents the numbering entry of an operation name.
struct OpNameNumbering {
  OpNameNumbering(DialectNumbering *dialect, OperationName name)
      : dialect(dialect), name(name) {}

  /// The dialect of this value.
  DialectNumbering *dialect;

  /// The concrete name.
  OperationName name;

  /// The number assigned to this name.
  unsigned number = 0;

  /// The number of references to this name.
  unsigned refCount = 1;
};

//===----------------------------------------------------------------------===//
// Dialect Resource Numbering
//===----------------------------------------------------------------------===//

/// This class represents a numbering entry for a dialect resource.
struct DialectResourceNumbering {
  DialectResourceNumbering(std::string resourceKey)
      : key(std::move(resourceKey)) {}

  /// The string key through which this resource is referenced.
  std::string key;

  /// The number given to this resource (defaults to 0).
  unsigned number{0};

  /// True while this resource is merely declared rather than fully defined;
  /// entries start out as declarations.
  bool isDeclaration{true};
};

//===----------------------------------------------------------------------===//
// Dialect Numbering
//===----------------------------------------------------------------------===//

/// This class represents a numbering entry for a Dialect.
struct DialectNumbering {
  DialectNumbering(StringRef name, unsigned number)
      : name(name), number(number) {}

  /// The namespace of the dialect.
  StringRef name;

  /// The number assigned to the dialect.
  unsigned number;

  /// The bytecode dialect interface of the dialect if defined.
  const BytecodeDialectInterface *interface = nullptr;

  /// The asm dialect interface of the dialect if defined.
  const OpAsmDialectInterface *asmInterface = nullptr;

  /// The referenced resources of this dialect.
  SetVector<AsmDialectResourceHandle> resources;

  /// A mapping from resource key to the corresponding resource numbering entry.
  llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
};

//===----------------------------------------------------------------------===//
// IRNumberingState
//===----------------------------------------------------------------------===//

/// This class manages numbering IR entities in preparation for bytecode
/// emission.
class IRNumberingState {
public:
  IRNumberingState(Operation *op);

  /// Return the numbered dialects.
  auto getDialects() {
    return llvm::make_pointee_range(llvm::make_second_range(dialects));
  }
  auto getAttributes() { return llvm::make_pointee_range(orderedAttrs); }
  auto getOpNames() { return llvm::make_pointee_range(orderedOpNames); }
  auto getTypes() { return llvm::make_pointee_range(orderedTypes); }

  /// Return the number for the given IR unit.
  unsigned getNumber(Attribute attr) {
    assert(attrs.count(attr) && "attribute not numbered");
    return attrs[attr]->number;
  }
  unsigned getNumber(Block *block) {
    assert(blockIDs.count(block) && "block not numbered");
    return blockIDs[block];
  }
  unsigned getNumber(OperationName opName) {
    assert(opNames.count(opName) && "opName not numbered");
    return opNames[opName]->number;
  }
  unsigned getNumber(Type type) {
    assert(types.count(type) && "type not numbered");
    return types[type]->number;
  }
  unsigned getNumber(Value value) {
    assert(valueIDs.count(value) && "value not numbered");
    return valueIDs[value];
  }
  unsigned getNumber(const AsmDialectResourceHandle &resource) {
    assert(dialectResources.count(resource) && "resource not numbered");
    return dialectResources[resource]->number;
  }

  /// Return the block and value counts of the given region.
  std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
    assert(regionBlockValueCounts.count(region) && "value not numbered");
    return regionBlockValueCounts[region];
  }

  /// Return the number of operations in the given block.
  unsigned getOperationCount(Block *block) {
    assert(blockOperationCounts.count(block) && "block not numbered");
    return blockOperationCounts[block];
  }

private:
  /// This class is used to provide a fake dialect writer for numbering nested
  /// attributes and types.
  struct NumberingDialectWriter;

  /// Number the given IR unit for bytecode emission.
  void number(Attribute attr);
  void number(Block &block);
  DialectNumbering &numberDialect(Dialect *dialect);
  DialectNumbering &numberDialect(StringRef dialect);
  void number(Operation &op);
  void number(OperationName opName);
  void number(Region &region);
  void number(Type type);

  /// Number the given dialect resources.
  void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources);

  /// Finalize the numberings of any dialect resources.
  void finalizeDialectResourceNumberings(Operation *rootOp);

  /// Mapping from IR to the respective numbering entries.
  DenseMap<Attribute, AttributeNumbering *> attrs;
  DenseMap<OperationName, OpNameNumbering *> opNames;
  DenseMap<Type, TypeNumbering *> types;
  DenseMap<Dialect *, DialectNumbering *> registeredDialects;
  llvm::MapVector<StringRef, DialectNumbering *> dialects;
  std::vector<AttributeNumbering *> orderedAttrs;
  std::vector<OpNameNumbering *> orderedOpNames;
  std::vector<TypeNumbering *> orderedTypes;

  /// A mapping from dialect resource handle to the numbering for the referenced
  /// resource.
  llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *>
      dialectResources;

  /// Allocators used for the various numbering entries.
  llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
  llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
  llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
  llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
  llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;

  /// The value ID for each Block and Value.
  DenseMap<Block *, unsigned> blockIDs;
  DenseMap<Value, unsigned> valueIDs;

  /// The number of operations in each block.
  DenseMap<Block *, unsigned> blockOperationCounts;

  /// A map from region to the number of blocks and values within that region.
  DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts;

  /// The next value ID to assign when numbering.
  unsigned nextValueID = 0;
};
} // namespace detail
} // namespace bytecode
} // namespace mlir

#endif