diff options
author | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2022-12-01 10:59:20 +0100 |
---|---|---|
committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-12-01 11:05:57 +0100 |
commit | d3e746ace5eeea07216da97d9c3801f2fdc43223 (patch) | |
tree | b3763756d824d45447d58dfef987bef4ef6ab94b | |
parent | 149b55fefad03c18589d580ef53d41e7c99408ed (diff) | |
download | django-d3e746ace5eeea07216da97d9c3801f2fdc43223.tar.gz |
Refs #33308 -- Added get_type_oids() hook and simplified registering type handlers on PostgreSQL.
-rw-r--r-- | django/contrib/postgres/signals.py | 57 | ||||
-rw-r--r-- | tests/postgres_tests/test_signals.py | 3 |
2 files changed, 27 insertions, 33 deletions
diff --git a/django/contrib/postgres/signals.py b/django/contrib/postgres/signals.py index b61673fe1f..5c6ca3687a 100644 --- a/django/contrib/postgres/signals.py +++ b/django/contrib/postgres/signals.py @@ -1,22 +1,16 @@ import functools import psycopg2 -from psycopg2 import ProgrammingError from psycopg2.extras import register_hstore from django.db import connections from django.db.backends.base.base import NO_DB_ALIAS -@functools.lru_cache -def get_hstore_oids(connection_alias): - """Return hstore and hstore array OIDs.""" +def get_type_oids(connection_alias, type_name): with connections[connection_alias].cursor() as cursor: cursor.execute( - "SELECT t.oid, typarray " - "FROM pg_type t " - "JOIN pg_namespace ns ON typnamespace = ns.oid " - "WHERE typname = 'hstore'" + "SELECT oid, typarray FROM pg_type WHERE typname = %s", (type_name,) ) oids = [] array_oids = [] @@ -27,42 +21,41 @@ def get_hstore_oids(connection_alias): @functools.lru_cache +def get_hstore_oids(connection_alias): + """Return hstore and hstore array OIDs.""" + return get_type_oids(connection_alias, "hstore") + + +@functools.lru_cache def get_citext_oids(connection_alias): - """Return citext array OIDs.""" - with connections[connection_alias].cursor() as cursor: - cursor.execute("SELECT typarray FROM pg_type WHERE typname = 'citext'") - return tuple(row[0] for row in cursor) + """Return citext and citext array OIDs.""" + return get_type_oids(connection_alias, "citext") def register_type_handlers(connection, **kwargs): if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS: return - try: - oids, array_oids = get_hstore_oids(connection.alias) + oids, array_oids = get_hstore_oids(connection.alias) + # Don't register handlers when hstore is not available on the database. + # + # If someone tries to create an hstore field it will error there. This is + # necessary as someone may be using PSQL without extensions installed but + # be using other features of contrib.postgres. + # + # This is also needed in order to create the connection in order to install + # the hstore extension. + if oids: register_hstore( connection.connection, globally=True, oid=oids, array_oid=array_oids ) - except ProgrammingError: - # Hstore is not available on the database. - # - # If someone tries to create an hstore field it will error there. - # This is necessary as someone may be using PSQL without extensions - # installed but be using other features of contrib.postgres. - # - # This is also needed in order to create the connection in order to - # install the hstore extension. - pass - try: - citext_oids = get_citext_oids(connection.alias) + oids, citext_oids = get_citext_oids(connection.alias) + # Don't register handlers when citext is not available on the database. + # + # The same comments in the above call to register_hstore() also apply here. + if oids: array_type = psycopg2.extensions.new_array_type( citext_oids, "citext[]", psycopg2.STRING ) psycopg2.extensions.register_type(array_type, None) - except ProgrammingError: - # citext is not available on the database. - # - # The same comments in the except block of the above call to - # register_hstore() also apply here. - pass diff --git a/tests/postgres_tests/test_signals.py b/tests/postgres_tests/test_signals.py index 764524d8e6..80a28a9776 100644 --- a/tests/postgres_tests/test_signals.py +++ b/tests/postgres_tests/test_signals.py @@ -34,8 +34,9 @@ class OIDTests(PostgreSQLTestCase): self.assertOIDs(array_oids) def test_citext_values(self): - oids = get_citext_oids(connection.alias) + oids, citext_oids = get_citext_oids(connection.alias) self.assertOIDs(oids) + self.assertOIDs(citext_oids) def test_register_type_handlers_no_db(self): """Registering type handlers for the nodb connection does nothing.""" |