diff options
author | Bar Shaul <88437685+barshaul@users.noreply.github.com> | 2022-11-10 12:38:47 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-10 12:38:47 +0200 |
commit | bb06ccd52924800ac501d17c8a42038c8e5c5770 (patch) | |
tree | df9fa0ae2c2553ecc3779b3f7166d6cad4855c03 | |
parent | fb647430f00cc7bb67c978e75f2dabc661567779 (diff) | |
download | redis-py-bb06ccd52924800ac501d17c8a42038c8e5c5770.tar.gz |
CredentialsProvider class added to support password rotation (#2261)
* A CredentialsProvider class has been added to allow the user to add his own provider for password rotation
* Moved CredentialsProvider to a separate file, added type hints
* Changed username and password to properties
* Added: StaticCredentialProvider, examples, tests
Changed: CredentialsProvider to CredentialProvider
Fixed: calling AUTH only with password
* Changed private members' prefix to __
* fixed linters
* fixed auth test
* fixed credential test
* Raise an error if username or password are passed along with credential_provider
* fixing linters
* fixing test
* Changed dundered to single per side underscore
* Changed Connection class members username and password to properties to enable backward compatibility with changing the members value on existing connection.
* Reverting last commit and adding backward compatibility to 'username' and 'password' inside on_connect function
* Refactored CredentialProvider class
* Fixing tuple type to Tuple
* Fixing optional string members in UsernamePasswordCredentialProvider
* Fixed credential test
* Added credential provider support to AsyncRedis
* linters
* linters
* linters
* linters - black
Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
Co-authored-by: dvora-h <dvora.heller@redis.com>
-rw-r--r-- | CHANGES | 1 | ||||
-rw-r--r-- | docs/examples/asyncio_examples.ipynb | 129 | ||||
-rw-r--r-- | docs/examples/connection_examples.ipynb | 193 | ||||
-rw-r--r-- | redis/__init__.py | 3 | ||||
-rw-r--r-- | redis/asyncio/client.py | 3 | ||||
-rw-r--r-- | redis/asyncio/cluster.py | 3 | ||||
-rw-r--r-- | redis/asyncio/connection.py | 41 | ||||
-rwxr-xr-x | redis/client.py | 4 | ||||
-rw-r--r-- | redis/cluster.py | 1 | ||||
-rwxr-xr-x | redis/connection.py | 39 | ||||
-rw-r--r-- | redis/credentials.py | 26 | ||||
-rw-r--r-- | tests/test_asyncio/test_credentials.py | 284 | ||||
-rw-r--r-- | tests/test_credentials.py | 245 |
13 files changed, 898 insertions, 74 deletions
@@ -24,6 +24,7 @@ * ClusterPipeline Doesn't Handle ConnectionError for Dead Hosts (#2225) * Remove compatibility code for old versions of Hiredis, drop Packaging dependency * The `deprecated` library is no longer a dependency + * Added CredentialsProvider class to support password rotation * Enable Lock for asyncio cluster mode * 4.1.3 (Feb 8, 2022) diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb index dab7a96..855255c 100644 --- a/docs/examples/asyncio_examples.ipynb +++ b/docs/examples/asyncio_examples.ipynb @@ -21,11 +21,6 @@ { "cell_type": "code", "execution_count": 1, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [ { "name": "stdout", @@ -41,27 +36,29 @@ "connection = redis.Redis()\n", "print(f\"Ping successful: {await connection.ping()}\")\n", "await connection.close()" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "markdown", + "source": [ + "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + ], "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } - }, - "source": [ - "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." - ] + } }, { "cell_type": "code", "execution_count": 2, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [], "source": [ "import redis.asyncio as redis\n", @@ -70,15 +67,16 @@ "await connection.close()\n", "# Or: await connection.close(close_connection_pool=False)\n", "await connection.connection_pool.disconnect()" - ] - }, - { - "cell_type": "markdown", + ], "metadata": { + "collapsed": false, "pycharm": { - "name": "#%% md\n" + "name": "#%%\n" } - }, + } + }, + { + "cell_type": "markdown", "source": [ "## Transactions (Multi/Exec)\n", "\n", @@ -87,16 +85,17 @@ "The commands will not be reflected in Redis until execute() is called & awaited.\n", "\n", "Usually, when performing a bulk operation, taking advantage of a “transaction” (e.g., Multi/Exec) is to be desired, as it will also add a layer of atomicity to your bulk operation." - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } }, { "cell_type": "code", "execution_count": 3, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [], "source": [ "import redis.asyncio as redis\n", @@ -106,25 +105,31 @@ " ok1, ok2 = await (pipe.set(\"key1\", \"value1\").set(\"key2\", \"value2\").execute())\n", "assert ok1\n", "assert ok2" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Pub/Sub Mode\n", "\n", "Subscribing to specific channels:" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } }, { "cell_type": "code", "execution_count": 4, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [ { "name": "stdout", @@ -165,23 +170,29 @@ " await r.publish(\"channel:1\", STOPWORD)\n", "\n", " await future" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "markdown", - "metadata": {}, "source": [ "Subscribing to channels matching a glob-style pattern:" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } }, { "cell_type": "code", "execution_count": 5, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [ { "name": "stdout", @@ -223,11 +234,16 @@ " await r.publish(\"channel:1\", STOPWORD)\n", "\n", " await future" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Sentinel Client\n", "\n", @@ -236,16 +252,17 @@ "Calling aioredis.sentinel.Sentinel.master_for or aioredis.sentinel.Sentinel.slave_for methods will return Redis clients connected to specified services monitored by Sentinel.\n", "\n", "Sentinel client will detect failover and reconnect Redis clients automatically." - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [], "source": [ "import asyncio\n", @@ -260,7 +277,13 @@ "assert ok\n", "val = await r.get(\"key\")\n", "assert val == b\"value\"" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } } ], "metadata": { @@ -284,4 +307,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} +}
\ No newline at end of file diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index b0084ff..ca8dd44 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -99,6 +99,197 @@ }, { "cell_type": "markdown", + "source": [ + "## Connecting to a redis instance with username and password credential provider" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import redis\n", + "\n", + "creds_provider = redis.UsernamePasswordCredentialProvider(\"username\", \"password\")\n", + "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n", + "user_connection.ping()" + ], + "metadata": {} + } + }, + { + "cell_type": "markdown", + "source": [ + "## Connecting to a redis instance with standard credential provider" + ], + "metadata": {} + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from typing import Tuple\n", + "import redis\n", + "\n", + "creds_map = {\"user_1\": \"pass_1\",\n", + " \"user_2\": \"pass_2\"}\n", + "\n", + "class UserMapCredentialProvider(redis.CredentialProvider):\n", + " def __init__(self, username: str):\n", + " self.username = username\n", + "\n", + " def get_credentials(self) -> Tuple[str, str]:\n", + " return self.username, creds_map.get(self.username)\n", + "\n", + "# Create a default connection to set the ACL user\n", + "default_connection = redis.Redis(host=\"localhost\", port=6379)\n", + "default_connection.acl_setuser(\n", + " \"user_1\",\n", + " enabled=True,\n", + " passwords=[\"+\" + \"pass_1\"],\n", + " keys=\"~*\",\n", + " commands=[\"+ping\", \"+command\", \"+info\", \"+select\", \"+flushdb\"],\n", + ")\n", + "\n", + "# Create a UserMapCredentialProvider instance for user_1\n", + "creds_provider = UserMapCredentialProvider(\"user_1\")\n", + "# Initiate user connection with the credential provider\n", + "user_connection = redis.Redis(host=\"localhost\", port=6379,\n", + " credential_provider=creds_provider)\n", + "user_connection.ping()" + ], + "metadata": {} + } + }, + { + "cell_type": "markdown", + "source": [ + "## Connecting to a redis instance first with an initial credential set and then calling the credential provider" + ], + "metadata": {} + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from typing import Union\n", + "import redis\n", + "\n", + "class InitCredsSetCredentialProvider(redis.CredentialProvider):\n", + " def __init__(self, username, password):\n", + " self.username = username\n", + " self.password = password\n", + " self.call_supplier = False\n", + "\n", + " def call_external_supplier(self) -> Union[Tuple[str], Tuple[str, str]]:\n", + " # Call to an external credential supplier\n", + " raise NotImplementedError\n", + "\n", + " def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n", + " if self.call_supplier:\n", + " return self.call_external_supplier()\n", + " # Use the init set only for the first time\n", + " self.call_supplier = True\n", + " return self.username, self.password\n", + "\n", + "cred_provider = InitCredsSetCredentialProvider(username=\"init_user\", password=\"init_pass\")" + ], + "metadata": {} + } + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Connecting to a redis instance with AWS Secrets Manager credential provider." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import redis\n", + "import boto3\n", + "import json\n", + "import cachetools.func\n", + "\n", + "sm_client = boto3.client('secretsmanager')\n", + " \n", + "def sm_auth_provider(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n", + " @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n", + " def get_sm_user_credentials(secret_id, version_id, version_stage):\n", + " secret = sm_client.get_secret_value(secret_id, version_id)\n", + " return json.loads(secret['SecretString'])\n", + " creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n", + " return creds['username'], creds['password']\n", + "\n", + "secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n", + "creds_provider = redis.CredentialProvider(supplier=sm_auth_provider, secret_id=secret_id)\n", + "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n", + "user_connection.ping()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Connecting to a redis instance with ElastiCache IAM credential provider." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import redis\n", + "import boto3\n", + "import cachetools.func\n", + "\n", + "ec_client = boto3.client('elasticache')\n", + "\n", + "def iam_auth_provider(self, user, endpoint, port=6379, region=\"us-east-1\"):\n", + " @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n", + " def get_iam_auth_token(user, endpoint, port, region):\n", + " return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n", + " iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n", + " return iam_auth_token\n", + "\n", + "username = \"barshaul\"\n", + "endpoint = \"test-001.use1.cache.amazonaws.com\"\n", + "creds_provider = redis.CredentialProvider(supplier=iam_auth_provider, user=username,\n", + " endpoint=endpoint)\n", + "user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n", + "user_connection.ping()" + ] + }, + { + "cell_type": "markdown", "metadata": {}, "source": [ "## Connecting to Redis instances by specifying a URL scheme.\n", @@ -176,4 +367,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +}
\ No newline at end of file diff --git a/redis/__init__.py b/redis/__init__.py index b7560a6..5201fe2 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -9,6 +9,7 @@ from redis.connection import ( SSLConnection, UnixDomainSocketConnection, ) +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -62,6 +63,7 @@ __all__ = [ "Connection", "ConnectionError", "ConnectionPool", + "CredentialProvider", "DataError", "from_url", "InvalidResponse", @@ -76,6 +78,7 @@ __all__ = [ "SentinelManagedConnection", "SentinelManagedSSLConnection", "SSLConnection", + "UsernamePasswordCredentialProvider", "StrictRedis", "TimeoutError", "UnixDomainSocketConnection", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 619ee11..c085571 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -46,6 +46,7 @@ from redis.commands import ( list_or_args, ) from redis.compat import Protocol, TypedDict +from redis.credentials import CredentialProvider from redis.exceptions import ( ConnectionError, ExecAbortError, @@ -174,6 +175,7 @@ class Redis( retry: Optional[Retry] = None, auto_close_connection_pool: bool = True, redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new Redis client. @@ -199,6 +201,7 @@ class Redis( "db": db, "username": username, "password": password, + "credential_provider": credential_provider, "socket_timeout": socket_timeout, "encoding": encoding, "encoding_errors": encoding_errors, diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 97f4151..57aafbd 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -40,6 +40,7 @@ from redis.cluster import ( ) from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot +from redis.credentials import CredentialProvider from redis.exceptions import ( AskError, BusyLoadingError, @@ -220,6 +221,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand # Client related kwargs db: Union[str, int] = 0, path: Optional[str] = None, + credential_provider: Optional[CredentialProvider] = None, username: Optional[str] = None, password: Optional[str] = None, client_name: Optional[str] = None, @@ -266,6 +268,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand "connection_class": Connection, "parser_class": ClusterParser, # Client related kwargs + "credential_provider": credential_provider, "username": username, "password": password, "client_name": client_name, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 1288bb6..df066c4 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -29,6 +29,7 @@ import async_timeout from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.compat import Protocol, TypedDict +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -416,6 +417,7 @@ class Connection: "db", "username", "client_name", + "credential_provider", "password", "socket_timeout", "socket_connect_timeout", @@ -465,14 +467,23 @@ class Connection: retry: Optional[Retry] = None, redis_connect_func: Optional[ConnectCallbackT] = None, encoder_class: Type[Encoder] = Encoder, + credential_provider: Optional[CredentialProvider] = None, ): + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) self.pid = os.getpid() self.host = host self.port = int(port) self.db = db - self.username = username self.client_name = client_name + self.credential_provider = credential_provider self.password = password + self.username = username self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None self.socket_keepalive = socket_keepalive @@ -637,14 +648,13 @@ class Connection: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) - # if username and/or password are set, authenticate - if self.username or self.password: - auth_args: Union[Tuple[str], Tuple[str, str]] - if self.username: - auth_args = (self.username, self.password or "") - else: - # Mypy bug: https://github.com/python/mypy/issues/10944 - auth_args = (self.password or "",) + # if credential provider or username and/or password are set, authenticate + if self.credential_provider or (self.username or self.password): + cred_provider = ( + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) + ) + auth_args = cred_provider.get_credentials() # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH await self.send_command("AUTH", *auth_args, check_health=False) @@ -656,7 +666,7 @@ class Connection: # server seems to be < 6.0.0 which expects a single password # arg. retry auth with just the password. # https://github.com/andymccurdy/redis-py/issues/1274 - await self.send_command("AUTH", self.password, check_health=False) + await self.send_command("AUTH", auth_args[-1], check_health=False) auth_response = await self.read_response() if str_if_bytes(auth_response) != "OK": @@ -1014,18 +1024,27 @@ class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] client_name: str = None, retry: Optional[Retry] = None, redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new UnixDomainSocketConnection. To specify a retry policy, first set `retry_on_timeout` to `True` then set `retry` to a valid `Retry` object """ + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) self.pid = os.getpid() self.path = path self.db = db - self.username = username self.client_name = client_name + self.credential_provider = credential_provider self.password = password + self.username = username self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None self.retry_on_timeout = retry_on_timeout diff --git a/redis/client.py b/redis/client.py index 6a26d28..8356ba7 100755 --- a/redis/client.py +++ b/redis/client.py @@ -5,6 +5,7 @@ import threading import time import warnings from itertools import chain +from typing import Optional from redis.commands import ( CoreCommands, @@ -13,6 +14,7 @@ from redis.commands import ( list_or_args, ) from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection +from redis.credentials import CredentialProvider from redis.exceptions import ( ConnectionError, ExecAbortError, @@ -938,6 +940,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): username=None, retry=None, redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new Redis client. @@ -985,6 +988,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): "health_check_interval": health_check_interval, "client_name": client_name, "redis_connect_func": redis_connect_func, + "credential_provider": credential_provider, } # based on input, setup appropriate connection args if unix_socket_path is not None: diff --git a/redis/cluster.py b/redis/cluster.py index cb3b2a6..027fe40 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -121,6 +121,7 @@ REDIS_ALLOWED_KEYS = ( "connection_class", "connection_pool", "client_name", + "credential_provider", "db", "decode_responses", "encoding", diff --git a/redis/connection.py b/redis/connection.py index fecb06b..a2b0074 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,9 +8,11 @@ import weakref from itertools import chain from queue import Empty, Full, LifoQueue from time import time +from typing import Optional from urllib.parse import parse_qs, unquote, urlparse from redis.backoff import NoBackoff +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -502,6 +504,7 @@ class Connection: username=None, retry=None, redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new Connection. @@ -510,13 +513,21 @@ class Connection: `retry` to a valid `Retry` object. To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) self.pid = os.getpid() self.host = host self.port = int(port) self.db = db - self.username = username self.client_name = client_name + self.credential_provider = credential_provider self.password = password + self.username = username self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout self.socket_keepalive = socket_keepalive @@ -675,12 +686,13 @@ class Connection: "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) - # if username and/or password are set, authenticate - if self.username or self.password: - if self.username: - auth_args = (self.username, self.password or "") - else: - auth_args = (self.password,) + # if credential provider or username and/or password are set, authenticate + if self.credential_provider or (self.username or self.password): + cred_provider = ( + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) + ) + auth_args = cred_provider.get_credentials() # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH self.send_command("AUTH", *auth_args, check_health=False) @@ -692,7 +704,7 @@ class Connection: # server seems to be < 6.0.0 which expects a single password # arg. retry auth with just the password. # https://github.com/andymccurdy/redis-py/issues/1274 - self.send_command("AUTH", self.password, check_health=False) + self.send_command("AUTH", auth_args[-1], check_health=False) auth_response = self.read_response() if str_if_bytes(auth_response) != "OK": @@ -1050,6 +1062,7 @@ class UnixDomainSocketConnection(Connection): client_name=None, retry=None, redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new UnixDomainSocketConnection. @@ -1058,12 +1071,20 @@ class UnixDomainSocketConnection(Connection): `retry` to a valid `Retry` object. To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) self.pid = os.getpid() self.path = path self.db = db - self.username = username self.client_name = client_name + self.credential_provider = credential_provider self.password = password + self.username = username self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: diff --git a/redis/credentials.py b/redis/credentials.py new file mode 100644 index 0000000..7ba26dc --- /dev/null +++ b/redis/credentials.py @@ -0,0 +1,26 @@ +from typing import Optional, Tuple, Union + + +class CredentialProvider: + """ + Credentials Provider. + """ + + def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: + raise NotImplementedError("get_credentials must be implemented") + + +class UsernamePasswordCredentialProvider(CredentialProvider): + """ + Simple implementation of CredentialProvider that just wraps static + username and password. + """ + + def __init__(self, username: Optional[str] = None, password: Optional[str] = None): + self.username = username or "" + self.password = password or "" + + def get_credentials(self): + if self.username: + return self.username, self.password + return (self.password,) diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py new file mode 100644 index 0000000..8e213cd --- /dev/null +++ b/tests/test_asyncio/test_credentials.py @@ -0,0 +1,284 @@ +import functools +import random +import string +from typing import Optional, Tuple, Union + +import pytest +import pytest_asyncio + +import redis +from redis import AuthenticationError, DataError, ResponseError +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider +from redis.utils import str_if_bytes +from tests.conftest import skip_if_redis_enterprise + + +@pytest_asyncio.fixture() +async def r_acl_teardown(r: redis.Redis): + """ + A special fixture which removes the provided names from the database after use + """ + usernames = [] + + def factory(username): + usernames.append(username) + return r + + yield factory + for username in usernames: + await r.acl_deluser(username) + + +@pytest_asyncio.fixture() +async def r_required_pass_teardown(r: redis.Redis): + """ + A special fixture which removes the provided password from the database after use + """ + passwords = [] + + def factory(username): + passwords.append(username) + return r + + yield factory + for password in passwords: + try: + await r.auth(password) + except (ResponseError, AuthenticationError): + await r.auth("default", "") + await r.config_set("requirepass", "") + + +class NoPassCredProvider(CredentialProvider): + def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: + return "username", "" + + +class AsyncRandomAuthCredProvider(CredentialProvider): + def __init__(self, user: Optional[str], endpoint: str): + self.user = user + self.endpoint = endpoint + + @functools.lru_cache(maxsize=10) + def get_credentials(self) -> Union[Tuple[str, str], Tuple[str]]: + def get_random_string(length): + letters = string.ascii_lowercase + result_str = "".join(random.choice(letters) for i in range(length)) + return result_str + + if self.user: + auth_token: str = get_random_string(5) + self.user + "_" + self.endpoint + return self.user, auth_token + else: + auth_token: str = get_random_string(5) + self.endpoint + return (auth_token,) + + +async def init_acl_user(r, username, password): + # reset the user + await r.acl_deluser(username) + if password: + assert ( + await r.acl_setuser( + username, + enabled=True, + passwords=["+" + password], + keys="~*", + commands=[ + "+ping", + "+command", + "+info", + "+select", + "+flushdb", + "+cluster", + ], + ) + is True + ) + else: + assert ( + await r.acl_setuser( + username, + enabled=True, + keys="~*", + commands=[ + "+ping", + "+command", + "+info", + "+select", + "+flushdb", + "+cluster", + ], + nopass=True, + ) + is True + ) + + +async def init_required_pass(r, password): + await r.config_set("requirepass", password) + + +@pytest.mark.asyncio +class TestCredentialsProvider: + @skip_if_redis_enterprise() + async def test_only_pass_without_creds_provider( + self, r_required_pass_teardown, create_redis + ): + # test for default user (`username` is supposed to be optional) + password = "password" + r = r_required_pass_teardown(password) + await init_required_pass(r, password) + assert await r.auth(password) is True + + r2 = await create_redis(flushdb=False, password=password) + + assert await r2.ping() is True + + @skip_if_redis_enterprise() + async def test_user_and_pass_without_creds_provider( + self, r_acl_teardown, create_redis + ): + """ + Test backward compatibility with username and password + """ + # test for other users + username = "username" + password = "password" + r = r_acl_teardown(username) + await init_acl_user(r, username, password) + r2 = await create_redis(flushdb=False, username=username, password=password) + + assert await r2.ping() is True + + @pytest.mark.parametrize("username", ["username", None]) + @skip_if_redis_enterprise() + @pytest.mark.onlynoncluster + async def test_credential_provider_with_supplier( + self, r_acl_teardown, r_required_pass_teardown, create_redis, username + ): + creds_provider = AsyncRandomAuthCredProvider( + user=username, + endpoint="localhost", + ) + + auth_args = creds_provider.get_credentials() + password = auth_args[-1] + + if username: + r = r_acl_teardown(username) + await init_acl_user(r, username, password) + else: + r = r_required_pass_teardown(password) + await init_required_pass(r, password) + + r2 = await create_redis(flushdb=False, credential_provider=creds_provider) + + assert await r2.ping() is True + + async def test_async_credential_provider_no_password_success( + self, r_acl_teardown, create_redis + ): + username = "username" + r = r_acl_teardown(username) + await init_acl_user(r, username, "") + r2 = await create_redis( + flushdb=False, + credential_provider=NoPassCredProvider(), + ) + assert await r2.ping() is True + + @pytest.mark.onlynoncluster + async def test_credential_provider_no_password_error( + self, r_acl_teardown, create_redis + ): + username = "username" + r = r_acl_teardown(username) + await init_acl_user(r, username, "password") + with pytest.raises(AuthenticationError) as e: + await create_redis( + flushdb=False, + credential_provider=NoPassCredProvider(), + single_connection_client=True, + ) + assert e.match("invalid username-password") + assert await r.acl_deluser(username) + + @pytest.mark.onlynoncluster + async def test_password_and_username_together_with_cred_provider_raise_error( + self, r_acl_teardown, create_redis + ): + username = "username" + r = r_acl_teardown(username) + await init_acl_user(r, username, "password") + cred_provider = UsernamePasswordCredentialProvider( + username="username", password="password" + ) + with pytest.raises(DataError) as e: + await create_redis( + flushdb=False, + username="username", + password="password", + credential_provider=cred_provider, + single_connection_client=True, + ) + assert e.match( + "'username' and 'password' cannot be passed along with " + "'credential_provider'." + ) + + @pytest.mark.onlynoncluster + async def test_change_username_password_on_existing_connection( + self, r_acl_teardown, create_redis + ): + username = "origin_username" + password = "origin_password" + new_username = "new_username" + new_password = "new_password" + r = r_acl_teardown(username) + await init_acl_user(r, username, password) + r2 = await create_redis(flushdb=False, username=username, password=password) + assert await r2.ping() is True + conn = await r2.connection_pool.get_connection("_") + await conn.send_command("PING") + assert str_if_bytes(await conn.read_response()) == "PONG" + assert conn.username == username + assert conn.password == password + await init_acl_user(r, new_username, new_password) + conn.password = new_password + conn.username = new_username + await conn.send_command("PING") + assert str_if_bytes(await conn.read_response()) == "PONG" + + +@pytest.mark.asyncio +class TestUsernamePasswordCredentialProvider: + async def test_user_pass_credential_provider_acl_user_and_pass( + self, r_acl_teardown, create_redis + ): + username = "username" + password = "password" + r = r_acl_teardown(username) + provider = UsernamePasswordCredentialProvider(username, password) + assert provider.username == username + assert provider.password == password + assert provider.get_credentials() == (username, password) + await init_acl_user(r, provider.username, provider.password) + r2 = await create_redis(flushdb=False, credential_provider=provider) + assert await r2.ping() is True + + async def test_user_pass_provider_only_password( + self, r_required_pass_teardown, create_redis + ): + password = "password" + provider = UsernamePasswordCredentialProvider(password=password) + r = r_required_pass_teardown(password) + assert provider.username == "" + assert provider.password == password + assert provider.get_credentials() == (password,) + + await init_required_pass(r, password) + + r2 = await create_redis(flushdb=False, credential_provider=provider) + assert await r2.auth(provider.password) is True + assert await r2.ping() is True diff --git a/tests/test_credentials.py b/tests/test_credentials.py new file mode 100644 index 0000000..9aeb1ef --- /dev/null +++ b/tests/test_credentials.py @@ -0,0 +1,245 @@ +import functools +import random +import string +from typing import Optional, Tuple, Union + +import pytest + +import redis +from redis import AuthenticationError, DataError, ResponseError +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider +from redis.utils import str_if_bytes +from tests.conftest import _get_client, skip_if_redis_enterprise + + +class NoPassCredProvider(CredentialProvider): + def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: + return "username", "" + + +class RandomAuthCredProvider(CredentialProvider): + def __init__(self, user: Optional[str], endpoint: str): + self.user = user + self.endpoint = endpoint + + @functools.lru_cache(maxsize=10) + def get_credentials(self) -> Union[Tuple[str, str], Tuple[str]]: + def get_random_string(length): + letters = string.ascii_lowercase + result_str = "".join(random.choice(letters) for i in range(length)) + return result_str + + if self.user: + auth_token: str = get_random_string(5) + self.user + "_" + self.endpoint + return self.user, auth_token + else: + auth_token: str = get_random_string(5) + self.endpoint + return (auth_token,) + + +def init_acl_user(r, request, username, password): + # reset the user + r.acl_deluser(username) + if password: + assert ( + r.acl_setuser( + username, + enabled=True, + passwords=["+" + password], + keys="~*", + commands=[ + "+ping", + "+command", + "+info", + "+select", + "+flushdb", + "+cluster", + ], + ) + is True + ) + else: + assert ( + r.acl_setuser( + username, + enabled=True, + keys="~*", + commands=[ + "+ping", + "+command", + "+info", + "+select", + "+flushdb", + "+cluster", + ], + nopass=True, + ) + is True + ) + + if request is not None: + + def teardown(): + r.acl_deluser(username) + + request.addfinalizer(teardown) + + +def init_required_pass(r, request, password): + r.config_set("requirepass", password) + + def teardown(): + try: + r.auth(password) + except (ResponseError, AuthenticationError): + r.auth("default", "") + r.config_set("requirepass", "") + + request.addfinalizer(teardown) + + +class TestCredentialsProvider: + @skip_if_redis_enterprise() + def test_only_pass_without_creds_provider(self, r, request): + # test for default user (`username` is supposed to be optional) + password = "password" + init_required_pass(r, request, password) + assert r.auth(password) is True + + r2 = _get_client(redis.Redis, request, flushdb=False, password=password) + + assert r2.ping() is True + + @skip_if_redis_enterprise() + def test_user_and_pass_without_creds_provider(self, r, request): + """ + Test backward compatibility with username and password + """ + # test for other users + username = "username" + password = "password" + + init_acl_user(r, request, username, password) + r2 = _get_client( + redis.Redis, request, flushdb=False, username=username, password=password + ) + + assert r2.ping() is True + + @pytest.mark.parametrize("username", ["username", None]) + @skip_if_redis_enterprise() + @pytest.mark.onlynoncluster + def test_credential_provider_with_supplier(self, r, request, username): + creds_provider = RandomAuthCredProvider( + user=username, + endpoint="localhost", + ) + + password = creds_provider.get_credentials()[-1] + + if username: + init_acl_user(r, request, username, password) + else: + init_required_pass(r, request, password) + + r2 = _get_client( + redis.Redis, request, flushdb=False, credential_provider=creds_provider + ) + + assert r2.ping() is True + + def test_credential_provider_no_password_success(self, r, request): + init_acl_user(r, request, "username", "") + r2 = _get_client( + redis.Redis, + request, + flushdb=False, + credential_provider=NoPassCredProvider(), + ) + assert r2.ping() is True + + @pytest.mark.onlynoncluster + def test_credential_provider_no_password_error(self, r, request): + init_acl_user(r, request, "username", "password") + with pytest.raises(AuthenticationError) as e: + _get_client( + redis.Redis, + request, + flushdb=False, + credential_provider=NoPassCredProvider(), + ) + assert e.match("invalid username-password") + + @pytest.mark.onlynoncluster + def test_password_and_username_together_with_cred_provider_raise_error( + self, r, request + ): + init_acl_user(r, request, "username", "password") + cred_provider = UsernamePasswordCredentialProvider( + username="username", password="password" + ) + with pytest.raises(DataError) as e: + _get_client( + redis.Redis, + request, + flushdb=False, + username="username", + password="password", + credential_provider=cred_provider, + ) + assert e.match( + "'username' and 'password' cannot be passed along with " + "'credential_provider'." + ) + + @pytest.mark.onlynoncluster + def test_change_username_password_on_existing_connection(self, r, request): + username = "origin_username" + password = "origin_password" + new_username = "new_username" + new_password = "new_password" + init_acl_user(r, request, username, password) + r2 = _get_client( + redis.Redis, request, flushdb=False, username=username, password=password + ) + assert r2.ping() is True + conn = r2.connection_pool.get_connection("_") + conn.send_command("PING") + assert str_if_bytes(conn.read_response()) == "PONG" + assert conn.username == username + assert conn.password == password + init_acl_user(r, request, new_username, new_password) + conn.password = new_password + conn.username = new_username + conn.send_command("PING") + assert str_if_bytes(conn.read_response()) == "PONG" + + +class TestUsernamePasswordCredentialProvider: + def test_user_pass_credential_provider_acl_user_and_pass(self, r, request): + username = "username" + password = "password" + provider = UsernamePasswordCredentialProvider(username, password) + assert provider.username == username + assert provider.password == password + assert provider.get_credentials() == (username, password) + init_acl_user(r, request, provider.username, provider.password) + r2 = _get_client( + redis.Redis, request, flushdb=False, credential_provider=provider + ) + assert r2.ping() is True + + def test_user_pass_provider_only_password(self, r, request): + password = "password" + provider = UsernamePasswordCredentialProvider(password=password) + assert provider.username == "" + assert provider.password == password + assert provider.get_credentials() == (password,) + + init_required_pass(r, request, password) + + r2 = _get_client( + redis.Redis, request, flushdb=False, credential_provider=provider + ) + assert r2.auth(provider.password) is True + assert r2.ping() is True |