diff options
author | Kenn Knowles <kenn@kennknowles.com> | 2016-10-04 20:55:03 -0700 |
---|---|---|
committer | Kenn Knowles <kenn@kennknowles.com> | 2016-10-04 20:55:03 -0700 |
commit | 270fcbf99cc5770fe545f12fa93d4930d3929481 (patch) | |
tree | 8025e26c243d348de38da28316d181584c633f2e | |
parent | 1235defdb679bdf4f91ba408548b7ff3e27a6b1e (diff) | |
download | jsonpath-rw-270fcbf99cc5770fe545f12fa93d4930d3929481.tar.gz |
Add tests and fixes for update
-rw-r--r-- | jsonpath_rw/jsonpath.py | 38 | ||||
-rw-r--r-- | tests/test_jsonpath.py | 62 | ||||
-rw-r--r-- | tests/test_update.py | 67 |
3 files changed, 95 insertions, 72 deletions
diff --git a/jsonpath_rw/jsonpath.py b/jsonpath_rw/jsonpath.py index 395516e..93d7b81 100644 --- a/jsonpath_rw/jsonpath.py +++ b/jsonpath_rw/jsonpath.py @@ -26,7 +26,11 @@ class JSONPath(object): raise NotImplementedError() def update(self, data, val): - "Returns `data` with the specified path replaced by `val`" + """ + Returns `data` with the specified path replaced by `val`. Only updates + if the specified path exists. + """ + raise NotImplementedError() def child(self, child): @@ -340,8 +344,30 @@ class Descendants(JSONPath): return False def update(self, data, val): - for datum in self.left.find(data): - self.right.update(datum.value, val) + # Get all left matches into a list + left_matches = self.left.find(data) + if not isinstance(left_matches, list): + left_matches = [left_matches] + + def update_recursively(data): + # Update only mutable values corresponding to JSON types + if not (isinstance(data, list) or isinstance(data, dict)): + return + + self.right.update(data, val) + + # Manually do the * or [*] to avoid coercion and recurse just the right-hand pattern + if isinstance(data, list): + for i in range(0, len(data)): + update_recursively(data[i]) + + elif isinstance(data, dict): + for field in data.keys(): + update_recursively(data[field]) + + for submatch in left_matches: + update_recursively(submatch.value) + return data def __str__(self): @@ -432,7 +458,8 @@ class Fields(JSONPath): def update(self, data, val): for field in self.reified_fields(DatumInContext.wrap(data)): - data[field] = val + if field in data: + data[field] = val return data def __str__(self): @@ -466,7 +493,8 @@ class Index(JSONPath): return [] def update(self, data, val): - data[self.index] = val + if len(data) > self.index: + data[self.index] = val return data def __eq__(self, other): diff --git a/tests/test_jsonpath.py b/tests/test_jsonpath.py index bf3d5c6..2c01992 100644 --- a/tests/test_jsonpath.py +++ b/tests/test_jsonpath.py @@ -290,3 +290,65 @@ class TestJsonPath(unittest.TestCase): } }, ['foo.baz', 'foo.bing.baz'] )]) + + def check_update_cases(self, test_cases): + for original, expr_str, value, expected in test_cases: + print('parse(%r).update(%r, %r) =?= %r' + % (expr_str, original, value, expected)) + expr = parse(expr_str) + actual = expr.update(original, value) + assert actual == expected + + def test_update_root(self): + self.check_update_cases([ + ('foo', '$', 'bar', 'bar') + ]) + + def test_update_this(self): + self.check_update_cases([ + ('foo', '`this`', 'bar', 'bar') + ]) + + def test_update_fields(self): + self.check_update_cases([ + ({'foo': 1}, 'foo', 5, {'foo': 5}), + ({'foo': 1, 'bar': 2}, '$.*', 3, {'foo': 3, 'bar': 3}) + ]) + + def test_update_child(self): + self.check_update_cases([ + ({'foo': 'bar'}, '$.foo', 'baz', {'foo': 'baz'}), + ({'foo': {'bar': 1}}, 'foo.bar', 'baz', {'foo': {'bar': 'baz'}}) + ]) + + def test_update_where(self): + self.check_update_cases([ + ({'foo': {'bar': {'baz': 1}}, 'bar': {'baz': 2}}, + '*.bar where baz', 5, {'foo': {'bar': 5}, 'bar': {'baz': 2}}) + ]) + + def test_update_descendants_where(self): + self.check_update_cases([ + ({'foo': {'bar': 1, 'flag': 1}, 'baz': {'bar': 2}}, + '(* where flag) .. bar', 3, + {'foo': {'bar': 3, 'flag': 1}, 'baz': {'bar': 2}}) + ]) + + def test_update_descendants(self): + self.check_update_cases([ + ({'somefield': 1}, '$..somefield', 42, {'somefield': 42}), + ({'outer': {'nestedfield': 1}}, '$..nestedfield', 42, {'outer': {'nestedfield': 42}}), + ({'outs': {'bar': 1, 'ins': {'bar': 9}}, 'outs2': {'bar': 2}}, + '$..bar', 42, + {'outs': {'bar': 42, 'ins': {'bar': 42}}, 'outs2': {'bar': 42}}) + ]) + + def test_update_index(self): + self.check_update_cases([ + (['foo', 'bar', 'baz'], '[0]', 'test', ['test', 'bar', 'baz']) + ]) + + def test_update_slice(self): + self.check_update_cases([ + (['foo', 'bar', 'baz'], '[0:2]', 'test', ['test', 'test', 'baz']) + ]) diff --git a/tests/test_update.py b/tests/test_update.py deleted file mode 100644 index f556c66..0000000 --- a/tests/test_update.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - -import unittest -import logging - -from jsonpath_rw.parser import parse - - -class TestUpdate(unittest.TestCase): - - @classmethod - def setup_class(cls): - logging.basicConfig() - - def check_update_cases(self, test_cases): - for original, expr_str, value, expected in test_cases: - print('parse(%r).update(%r, %r) =?= %r' - % (expr_str, original, value, expected)) - expr = parse(expr_str) - actual = expr.update(original, value) - assert actual == expected - - def test_update_root(self): - self.check_update_cases([ - ('foo', '$', 'bar', 'bar') - ]) - - def test_update_this(self): - self.check_update_cases([ - ('foo', '`this`', 'bar', 'bar') - ]) - - def test_update_fields(self): - self.check_update_cases([ - ({'foo': 1}, 'foo', 5, {'foo': 5}), - ({}, 'foo', 1, {'foo': 1}), - ({'foo': 1, 'bar': 2}, '$.*', 3, {'foo': 3, 'bar': 3}) - ]) - - def test_update_child(self): - self.check_update_cases([ - ({'foo': 'bar'}, '$.foo', 'baz', {'foo': 'baz'}), - ({'foo': {'bar': 1}}, 'foo.bar', 'baz', {'foo': {'bar': 'baz'}}) - ]) - - def test_update_where(self): - self.check_update_cases([ - ({'foo': {'bar': {'baz': 1}}, 'bar': {'baz': 2}}, - '*.bar where baz', 5, {'foo': {'bar': 5}, 'bar': {'baz': 2}}) - ]) - - def test_update_descendants(self): - self.check_update_cases([ - ({'foo': {'bar': 1, 'flag': 1}, 'baz': {'bar': 2}}, - '* where flag .. bar', 3, - {'foo': {'bar': 3, 'flag': 1}, 'baz': {'bar': 2}}) - ]) - - def test_update_index(self): - self.check_update_cases([ - (['foo', 'bar', 'baz'], '[0]', 'test', ['test', 'bar', 'baz']) - ]) - - def test_update_slice(self): - self.check_update_cases([ - (['foo', 'bar', 'baz'], '[0:2]', 'test', ['test', 'test', 'baz']) - ]) |