diff options
Diffstat (limited to 'django/db/backends/base/schema.py')
-rw-r--r-- | django/db/backends/base/schema.py | 88 |
1 files changed, 81 insertions, 7 deletions
diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 6b03450e2f..01b56151be 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -12,7 +12,7 @@ from django.db.backends.ddl_references import ( Table, ) from django.db.backends.utils import names_digest, split_identifier, truncate_name -from django.db.models import Deferrable, Index +from django.db.models import NOT_PROVIDED, Deferrable, Index from django.db.models.sql import Query from django.db.transaction import TransactionManagementError, atomic from django.utils import timezone @@ -296,6 +296,12 @@ class BaseDatabaseSchemaEditor: yield self._comment_sql(field.db_comment) # Work out nullability. null = field.null + # Add database default. + if field.db_default is not NOT_PROVIDED: + default_sql, default_params = self.db_default_sql(field) + yield f"DEFAULT {default_sql}" + params.extend(default_params) + include_default = False # Include a default value, if requested. include_default = ( include_default @@ -400,6 +406,22 @@ class BaseDatabaseSchemaEditor: """ return "%s" + def db_default_sql(self, field): + """Return the sql and params for the field's database default.""" + from django.db.models.expressions import Value + + sql = "%s" if isinstance(field.db_default, Value) else "(%s)" + query = Query(model=field.model) + compiler = query.get_compiler(connection=self.connection) + default_sql, params = compiler.compile(field.db_default) + if self.connection.features.requires_literal_defaults: + # Some databases doesn't support parameterized defaults (Oracle, + # SQLite). If this is the case, the individual schema backend + # should implement prepare_default(). + default_sql %= tuple(self.prepare_default(p) for p in params) + params = [] + return sql % default_sql, params + @staticmethod def _effective_default(field): # This method allows testing its logic without a connection. @@ -1025,6 +1047,21 @@ class BaseDatabaseSchemaEditor: ) actions.append(fragment) post_actions.extend(other_actions) + + if new_field.db_default is not NOT_PROVIDED: + if ( + old_field.db_default is NOT_PROVIDED + or new_field.db_default != old_field.db_default + ): + actions.append( + self._alter_column_database_default_sql(model, old_field, new_field) + ) + elif old_field.db_default is not NOT_PROVIDED: + actions.append( + self._alter_column_database_default_sql( + model, old_field, new_field, drop=True + ) + ) # When changing a column NULL constraint to NOT NULL with a given # default value, we need to perform 4 steps: # 1. Add a default for new incoming writes @@ -1033,7 +1070,11 @@ class BaseDatabaseSchemaEditor: # 4. Drop the default again. # Default change? needs_database_default = False - if old_field.null and not new_field.null: + if ( + old_field.null + and not new_field.null + and new_field.db_default is NOT_PROVIDED + ): old_default = self.effective_default(old_field) new_default = self.effective_default(new_field) if ( @@ -1051,9 +1092,9 @@ class BaseDatabaseSchemaEditor: if fragment: null_actions.append(fragment) # Only if we have a default and there is a change from NULL to NOT NULL - four_way_default_alteration = new_field.has_default() and ( - old_field.null and not new_field.null - ) + four_way_default_alteration = ( + new_field.has_default() or new_field.db_default is not NOT_PROVIDED + ) and (old_field.null and not new_field.null) if actions or null_actions: if not four_way_default_alteration: # If we don't have to do a 4-way default alteration we can @@ -1074,15 +1115,20 @@ class BaseDatabaseSchemaEditor: params, ) if four_way_default_alteration: + if new_field.db_default is NOT_PROVIDED: + default_sql = "%s" + params = [new_default] + else: + default_sql, params = self.db_default_sql(new_field) # Update existing rows with default value self.execute( self.sql_update_with_default % { "table": self.quote_name(model._meta.db_table), "column": self.quote_name(new_field.column), - "default": "%s", + "default": default_sql, }, - [new_default], + params, ) # Since we didn't run a NOT NULL change before we need to do it # now @@ -1264,6 +1310,34 @@ class BaseDatabaseSchemaEditor: params, ) + def _alter_column_database_default_sql( + self, model, old_field, new_field, drop=False + ): + """ + Hook to specialize column database default alteration. + + Return a (sql, params) fragment to add or drop (depending on the drop + argument) a default to new_field's column. + """ + if drop: + sql = self.sql_alter_column_no_default + default_sql = "" + params = [] + else: + sql = self.sql_alter_column_default + default_sql, params = self.db_default_sql(new_field) + + new_db_params = new_field.db_parameters(connection=self.connection) + return ( + sql + % { + "column": self.quote_name(new_field.column), + "type": new_db_params["type"], + "default": default_sql, + }, + params, + ) + def _alter_column_type_sql( self, model, old_field, new_field, new_type, old_collation, new_collation ): |