summaryrefslogtreecommitdiff
path: root/mlir/lib/Interfaces/InferIntRangeInterface.cpp
blob: cc31104ce3335272b42077bc461b151c27adf510 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
//===- InferIntRangeInterface.cpp -  Integer range inference interface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include <optional>
#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"

using namespace mlir;

/// Two ranges are equal when they have the same storage width and all four
/// bounds (unsigned min/max, signed min/max) match.
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
  // Only the `umin` widths are compared here; this presumes every bound of a
  // range shares one width. The width test runs first so the element-wise
  // APInt comparisons below are only reached for operands of matching width.
  if (umin().getBitWidth() != other.umin().getBitWidth())
    return false;
  return umin() == other.umin() && umax() == other.umax() &&
         smin() == other.smin() && smax() == other.smax();
}

const APInt &ConstantIntRanges::umin() const { return uminVal; }

const APInt &ConstantIntRanges::umax() const { return umaxVal; }

const APInt &ConstantIntRanges::smin() const { return sminVal; }

const APInt &ConstantIntRanges::smax() const { return smaxVal; }

/// Returns the number of bits used to store range bounds for `type`.
///
/// `index` uses `IndexType::kInternalStorageBitWidth`, integer types use
/// their declared width, and every other type yields 0. A width of 0 marks
/// the range as "not an integer".
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
  if (type.isIndex())
    return IndexType::kInternalStorageBitWidth;
  auto intTy = dyn_cast<IntegerType>(type);
  return intTy ? intTy.getWidth() : 0;
}

/// Returns the range covering every `bitwidth`-bit value. It is built from
/// the full unsigned interval [0, 2^bitwidth - 1], and the signed bounds are
/// derived by `fromUnsigned`.
ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
  APInt lowest = APInt::getZero(bitwidth);
  APInt highest = APInt::getMaxValue(bitwidth);
  return fromUnsigned(lowest, highest);
}

/// Returns the singleton range {value}. The single value is both the minimum
/// and the maximum under both the unsigned and the signed ordering.
ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
  return ConstantIntRanges(value, value, value, value);
}

/// Builds a range from `[min, max]`. The bounds are read as signed values
/// when `isSigned` is set and as unsigned values otherwise.
ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
                                           bool isSigned) {
  return isSigned ? fromSigned(min, max) : fromUnsigned(min, max);
}

/// Builds a range from the signed interval [smin, smax] and derives the
/// matching unsigned bounds.
ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
                                                const APInt &smax) {
  // Endpoints with different sign bits mean the interval crosses the signed
  // boundary. For a well-ordered [smin, smax] it then contains both -1 and 0,
  // which are the unsigned maximum and minimum, so the unsigned view is the
  // whole space.
  if (smin.isNonNegative() != smax.isNonNegative()) {
    unsigned int width = smin.getBitWidth();
    return {APInt::getMinValue(width), APInt::getMaxValue(width), smin, smax};
  }

  // Both endpoints share a sign bit, so the unsigned bounds are just the
  // endpoints reordered under unsigned comparison.
  const APInt &lowU = smin.ult(smax) ? smin : smax;
  const APInt &highU = smin.ugt(smax) ? smin : smax;
  return {lowU, highU, smin, smax};
}

/// Builds a range from the unsigned interval [umin, umax] and derives the
/// matching signed bounds.
ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
                                                  const APInt &umax) {
  // Endpoints with different top bits mean the unsigned interval spans the
  // point where the signed view wraps. Only the full signed range is then
  // a safe approximation.
  if (umin.isNonNegative() != umax.isNonNegative()) {
    unsigned int width = umin.getBitWidth();
    return {umin, umax, APInt::getSignedMinValue(width),
            APInt::getSignedMaxValue(width)};
  }

  // Both endpoints share a top bit, so the signed bounds are just the
  // endpoints reordered under signed comparison.
  const APInt &lowS = umin.slt(umax) ? umin : umax;
  const APInt &highS = umin.sgt(umax) ? umin : umax;
  return {umin, umax, lowS, highS};
}

/// Returns the smallest range describing values in either `*this` or
/// `other`. Each bound is widened independently under its own (unsigned or
/// signed) ordering.
ConstantIntRanges
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
  // A width-0 range stands for "not an integer". It poisons the result, and
  // its bounds must never reach the APInt comparisons below.
  if (umin().getBitWidth() == 0)
    return *this;
  if (other.umin().getBitWidth() == 0)
    return other;

  // Take the smaller lower bound and the larger upper bound in each ordering.
  const APInt &lowU = umin().ult(other.umin()) ? umin() : other.umin();
  const APInt &highU = umax().ugt(other.umax()) ? umax() : other.umax();
  const APInt &lowS = smin().slt(other.smin()) ? smin() : other.smin();
  const APInt &highS = smax().sgt(other.smax()) ? smax() : other.smax();
  return {lowU, highU, lowS, highS};
}

/// Returns the range of values contained in both `*this` and `other`. Each
/// bound is tightened independently under its own (unsigned or signed)
/// ordering.
///
/// NOTE(review): nothing here detects an empty intersection. The result may
/// have a lower bound above its upper bound in either ordering.
ConstantIntRanges
ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
  // A width-0 range stands for "not an integer". It poisons the result, and
  // its bounds must never reach the APInt comparisons below.
  if (umin().getBitWidth() == 0)
    return *this;
  if (other.umin().getBitWidth() == 0)
    return other;

  // Take the larger lower bound and the smaller upper bound in each ordering.
  const APInt &lowU = umin().ugt(other.umin()) ? umin() : other.umin();
  const APInt &highU = umax().ult(other.umax()) ? umax() : other.umax();
  const APInt &lowS = smin().sgt(other.smin()) ? smin() : other.smin();
  const APInt &highS = smax().slt(other.smax()) ? smax() : other.smax();
  return {lowU, highU, lowS, highS};
}

/// Returns the single value the range pins down, if there is one.
///
/// The unsigned bounds are checked first, then the signed bounds. Width-0
/// bounds always compare equal, so they are explicitly excluded from being
/// reported as a constant.
std::optional<APInt> ConstantIntRanges::getConstantValue() const {
  const bool unsignedSingleton =
      umin() == umax() && umin().getBitWidth() != 0;
  if (unsignedSingleton)
    return umin();

  const bool signedSingleton = smin() == smax() && smin().getBitWidth() != 0;
  if (signedSingleton)
    return smin();

  return std::nullopt;
}

/// Prints `range` as "unsigned : [umin, umax] signed : [smin, smax]".
///
/// The stream `<<` overload for APInt prints values as signed. Used for the
/// unsigned bounds, it would show e.g. an i8 `umax` of 255 as -1. The
/// unsigned bounds are therefore printed explicitly with `isSigned = false`,
/// while the signed bounds keep the signed rendering.
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
  os << "unsigned : [";
  range.umin().print(os, /*isSigned=*/false);
  os << ", ";
  range.umax().print(os, /*isSigned=*/false);
  return os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
}