summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenn Knowles <kenn@kennknowles.com>2016-10-04 20:55:03 -0700
committerKenn Knowles <kenn@kennknowles.com>2016-10-04 20:55:03 -0700
commit270fcbf99cc5770fe545f12fa93d4930d3929481 (patch)
tree8025e26c243d348de38da28316d181584c633f2e
parent1235defdb679bdf4f91ba408548b7ff3e27a6b1e (diff)
downloadjsonpath-rw-270fcbf99cc5770fe545f12fa93d4930d3929481.tar.gz
Add tests and fixes for update
-rw-r--r--jsonpath_rw/jsonpath.py38
-rw-r--r--tests/test_jsonpath.py62
-rw-r--r--tests/test_update.py67
3 files changed, 95 insertions, 72 deletions
diff --git a/jsonpath_rw/jsonpath.py b/jsonpath_rw/jsonpath.py
index 395516e..93d7b81 100644
--- a/jsonpath_rw/jsonpath.py
+++ b/jsonpath_rw/jsonpath.py
@@ -26,7 +26,11 @@ class JSONPath(object):
raise NotImplementedError()
def update(self, data, val):
- "Returns `data` with the specified path replaced by `val`"
+ """
+ Returns `data` with the specified path replaced by `val`. Only updates
+ if the specified path exists.
+ """
+
raise NotImplementedError()
def child(self, child):
@@ -340,8 +344,30 @@ class Descendants(JSONPath):
return False
def update(self, data, val):
- for datum in self.left.find(data):
- self.right.update(datum.value, val)
+ # Get all left matches into a list
+ left_matches = self.left.find(data)
+ if not isinstance(left_matches, list):
+ left_matches = [left_matches]
+
+ def update_recursively(data):
+ # Update only mutable values corresponding to JSON types
+ if not (isinstance(data, list) or isinstance(data, dict)):
+ return
+
+ self.right.update(data, val)
+
+ # Manually do the * or [*] to avoid coercion and recurse just the right-hand pattern
+ if isinstance(data, list):
+ for i in range(0, len(data)):
+ update_recursively(data[i])
+
+ elif isinstance(data, dict):
+ for field in data.keys():
+ update_recursively(data[field])
+
+ for submatch in left_matches:
+ update_recursively(submatch.value)
+
return data
def __str__(self):
@@ -432,7 +458,8 @@ class Fields(JSONPath):
def update(self, data, val):
for field in self.reified_fields(DatumInContext.wrap(data)):
- data[field] = val
+ if field in data:
+ data[field] = val
return data
def __str__(self):
@@ -466,7 +493,8 @@ class Index(JSONPath):
return []
def update(self, data, val):
- data[self.index] = val
+ if len(data) > self.index:
+ data[self.index] = val
return data
def __eq__(self, other):
diff --git a/tests/test_jsonpath.py b/tests/test_jsonpath.py
index bf3d5c6..2c01992 100644
--- a/tests/test_jsonpath.py
+++ b/tests/test_jsonpath.py
@@ -290,3 +290,65 @@ class TestJsonPath(unittest.TestCase):
} },
['foo.baz',
'foo.bing.baz'] )])
+
+ def check_update_cases(self, test_cases):
+ for original, expr_str, value, expected in test_cases:
+ print('parse(%r).update(%r, %r) =?= %r'
+ % (expr_str, original, value, expected))
+ expr = parse(expr_str)
+ actual = expr.update(original, value)
+ assert actual == expected
+
+ def test_update_root(self):
+ self.check_update_cases([
+ ('foo', '$', 'bar', 'bar')
+ ])
+
+ def test_update_this(self):
+ self.check_update_cases([
+ ('foo', '`this`', 'bar', 'bar')
+ ])
+
+ def test_update_fields(self):
+ self.check_update_cases([
+ ({'foo': 1}, 'foo', 5, {'foo': 5}),
+ ({'foo': 1, 'bar': 2}, '$.*', 3, {'foo': 3, 'bar': 3})
+ ])
+
+ def test_update_child(self):
+ self.check_update_cases([
+ ({'foo': 'bar'}, '$.foo', 'baz', {'foo': 'baz'}),
+ ({'foo': {'bar': 1}}, 'foo.bar', 'baz', {'foo': {'bar': 'baz'}})
+ ])
+
+ def test_update_where(self):
+ self.check_update_cases([
+ ({'foo': {'bar': {'baz': 1}}, 'bar': {'baz': 2}},
+ '*.bar where baz', 5, {'foo': {'bar': 5}, 'bar': {'baz': 2}})
+ ])
+
+ def test_update_descendants_where(self):
+ self.check_update_cases([
+ ({'foo': {'bar': 1, 'flag': 1}, 'baz': {'bar': 2}},
+ '(* where flag) .. bar', 3,
+ {'foo': {'bar': 3, 'flag': 1}, 'baz': {'bar': 2}})
+ ])
+
+ def test_update_descendants(self):
+ self.check_update_cases([
+ ({'somefield': 1}, '$..somefield', 42, {'somefield': 42}),
+ ({'outer': {'nestedfield': 1}}, '$..nestedfield', 42, {'outer': {'nestedfield': 42}}),
+ ({'outs': {'bar': 1, 'ins': {'bar': 9}}, 'outs2': {'bar': 2}},
+ '$..bar', 42,
+ {'outs': {'bar': 42, 'ins': {'bar': 42}}, 'outs2': {'bar': 42}})
+ ])
+
+ def test_update_index(self):
+ self.check_update_cases([
+ (['foo', 'bar', 'baz'], '[0]', 'test', ['test', 'bar', 'baz'])
+ ])
+
+ def test_update_slice(self):
+ self.check_update_cases([
+ (['foo', 'bar', 'baz'], '[0:2]', 'test', ['test', 'test', 'baz'])
+ ])
diff --git a/tests/test_update.py b/tests/test_update.py
deleted file mode 100644
index f556c66..0000000
--- a/tests/test_update.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes
-
-import unittest
-import logging
-
-from jsonpath_rw.parser import parse
-
-
-class TestUpdate(unittest.TestCase):
-
- @classmethod
- def setup_class(cls):
- logging.basicConfig()
-
- def check_update_cases(self, test_cases):
- for original, expr_str, value, expected in test_cases:
- print('parse(%r).update(%r, %r) =?= %r'
- % (expr_str, original, value, expected))
- expr = parse(expr_str)
- actual = expr.update(original, value)
- assert actual == expected
-
- def test_update_root(self):
- self.check_update_cases([
- ('foo', '$', 'bar', 'bar')
- ])
-
- def test_update_this(self):
- self.check_update_cases([
- ('foo', '`this`', 'bar', 'bar')
- ])
-
- def test_update_fields(self):
- self.check_update_cases([
- ({'foo': 1}, 'foo', 5, {'foo': 5}),
- ({}, 'foo', 1, {'foo': 1}),
- ({'foo': 1, 'bar': 2}, '$.*', 3, {'foo': 3, 'bar': 3})
- ])
-
- def test_update_child(self):
- self.check_update_cases([
- ({'foo': 'bar'}, '$.foo', 'baz', {'foo': 'baz'}),
- ({'foo': {'bar': 1}}, 'foo.bar', 'baz', {'foo': {'bar': 'baz'}})
- ])
-
- def test_update_where(self):
- self.check_update_cases([
- ({'foo': {'bar': {'baz': 1}}, 'bar': {'baz': 2}},
- '*.bar where baz', 5, {'foo': {'bar': 5}, 'bar': {'baz': 2}})
- ])
-
- def test_update_descendants(self):
- self.check_update_cases([
- ({'foo': {'bar': 1, 'flag': 1}, 'baz': {'bar': 2}},
- '* where flag .. bar', 3,
- {'foo': {'bar': 3, 'flag': 1}, 'baz': {'bar': 2}})
- ])
-
- def test_update_index(self):
- self.check_update_cases([
- (['foo', 'bar', 'baz'], '[0]', 'test', ['test', 'bar', 'baz'])
- ])
-
- def test_update_slice(self):
- self.check_update_cases([
- (['foo', 'bar', 'baz'], '[0:2]', 'test', ['test', 'test', 'baz'])
- ])