summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephen Rosen <sirosen@globus.org>2019-08-26 19:20:04 -0400
committerJon Parise <jon@pinterest.com>2019-08-26 16:20:04 -0700
commitd5dafc1490b26fc8109571bd30311c9b729ea8a5 (patch)
tree84be8650d9ffe6a2bcde717d3613aa20fb4c2a73
parent5699c9dfa7067a99000e281091dd6400a1e84122 (diff)
downloadpymemcache-d5dafc1490b26fc8109571bd30311c9b729ea8a5.tar.gz
Validate cas inputs as strings of digits (#250)
For consideration for v3.0.0 'cas' is documented as needing to be an int or bytestring of the digits 0-9. However, this is not actually enforced and it is possible to pass a value to pymemcache which doesn't conform to these rules. In fact, you can do weird things like `cas=b'noreply'` and potentially trigger "unexpected" behavior. To go along with validating int inputs, validate cas inputs. However, these are not necessarily integers. Instead, if an int or string is given, it will be encoded as a bytestring. But in order to validate the value given, it is checked against isdigit() . (NB: You could also use `int(cas)` for very similar checking.) Rationale for allowing non-integer inputs to cas is not obvious. Presumably it allows callers using `gets()` to pass the `cas` value they get back into a `cas` command without issue. But it may be debatable.
-rw-r--r--ChangeLog.rst3
-rw-r--r--pymemcache/client/base.py27
-rw-r--r--pymemcache/test/test_client.py25
-rw-r--r--pymemcache/test/test_integration.py5
4 files changed, 54 insertions, 6 deletions
diff --git a/ChangeLog.rst b/ChangeLog.rst
index fb8b628..69702d6 100644
--- a/ChangeLog.rst
+++ b/ChangeLog.rst
@@ -13,6 +13,9 @@ New in version 3.0.0 (unreleased)
as methods. (`serialize` and `deserialize` are still supported but considered
deprecated)
+* Validate inputs for ``cas`` -- values which are not integers or strings of
+ 0-9 now raise ``MemcacheIllegalInputError``
+
New in version 2.2.2
--------------------
* Fix ``long_description`` string in Python packaging.
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py
index 1214afc..f3fe270 100644
--- a/pymemcache/client/base.py
+++ b/pymemcache/client/base.py
@@ -787,6 +787,32 @@ class Client(object):
return six.text_type(value).encode(self.encoding)
+ def _check_cas(self, cas):
+ """Check that a value is a valid input for 'cas' -- either an int or a
+ string containing only 0-9
+
+ The value will be (re)encoded so that we can accept strings or bytes.
+ """
+ # convert non-binary values to binary
+ if isinstance(cas, (six.integer_types, VALID_STRING_TYPES)):
+ try:
+ cas = six.text_type(cas).encode(self.encoding)
+ except UnicodeEncodeError:
+ raise MemcacheIllegalInputError(
+ 'non-ASCII cas value: %r' % cas)
+ elif not isinstance(cas, six.binary_type):
+ raise MemcacheIllegalInputError(
+ 'cas must be integer, string, or bytes, got bad value: %r' % cas
+ )
+
+ if not cas.isdigit():
+ raise MemcacheIllegalInputError(
+ 'cas must only contain values in 0-9, got bad value: %r'
+ % cas
+ )
+
+ return cas
+
def _extract_value(self, expect_cas, line, buf, remapped_keys,
prefixed_keys):
"""
@@ -857,6 +883,7 @@ class Client(object):
extra = b''
if cas is not None:
+ cas = self._check_cas(cas)
extra += b' ' + cas
if noreply:
extra += b' noreply'
diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py
index bfee6e1..0abad02 100644
--- a/pymemcache/test/test_client.py
+++ b/pymemcache/test/test_client.py
@@ -471,34 +471,47 @@ class TestClient(ClientTestMixin, unittest.TestCase):
result = client.prepend(b'key', b'value', noreply=False)
assert result is True
+ def test_cas_malformed(self):
+ client = self.make_client([b'STORED\r\n'])
+ with pytest.raises(MemcacheIllegalInputError):
+ client.cas(b'key', b'value', 'nonintegerstring', noreply=False)
+
+ with pytest.raises(MemcacheIllegalInputError):
+ # even a space makes it a noninteger string
+ client.cas(b'key', b'value', '123 ', noreply=False)
+
+ with pytest.raises(MemcacheIllegalInputError):
+ # non-ASCII digit
+ client.cas(b'key', b'value', u'⁰', noreply=False)
+
def test_cas_stored(self):
client = self.make_client([b'STORED\r\n'])
- result = client.cas(b'key', b'value', b'cas', noreply=False)
+ result = client.cas(b'key', b'value', b'123', noreply=False)
assert result is True
# unit test for encoding passed in __init__()
client = self.make_client([b'STORED\r\n'], encoding='utf-8')
- result = client.cas(b'key', b'value', b'cas', noreply=False)
+ result = client.cas(b'key', b'value', b'123', noreply=False)
assert result is True
def test_cas_exists(self):
client = self.make_client([b'EXISTS\r\n'])
- result = client.cas(b'key', b'value', b'cas', noreply=False)
+ result = client.cas(b'key', b'value', b'123', noreply=False)
assert result is False
# unit test for encoding passed in __init__()
client = self.make_client([b'EXISTS\r\n'], encoding='utf-8')
- result = client.cas(b'key', b'value', b'cas', noreply=False)
+ result = client.cas(b'key', b'value', b'123', noreply=False)
assert result is False
def test_cas_not_found(self):
client = self.make_client([b'NOT_FOUND\r\n'])
- result = client.cas(b'key', b'value', b'cas', noreply=False)
+ result = client.cas(b'key', b'value', b'123', noreply=False)
assert result is None
# unit test for encoding passed in __init__()
client = self.make_client([b'NOT_FOUND\r\n'], encoding='utf-8')
- result = client.cas(b'key', b'value', b'cas', noreply=False)
+ result = client.cas(b'key', b'value', b'123', noreply=False)
assert result is None
def test_cr_nl_boundaries(self):
diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py
index 9ba6f97..f5dbb1f 100644
--- a/pymemcache/test/test_integration.py
+++ b/pymemcache/test/test_integration.py
@@ -151,8 +151,13 @@ def test_cas(client_class, host, port, socket_module):
result = client.set(b'key', b'value', noreply=False)
assert result is True
+ # binary, string, and raw int all match -- should all be encoded as b'1'
result = client.cas(b'key', b'value', b'1', noreply=False)
assert result is False
+ result = client.cas(b'key', b'value', '1', noreply=False)
+ assert result is False
+ result = client.cas(b'key', b'value', 1, noreply=False)
+ assert result is False
result, cas = client.gets(b'key')
assert result == b'value'