1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
|
//===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace llvm;
using namespace mlir;
using namespace mlir::python::adaptors;
/// Registers the Python classes for the 'quant' dialect types on module `m`.
///
/// The following classes are created, each wrapping the corresponding MLIR C
/// API functions from "mlir-c/Dialect/Quant.h":
///   - QuantizedType (base class of all the classes below)
///   - AnyQuantizedType
///   - UniformQuantizedType
///   - UniformQuantizedPerAxisType
///   - CalibratedQuantizedType
///
/// Cast helpers raise a Python TypeError when the C API returns a null type;
/// UniformQuantizedPerAxisType.get raises ValueError when the number of scales
/// differs from the number of zero points.
static void populateDialectQuantSubmodule(const py::module &m) {
  //===-------------------------------------------------------------------===//
  // QuantizedType
  //===-------------------------------------------------------------------===//

  auto quantizedType =
      mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
  quantizedType.def_staticmethod(
      "default_minimum_for_integer",
      [](bool isSigned, unsigned integralWidth) {
        return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
                                                            integralWidth);
      },
      "Default minimum value for the integer with the specified signedness and "
      "bit width.",
      py::arg("is_signed"), py::arg("integral_width"));
  quantizedType.def_staticmethod(
      "default_maximum_for_integer",
      [](bool isSigned, unsigned integralWidth) {
        return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
                                                            integralWidth);
      },
      "Default maximum value for the integer with the specified signedness and "
      "bit width.",
      py::arg("is_signed"), py::arg("integral_width"));
  quantizedType.def_property_readonly(
      "expressed_type",
      [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
      "Type expressed by this quantized type.");
  quantizedType.def_property_readonly(
      "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
      "Flags of this quantized type (named accessors should be preferred to "
      "this)");
  quantizedType.def_property_readonly(
      "is_signed",
      [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
      "Signedness of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type",
      [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
      "Storage type backing this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_min",
      [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
      "The minimum value held by the storage type of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_max",
      [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
      "The maximum value held by the storage type of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_integral_width",
      [](MlirType type) {
        return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
      },
      "The bitwidth of the storage type of this quantized type.");
  quantizedType.def(
      "is_compatible_expressed_type",
      [](MlirType type, MlirType candidate) {
        return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
      },
      "Checks whether the candidate type can be expressed by this quantized "
      "type.",
      py::arg("candidate"));
  quantizedType.def_property_readonly(
      "quantized_element_type",
      [](MlirType type) {
        return mlirQuantizedTypeGetQuantizedElementType(type);
      },
      "Element type of this quantized type expressed as quantized type.");

  // All cast helpers below follow the same pattern: a null MlirType returned
  // by the C API is reported to Python as a TypeError instead of being handed
  // back to the caller.
  quantizedType.def(
      "cast_from_storage_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastFromStorageType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on the storage type of this quantized type to a "
      "corresponding type based on the quantized type. Raises TypeError if the "
      "cast is not valid.",
      py::arg("candidate"));
  quantizedType.def_staticmethod(
      "cast_to_storage_type",
      [](MlirType type) {
        MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on a quantized type to a corresponding type "
      "based on the storage type of this quantized type. Raises TypeError if "
      "the cast is not valid.",
      py::arg("type"));
  quantizedType.def(
      "cast_from_expressed_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastFromExpressedType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on the expressed type of this quantized type to "
      "a corresponding type based on the quantized type. Raises TypeError if "
      "the cast is not valid.",
      py::arg("candidate"));
  quantizedType.def_staticmethod(
      "cast_to_expressed_type",
      [](MlirType type) {
        MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on a quantized type to a corresponding type "
      "based on the expressed type of this quantized type. Raises TypeError if "
      "the cast is not valid.",
      py::arg("type"));
  quantizedType.def(
      "cast_expressed_to_storage_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on the expressed type of this quantized type to "
      "a corresponding type based on the storage type. Raises TypeError if the "
      "cast is not valid.",
      py::arg("candidate"));

  // Expose the value of the "signed" flag as a class-level attribute.
  quantizedType.get_class().attr("FLAG_SIGNED") =
      mlirQuantizedTypeGetSignedFlag();

  //===-------------------------------------------------------------------===//
  // AnyQuantizedType
  //===-------------------------------------------------------------------===//

  auto anyQuantizedType =
      mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
                         quantizedType.get_class());
  anyQuantizedType.def_classmethod(
      "get",
      [](py::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, int64_t storageTypeMin,
         int64_t storageTypeMax) {
        return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
                                           storageTypeMin, storageTypeMax));
      },
      "Gets an instance of AnyQuantizedType in the same context as the "
      "provided storage type.",
      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
      py::arg("expressed_type"), py::arg("storage_type_min"),
      py::arg("storage_type_max"));

  //===-------------------------------------------------------------------===//
  // UniformQuantizedType
  //===-------------------------------------------------------------------===//

  auto uniformQuantizedType = mlir_type_subclass(
      m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
      quantizedType.get_class());
  uniformQuantizedType.def_classmethod(
      "get",
      [](py::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, double scale, int64_t zeroPoint,
         int64_t storageTypeMin, int64_t storageTypeMax) {
        return cls(mlirUniformQuantizedTypeGet(flags, storageType,
                                               expressedType, scale, zeroPoint,
                                               storageTypeMin, storageTypeMax));
      },
      "Gets an instance of UniformQuantizedType in the same context as the "
      "provided storage type.",
      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
      py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
      py::arg("storage_type_min"), py::arg("storage_type_max"));
  uniformQuantizedType.def_property_readonly(
      "scale",
      [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
      "The scale designates the difference between the real values "
      "corresponding to consecutive quantized values differing by 1.");
  uniformQuantizedType.def_property_readonly(
      "zero_point",
      [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
      "The storage value corresponding to the real value 0 in the affine "
      "equation.");
  uniformQuantizedType.def_property_readonly(
      "is_fixed_point",
      [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
      "Fixed point values are real numbers divided by a scale.");

  //===-------------------------------------------------------------------===//
  // UniformQuantizedPerAxisType
  //===-------------------------------------------------------------------===//

  auto uniformQuantizedPerAxisType = mlir_type_subclass(
      m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
      quantizedType.get_class());
  uniformQuantizedPerAxisType.def_classmethod(
      "get",
      [](py::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, std::vector<double> scales,
         std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
         int64_t storageTypeMin, int64_t storageTypeMax) {
        // The C API takes a single element count for both arrays, so they
        // must have the same length.
        if (scales.size() != zeroPoints.size())
          throw py::value_error(
              "Mismatching number of scales and zero points.");
        auto nDims = static_cast<intptr_t>(scales.size());
        return cls(mlirUniformQuantizedPerAxisTypeGet(
            flags, storageType, expressedType, nDims, scales.data(),
            zeroPoints.data(), quantizedDimension, storageTypeMin,
            storageTypeMax));
      },
      "Gets an instance of UniformQuantizedPerAxisType in the same context as "
      "the provided storage type.",
      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
      py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
      py::arg("quantized_dimension"), py::arg("storage_type_min"),
      py::arg("storage_type_max"));
  uniformQuantizedPerAxisType.def_property_readonly(
      "scales",
      [](MlirType type) {
        intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
        std::vector<double> scales;
        scales.reserve(nDim);
        for (intptr_t i = 0; i < nDim; ++i) {
          double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
          scales.push_back(scale);
        }
        // The collected values must be returned; without this the lambda
        // deduces `void` and the Python property always evaluates to None.
        return scales;
      },
      "The scales designate the difference between the real values "
      "corresponding to consecutive quantized values differing by 1. The ith "
      "scale corresponds to the ith slice in the quantized_dimension.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "zero_points",
      [](MlirType type) {
        intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
        std::vector<int64_t> zeroPoints;
        zeroPoints.reserve(nDim);
        for (intptr_t i = 0; i < nDim; ++i) {
          int64_t zeroPoint =
              mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
          zeroPoints.push_back(zeroPoint);
        }
        // Same as for `scales`: return the collected values so the property
        // does not evaluate to None.
        return zeroPoints;
      },
      "the storage values corresponding to the real value 0 in the affine "
      "equation. The ith zero point corresponds to the ith slice in the "
      "quantized_dimension.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "quantized_dimension",
      [](MlirType type) {
        return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
      },
      "Specifies the dimension of the shape that the scales and zero points "
      "correspond to.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "is_fixed_point",
      [](MlirType type) {
        return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
      },
      "Fixed point values are real numbers divided by a scale.");

  //===-------------------------------------------------------------------===//
  // CalibratedQuantizedType
  //===-------------------------------------------------------------------===//

  auto calibratedQuantizedType = mlir_type_subclass(
      m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
      quantizedType.get_class());
  calibratedQuantizedType.def_classmethod(
      "get",
      [](py::object cls, MlirType expressedType, double min, double max) {
        return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
      },
      "Gets an instance of CalibratedQuantizedType in the same context as the "
      "provided expressed type.",
      py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
      py::arg("max"));
  calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
    return mlirCalibratedQuantizedTypeGetMin(type);
  });
  calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
    return mlirCalibratedQuantizedTypeGetMax(type);
  });
}
// Entry point of the `_mlirDialectsQuant` pybind11 extension module: sets the
// module docstring and registers the 'quant' dialect type classes on it via
// populateDialectQuantSubmodule.
PYBIND11_MODULE(_mlirDialectsQuant, m) {
m.doc() = "MLIR Quantization dialect";
populateDialectQuantSubmodule(m);
}
|