summaryrefslogtreecommitdiff
path: root/tests/custom_lookups
diff options
context:
space:
mode:
authordjango-bot <ops@djangoproject.com>2022-02-03 20:24:19 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-02-07 20:37:05 +0100
commit9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch)
treef0506b668a013d0063e5fba3dbf4863b466713ba /tests/custom_lookups
parentf68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff)
downloaddjango-9c19aff7c7561e3a82978a272ecdaad40dda5c00.tar.gz
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'tests/custom_lookups')
-rw-r--r--tests/custom_lookups/tests.py428
1 files changed, 249 insertions, 179 deletions
diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py
index 1be4de790d..2af145595e 100644
--- a/tests/custom_lookups/tests.py
+++ b/tests/custom_lookups/tests.py
@@ -12,31 +12,31 @@ from .models import Article, Author, MySQLUnixTimestamp
class Div3Lookup(models.Lookup):
- lookup_name = 'div3'
+ lookup_name = "div3"
def as_sql(self, compiler, connection):
lhs, params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
- return '(%s) %%%% 3 = %s' % (lhs, rhs), params
+ return "(%s) %%%% 3 = %s" % (lhs, rhs), params
def as_oracle(self, compiler, connection):
lhs, params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
- return 'mod(%s, 3) = %s' % (lhs, rhs), params
+ return "mod(%s, 3) = %s" % (lhs, rhs), params
class Div3Transform(models.Transform):
- lookup_name = 'div3'
+ lookup_name = "div3"
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
- return '(%s) %%%% 3' % lhs, lhs_params
+ return "(%s) %%%% 3" % lhs, lhs_params
def as_oracle(self, compiler, connection, **extra_context):
lhs, lhs_params = compiler.compile(self.lhs)
- return 'mod(%s, 3)' % lhs, lhs_params
+ return "mod(%s, 3)" % lhs, lhs_params
class Div3BilateralTransform(Div3Transform):
@@ -45,37 +45,37 @@ class Div3BilateralTransform(Div3Transform):
class Mult3BilateralTransform(models.Transform):
bilateral = True
- lookup_name = 'mult3'
+ lookup_name = "mult3"
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
- return '3 * (%s)' % lhs, lhs_params
+ return "3 * (%s)" % lhs, lhs_params
class LastDigitTransform(models.Transform):
- lookup_name = 'lastdigit'
+ lookup_name = "lastdigit"
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
- return 'SUBSTR(CAST(%s AS CHAR(2)), 2, 1)' % lhs, lhs_params
+ return "SUBSTR(CAST(%s AS CHAR(2)), 2, 1)" % lhs, lhs_params
class UpperBilateralTransform(models.Transform):
bilateral = True
- lookup_name = 'upper'
+ lookup_name = "upper"
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
- return 'UPPER(%s)' % lhs, lhs_params
+ return "UPPER(%s)" % lhs, lhs_params
class YearTransform(models.Transform):
# Use a name that avoids collision with the built-in year lookup.
- lookup_name = 'testyear'
+ lookup_name = "testyear"
def as_sql(self, compiler, connection):
lhs_sql, params = compiler.compile(self.lhs)
- return connection.ops.date_extract_sql('year', lhs_sql), params
+ return connection.ops.date_extract_sql("year", lhs_sql), params
@property
def output_field(self):
@@ -84,7 +84,7 @@ class YearTransform(models.Transform):
@YearTransform.register_lookup
class YearExact(models.lookups.Lookup):
- lookup_name = 'exact'
+ lookup_name = "exact"
def as_sql(self, compiler, connection):
# We will need to skip the extract part, and instead go
@@ -96,9 +96,12 @@ class YearExact(models.lookups.Lookup):
params = lhs_params + rhs_params + lhs_params + rhs_params
# We use PostgreSQL specific SQL here. Note that we must do the
# conversions in SQL instead of in Python to support F() references.
- return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
- "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
- {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
+ return (
+ "%(lhs)s >= (%(rhs)s || '-01-01')::date "
+ "AND %(lhs)s <= (%(rhs)s || '-12-31')::date"
+ % {"lhs": lhs_sql, "rhs": rhs_sql},
+ params,
+ )
@YearTransform.register_lookup
@@ -125,15 +128,16 @@ class Exactly(models.lookups.Exact):
"""
This lookup is used to test lookup registration.
"""
- lookup_name = 'exactly'
+
+ lookup_name = "exactly"
def get_rhs_op(self, connection, rhs):
- return connection.operators['exact'] % rhs
+ return connection.operators["exact"] % rhs
class SQLFuncMixin:
def as_sql(self, compiler, connection):
- return '%s()' % self.name, []
+ return "%s()" % self.name, []
@property
def output_field(self):
@@ -153,28 +157,26 @@ class SQLFuncTransform(SQLFuncMixin, models.Transform):
class SQLFuncFactory:
-
def __init__(self, key, name):
self.key = key
self.name = name
def __call__(self, *args, **kwargs):
- if self.key == 'lookupfunc':
+ if self.key == "lookupfunc":
return SQLFuncLookup(self.name, *args, **kwargs)
return SQLFuncTransform(self.name, *args, **kwargs)
class CustomField(models.TextField):
-
def get_lookup(self, lookup_name):
- if lookup_name.startswith('lookupfunc_'):
- key, name = lookup_name.split('_', 1)
+ if lookup_name.startswith("lookupfunc_"):
+ key, name = lookup_name.split("_", 1)
return SQLFuncFactory(key, name)
return super().get_lookup(lookup_name)
def get_transform(self, lookup_name):
- if lookup_name.startswith('transformfunc_'):
- key, name = lookup_name.split('_', 1)
+ if lookup_name.startswith("transformfunc_"):
+ key, name = lookup_name.split("_", 1)
return SQLFuncFactory(key, name)
return super().get_transform(lookup_name)
@@ -190,7 +192,8 @@ class InMonth(models.lookups.Lookup):
"""
InMonth matches if the column's month is the same as value's month.
"""
- lookup_name = 'inmonth'
+
+ lookup_name = "inmonth"
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
@@ -198,13 +201,15 @@ class InMonth(models.lookups.Lookup):
# We need to be careful so that we get the params in right
# places.
params = lhs_params + rhs_params + lhs_params + rhs_params
- return ("%s >= date_trunc('month', %s) and "
- "%s < date_trunc('month', %s) + interval '1 months'" %
- (lhs, rhs, lhs, rhs), params)
+ return (
+ "%s >= date_trunc('month', %s) and "
+ "%s < date_trunc('month', %s) + interval '1 months'" % (lhs, rhs, lhs, rhs),
+ params,
+ )
class DateTimeTransform(models.Transform):
- lookup_name = 'as_datetime'
+ lookup_name = "as_datetime"
@property
def output_field(self):
@@ -212,18 +217,18 @@ class DateTimeTransform(models.Transform):
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
- return 'from_unixtime({})'.format(lhs), params
+ return "from_unixtime({})".format(lhs), params
class LookupTests(TestCase):
-
def test_custom_name_lookup(self):
- a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
- Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
- with register_lookup(models.DateField, YearTransform), \
- register_lookup(models.DateField, YearTransform, lookup_name='justtheyear'), \
- register_lookup(YearTransform, Exactly), \
- register_lookup(YearTransform, Exactly, lookup_name='isactually'):
+ a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
+ Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
+ with register_lookup(models.DateField, YearTransform), register_lookup(
+ models.DateField, YearTransform, lookup_name="justtheyear"
+ ), register_lookup(YearTransform, Exactly), register_lookup(
+ YearTransform, Exactly, lookup_name="isactually"
+ ):
qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
self.assertSequenceEqual(qs1, [a1])
@@ -234,175 +239,208 @@ class LookupTests(TestCase):
__exact=None is transformed to __isnull=True if a custom lookup class
with lookup_name != 'exact' is registered as the `exact` lookup.
"""
- field = Author._meta.get_field('birthdate')
- OldExactLookup = field.get_lookup('exact')
- author = Author.objects.create(name='author', birthdate=None)
+ field = Author._meta.get_field("birthdate")
+ OldExactLookup = field.get_lookup("exact")
+ author = Author.objects.create(name="author", birthdate=None)
try:
- field.register_lookup(Exactly, 'exact')
+ field.register_lookup(Exactly, "exact")
self.assertEqual(Author.objects.get(birthdate__exact=None), author)
finally:
- field.register_lookup(OldExactLookup, 'exact')
+ field.register_lookup(OldExactLookup, "exact")
def test_basic_lookup(self):
- a1 = Author.objects.create(name='a1', age=1)
- a2 = Author.objects.create(name='a2', age=2)
- a3 = Author.objects.create(name='a3', age=3)
- a4 = Author.objects.create(name='a4', age=4)
+ a1 = Author.objects.create(name="a1", age=1)
+ a2 = Author.objects.create(name="a2", age=2)
+ a3 = Author.objects.create(name="a3", age=3)
+ a4 = Author.objects.create(name="a4", age=4)
with register_lookup(models.IntegerField, Div3Lookup):
self.assertSequenceEqual(Author.objects.filter(age__div3=0), [a3])
- self.assertSequenceEqual(Author.objects.filter(age__div3=1).order_by('age'), [a1, a4])
+ self.assertSequenceEqual(
+ Author.objects.filter(age__div3=1).order_by("age"), [a1, a4]
+ )
self.assertSequenceEqual(Author.objects.filter(age__div3=2), [a2])
self.assertSequenceEqual(Author.objects.filter(age__div3=3), [])
- @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ @unittest.skipUnless(
+ connection.vendor == "postgresql", "PostgreSQL specific SQL used"
+ )
def test_birthdate_month(self):
- a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
- a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
- a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
- a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
+ a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
+ a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
+ a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
+ a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
with register_lookup(models.DateField, InMonth):
- self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), [a3])
- self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), [a2])
- self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), [a1])
- self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), [a4])
- self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), [])
+ self.assertSequenceEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), [a3]
+ )
+ self.assertSequenceEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), [a2]
+ )
+ self.assertSequenceEqual(
+ Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), [a1]
+ )
+ self.assertSequenceEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), [a4]
+ )
+ self.assertSequenceEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), []
+ )
def test_div3_extract(self):
with register_lookup(models.IntegerField, Div3Transform):
- a1 = Author.objects.create(name='a1', age=1)
- a2 = Author.objects.create(name='a2', age=2)
- a3 = Author.objects.create(name='a3', age=3)
- a4 = Author.objects.create(name='a4', age=4)
- baseqs = Author.objects.order_by('name')
+ a1 = Author.objects.create(name="a1", age=1)
+ a2 = Author.objects.create(name="a2", age=2)
+ a3 = Author.objects.create(name="a3", age=3)
+ a4 = Author.objects.create(name="a4", age=4)
+ baseqs = Author.objects.order_by("name")
self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a1, a2, a3, a4])
self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a2])
self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [])
- self.assertSequenceEqual(baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4])
+ self.assertSequenceEqual(
+ baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
+ )
def test_foreignobject_lookup_registration(self):
- field = Article._meta.get_field('author')
+ field = Article._meta.get_field("author")
with register_lookup(models.ForeignObject, Exactly):
- self.assertIs(field.get_lookup('exactly'), Exactly)
+ self.assertIs(field.get_lookup("exactly"), Exactly)
# ForeignObject should ignore regular Field lookups
with register_lookup(models.Field, Exactly):
- self.assertIsNone(field.get_lookup('exactly'))
+ self.assertIsNone(field.get_lookup("exactly"))
def test_lookups_caching(self):
- field = Article._meta.get_field('author')
+ field = Article._meta.get_field("author")
# clear and re-cache
field.get_lookups.cache_clear()
- self.assertNotIn('exactly', field.get_lookups())
+ self.assertNotIn("exactly", field.get_lookups())
# registration should bust the cache
with register_lookup(models.ForeignObject, Exactly):
# getting the lookups again should re-cache
- self.assertIn('exactly', field.get_lookups())
+ self.assertIn("exactly", field.get_lookups())
class BilateralTransformTests(TestCase):
-
def test_bilateral_upper(self):
with register_lookup(models.CharField, UpperBilateralTransform):
- author1 = Author.objects.create(name='Doe')
- author2 = Author.objects.create(name='doe')
- author3 = Author.objects.create(name='Foo')
+ author1 = Author.objects.create(name="Doe")
+ author2 = Author.objects.create(name="doe")
+ author3 = Author.objects.create(name="Foo")
self.assertCountEqual(
- Author.objects.filter(name__upper='doe'),
+ Author.objects.filter(name__upper="doe"),
[author1, author2],
)
self.assertSequenceEqual(
- Author.objects.filter(name__upper__contains='f'),
+ Author.objects.filter(name__upper__contains="f"),
[author3],
)
def test_bilateral_inner_qs(self):
with register_lookup(models.CharField, UpperBilateralTransform):
- msg = 'Bilateral transformations on nested querysets are not implemented.'
+ msg = "Bilateral transformations on nested querysets are not implemented."
with self.assertRaisesMessage(NotImplementedError, msg):
- Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
+ Author.objects.filter(
+ name__upper__in=Author.objects.values_list("name")
+ )
def test_bilateral_multi_value(self):
with register_lookup(models.CharField, UpperBilateralTransform):
- Author.objects.bulk_create([
- Author(name='Foo'),
- Author(name='Bar'),
- Author(name='Ray'),
- ])
+ Author.objects.bulk_create(
+ [
+ Author(name="Foo"),
+ Author(name="Bar"),
+ Author(name="Ray"),
+ ]
+ )
self.assertQuerysetEqual(
- Author.objects.filter(name__upper__in=['foo', 'bar', 'doe']).order_by('name'),
- ['Bar', 'Foo'],
- lambda a: a.name
+ Author.objects.filter(name__upper__in=["foo", "bar", "doe"]).order_by(
+ "name"
+ ),
+ ["Bar", "Foo"],
+ lambda a: a.name,
)
def test_div3_bilateral_extract(self):
with register_lookup(models.IntegerField, Div3BilateralTransform):
- a1 = Author.objects.create(name='a1', age=1)
- a2 = Author.objects.create(name='a2', age=2)
- a3 = Author.objects.create(name='a3', age=3)
- a4 = Author.objects.create(name='a4', age=4)
- baseqs = Author.objects.order_by('name')
+ a1 = Author.objects.create(name="a1", age=1)
+ a2 = Author.objects.create(name="a2", age=2)
+ a3 = Author.objects.create(name="a3", age=3)
+ a4 = Author.objects.create(name="a4", age=4)
+ baseqs = Author.objects.order_by("name")
self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a3])
self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a1, a2, a4])
self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [a1, a2, a3, a4])
- self.assertSequenceEqual(baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4])
+ self.assertSequenceEqual(
+ baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
+ )
def test_bilateral_order(self):
- with register_lookup(models.IntegerField, Mult3BilateralTransform, Div3BilateralTransform):
- a1 = Author.objects.create(name='a1', age=1)
- a2 = Author.objects.create(name='a2', age=2)
- a3 = Author.objects.create(name='a3', age=3)
- a4 = Author.objects.create(name='a4', age=4)
- baseqs = Author.objects.order_by('name')
+ with register_lookup(
+ models.IntegerField, Mult3BilateralTransform, Div3BilateralTransform
+ ):
+ a1 = Author.objects.create(name="a1", age=1)
+ a2 = Author.objects.create(name="a2", age=2)
+ a3 = Author.objects.create(name="a3", age=3)
+ a4 = Author.objects.create(name="a4", age=4)
+ baseqs = Author.objects.order_by("name")
# mult3__div3 always leads to 0
- self.assertSequenceEqual(baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4])
+ self.assertSequenceEqual(
+ baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4]
+ )
self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3])
def test_transform_order_by(self):
with register_lookup(models.IntegerField, LastDigitTransform):
- a1 = Author.objects.create(name='a1', age=11)
- a2 = Author.objects.create(name='a2', age=23)
- a3 = Author.objects.create(name='a3', age=32)
- a4 = Author.objects.create(name='a4', age=40)
- qs = Author.objects.order_by('age__lastdigit')
+ a1 = Author.objects.create(name="a1", age=11)
+ a2 = Author.objects.create(name="a2", age=23)
+ a3 = Author.objects.create(name="a3", age=32)
+ a4 = Author.objects.create(name="a4", age=40)
+ qs = Author.objects.order_by("age__lastdigit")
self.assertSequenceEqual(qs, [a4, a1, a3, a2])
def test_bilateral_fexpr(self):
with register_lookup(models.IntegerField, Mult3BilateralTransform):
- a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)
- a2 = Author.objects.create(name='a2', age=2, average_rating=0.5)
- a3 = Author.objects.create(name='a3', age=3, average_rating=1.5)
- a4 = Author.objects.create(name='a4', age=4)
- baseqs = Author.objects.order_by('name')
- self.assertSequenceEqual(baseqs.filter(age__mult3=models.F('age')), [a1, a2, a3, a4])
+ a1 = Author.objects.create(name="a1", age=1, average_rating=3.2)
+ a2 = Author.objects.create(name="a2", age=2, average_rating=0.5)
+ a3 = Author.objects.create(name="a3", age=3, average_rating=1.5)
+ a4 = Author.objects.create(name="a4", age=4)
+ baseqs = Author.objects.order_by("name")
+ self.assertSequenceEqual(
+ baseqs.filter(age__mult3=models.F("age")), [a1, a2, a3, a4]
+ )
# Same as age >= average_rating
- self.assertSequenceEqual(baseqs.filter(age__mult3__gte=models.F('average_rating')), [a2, a3])
+ self.assertSequenceEqual(
+ baseqs.filter(age__mult3__gte=models.F("average_rating")), [a2, a3]
+ )
@override_settings(USE_TZ=True)
class DateTimeLookupTests(TestCase):
- @unittest.skipUnless(connection.vendor == 'mysql', "MySQL specific SQL used")
+ @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific SQL used")
def test_datetime_output_field(self):
with register_lookup(models.PositiveIntegerField, DateTimeTransform):
ut = MySQLUnixTimestamp.objects.create(timestamp=time.time())
y2k = timezone.make_aware(datetime(2000, 1, 1))
- self.assertSequenceEqual(MySQLUnixTimestamp.objects.filter(timestamp__as_datetime__gt=y2k), [ut])
+ self.assertSequenceEqual(
+ MySQLUnixTimestamp.objects.filter(timestamp__as_datetime__gt=y2k), [ut]
+ )
class YearLteTests(TestCase):
@classmethod
def setUpTestData(cls):
- cls.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
- cls.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
- cls.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
- cls.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
+ cls.a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
+ cls.a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
+ cls.a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
+ cls.a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
def setUp(self):
models.DateField.register_lookup(YearTransform)
@@ -410,18 +448,29 @@ class YearLteTests(TestCase):
def tearDown(self):
models.DateField._unregister_lookup(YearTransform)
- @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ @unittest.skipUnless(
+ connection.vendor == "postgresql", "PostgreSQL specific SQL used"
+ )
def test_year_lte(self):
- baseqs = Author.objects.order_by('name')
- self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lte=2012), [self.a1, self.a2, self.a3, self.a4])
- self.assertSequenceEqual(baseqs.filter(birthdate__testyear=2012), [self.a2, self.a3, self.a4])
-
- self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__testyear=2012).query))
- self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lte=2011), [self.a1])
+ baseqs = Author.objects.order_by("name")
+ self.assertSequenceEqual(
+ baseqs.filter(birthdate__testyear__lte=2012),
+ [self.a1, self.a2, self.a3, self.a4],
+ )
+ self.assertSequenceEqual(
+ baseqs.filter(birthdate__testyear=2012), [self.a2, self.a3, self.a4]
+ )
+
+ self.assertNotIn("BETWEEN", str(baseqs.filter(birthdate__testyear=2012).query))
+ self.assertSequenceEqual(
+ baseqs.filter(birthdate__testyear__lte=2011), [self.a1]
+ )
# The non-optimized version works, too.
self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=2012), [self.a1])
- @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ @unittest.skipUnless(
+ connection.vendor == "postgresql", "PostgreSQL specific SQL used"
+ )
def test_year_lte_fexpr(self):
self.a2.age = 2011
self.a2.save()
@@ -429,44 +478,52 @@ class YearLteTests(TestCase):
self.a3.save()
self.a4.age = 2013
self.a4.save()
- baseqs = Author.objects.order_by('name')
- self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lte=models.F('age')), [self.a3, self.a4])
- self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=models.F('age')), [self.a4])
+ baseqs = Author.objects.order_by("name")
+ self.assertSequenceEqual(
+ baseqs.filter(birthdate__testyear__lte=models.F("age")), [self.a3, self.a4]
+ )
+ self.assertSequenceEqual(
+ baseqs.filter(birthdate__testyear__lt=models.F("age")), [self.a4]
+ )
def test_year_lte_sql(self):
# This test will just check the generated SQL for __lte. This
# doesn't require running on PostgreSQL and spots the most likely
# error - not running YearLte SQL at all.
- baseqs = Author.objects.order_by('name')
- self.assertIn(
- '<= (2011 || ', str(baseqs.filter(birthdate__testyear__lte=2011).query))
+ baseqs = Author.objects.order_by("name")
self.assertIn(
- '-12-31', str(baseqs.filter(birthdate__testyear__lte=2011).query))
+ "<= (2011 || ", str(baseqs.filter(birthdate__testyear__lte=2011).query)
+ )
+ self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear__lte=2011).query))
def test_postgres_year_exact(self):
- baseqs = Author.objects.order_by('name')
- self.assertIn(
- '= (2011 || ', str(baseqs.filter(birthdate__testyear=2011).query))
- self.assertIn(
- '-12-31', str(baseqs.filter(birthdate__testyear=2011).query))
+ baseqs = Author.objects.order_by("name")
+ self.assertIn("= (2011 || ", str(baseqs.filter(birthdate__testyear=2011).query))
+ self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear=2011).query))
def test_custom_implementation_year_exact(self):
try:
# Two ways to add a customized implementation for different backends:
# First is MonkeyPatch of the class.
def as_custom_sql(self, compiler, connection):
- lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
+ lhs_sql, lhs_params = self.process_lhs(
+ compiler, connection, self.lhs.lhs
+ )
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params + lhs_params + rhs_params
- return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
- "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
- {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
- setattr(YearExact, 'as_' + connection.vendor, as_custom_sql)
+ return (
+ "%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
+ "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
+ % {"lhs": lhs_sql, "rhs": rhs_sql},
+ params,
+ )
+
+ setattr(YearExact, "as_" + connection.vendor, as_custom_sql)
self.assertIn(
- 'concat(',
- str(Author.objects.filter(birthdate__testyear=2012).query))
+ "concat(", str(Author.objects.filter(birthdate__testyear=2012).query)
+ )
finally:
- delattr(YearExact, 'as_' + connection.vendor)
+ delattr(YearExact, "as_" + connection.vendor)
try:
# The other way is to subclass the original lookup and register the subclassed
# lookup instead of the original.
@@ -475,17 +532,27 @@ class YearLteTests(TestCase):
# and so on, but as we don't know which DB we are running on, we need to use
# setattr.
def as_custom_sql(self, compiler, connection):
- lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
+ lhs_sql, lhs_params = self.process_lhs(
+ compiler, connection, self.lhs.lhs
+ )
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params + lhs_params + rhs_params
- return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
- "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
- {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
- setattr(CustomYearExact, 'as_' + connection.vendor, CustomYearExact.as_custom_sql)
+ return (
+ "%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
+ "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
+ % {"lhs": lhs_sql, "rhs": rhs_sql},
+ params,
+ )
+
+ setattr(
+ CustomYearExact,
+ "as_" + connection.vendor,
+ CustomYearExact.as_custom_sql,
+ )
YearTransform.register_lookup(CustomYearExact)
self.assertIn(
- 'CONCAT(',
- str(Author.objects.filter(birthdate__testyear=2012).query))
+ "CONCAT(", str(Author.objects.filter(birthdate__testyear=2012).query)
+ )
finally:
YearTransform._unregister_lookup(CustomYearExact)
YearTransform.register_lookup(YearExact)
@@ -493,23 +560,23 @@ class YearLteTests(TestCase):
class TrackCallsYearTransform(YearTransform):
# Use a name that avoids collision with the built-in year lookup.
- lookup_name = 'testyear'
+ lookup_name = "testyear"
call_order = []
def as_sql(self, compiler, connection):
lhs_sql, params = compiler.compile(self.lhs)
- return connection.ops.date_extract_sql('year', lhs_sql), params
+ return connection.ops.date_extract_sql("year", lhs_sql), params
@property
def output_field(self):
return models.IntegerField()
def get_lookup(self, lookup_name):
- self.call_order.append('lookup')
+ self.call_order.append("lookup")
return super().get_lookup(lookup_name)
def get_transform(self, lookup_name):
- self.call_order.append('transform')
+ self.call_order.append("transform")
return super().get_transform(lookup_name)
@@ -520,51 +587,54 @@ class LookupTransformCallOrderTests(SimpleTestCase):
msg = "Unsupported lookup 'junk' for IntegerField or join on the field not permitted."
with self.assertRaisesMessage(FieldError, msg):
Author.objects.filter(birthdate__testyear__junk=2012)
- self.assertEqual(TrackCallsYearTransform.call_order,
- ['lookup', 'transform'])
+ self.assertEqual(
+ TrackCallsYearTransform.call_order, ["lookup", "transform"]
+ )
TrackCallsYearTransform.call_order = []
# junk transform - tries transform only, then fails
with self.assertRaisesMessage(FieldError, msg):
Author.objects.filter(birthdate__testyear__junk__more_junk=2012)
- self.assertEqual(TrackCallsYearTransform.call_order,
- ['transform'])
+ self.assertEqual(TrackCallsYearTransform.call_order, ["transform"])
TrackCallsYearTransform.call_order = []
# Just getting the year (implied __exact) - lookup only
Author.objects.filter(birthdate__testyear=2012)
- self.assertEqual(TrackCallsYearTransform.call_order,
- ['lookup'])
+ self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
TrackCallsYearTransform.call_order = []
# Just getting the year (explicit __exact) - lookup only
Author.objects.filter(birthdate__testyear__exact=2012)
- self.assertEqual(TrackCallsYearTransform.call_order,
- ['lookup'])
+ self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
class CustomisedMethodsTests(SimpleTestCase):
-
def test_overridden_get_lookup(self):
q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
- self.assertIn('monkeys()', str(q.query))
+ self.assertIn("monkeys()", str(q.query))
def test_overridden_get_transform(self):
q = CustomModel.objects.filter(field__transformfunc_banana=3)
- self.assertIn('banana()', str(q.query))
+ self.assertIn("banana()", str(q.query))
def test_overridden_get_lookup_chain(self):
- q = CustomModel.objects.filter(field__transformfunc_banana__lookupfunc_elephants=3)
- self.assertIn('elephants()', str(q.query))
+ q = CustomModel.objects.filter(
+ field__transformfunc_banana__lookupfunc_elephants=3
+ )
+ self.assertIn("elephants()", str(q.query))
def test_overridden_get_transform_chain(self):
- q = CustomModel.objects.filter(field__transformfunc_banana__transformfunc_pear=3)
- self.assertIn('pear()', str(q.query))
+ q = CustomModel.objects.filter(
+ field__transformfunc_banana__transformfunc_pear=3
+ )
+ self.assertIn("pear()", str(q.query))
class SubqueryTransformTests(TestCase):
def test_subquery_usage(self):
with register_lookup(models.IntegerField, Div3Transform):
- Author.objects.create(name='a1', age=1)
- a2 = Author.objects.create(name='a2', age=2)
- Author.objects.create(name='a3', age=3)
- Author.objects.create(name='a4', age=4)
- qs = Author.objects.order_by('name').filter(id__in=Author.objects.filter(age__div3=2))
+ Author.objects.create(name="a1", age=1)
+ a2 = Author.objects.create(name="a2", age=2)
+ Author.objects.create(name="a3", age=3)
+ Author.objects.create(name="a4", age=4)
+ qs = Author.objects.order_by("name").filter(
+ id__in=Author.objects.filter(age__div3=2)
+ )
self.assertSequenceEqual(qs, [a2])