summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordvora-h <67596500+dvora-h@users.noreply.github.com>2022-06-01 14:32:45 +0300
committerGitHub <noreply@github.com>2022-06-01 14:32:45 +0300
commit7880460b72aca49aa5b9512f0995c0d17d884a7d (patch)
tree45f1b580944858202ee3b7e21a9f41a057b1611b
parentedf10043f4a383ec93a2fb51917075ae1bbfbf48 (diff)
downloadredis-py-7880460b72aca49aa5b9512f0995c0d17d884a7d.tar.gz
Add `query_params` to FT.PROFILE (#2198)
* ft.profile query_params * fix pr comments * type hints
-rw-r--r--redis/commands/search/commands.py30
-rw-r--r--tests/test_search.py24
2 files changed, 43 insertions, 11 deletions
diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py
index bf66147..0121436 100644
--- a/redis/commands/search/commands.py
+++ b/redis/commands/search/commands.py
@@ -1,6 +1,6 @@
import itertools
import time
-from typing import Dict, Union
+from typing import Dict, Optional, Union
from redis.client import Pipeline
@@ -363,7 +363,11 @@ class SearchCommands:
it = map(to_string, res)
return dict(zip(it, it))
- def get_params_args(self, query_params: Dict[str, Union[str, int, float]]):
+ def get_params_args(
+ self, query_params: Union[Dict[str, Union[str, int, float]], None]
+ ):
+ if query_params is None:
+ return []
args = []
if len(query_params) > 0:
args.append("params")
@@ -383,8 +387,7 @@ class SearchCommands:
raise ValueError(f"Bad query type {type(query)}")
args += query.get_args()
- if query_params is not None:
- args += self.get_params_args(query_params)
+ args += self.get_params_args(query_params)
return args, query
@@ -459,8 +462,7 @@ class SearchCommands:
cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args()
else:
raise ValueError("Bad query", query)
- if query_params is not None:
- cmd += self.get_params_args(query_params)
+ cmd += self.get_params_args(query_params)
raw = self.execute_command(*cmd)
return self._get_aggregate_result(raw, query, has_cursor)
@@ -485,16 +487,22 @@ class SearchCommands:
return AggregateResult(rows, cursor, schema)
- def profile(self, query, limited=False):
+ def profile(
+ self,
+ query: Union[str, Query, AggregateRequest],
+ limited: bool = False,
+ query_params: Optional[Dict[str, Union[str, int, float]]] = None,
+ ):
"""
Performs a search or aggregate command and collects performance
information.
### Parameters
- **query**: This can be either an `AggregateRequest`, `Query` or
- string.
+ **query**: This can be either an `AggregateRequest`, `Query` or string.
**limited**: If set to True, removes details of reader iterator.
+ **query_params**: Define one or more value parameters.
+ Each parameter has a name and a value.
"""
st = time.time()
@@ -509,6 +517,7 @@ class SearchCommands:
elif isinstance(query, Query):
cmd[2] = "SEARCH"
cmd += query.get_args()
+ cmd += self.get_params_args(query_params)
else:
raise ValueError("Must provide AggregateRequest object or " "Query object.")
@@ -907,8 +916,7 @@ class AsyncSearchCommands(SearchCommands):
cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args()
else:
raise ValueError("Bad query", query)
- if query_params is not None:
- cmd += self.get_params_args(query_params)
+ cmd += self.get_params_args(query_params)
raw = await self.execute_command(*cmd)
return self._get_aggregate_result(raw, query, has_cursor)
diff --git a/tests/test_search.py b/tests/test_search.py
index dba914a..f0a1190 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -1521,6 +1521,30 @@ def test_profile_limited(client):
@pytest.mark.redismod
@skip_ifmodversion_lt("2.4.3", "search")
+def test_profile_query_params(modclient: redis.Redis):
+ modclient.flushdb()
+ modclient.ft().create_index(
+ (
+ VectorField(
+ "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}
+ ),
+ )
+ )
+ modclient.hset("a", "v", "aaaaaaaa")
+ modclient.hset("b", "v", "aaaabaaa")
+ modclient.hset("c", "v", "aaaaabaa")
+ query = "*=>[KNN 2 @v $vec]"
+ q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2)
+ res, det = modclient.ft().profile(q, query_params={"vec": "aaaaaaaa"})
+ assert det["Iterators profile"]["Counter"] == 2.0
+ assert det["Iterators profile"]["Type"] == "VECTOR"
+ assert res.total == 2
+ assert "a" == res.docs[0].id
+ assert "0" == res.docs[0].__getattribute__("__v_score")
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
def test_vector_field(modclient):
modclient.flushdb()
modclient.ft().create_index(