summaryrefslogtreecommitdiff
path: root/django/db
diff options
context:
space:
mode:
authorTim Graham <timograham@gmail.com>2016-04-22 13:43:05 -0400
committerTim Graham <timograham@gmail.com>2016-05-02 07:58:24 -0400
commiteab5df12b664b154b2e280330aa43d8c0621b94a (patch)
tree73a6bd2a322ff0c62d17287b1c9a0cfd263102b1 /django/db
parenteb5d7bc2f465b055f4eb54a3d238502bdddb6d7e (diff)
downloaddjango-eab5df12b664b154b2e280330aa43d8c0621b94a.tar.gz
Refs #22936 -- Moved more of Field.get_db_prep_lookup() to lookups.
Diffstat (limited to 'django/db')
-rw-r--r--django/db/models/fields/__init__.py9
-rw-r--r--django/db/models/lookups.py46
2 files changed, 38 insertions, 17 deletions
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index ec6cf3cf29..5951681e81 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -783,14 +783,7 @@ class Field(RegisterLookupMixin):
value = self.get_prep_lookup(lookup_type, value)
prepared = True
- if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
- return [self.get_db_prep_value(value, connection=connection,
- prepared=prepared)]
- elif lookup_type in ('range', 'in'):
- return [self.get_db_prep_value(v, connection=connection,
- prepared=prepared) for v in value]
- else:
- return [value]
+ return [value]
def has_default(self):
"""
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
index ad55a0138d..62acfcb887 100644
--- a/django/db/models/lookups.py
+++ b/django/db/models/lookups.py
@@ -51,8 +51,7 @@ class Lookup(object):
sqls.append(sql)
sqls_params.extend(sql_params)
else:
- params = self.lhs.output_field.get_db_prep_lookup(
- self.lookup_name, rhs, connection, prepared=True)
+ _, params = self.get_db_prep_lookup(rhs, connection)
sqls, sqls_params = ['%s'] * len(params), params
return sqls, sqls_params
@@ -164,7 +163,36 @@ class BuiltinLookup(Lookup):
return connection.operators[self.lookup_name] % rhs
-class Exact(BuiltinLookup):
+class FieldGetDbPrepValueMixin(object):
+ """
+ Some lookups require Field.get_db_prep_value() to be called on their
+ inputs.
+ """
+ get_db_prep_lookup_value_is_iterable = False
+
+ def get_db_prep_lookup(self, value, connection):
+ # For relational fields, use the output_field of the 'field' attribute.
+ field = getattr(self.lhs.output_field, 'field', None)
+ get_db_prep_value = getattr(field, 'get_db_prep_value', None)
+ if not get_db_prep_value:
+ get_db_prep_value = self.lhs.output_field.get_db_prep_value
+ return (
+ '%s',
+ [get_db_prep_value(v, connection, prepared=True) for v in value]
+ if self.get_db_prep_lookup_value_is_iterable else
+ [get_db_prep_value(value, connection, prepared=True)]
+ )
+
+
+class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
+ """
+ Some lookups require Field.get_db_prep_value() to be called on each value
+ in an iterable.
+ """
+ get_db_prep_lookup_value_is_iterable = True
+
+
+class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'exact'
Field.register_lookup(Exact)
@@ -182,22 +210,22 @@ class IExact(BuiltinLookup):
Field.register_lookup(IExact)
-class GreaterThan(BuiltinLookup):
+class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'gt'
Field.register_lookup(GreaterThan)
-class GreaterThanOrEqual(BuiltinLookup):
+class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'gte'
Field.register_lookup(GreaterThanOrEqual)
-class LessThan(BuiltinLookup):
+class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'lt'
Field.register_lookup(LessThan)
-class LessThanOrEqual(BuiltinLookup):
+class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'lte'
Field.register_lookup(LessThanOrEqual)
@@ -223,7 +251,7 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
IntegerField.register_lookup(IntegerLessThan)
-class In(BuiltinLookup):
+class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = 'in'
def process_rhs(self, compiler, connection):
@@ -365,7 +393,7 @@ class IEndsWith(PatternLookup):
Field.register_lookup(IEndsWith)
-class Range(BuiltinLookup):
+class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = 'range'
def get_rhs_op(self, connection, rhs):