summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Benner <josh@bennerweb.com>2015-07-13 14:11:31 -0400
committerJosh Benner <josh@bennerweb.com>2015-07-13 14:11:31 -0400
commit1235defdb679bdf4f91ba408548b7ff3e27a6b1e (patch)
treec4d40ff05d51b7726b2371b02f88c84ca9bbdcd1
parent5a4b2c7053ceeb1ea5d30abb019dfb7b2c2e31bb (diff)
downloadjsonpath-rw-1235defdb679bdf4f91ba408548b7ff3e27a6b1e.tar.gz
Implemented update method on most JSONPath descendents. Tests included.
-rw-r--r--jsonpath_rw/jsonpath.py29
-rw-r--r--tests/test_update.py67
2 files changed, 96 insertions, 0 deletions
diff --git a/jsonpath_rw/jsonpath.py b/jsonpath_rw/jsonpath.py
index 3c491d0..395516e 100644
--- a/jsonpath_rw/jsonpath.py
+++ b/jsonpath_rw/jsonpath.py
@@ -227,6 +227,11 @@ class Child(JSONPath):
if not isinstance(subdata, AutoIdForDatum)
for submatch in self.right.find(subdata)]
+ def update(self, data, val):
+ for datum in self.left.find(data):
+ self.right.update(datum.value, val)
+ return data
+
def __eq__(self, other):
return isinstance(other, Child) and self.left == other.left and self.right == other.right
@@ -274,6 +279,11 @@ class Where(JSONPath):
def find(self, data):
return [subdata for subdata in self.left.find(data) if self.right.find(subdata)]
+ def update(self, data, val):
+ for datum in self.find(data):
+ datum.path.update(data, val)
+ return data
+
def __str__(self):
return '%s where %s' % (self.left, self.right)
@@ -329,6 +339,11 @@ class Descendants(JSONPath):
def is_singular():
return False
+ def update(self, data, val):
+ for datum in self.left.find(data):
+ self.right.update(datum.value, val)
+ return data
+
def __str__(self):
return '%s..%s' % (self.left, self.right)
@@ -415,6 +430,11 @@ class Fields(JSONPath):
for field_datum in [self.get_field_datum(datum, field) for field in self.reified_fields(datum)]
if field_datum is not None]
+ def update(self, data, val):
+ for field in self.reified_fields(DatumInContext.wrap(data)):
+ data[field] = val
+ return data
+
def __str__(self):
return ','.join(map(str, self.fields))
@@ -445,6 +465,10 @@ class Index(JSONPath):
else:
return []
+ def update(self, data, val):
+ data[self.index] = val
+ return data
+
def __eq__(self, other):
return isinstance(other, Index) and self.index == other.index
@@ -495,6 +519,11 @@ class Slice(JSONPath):
else:
return [DatumInContext(datum.value[i], path=Index(i), context=datum) for i in range(0, len(datum.value))[self.start:self.end:self.step]]
+ def update(self, data, val):
+ for datum in self.find(data):
+ datum.path.update(data, val)
+ return data
+
def __str__(self):
if self.start == None and self.end == None and self.step == None:
return '[*]'
diff --git a/tests/test_update.py b/tests/test_update.py
new file mode 100644
index 0000000..f556c66
--- /dev/null
+++ b/tests/test_update.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes
+
+import unittest
+import logging
+
+from jsonpath_rw.parser import parse
+
+
+class TestUpdate(unittest.TestCase):
+
+ @classmethod
+ def setup_class(cls):
+ logging.basicConfig()
+
+ def check_update_cases(self, test_cases):
+ for original, expr_str, value, expected in test_cases:
+ print('parse(%r).update(%r, %r) =?= %r'
+ % (expr_str, original, value, expected))
+ expr = parse(expr_str)
+ actual = expr.update(original, value)
+ assert actual == expected
+
+ def test_update_root(self):
+ self.check_update_cases([
+ ('foo', '$', 'bar', 'bar')
+ ])
+
+ def test_update_this(self):
+ self.check_update_cases([
+ ('foo', '`this`', 'bar', 'bar')
+ ])
+
+ def test_update_fields(self):
+ self.check_update_cases([
+ ({'foo': 1}, 'foo', 5, {'foo': 5}),
+ ({}, 'foo', 1, {'foo': 1}),
+ ({'foo': 1, 'bar': 2}, '$.*', 3, {'foo': 3, 'bar': 3})
+ ])
+
+ def test_update_child(self):
+ self.check_update_cases([
+ ({'foo': 'bar'}, '$.foo', 'baz', {'foo': 'baz'}),
+ ({'foo': {'bar': 1}}, 'foo.bar', 'baz', {'foo': {'bar': 'baz'}})
+ ])
+
+ def test_update_where(self):
+ self.check_update_cases([
+ ({'foo': {'bar': {'baz': 1}}, 'bar': {'baz': 2}},
+ '*.bar where baz', 5, {'foo': {'bar': 5}, 'bar': {'baz': 2}})
+ ])
+
+ def test_update_descendants(self):
+ self.check_update_cases([
+ ({'foo': {'bar': 1, 'flag': 1}, 'baz': {'bar': 2}},
+ '* where flag .. bar', 3,
+ {'foo': {'bar': 3, 'flag': 1}, 'baz': {'bar': 2}})
+ ])
+
+ def test_update_index(self):
+ self.check_update_cases([
+ (['foo', 'bar', 'baz'], '[0]', 'test', ['test', 'bar', 'baz'])
+ ])
+
+ def test_update_slice(self):
+ self.check_update_cases([
+ (['foo', 'bar', 'baz'], '[0:2]', 'test', ['test', 'test', 'baz'])
+ ])