summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAshley Sommer <ashleysommer@gmail.com>2020-03-28 13:11:32 +1000
committerGitHub <noreply@github.com>2020-03-28 13:11:32 +1000
commit2aebbf00449065ffe33a5224cfaf35165394b6b1 (patch)
treefeba0db88ed1dfde1eb51de1c185d679b1ea3ed0
parenta86eaa845857e00af5f393204a9546d3c76b08f7 (diff)
parentf01effead03b3b118573580536da7275c8f4b7cd (diff)
downloadrdflib-2aebbf00449065ffe33a5224cfaf35165394b6b1.tar.gz
Merge pull request #978 from RDFLib/pr_451_redux
Fix adding together literals, always get the expected datatype out.
-rw-r--r--rdflib/term.py51
-rw-r--r--test/test_term.py81
2 files changed, 126 insertions, 6 deletions
diff --git a/rdflib/term.py b/rdflib/term.py
index 8f3c8e6f..3d290258 100644
--- a/rdflib/term.py
+++ b/rdflib/term.py
@@ -644,15 +644,45 @@ class Literal(Identifier):
rdflib.term.Literal(u'11')
"""
- py = self.toPython()
- if not isinstance(py, Literal):
+ # if no val is supplied, return this Literal
+ if val is None:
+ return self
+
+ # convert the val to a Literal, if it isn't already one
+ if not isinstance(val, Literal):
+ val = Literal(val)
+
+ # if the datatypes are the same, just add the Python values and convert back
+ if self.datatype == val.datatype:
+ return Literal(self.toPython() + val.toPython(), self.language, datatype=self.datatype)
+ # if the datatypes are not the same but are both numeric, add the Python values and strip off decimal junk
+ # (i.e. tiny numbers (more than 17 decimal places) and trailing zeros) and return as a decimal
+ elif (
+ self.datatype in _NUMERIC_LITERAL_TYPES
+ and
+ val.datatype in _NUMERIC_LITERAL_TYPES
+ ):
+ return Literal(
+ Decimal(
+ ('%f' % round(Decimal(self.toPython()) + Decimal(val.toPython()), 15)).rstrip('0').rstrip('.')
+ ),
+ datatype=_XSD_DECIMAL
+ )
+ # in all other cases, perform string concatenation
+ else:
try:
- return Literal(py + val)
+ s = text_type.__add__(self, val)
except TypeError:
- pass # fall-through
+ s = str(self.value) + str(val)
+
+ # if the original datatype is string-like, use that
+ if self.datatype in _STRING_LITERAL_TYPES:
+ new_datatype = self.datatype
+ # if not, use string
+ else:
+ new_datatype = _XSD_STRING
- s = text_type.__add__(self, val)
- return Literal(s, self.language, self.datatype)
+ return Literal(s, self.language, datatype=new_datatype)
def __bool__(self):
"""
@@ -1434,6 +1464,15 @@ _TOTAL_ORDER_CASTERS = {
}
+_STRING_LITERAL_TYPES = (
+ _XSD_STRING,
+ _RDF_XMLLITERAL,
+ _RDF_HTMLLITERAL,
+ URIRef(_XSD_PFX + 'normalizedString'),
+ URIRef(_XSD_PFX + 'token')
+)
+
+
def _py2literal(obj, pType, castFunc, dType):
if castFunc:
return castFunc(obj), dType
diff --git a/test/test_term.py b/test/test_term.py
index c222a8d2..4dea9c3c 100644
--- a/test/test_term.py
+++ b/test/test_term.py
@@ -123,6 +123,87 @@ class TestLiteral(unittest.TestCase):
random.shuffle(l2)
self.assertListEqual(l1, sorted(l2))
+ def test_literal_add(self):
+ from decimal import Decimal
+
+ # compares Python decimals
+ def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
+ a = float(a)
+ b = float(b)
+ return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
+
+ cases = [
+ (1, Literal(1), Literal(1), Literal(2)),
+ (2, Literal(Decimal(1)), Literal(Decimal(1)), Literal(Decimal(2))),
+ (3, Literal(float(1)), Literal(float(1)), Literal(float(2))),
+ (4, Literal(1), Literal(1.1), Literal(2.1, datatype=XSD.decimal)),
+ (5, Literal(1.1), Literal(1.1), Literal(2.2)),
+ (6, Literal(Decimal(1)), Literal(Decimal(1.1)), Literal(Decimal(2.1), datatype=XSD.decimal)),
+ (7, Literal(Decimal(1.1)), Literal(Decimal(1.1)), Literal(Decimal(2.2))),
+ (8, Literal(float(1)), Literal(float(1.1)), Literal(float(2.1))),
+ (9, Literal(float(1.1)), Literal(float(1.1)), Literal(float(2.2))),
+ (10, Literal(-1), Literal(-1), Literal(-2)),
+ (12, Literal(Decimal(-1)), Literal(Decimal(-1)), Literal(Decimal(-2))),
+ (13, Literal(float(-1)), Literal(float(-1)), Literal(float(-2))),
+ (14, Literal(-1), Literal(-1.1), Literal(-2.1)),
+ (15, Literal(-1.1), Literal(-1.1), Literal(-2.2)),
+ (16, Literal(Decimal(-1)), Literal(Decimal(-1.1)), Literal(Decimal(-2.1))),
+ (17, Literal(Decimal(-1.1)), Literal(Decimal(-1.1)), Literal(Decimal(-2.2))),
+ (18, Literal(float(-1)), Literal(float(-1.1)), Literal(float(-2.1))),
+ (19, Literal(float(-1.1)), Literal(float(-1.1)), Literal(float(-2.2))),
+
+ (20, Literal(1), Literal(1.0), Literal(2.0)),
+ (21, Literal(1.0), Literal(1.0), Literal(2.0)),
+ (22, Literal(Decimal(1)), Literal(Decimal(1.0)), Literal(Decimal(2.0))),
+ (23, Literal(Decimal(1.0)), Literal(Decimal(1.0)), Literal(Decimal(2.0))),
+ (24, Literal(float(1)), Literal(float(1.0)), Literal(float(2.0))),
+ (25, Literal(float(1.0)), Literal(float(1.0)), Literal(float(2.0))),
+
+ (26, Literal(1, datatype=XSD.integer), Literal(1, datatype=XSD.integer), Literal(2, datatype=XSD.integer)),
+ (27, Literal(1, datatype=XSD.integer), Literal("1", datatype=XSD.integer), Literal("2", datatype=XSD.integer)),
+ (28, Literal("1", datatype=XSD.integer), Literal("1", datatype=XSD.integer), Literal("2", datatype=XSD.integer)),
+ (29, Literal("1"), Literal("1", datatype=XSD.integer), Literal("11", datatype=XSD.string)),
+ (30, Literal(1), Literal("1", datatype=XSD.integer), Literal("2", datatype=XSD.integer)),
+ (31, Literal(Decimal(1), datatype=XSD.decimal), Literal(Decimal(1), datatype=XSD.decimal), Literal(Decimal(2), datatype=XSD.decimal)),
+ (32, Literal(Decimal(1)), Literal(Decimal(1), datatype=XSD.decimal), Literal(Decimal(2), datatype=XSD.decimal)),
+ (33, Literal(float(1)), Literal(float(1), datatype=XSD.float), Literal(float(2), datatype=XSD.float)),
+ (34, Literal(float(1), datatype=XSD.float), Literal(float(1), datatype=XSD.float), Literal(float(2), datatype=XSD.float)),
+
+ (35, Literal(1), 1, Literal(2)),
+ (36, Literal(1), 1.0, Literal(2, datatype=XSD.decimal)),
+ (37, Literal(1.0), 1, Literal(2, datatype=XSD.decimal)),
+ (38, Literal(1.0), 1.0, Literal(2.0)),
+ (39, Literal(Decimal(1.0)), Decimal(1), Literal(Decimal(2.0))),
+ (40, Literal(Decimal(1.0)), Decimal(1.0), Literal(Decimal(2.0))),
+ (41, Literal(float(1.0)), float(1), Literal(float(2.0))),
+ (42, Literal(float(1.0)), float(1.0), Literal(float(2.0))),
+
+ (43, Literal(1, datatype=XSD.integer), "+1.1", Literal("1+1.1", datatype=XSD.string)),
+ (44, Literal(1, datatype=XSD.integer), Literal("+1.1", datatype=XSD.string), Literal("1+1.1", datatype=XSD.string)),
+ (45, Literal(Decimal(1.0), datatype=XSD.integer), Literal(u"1", datatype=XSD.string), Literal("11", datatype=XSD.string)),
+ (46, Literal(1.1, datatype=XSD.integer), Literal("1", datatype=XSD.string), Literal("1.11", datatype=XSD.string)),
+
+ (47, Literal(1, datatype=XSD.integer), None, Literal(1, datatype=XSD.integer)),
+ (48, Literal("1", datatype=XSD.string), None, Literal("1", datatype=XSD.string)),
+ ]
+
+ for case in cases:
+ # see if the addition exactly matches the expected output
+ case_passed = (case[1] + case[2]) == (case[3])
+ # see if the addition almost matches the expected output, for decimal precision errors
+ if not case_passed:
+ try:
+ case_passed = isclose((case[1] + case[2].value), case[3].value)
+ except:
+ pass
+
+ if not case_passed:
+ print(case[1], case[2])
+ print("expected: " + case[3] + ", " + case[3].datatype)
+ print("actual: " + (case[1] + case[2]) + ", " + (case[1] + case[2]).datatype)
+
+ self.assertTrue(case_passed, "Case " + str(case[0]) + " failed")
+
class TestValidityFunctions(unittest.TestCase):