diff options
author | Stephen Rosen <sirosen@globus.org> | 2019-08-26 19:20:04 -0400 |
---|---|---|
committer | Jon Parise <jon@pinterest.com> | 2019-08-26 16:20:04 -0700 |
commit | d5dafc1490b26fc8109571bd30311c9b729ea8a5 (patch) | |
tree | 84be8650d9ffe6a2bcde717d3613aa20fb4c2a73 | |
parent | 5699c9dfa7067a99000e281091dd6400a1e84122 (diff) | |
download | pymemcache-d5dafc1490b26fc8109571bd30311c9b729ea8a5.tar.gz |
Validate cas inputs as strings of digits (#250)
For consideration for v3.0.0
'cas' is documented as needing to be an int or bytestring of the digits
0-9. However, this is not actually enforced and it is possible to pass a
value to pymemcache which doesn't conform to these rules. In fact, you
can do weird things like `cas=b'noreply'` and potentially trigger
"unexpected" behavior.
To go along with validating int inputs, validate cas inputs. However,
these are not necessarily integers. Instead, if an int or string is
given, it will be encoded as a bytestring. But in order to validate the
value given, it is checked against isdigit().
(NB: You could also use `int(cas)` for very similar checking.)
The rationale for allowing non-integer inputs to cas is not obvious.
Presumably it allows callers using `gets()` to pass the `cas` value they
get back into a `cas` command without issue. But it may be debatable.
-rw-r--r-- | ChangeLog.rst | 3 | ||||
-rw-r--r-- | pymemcache/client/base.py | 27 | ||||
-rw-r--r-- | pymemcache/test/test_client.py | 25 | ||||
-rw-r--r-- | pymemcache/test/test_integration.py | 5 |
4 files changed, 54 insertions, 6 deletions
diff --git a/ChangeLog.rst b/ChangeLog.rst index fb8b628..69702d6 100644 --- a/ChangeLog.rst +++ b/ChangeLog.rst @@ -13,6 +13,9 @@ New in version 3.0.0 (unreleased) as methods. (`serialize` and `deserialize` are still supported but considered deprecated) +* Validate inputs for ``cas`` -- values which are not integers or strings of + 0-9 now raise ``MemcacheIllegalInputError`` + New in version 2.2.2 -------------------- * Fix ``long_description`` string in Python packaging. diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 1214afc..f3fe270 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -787,6 +787,32 @@ class Client(object): return six.text_type(value).encode(self.encoding) + def _check_cas(self, cas): + """Check that a value is a valid input for 'cas' -- either an int or a + string containing only 0-9 + + The value will be (re)encoded so that we can accept strings or bytes. + """ + # convert non-binary values to binary + if isinstance(cas, (six.integer_types, VALID_STRING_TYPES)): + try: + cas = six.text_type(cas).encode(self.encoding) + except UnicodeEncodeError: + raise MemcacheIllegalInputError( + 'non-ASCII cas value: %r' % cas) + elif not isinstance(cas, six.binary_type): + raise MemcacheIllegalInputError( + 'cas must be integer, string, or bytes, got bad value: %r' % cas + ) + + if not cas.isdigit(): + raise MemcacheIllegalInputError( + 'cas must only contain values in 0-9, got bad value: %r' + % cas + ) + + return cas + def _extract_value(self, expect_cas, line, buf, remapped_keys, prefixed_keys): """ @@ -857,6 +883,7 @@ class Client(object): extra = b'' if cas is not None: + cas = self._check_cas(cas) extra += b' ' + cas if noreply: extra += b' noreply' diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index bfee6e1..0abad02 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -471,34 +471,47 @@ class TestClient(ClientTestMixin, unittest.TestCase): 
result = client.prepend(b'key', b'value', noreply=False) assert result is True + def test_cas_malformed(self): + client = self.make_client([b'STORED\r\n']) + with pytest.raises(MemcacheIllegalInputError): + client.cas(b'key', b'value', 'nonintegerstring', noreply=False) + + with pytest.raises(MemcacheIllegalInputError): + # even a space makes it a noninteger string + client.cas(b'key', b'value', '123 ', noreply=False) + + with pytest.raises(MemcacheIllegalInputError): + # non-ASCII digit + client.cas(b'key', b'value', u'⁰', noreply=False) + def test_cas_stored(self): client = self.make_client([b'STORED\r\n']) - result = client.cas(b'key', b'value', b'cas', noreply=False) + result = client.cas(b'key', b'value', b'123', noreply=False) assert result is True # unit test for encoding passed in __init__() client = self.make_client([b'STORED\r\n'], encoding='utf-8') - result = client.cas(b'key', b'value', b'cas', noreply=False) + result = client.cas(b'key', b'value', b'123', noreply=False) assert result is True def test_cas_exists(self): client = self.make_client([b'EXISTS\r\n']) - result = client.cas(b'key', b'value', b'cas', noreply=False) + result = client.cas(b'key', b'value', b'123', noreply=False) assert result is False # unit test for encoding passed in __init__() client = self.make_client([b'EXISTS\r\n'], encoding='utf-8') - result = client.cas(b'key', b'value', b'cas', noreply=False) + result = client.cas(b'key', b'value', b'123', noreply=False) assert result is False def test_cas_not_found(self): client = self.make_client([b'NOT_FOUND\r\n']) - result = client.cas(b'key', b'value', b'cas', noreply=False) + result = client.cas(b'key', b'value', b'123', noreply=False) assert result is None # unit test for encoding passed in __init__() client = self.make_client([b'NOT_FOUND\r\n'], encoding='utf-8') - result = client.cas(b'key', b'value', b'cas', noreply=False) + result = client.cas(b'key', b'value', b'123', noreply=False) assert result is None def 
test_cr_nl_boundaries(self): diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py index 9ba6f97..f5dbb1f 100644 --- a/pymemcache/test/test_integration.py +++ b/pymemcache/test/test_integration.py @@ -151,8 +151,13 @@ def test_cas(client_class, host, port, socket_module): result = client.set(b'key', b'value', noreply=False) assert result is True + # binary, string, and raw int all match -- should all be encoded as b'1' result = client.cas(b'key', b'value', b'1', noreply=False) assert result is False + result = client.cas(b'key', b'value', '1', noreply=False) + assert result is False + result = client.cas(b'key', b'value', 1, noreply=False) + assert result is False result, cas = client.gets(b'key') assert result == b'value' |