diff options
author | Tim Graham <timograham@gmail.com> | 2016-04-22 13:43:05 -0400 |
---|---|---|
committer | Tim Graham <timograham@gmail.com> | 2016-05-02 07:58:24 -0400 |
commit | eab5df12b664b154b2e280330aa43d8c0621b94a (patch) | |
tree | 73a6bd2a322ff0c62d17287b1c9a0cfd263102b1 /django/db | |
parent | eb5d7bc2f465b055f4eb54a3d238502bdddb6d7e (diff) | |
download | django-eab5df12b664b154b2e280330aa43d8c0621b94a.tar.gz |
Refs #22936 -- Moved more of Field.get_db_prep_lookup() to lookups.
Diffstat (limited to 'django/db')
-rw-r--r-- | django/db/models/fields/__init__.py | 9 | ||||
-rw-r--r-- | django/db/models/lookups.py | 46 |
2 files changed, 38 insertions, 17 deletions
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index ec6cf3cf29..5951681e81 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -783,14 +783,7 @@ class Field(RegisterLookupMixin): value = self.get_prep_lookup(lookup_type, value) prepared = True - if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): - return [self.get_db_prep_value(value, connection=connection, - prepared=prepared)] - elif lookup_type in ('range', 'in'): - return [self.get_db_prep_value(v, connection=connection, - prepared=prepared) for v in value] - else: - return [value] + return [value] def has_default(self): """ diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index ad55a0138d..62acfcb887 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -51,8 +51,7 @@ class Lookup(object): sqls.append(sql) sqls_params.extend(sql_params) else: - params = self.lhs.output_field.get_db_prep_lookup( - self.lookup_name, rhs, connection, prepared=True) + _, params = self.get_db_prep_lookup(rhs, connection) sqls, sqls_params = ['%s'] * len(params), params return sqls, sqls_params @@ -164,7 +163,36 @@ class BuiltinLookup(Lookup): return connection.operators[self.lookup_name] % rhs -class Exact(BuiltinLookup): +class FieldGetDbPrepValueMixin(object): + """ + Some lookups require Field.get_db_prep_value() to be called on their + inputs. + """ + get_db_prep_lookup_value_is_iterable = False + + def get_db_prep_lookup(self, value, connection): + # For relational fields, use the output_field of the 'field' attribute. + field = getattr(self.lhs.output_field, 'field', None) + get_db_prep_value = getattr(field, 'get_db_prep_value', None) + if not get_db_prep_value: + get_db_prep_value = self.lhs.output_field.get_db_prep_value + return ( + '%s', + [get_db_prep_value(v, connection, prepared=True) for v in value] + if self.get_db_prep_lookup_value_is_iterable else + [get_db_prep_value(value, connection, prepared=True)] + ) + + +class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin): + """ + Some lookups require Field.get_db_prep_value() to be called on each value + in an iterable. + """ + get_db_prep_lookup_value_is_iterable = True + + +class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): lookup_name = 'exact' Field.register_lookup(Exact) @@ -182,22 +210,22 @@ class IExact(BuiltinLookup): Field.register_lookup(IExact) -class GreaterThan(BuiltinLookup): +class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup): lookup_name = 'gt' Field.register_lookup(GreaterThan) -class GreaterThanOrEqual(BuiltinLookup): +class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup): lookup_name = 'gte' Field.register_lookup(GreaterThanOrEqual) -class LessThan(BuiltinLookup): +class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup): lookup_name = 'lt' Field.register_lookup(LessThan) -class LessThanOrEqual(BuiltinLookup): +class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup): lookup_name = 'lte' Field.register_lookup(LessThanOrEqual) @@ -223,7 +251,7 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan): IntegerField.register_lookup(IntegerLessThan) -class In(BuiltinLookup): +class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): lookup_name = 'in' def process_rhs(self, compiler, connection): @@ -365,7 +393,7 @@ class IEndsWith(PatternLookup): Field.register_lookup(IEndsWith) -class Range(BuiltinLookup): +class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup): lookup_name = 'range' def get_rhs_op(self, connection, rhs): |