summaryrefslogtreecommitdiff
path: root/src/mbgl/style/expression/comparison.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mbgl/style/expression/comparison.cpp')
-rw-r--r--src/mbgl/style/expression/comparison.cpp275
1 files changed, 275 insertions, 0 deletions
diff --git a/src/mbgl/style/expression/comparison.cpp b/src/mbgl/style/expression/comparison.cpp
new file mode 100644
index 0000000000..6179c3ce88
--- /dev/null
+++ b/src/mbgl/style/expression/comparison.cpp
@@ -0,0 +1,275 @@
+#include <mbgl/style/expression/collator.hpp>
+#include <mbgl/style/expression/comparison.hpp>
+#include <mbgl/style/expression/dsl.hpp>
+
+namespace mbgl {
+namespace style {
+namespace expression {
+
+static bool isComparableType(const std::string& op, const type::Type& type) {
+ if (op == "==" || op == "!=") {
+ return type == type::String ||
+ type == type::Number ||
+ type == type::Boolean ||
+ type == type::Null ||
+ type == type::Value;
+ } else {
+ return type == type::String ||
+ type == type::Number ||
+ type == type::Value;
+ }
+}
+
+bool eq(Value a, Value b) { return a == b; }
+bool neq(Value a, Value b) { return a != b; }
+bool lt(Value lhs, Value rhs) {
+ return lhs.match(
+ [&](const std::string& a) { return a < rhs.get<std::string>(); },
+ [&](double a) { return a < rhs.get<double>(); },
+ [&](const auto&) { assert(false); return false; }
+ );
+}
+bool gt(Value lhs, Value rhs) {
+ return lhs.match(
+ [&](const std::string& a) { return a > rhs.get<std::string>(); },
+ [&](double a) { return a > rhs.get<double>(); },
+ [&](const auto&) { assert(false); return false; }
+ );
+}
+bool lteq(Value lhs, Value rhs) {
+ return lhs.match(
+ [&](const std::string& a) { return a <= rhs.get<std::string>(); },
+ [&](double a) { return a <= rhs.get<double>(); },
+ [&](const auto&) { assert(false); return false; }
+ );
+}
+bool gteq(Value lhs, Value rhs) {
+ return lhs.match(
+ [&](const std::string& a) { return a >= rhs.get<std::string>(); },
+ [&](double a) { return a >= rhs.get<double>(); },
+ [&](const auto&) { assert(false); return false; }
+ );
+}
+
+bool eqCollate(std::string a, std::string b, Collator c) { return c.compare(a, b) == 0; }
+bool neqCollate(std::string a, std::string b, Collator c) { return !eqCollate(a, b, c); }
+bool ltCollate(std::string a, std::string b, Collator c) { return c.compare(a, b) < 0; }
+bool gtCollate(std::string a, std::string b, Collator c) { return c.compare(a, b) > 0; }
+bool lteqCollate(std::string a, std::string b, Collator c) { return c.compare(a, b) <= 0; }
+bool gteqCollate(std::string a, std::string b, Collator c) { return c.compare(a, b) >= 0; }
+
+static BasicComparison::CompareFunctionType getBasicCompareFunction(const std::string& op) {
+ if (op == "==") return eq;
+ else if (op == "!=") return neq;
+ else if (op == ">") return gt;
+ else if (op == "<") return lt;
+ else if (op == ">=") return gteq;
+ else if (op == "<=") return lteq;
+ assert(false);
+ return nullptr;
+}
+
+static CollatorComparison::CompareFunctionType getCollatorComparisonFunction(const std::string& op) {
+ if (op == "==") return eqCollate;
+ else if (op == "!=") return neqCollate;
+ else if (op == ">") return gtCollate;
+ else if (op == "<") return ltCollate;
+ else if (op == ">=") return gteqCollate;
+ else if (op == "<=") return lteqCollate;
+ assert(false);
+ return nullptr;
+
+}
+
+
+BasicComparison::BasicComparison(
+ std::string op_,
+ std::unique_ptr<Expression> lhs_,
+ std::unique_ptr<Expression> rhs_)
+ : Expression(Kind::Comparison, type::Boolean),
+ op(std::move(op_)),
+ compare(getBasicCompareFunction(op)),
+ lhs(std::move(lhs_)),
+ rhs(std::move(rhs_)) {
+ assert(isComparableType(op, lhs->getType()) && isComparableType(op, rhs->getType()));
+ assert(lhs->getType() == rhs->getType() || lhs->getType() == type::Value || rhs->getType() == type::Value);
+
+ needsRuntimeTypeCheck = (op != "==" && op != "!=") &&
+ (lhs->getType() == type::Value || rhs->getType() == type::Value);
+}
+
+EvaluationResult BasicComparison::evaluate(const EvaluationContext& params) const {
+ EvaluationResult lhsResult = lhs->evaluate(params);
+ if (!lhsResult) return lhsResult;
+
+ EvaluationResult rhsResult = rhs->evaluate(params);
+ if (!rhsResult) return rhsResult;
+
+ if (needsRuntimeTypeCheck) {
+ type::Type lhsType = typeOf(*lhsResult);
+ type::Type rhsType = typeOf(*rhsResult);
+ // check that type is string or number, and equal
+ if (lhsType != rhsType || !(lhsType == type::String || lhsType == type::Number)) {
+ return EvaluationError {
+ R"(Expected arguments for ")" + op + R"(")" +
+ " to be (string, string) or (number, number), but found (" + toString(lhsType) + ", " +
+ toString(rhsType) + ") instead."
+ };
+ }
+ }
+
+ return compare(*lhsResult, *rhsResult);
+}
+
+void BasicComparison::eachChild(const std::function<void(const Expression&)>& visit) const {
+ visit(*lhs);
+ visit(*rhs);
+}
+
+std::string BasicComparison::getOperator() const { return op; }
+
+bool BasicComparison::operator==(const Expression& e) const {
+ if (e.getKind() == Kind::Comparison) {
+ auto comp = static_cast<const BasicComparison*>(&e);
+ return comp->op == op &&
+ *comp->lhs == *lhs &&
+ *comp->rhs == *rhs;
+ }
+ return false;
+}
+
+std::vector<optional<Value>> BasicComparison::possibleOutputs() const {
+ return {{ true }, { false }};
+}
+
+CollatorComparison::CollatorComparison(
+ std::string op_,
+ std::unique_ptr<Expression> lhs_,
+ std::unique_ptr<Expression> rhs_,
+ std::unique_ptr<Expression> collator_)
+ : Expression(Kind::Comparison, type::Boolean),
+ op(op_),
+ compare(getCollatorComparisonFunction(op)),
+ lhs(std::move(lhs_)),
+ rhs(std::move(rhs_)),
+ collator(std::move(collator_)) {
+ assert(isComparableType(op, lhs->getType()) && isComparableType(op, rhs->getType()));
+ assert(lhs->getType() == rhs->getType() || lhs->getType() == type::Value || rhs->getType() == type::Value);
+
+ needsRuntimeTypeCheck = (op == "==" || op == "!=") &&
+ (lhs->getType() == type::Value || rhs->getType() == type::Value);
+}
+
+EvaluationResult CollatorComparison::evaluate(const EvaluationContext& params) const {
+ EvaluationResult lhsResult = lhs->evaluate(params);
+ if (!lhsResult) return lhsResult;
+
+ EvaluationResult rhsResult = rhs->evaluate(params);
+ if (!rhsResult) return lhsResult;
+
+ if (needsRuntimeTypeCheck) {
+ if (typeOf(*lhsResult) != type::String || typeOf(*rhsResult) != type::String) {
+ return getBasicCompareFunction(op)(*lhsResult, *rhsResult);
+ }
+ }
+
+ auto collatorResult = collator->evaluate(params);
+ if (!collatorResult) return collatorResult;
+
+ const Collator& c = collatorResult->get<Collator>();
+ return compare(lhsResult->get<std::string>(), rhsResult->get<std::string>(), c);
+}
+
+void CollatorComparison::eachChild(const std::function<void(const Expression&)>& visit) const {
+ visit(*lhs);
+ visit(*rhs);
+ visit(*collator);
+}
+
+std::string CollatorComparison::getOperator() const { return op; }
+
+bool CollatorComparison::operator==(const Expression& e) const {
+ if (e.getKind() == Kind::Comparison) {
+ auto comp = static_cast<const CollatorComparison*>(&e);
+ return comp->op == op &&
+ *comp->collator == *collator &&
+ *comp->lhs == *lhs &&
+ *comp->rhs == *rhs;
+ }
+ return false;
+}
+
+std::vector<optional<Value>> CollatorComparison::possibleOutputs() const {
+ return {{ true }, { false }};
+}
+
+using namespace mbgl::style::conversion;
+ParseResult parseComparison(const Convertible& value, ParsingContext& ctx) {
+ std::size_t length = arrayLength(value);
+
+ if (length != 3 && length != 4) {
+ ctx.error("Expected two or three arguments.");
+ return ParseResult();
+ }
+
+ std::string op = *toString(arrayMember(value, 0));
+
+ assert(getBasicCompareFunction(op));
+
+ ParseResult lhs = ctx.parse(arrayMember(value, 1), 1, {type::Value});
+ if (!lhs) return ParseResult();
+ type::Type lhsType = (*lhs)->getType();
+ if (!isComparableType(op, lhsType)) {
+ ctx.error(R"(")" + op + R"(" comparisons are not supported for type ')" + toString(lhsType) + "'.", 1);
+ return ParseResult();
+ }
+
+ ParseResult rhs = ctx.parse(arrayMember(value, 2), 2, {type::Value});
+ if (!rhs) return ParseResult();
+ type::Type rhsType = (*rhs)->getType();
+ if (!isComparableType(op, rhsType)) {
+ ctx.error(R"(")" + op + R"(" comparisons are not supported for type ')" + toString(rhsType) + "'.", 2);
+ return ParseResult();
+ }
+
+ if (
+ lhsType != rhsType &&
+ lhsType != type::Value &&
+ rhsType != type::Value
+ ) {
+ ctx.error("Cannot compare types '" + toString(lhsType) + "' and '" + toString(rhsType) + "'.");
+ return ParseResult();
+ }
+
+ if (op != "==" && op != "!=") {
+ // typing rules specific to less/greater than operators
+ if (lhsType == type::Value && rhsType != type::Value) {
+ // (value, T)
+ lhs = dsl::assertion(rhsType, std::move(*lhs));
+ } else if (lhsType != type::Value && rhsType == type::Value) {
+ // (T, value)
+ rhs = dsl::assertion(lhsType, std::move(*rhs));
+ }
+ }
+
+ if (length == 4) {
+ if (
+ lhsType != type::String &&
+ rhsType != type::String &&
+ lhsType != type::Value &&
+ rhsType != type::Value
+ ) {
+ ctx.error("Cannot use collator to compare non-string types.");
+ return ParseResult();
+ }
+ ParseResult collatorParseResult = ctx.parse(arrayMember(value, 3), 3, {type::Collator});
+ if (!collatorParseResult) return ParseResult();
+ return ParseResult(std::make_unique<CollatorComparison>(op, std::move(*lhs), std::move(*rhs), std::move(*collatorParseResult)));
+ }
+
+ return ParseResult(std::make_unique<BasicComparison>(op, std::move(*lhs), std::move(*rhs)));
+}
+
+} // namespace expression
+} // namespace style
+} // namespace mbgl