summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBar Shaul <88437685+barshaul@users.noreply.github.com>2022-11-10 12:38:47 +0200
committerGitHub <noreply@github.com>2022-11-10 12:38:47 +0200
commitbb06ccd52924800ac501d17c8a42038c8e5c5770 (patch)
treedf9fa0ae2c2553ecc3779b3f7166d6cad4855c03
parentfb647430f00cc7bb67c978e75f2dabc661567779 (diff)
downloadredis-py-bb06ccd52924800ac501d17c8a42038c8e5c5770.tar.gz
CredentialsProvider class added to support password rotation (#2261)
* A CredentialsProvider class has been added to allow the user to add his own provider for password rotation * Moved CredentialsProvider to a separate file, added type hints * Changed username and password to properties * Added: StaticCredentialProvider, examples, tests Changed: CredentialsProvider to CredentialProvider Fixed: calling AUTH only with password * Changed private members' prefix to __ * fixed linters * fixed auth test * fixed credential test * Raise an error if username or password are passed along with credential_provider * fixing linters * fixing test * Changed dundered to single per side underscore * Changed Connection class members username and password to properties to enable backward compatibility with changing the members value on existing connection. * Reverting last commit and adding backward compatibility to 'username' and 'password' inside on_connect function * Refactored CredentialProvider class * Fixing tuple type to Tuple * Fixing optional string members in UsernamePasswordCredentialProvider * Fixed credential test * Added credential provider support to AsyncRedis * linters * linters * linters * linters - black Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> Co-authored-by: dvora-h <dvora.heller@redis.com>
-rw-r--r--CHANGES1
-rw-r--r--docs/examples/asyncio_examples.ipynb129
-rw-r--r--docs/examples/connection_examples.ipynb193
-rw-r--r--redis/__init__.py3
-rw-r--r--redis/asyncio/client.py3
-rw-r--r--redis/asyncio/cluster.py3
-rw-r--r--redis/asyncio/connection.py41
-rwxr-xr-xredis/client.py4
-rw-r--r--redis/cluster.py1
-rwxr-xr-xredis/connection.py39
-rw-r--r--redis/credentials.py26
-rw-r--r--tests/test_asyncio/test_credentials.py284
-rw-r--r--tests/test_credentials.py245
13 files changed, 898 insertions, 74 deletions
diff --git a/CHANGES b/CHANGES
index 7bdfacf..4945f61 100644
--- a/CHANGES
+++ b/CHANGES
@@ -24,6 +24,7 @@
* ClusterPipeline Doesn't Handle ConnectionError for Dead Hosts (#2225)
* Remove compatibility code for old versions of Hiredis, drop Packaging dependency
* The `deprecated` library is no longer a dependency
+ * Added CredentialsProvider class to support password rotation
* Enable Lock for asyncio cluster mode
* 4.1.3 (Feb 8, 2022)
diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb
index dab7a96..855255c 100644
--- a/docs/examples/asyncio_examples.ipynb
+++ b/docs/examples/asyncio_examples.ipynb
@@ -21,11 +21,6 @@
{
"cell_type": "code",
"execution_count": 1,
- "metadata": {
- "pycharm": {
- "name": "#%%\n"
- }
- },
"outputs": [
{
"name": "stdout",
@@ -41,27 +36,29 @@
"connection = redis.Redis()\n",
"print(f\"Ping successful: {await connection.ping()}\")\n",
"await connection.close()"
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
},
{
"cell_type": "markdown",
+ "source": [
+ "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool."
+ ],
"metadata": {
+ "collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
- },
- "source": [
- "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool."
- ]
+ }
},
{
"cell_type": "code",
"execution_count": 2,
- "metadata": {
- "pycharm": {
- "name": "#%%\n"
- }
- },
"outputs": [],
"source": [
"import redis.asyncio as redis\n",
@@ -70,15 +67,16 @@
"await connection.close()\n",
"# Or: await connection.close(close_connection_pool=False)\n",
"await connection.connection_pool.disconnect()"
- ]
- },
- {
- "cell_type": "markdown",
+ ],
"metadata": {
+ "collapsed": false,
"pycharm": {
- "name": "#%% md\n"
+ "name": "#%%\n"
}
- },
+ }
+ },
+ {
+ "cell_type": "markdown",
"source": [
"## Transactions (Multi/Exec)\n",
"\n",
@@ -87,16 +85,17 @@
"The commands will not be reflected in Redis until execute() is called & awaited.\n",
"\n",
"Usually, when performing a bulk operation, taking advantage of a “transaction” (e.g., Multi/Exec) is to be desired, as it will also add a layer of atomicity to your bulk operation."
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
},
{
"cell_type": "code",
"execution_count": 3,
- "metadata": {
- "pycharm": {
- "name": "#%%\n"
- }
- },
"outputs": [],
"source": [
"import redis.asyncio as redis\n",
@@ -106,25 +105,31 @@
" ok1, ok2 = await (pipe.set(\"key1\", \"value1\").set(\"key2\", \"value2\").execute())\n",
"assert ok1\n",
"assert ok2"
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
},
{
"cell_type": "markdown",
- "metadata": {},
"source": [
"## Pub/Sub Mode\n",
"\n",
"Subscribing to specific channels:"
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
},
{
"cell_type": "code",
"execution_count": 4,
- "metadata": {
- "pycharm": {
- "name": "#%%\n"
- }
- },
"outputs": [
{
"name": "stdout",
@@ -165,23 +170,29 @@
" await r.publish(\"channel:1\", STOPWORD)\n",
"\n",
" await future"
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
},
{
"cell_type": "markdown",
- "metadata": {},
"source": [
"Subscribing to channels matching a glob-style pattern:"
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
},
{
"cell_type": "code",
"execution_count": 5,
- "metadata": {
- "pycharm": {
- "name": "#%%\n"
- }
- },
"outputs": [
{
"name": "stdout",
@@ -223,11 +234,16 @@
" await r.publish(\"channel:1\", STOPWORD)\n",
"\n",
" await future"
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
},
{
"cell_type": "markdown",
- "metadata": {},
"source": [
"## Sentinel Client\n",
"\n",
@@ -236,16 +252,17 @@
"Calling aioredis.sentinel.Sentinel.master_for or aioredis.sentinel.Sentinel.slave_for methods will return Redis clients connected to specified services monitored by Sentinel.\n",
"\n",
"Sentinel client will detect failover and reconnect Redis clients automatically."
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "pycharm": {
- "name": "#%%\n"
- }
- },
"outputs": [],
"source": [
"import asyncio\n",
@@ -260,7 +277,13 @@
"assert ok\n",
"val = await r.get(\"key\")\n",
"assert val == b\"value\""
- ]
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
}
],
"metadata": {
@@ -284,4 +307,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
-}
+} \ No newline at end of file
diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb
index b0084ff..ca8dd44 100644
--- a/docs/examples/connection_examples.ipynb
+++ b/docs/examples/connection_examples.ipynb
@@ -99,6 +99,197 @@
},
{
"cell_type": "markdown",
+ "source": [
+ "## Connecting to a redis instance with username and password credential provider"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "import redis\n",
+ "\n",
+ "creds_provider = redis.UsernamePasswordCredentialProvider(\"username\", \"password\")\n",
+ "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n",
+ "user_connection.ping()"
+ ],
+ "metadata": {}
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Connecting to a redis instance with standard credential provider"
+ ],
+ "metadata": {}
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "from typing import Tuple\n",
+ "import redis\n",
+ "\n",
+ "creds_map = {\"user_1\": \"pass_1\",\n",
+ " \"user_2\": \"pass_2\"}\n",
+ "\n",
+ "class UserMapCredentialProvider(redis.CredentialProvider):\n",
+ " def __init__(self, username: str):\n",
+ " self.username = username\n",
+ "\n",
+ " def get_credentials(self) -> Tuple[str, str]:\n",
+ " return self.username, creds_map.get(self.username)\n",
+ "\n",
+ "# Create a default connection to set the ACL user\n",
+ "default_connection = redis.Redis(host=\"localhost\", port=6379)\n",
+ "default_connection.acl_setuser(\n",
+ " \"user_1\",\n",
+ " enabled=True,\n",
+ " passwords=[\"+\" + \"pass_1\"],\n",
+ " keys=\"~*\",\n",
+ " commands=[\"+ping\", \"+command\", \"+info\", \"+select\", \"+flushdb\"],\n",
+ ")\n",
+ "\n",
+ "# Create a UserMapCredentialProvider instance for user_1\n",
+ "creds_provider = UserMapCredentialProvider(\"user_1\")\n",
+ "# Initiate user connection with the credential provider\n",
+ "user_connection = redis.Redis(host=\"localhost\", port=6379,\n",
+ " credential_provider=creds_provider)\n",
+ "user_connection.ping()"
+ ],
+ "metadata": {}
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Connecting to a redis instance first with an initial credential set and then calling the credential provider"
+ ],
+ "metadata": {}
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "from typing import Union\n",
+ "import redis\n",
+ "\n",
+ "class InitCredsSetCredentialProvider(redis.CredentialProvider):\n",
+ " def __init__(self, username, password):\n",
+ " self.username = username\n",
+ " self.password = password\n",
+ " self.call_supplier = False\n",
+ "\n",
+ " def call_external_supplier(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
+ " # Call to an external credential supplier\n",
+ " raise NotImplementedError\n",
+ "\n",
+ " def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
+ " if self.call_supplier:\n",
+ " return self.call_external_supplier()\n",
+ " # Use the init set only for the first time\n",
+ " self.call_supplier = True\n",
+ " return self.username, self.password\n",
+ "\n",
+ "cred_provider = InitCredsSetCredentialProvider(username=\"init_user\", password=\"init_pass\")"
+ ],
+ "metadata": {}
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": [
+ "## Connecting to a redis instance with AWS Secrets Manager credential provider."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import redis\n",
+ "import boto3\n",
+ "import json\n",
+ "import cachetools.func\n",
+ "\n",
+ "sm_client = boto3.client('secretsmanager')\n",
+ " \n",
+ "def sm_auth_provider(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n",
+ " @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n",
+ " def get_sm_user_credentials(secret_id, version_id, version_stage):\n",
+ " secret = sm_client.get_secret_value(secret_id, version_id)\n",
+ " return json.loads(secret['SecretString'])\n",
+ " creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n",
+ " return creds['username'], creds['password']\n",
+ "\n",
+ "secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n",
+ "creds_provider = redis.CredentialProvider(supplier=sm_auth_provider, secret_id=secret_id)\n",
+ "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n",
+ "user_connection.ping()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Connecting to a redis instance with ElastiCache IAM credential provider."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import redis\n",
+ "import boto3\n",
+ "import cachetools.func\n",
+ "\n",
+ "ec_client = boto3.client('elasticache')\n",
+ "\n",
+ "def iam_auth_provider(self, user, endpoint, port=6379, region=\"us-east-1\"):\n",
+ " @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n",
+ " def get_iam_auth_token(user, endpoint, port, region):\n",
+ " return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n",
+ " iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n",
+ " return iam_auth_token\n",
+ "\n",
+ "username = \"barshaul\"\n",
+ "endpoint = \"test-001.use1.cache.amazonaws.com\"\n",
+ "creds_provider = redis.CredentialProvider(supplier=iam_auth_provider, user=username,\n",
+ " endpoint=endpoint)\n",
+ "user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n",
+ "user_connection.ping()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
"metadata": {},
"source": [
"## Connecting to Redis instances by specifying a URL scheme.\n",
@@ -176,4 +367,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
-}
+} \ No newline at end of file
diff --git a/redis/__init__.py b/redis/__init__.py
index b7560a6..5201fe2 100644
--- a/redis/__init__.py
+++ b/redis/__init__.py
@@ -9,6 +9,7 @@ from redis.connection import (
SSLConnection,
UnixDomainSocketConnection,
)
+from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
@@ -62,6 +63,7 @@ __all__ = [
"Connection",
"ConnectionError",
"ConnectionPool",
+ "CredentialProvider",
"DataError",
"from_url",
"InvalidResponse",
@@ -76,6 +78,7 @@ __all__ = [
"SentinelManagedConnection",
"SentinelManagedSSLConnection",
"SSLConnection",
+ "UsernamePasswordCredentialProvider",
"StrictRedis",
"TimeoutError",
"UnixDomainSocketConnection",
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 619ee11..c085571 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -46,6 +46,7 @@ from redis.commands import (
list_or_args,
)
from redis.compat import Protocol, TypedDict
+from redis.credentials import CredentialProvider
from redis.exceptions import (
ConnectionError,
ExecAbortError,
@@ -174,6 +175,7 @@ class Redis(
retry: Optional[Retry] = None,
auto_close_connection_pool: bool = True,
redis_connect_func=None,
+ credential_provider: Optional[CredentialProvider] = None,
):
"""
Initialize a new Redis client.
@@ -199,6 +201,7 @@ class Redis(
"db": db,
"username": username,
"password": password,
+ "credential_provider": credential_provider,
"socket_timeout": socket_timeout,
"encoding": encoding,
"encoding_errors": encoding_errors,
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 97f4151..57aafbd 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -40,6 +40,7 @@ from redis.cluster import (
)
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
+from redis.credentials import CredentialProvider
from redis.exceptions import (
AskError,
BusyLoadingError,
@@ -220,6 +221,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
# Client related kwargs
db: Union[str, int] = 0,
path: Optional[str] = None,
+ credential_provider: Optional[CredentialProvider] = None,
username: Optional[str] = None,
password: Optional[str] = None,
client_name: Optional[str] = None,
@@ -266,6 +268,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
"connection_class": Connection,
"parser_class": ClusterParser,
# Client related kwargs
+ "credential_provider": credential_provider,
"username": username,
"password": password,
"client_name": client_name,
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 1288bb6..df066c4 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -29,6 +29,7 @@ import async_timeout
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.compat import Protocol, TypedDict
+from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
@@ -416,6 +417,7 @@ class Connection:
"db",
"username",
"client_name",
+ "credential_provider",
"password",
"socket_timeout",
"socket_connect_timeout",
@@ -465,14 +467,23 @@ class Connection:
retry: Optional[Retry] = None,
redis_connect_func: Optional[ConnectCallbackT] = None,
encoder_class: Type[Encoder] = Encoder,
+ credential_provider: Optional[CredentialProvider] = None,
):
+ if (username or password) and credential_provider is not None:
+ raise DataError(
+ "'username' and 'password' cannot be passed along with 'credential_"
+ "provider'. Please provide only one of the following arguments: \n"
+ "1. 'password' and (optional) 'username'\n"
+ "2. 'credential_provider'"
+ )
self.pid = os.getpid()
self.host = host
self.port = int(port)
self.db = db
- self.username = username
self.client_name = client_name
+ self.credential_provider = credential_provider
self.password = password
+ self.username = username
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
self.socket_keepalive = socket_keepalive
@@ -637,14 +648,13 @@ class Connection:
"""Initialize the connection, authenticate and select a database"""
self._parser.on_connect(self)
- # if username and/or password are set, authenticate
- if self.username or self.password:
- auth_args: Union[Tuple[str], Tuple[str, str]]
- if self.username:
- auth_args = (self.username, self.password or "")
- else:
- # Mypy bug: https://github.com/python/mypy/issues/10944
- auth_args = (self.password or "",)
+ # if credential provider or username and/or password are set, authenticate
+ if self.credential_provider or (self.username or self.password):
+ cred_provider = (
+ self.credential_provider
+ or UsernamePasswordCredentialProvider(self.username, self.password)
+ )
+ auth_args = cred_provider.get_credentials()
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
await self.send_command("AUTH", *auth_args, check_health=False)
@@ -656,7 +666,7 @@ class Connection:
# server seems to be < 6.0.0 which expects a single password
# arg. retry auth with just the password.
# https://github.com/andymccurdy/redis-py/issues/1274
- await self.send_command("AUTH", self.password, check_health=False)
+ await self.send_command("AUTH", auth_args[-1], check_health=False)
auth_response = await self.read_response()
if str_if_bytes(auth_response) != "OK":
@@ -1014,18 +1024,27 @@ class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init]
client_name: str = None,
retry: Optional[Retry] = None,
redis_connect_func=None,
+ credential_provider: Optional[CredentialProvider] = None,
):
"""
Initialize a new UnixDomainSocketConnection.
To specify a retry policy, first set `retry_on_timeout` to `True`
then set `retry` to a valid `Retry` object
"""
+ if (username or password) and credential_provider is not None:
+ raise DataError(
+ "'username' and 'password' cannot be passed along with 'credential_"
+ "provider'. Please provide only one of the following arguments: \n"
+ "1. 'password' and (optional) 'username'\n"
+ "2. 'credential_provider'"
+ )
self.pid = os.getpid()
self.path = path
self.db = db
- self.username = username
self.client_name = client_name
+ self.credential_provider = credential_provider
self.password = password
+ self.username = username
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
self.retry_on_timeout = retry_on_timeout
diff --git a/redis/client.py b/redis/client.py
index 6a26d28..8356ba7 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -5,6 +5,7 @@ import threading
import time
import warnings
from itertools import chain
+from typing import Optional
from redis.commands import (
CoreCommands,
@@ -13,6 +14,7 @@ from redis.commands import (
list_or_args,
)
from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection
+from redis.credentials import CredentialProvider
from redis.exceptions import (
ConnectionError,
ExecAbortError,
@@ -938,6 +940,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands):
username=None,
retry=None,
redis_connect_func=None,
+ credential_provider: Optional[CredentialProvider] = None,
):
"""
Initialize a new Redis client.
@@ -985,6 +988,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands):
"health_check_interval": health_check_interval,
"client_name": client_name,
"redis_connect_func": redis_connect_func,
+ "credential_provider": credential_provider,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
diff --git a/redis/cluster.py b/redis/cluster.py
index cb3b2a6..027fe40 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -121,6 +121,7 @@ REDIS_ALLOWED_KEYS = (
"connection_class",
"connection_pool",
"client_name",
+ "credential_provider",
"db",
"decode_responses",
"encoding",
diff --git a/redis/connection.py b/redis/connection.py
index fecb06b..a2b0074 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -8,9 +8,11 @@ import weakref
from itertools import chain
from queue import Empty, Full, LifoQueue
from time import time
+from typing import Optional
from urllib.parse import parse_qs, unquote, urlparse
from redis.backoff import NoBackoff
+from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
@@ -502,6 +504,7 @@ class Connection:
username=None,
retry=None,
redis_connect_func=None,
+ credential_provider: Optional[CredentialProvider] = None,
):
"""
Initialize a new Connection.
@@ -510,13 +513,21 @@ class Connection:
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
+ if (username or password) and credential_provider is not None:
+ raise DataError(
+ "'username' and 'password' cannot be passed along with 'credential_"
+ "provider'. Please provide only one of the following arguments: \n"
+ "1. 'password' and (optional) 'username'\n"
+ "2. 'credential_provider'"
+ )
self.pid = os.getpid()
self.host = host
self.port = int(port)
self.db = db
- self.username = username
self.client_name = client_name
+ self.credential_provider = credential_provider
self.password = password
+ self.username = username
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
self.socket_keepalive = socket_keepalive
@@ -675,12 +686,13 @@ class Connection:
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)
- # if username and/or password are set, authenticate
- if self.username or self.password:
- if self.username:
- auth_args = (self.username, self.password or "")
- else:
- auth_args = (self.password,)
+ # if credential provider or username and/or password are set, authenticate
+ if self.credential_provider or (self.username or self.password):
+ cred_provider = (
+ self.credential_provider
+ or UsernamePasswordCredentialProvider(self.username, self.password)
+ )
+ auth_args = cred_provider.get_credentials()
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
self.send_command("AUTH", *auth_args, check_health=False)
@@ -692,7 +704,7 @@ class Connection:
# server seems to be < 6.0.0 which expects a single password
# arg. retry auth with just the password.
# https://github.com/andymccurdy/redis-py/issues/1274
- self.send_command("AUTH", self.password, check_health=False)
+ self.send_command("AUTH", auth_args[-1], check_health=False)
auth_response = self.read_response()
if str_if_bytes(auth_response) != "OK":
@@ -1050,6 +1062,7 @@ class UnixDomainSocketConnection(Connection):
client_name=None,
retry=None,
redis_connect_func=None,
+ credential_provider: Optional[CredentialProvider] = None,
):
"""
Initialize a new UnixDomainSocketConnection.
@@ -1058,12 +1071,20 @@ class UnixDomainSocketConnection(Connection):
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
+ if (username or password) and credential_provider is not None:
+ raise DataError(
+ "'username' and 'password' cannot be passed along with 'credential_"
+ "provider'. Please provide only one of the following arguments: \n"
+ "1. 'password' and (optional) 'username'\n"
+ "2. 'credential_provider'"
+ )
self.pid = os.getpid()
self.path = path
self.db = db
- self.username = username
self.client_name = client_name
+ self.credential_provider = credential_provider
self.password = password
+ self.username = username
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
diff --git a/redis/credentials.py b/redis/credentials.py
new file mode 100644
index 0000000..7ba26dc
--- /dev/null
+++ b/redis/credentials.py
@@ -0,0 +1,26 @@
+from typing import Optional, Tuple, Union
+
+
+class CredentialProvider:
+ """
+ Credentials Provider.
+ """
+
+ def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
+ raise NotImplementedError("get_credentials must be implemented")
+
+
+class UsernamePasswordCredentialProvider(CredentialProvider):
+ """
+ Simple implementation of CredentialProvider that just wraps static
+ username and password.
+ """
+
+ def __init__(self, username: Optional[str] = None, password: Optional[str] = None):
+ self.username = username or ""
+ self.password = password or ""
+
+ def get_credentials(self):
+ if self.username:
+ return self.username, self.password
+ return (self.password,)
diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py
new file mode 100644
index 0000000..8e213cd
--- /dev/null
+++ b/tests/test_asyncio/test_credentials.py
@@ -0,0 +1,284 @@
+import functools
+import random
+import string
+from typing import Optional, Tuple, Union
+
+import pytest
+import pytest_asyncio
+
+import redis
+from redis import AuthenticationError, DataError, ResponseError
+from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
+from redis.utils import str_if_bytes
+from tests.conftest import skip_if_redis_enterprise
+
+
+@pytest_asyncio.fixture()
+async def r_acl_teardown(r: redis.Redis):
+ """
+ A special fixture which removes the provided names from the database after use
+ """
+ usernames = []
+
+ def factory(username):
+ usernames.append(username)
+ return r
+
+ yield factory
+ for username in usernames:
+ await r.acl_deluser(username)
+
+
+@pytest_asyncio.fixture()
+async def r_required_pass_teardown(r: redis.Redis):
+ """
+ A special fixture which removes the provided password from the database after use
+ """
+ passwords = []
+
+ def factory(username):
+ passwords.append(username)
+ return r
+
+ yield factory
+ for password in passwords:
+ try:
+ await r.auth(password)
+ except (ResponseError, AuthenticationError):
+ await r.auth("default", "")
+ await r.config_set("requirepass", "")
+
+
+class NoPassCredProvider(CredentialProvider):
+ def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
+ return "username", ""
+
+
+class AsyncRandomAuthCredProvider(CredentialProvider):
+ def __init__(self, user: Optional[str], endpoint: str):
+ self.user = user
+ self.endpoint = endpoint
+
+ @functools.lru_cache(maxsize=10)
+ def get_credentials(self) -> Union[Tuple[str, str], Tuple[str]]:
+ def get_random_string(length):
+ letters = string.ascii_lowercase
+ result_str = "".join(random.choice(letters) for i in range(length))
+ return result_str
+
+ if self.user:
+ auth_token: str = get_random_string(5) + self.user + "_" + self.endpoint
+ return self.user, auth_token
+ else:
+ auth_token: str = get_random_string(5) + self.endpoint
+ return (auth_token,)
+
+
+async def init_acl_user(r, username, password):
+ # reset the user
+ await r.acl_deluser(username)
+ if password:
+ assert (
+ await r.acl_setuser(
+ username,
+ enabled=True,
+ passwords=["+" + password],
+ keys="~*",
+ commands=[
+ "+ping",
+ "+command",
+ "+info",
+ "+select",
+ "+flushdb",
+ "+cluster",
+ ],
+ )
+ is True
+ )
+ else:
+ assert (
+ await r.acl_setuser(
+ username,
+ enabled=True,
+ keys="~*",
+ commands=[
+ "+ping",
+ "+command",
+ "+info",
+ "+select",
+ "+flushdb",
+ "+cluster",
+ ],
+ nopass=True,
+ )
+ is True
+ )
+
+
+async def init_required_pass(r, password):
+ await r.config_set("requirepass", password)
+
+
+@pytest.mark.asyncio
+class TestCredentialsProvider:
+ @skip_if_redis_enterprise()
+ async def test_only_pass_without_creds_provider(
+ self, r_required_pass_teardown, create_redis
+ ):
+ # test for default user (`username` is supposed to be optional)
+ password = "password"
+ r = r_required_pass_teardown(password)
+ await init_required_pass(r, password)
+ assert await r.auth(password) is True
+
+ r2 = await create_redis(flushdb=False, password=password)
+
+ assert await r2.ping() is True
+
+ @skip_if_redis_enterprise()
+ async def test_user_and_pass_without_creds_provider(
+ self, r_acl_teardown, create_redis
+ ):
+ """
+ Test backward compatibility with username and password
+ """
+ # test for other users
+ username = "username"
+ password = "password"
+ r = r_acl_teardown(username)
+ await init_acl_user(r, username, password)
+ r2 = await create_redis(flushdb=False, username=username, password=password)
+
+ assert await r2.ping() is True
+
+ @pytest.mark.parametrize("username", ["username", None])
+ @skip_if_redis_enterprise()
+ @pytest.mark.onlynoncluster
+ async def test_credential_provider_with_supplier(
+ self, r_acl_teardown, r_required_pass_teardown, create_redis, username
+ ):
+ creds_provider = AsyncRandomAuthCredProvider(
+ user=username,
+ endpoint="localhost",
+ )
+
+ auth_args = creds_provider.get_credentials()
+ password = auth_args[-1]
+
+ if username:
+ r = r_acl_teardown(username)
+ await init_acl_user(r, username, password)
+ else:
+ r = r_required_pass_teardown(password)
+ await init_required_pass(r, password)
+
+ r2 = await create_redis(flushdb=False, credential_provider=creds_provider)
+
+ assert await r2.ping() is True
+
+ async def test_async_credential_provider_no_password_success(
+ self, r_acl_teardown, create_redis
+ ):
+ username = "username"
+ r = r_acl_teardown(username)
+ await init_acl_user(r, username, "")
+ r2 = await create_redis(
+ flushdb=False,
+ credential_provider=NoPassCredProvider(),
+ )
+ assert await r2.ping() is True
+
+ @pytest.mark.onlynoncluster
+ async def test_credential_provider_no_password_error(
+ self, r_acl_teardown, create_redis
+ ):
+ username = "username"
+ r = r_acl_teardown(username)
+ await init_acl_user(r, username, "password")
+ with pytest.raises(AuthenticationError) as e:
+ await create_redis(
+ flushdb=False,
+ credential_provider=NoPassCredProvider(),
+ single_connection_client=True,
+ )
+ assert e.match("invalid username-password")
+ assert await r.acl_deluser(username)
+
+ @pytest.mark.onlynoncluster
+ async def test_password_and_username_together_with_cred_provider_raise_error(
+ self, r_acl_teardown, create_redis
+ ):
+ username = "username"
+ r = r_acl_teardown(username)
+ await init_acl_user(r, username, "password")
+ cred_provider = UsernamePasswordCredentialProvider(
+ username="username", password="password"
+ )
+ with pytest.raises(DataError) as e:
+ await create_redis(
+ flushdb=False,
+ username="username",
+ password="password",
+ credential_provider=cred_provider,
+ single_connection_client=True,
+ )
+ assert e.match(
+ "'username' and 'password' cannot be passed along with "
+ "'credential_provider'."
+ )
+
+ @pytest.mark.onlynoncluster
+ async def test_change_username_password_on_existing_connection(
+ self, r_acl_teardown, create_redis
+ ):
+ username = "origin_username"
+ password = "origin_password"
+ new_username = "new_username"
+ new_password = "new_password"
+ r = r_acl_teardown(username)
+ await init_acl_user(r, username, password)
+ r2 = await create_redis(flushdb=False, username=username, password=password)
+ assert await r2.ping() is True
+ conn = await r2.connection_pool.get_connection("_")
+ await conn.send_command("PING")
+ assert str_if_bytes(await conn.read_response()) == "PONG"
+ assert conn.username == username
+ assert conn.password == password
+ await init_acl_user(r, new_username, new_password)
+ conn.password = new_password
+ conn.username = new_username
+ await conn.send_command("PING")
+ assert str_if_bytes(await conn.read_response()) == "PONG"
+
+
+@pytest.mark.asyncio
+class TestUsernamePasswordCredentialProvider:
+ async def test_user_pass_credential_provider_acl_user_and_pass(
+ self, r_acl_teardown, create_redis
+ ):
+ username = "username"
+ password = "password"
+ r = r_acl_teardown(username)
+ provider = UsernamePasswordCredentialProvider(username, password)
+ assert provider.username == username
+ assert provider.password == password
+ assert provider.get_credentials() == (username, password)
+ await init_acl_user(r, provider.username, provider.password)
+ r2 = await create_redis(flushdb=False, credential_provider=provider)
+ assert await r2.ping() is True
+
+ async def test_user_pass_provider_only_password(
+ self, r_required_pass_teardown, create_redis
+ ):
+ password = "password"
+ provider = UsernamePasswordCredentialProvider(password=password)
+ r = r_required_pass_teardown(password)
+ assert provider.username == ""
+ assert provider.password == password
+ assert provider.get_credentials() == (password,)
+
+ await init_required_pass(r, password)
+
+ r2 = await create_redis(flushdb=False, credential_provider=provider)
+ assert await r2.auth(provider.password) is True
+ assert await r2.ping() is True
diff --git a/tests/test_credentials.py b/tests/test_credentials.py
new file mode 100644
index 0000000..9aeb1ef
--- /dev/null
+++ b/tests/test_credentials.py
@@ -0,0 +1,245 @@
+import functools
+import random
+import string
+from typing import Optional, Tuple, Union
+
+import pytest
+
+import redis
+from redis import AuthenticationError, DataError, ResponseError
+from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
+from redis.utils import str_if_bytes
+from tests.conftest import _get_client, skip_if_redis_enterprise
+
+
+class NoPassCredProvider(CredentialProvider):
+ def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
+ return "username", ""
+
+
+class RandomAuthCredProvider(CredentialProvider):
+ def __init__(self, user: Optional[str], endpoint: str):
+ self.user = user
+ self.endpoint = endpoint
+
+ @functools.lru_cache(maxsize=10)
+ def get_credentials(self) -> Union[Tuple[str, str], Tuple[str]]:
+ def get_random_string(length):
+ letters = string.ascii_lowercase
+ result_str = "".join(random.choice(letters) for i in range(length))
+ return result_str
+
+ if self.user:
+ auth_token: str = get_random_string(5) + self.user + "_" + self.endpoint
+ return self.user, auth_token
+ else:
+ auth_token: str = get_random_string(5) + self.endpoint
+ return (auth_token,)
+
+
+def init_acl_user(r, request, username, password):
+ # reset the user
+ r.acl_deluser(username)
+ if password:
+ assert (
+ r.acl_setuser(
+ username,
+ enabled=True,
+ passwords=["+" + password],
+ keys="~*",
+ commands=[
+ "+ping",
+ "+command",
+ "+info",
+ "+select",
+ "+flushdb",
+ "+cluster",
+ ],
+ )
+ is True
+ )
+ else:
+ assert (
+ r.acl_setuser(
+ username,
+ enabled=True,
+ keys="~*",
+ commands=[
+ "+ping",
+ "+command",
+ "+info",
+ "+select",
+ "+flushdb",
+ "+cluster",
+ ],
+ nopass=True,
+ )
+ is True
+ )
+
+ if request is not None:
+
+ def teardown():
+ r.acl_deluser(username)
+
+ request.addfinalizer(teardown)
+
+
+def init_required_pass(r, request, password):
+ r.config_set("requirepass", password)
+
+ def teardown():
+ try:
+ r.auth(password)
+ except (ResponseError, AuthenticationError):
+ r.auth("default", "")
+ r.config_set("requirepass", "")
+
+ request.addfinalizer(teardown)
+
+
+class TestCredentialsProvider:
+ @skip_if_redis_enterprise()
+ def test_only_pass_without_creds_provider(self, r, request):
+ # test for default user (`username` is supposed to be optional)
+ password = "password"
+ init_required_pass(r, request, password)
+ assert r.auth(password) is True
+
+ r2 = _get_client(redis.Redis, request, flushdb=False, password=password)
+
+ assert r2.ping() is True
+
+ @skip_if_redis_enterprise()
+ def test_user_and_pass_without_creds_provider(self, r, request):
+ """
+ Test backward compatibility with username and password
+ """
+ # test for other users
+ username = "username"
+ password = "password"
+
+ init_acl_user(r, request, username, password)
+ r2 = _get_client(
+ redis.Redis, request, flushdb=False, username=username, password=password
+ )
+
+ assert r2.ping() is True
+
+ @pytest.mark.parametrize("username", ["username", None])
+ @skip_if_redis_enterprise()
+ @pytest.mark.onlynoncluster
+ def test_credential_provider_with_supplier(self, r, request, username):
+ creds_provider = RandomAuthCredProvider(
+ user=username,
+ endpoint="localhost",
+ )
+
+ password = creds_provider.get_credentials()[-1]
+
+ if username:
+ init_acl_user(r, request, username, password)
+ else:
+ init_required_pass(r, request, password)
+
+ r2 = _get_client(
+ redis.Redis, request, flushdb=False, credential_provider=creds_provider
+ )
+
+ assert r2.ping() is True
+
+ def test_credential_provider_no_password_success(self, r, request):
+ init_acl_user(r, request, "username", "")
+ r2 = _get_client(
+ redis.Redis,
+ request,
+ flushdb=False,
+ credential_provider=NoPassCredProvider(),
+ )
+ assert r2.ping() is True
+
+ @pytest.mark.onlynoncluster
+ def test_credential_provider_no_password_error(self, r, request):
+ init_acl_user(r, request, "username", "password")
+ with pytest.raises(AuthenticationError) as e:
+ _get_client(
+ redis.Redis,
+ request,
+ flushdb=False,
+ credential_provider=NoPassCredProvider(),
+ )
+ assert e.match("invalid username-password")
+
+ @pytest.mark.onlynoncluster
+ def test_password_and_username_together_with_cred_provider_raise_error(
+ self, r, request
+ ):
+ init_acl_user(r, request, "username", "password")
+ cred_provider = UsernamePasswordCredentialProvider(
+ username="username", password="password"
+ )
+ with pytest.raises(DataError) as e:
+ _get_client(
+ redis.Redis,
+ request,
+ flushdb=False,
+ username="username",
+ password="password",
+ credential_provider=cred_provider,
+ )
+ assert e.match(
+ "'username' and 'password' cannot be passed along with "
+ "'credential_provider'."
+ )
+
+ @pytest.mark.onlynoncluster
+ def test_change_username_password_on_existing_connection(self, r, request):
+ username = "origin_username"
+ password = "origin_password"
+ new_username = "new_username"
+ new_password = "new_password"
+ init_acl_user(r, request, username, password)
+ r2 = _get_client(
+ redis.Redis, request, flushdb=False, username=username, password=password
+ )
+ assert r2.ping() is True
+ conn = r2.connection_pool.get_connection("_")
+ conn.send_command("PING")
+ assert str_if_bytes(conn.read_response()) == "PONG"
+ assert conn.username == username
+ assert conn.password == password
+ init_acl_user(r, request, new_username, new_password)
+ conn.password = new_password
+ conn.username = new_username
+ conn.send_command("PING")
+ assert str_if_bytes(conn.read_response()) == "PONG"
+
+
+class TestUsernamePasswordCredentialProvider:
+ def test_user_pass_credential_provider_acl_user_and_pass(self, r, request):
+ username = "username"
+ password = "password"
+ provider = UsernamePasswordCredentialProvider(username, password)
+ assert provider.username == username
+ assert provider.password == password
+ assert provider.get_credentials() == (username, password)
+ init_acl_user(r, request, provider.username, provider.password)
+ r2 = _get_client(
+ redis.Redis, request, flushdb=False, credential_provider=provider
+ )
+ assert r2.ping() is True
+
+ def test_user_pass_provider_only_password(self, r, request):
+ password = "password"
+ provider = UsernamePasswordCredentialProvider(password=password)
+ assert provider.username == ""
+ assert provider.password == password
+ assert provider.get_credentials() == (password,)
+
+ init_required_pass(r, request, password)
+
+ r2 = _get_client(
+ redis.Redis, request, flushdb=False, credential_provider=provider
+ )
+ assert r2.auth(provider.password) is True
+ assert r2.ping() is True