diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql')
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 47 |
1 files changed, 33 insertions, 14 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7c9a68236..502371be9 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1056,6 +1056,7 @@ from ... import sql from ... import util from ...engine import default from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -2648,7 +2649,8 @@ class MySQLDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return connection.exec_driver_sql("SELECT DATABASE()").scalar() - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): self._ensure_has_table_connection(connection) if schema is None: @@ -2670,7 +2672,8 @@ class MySQLDialect(default.DefaultDialect): ) return bool(rs.scalar()) - def has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): if not self.supports_sequences: self._sequences_not_supported() if not schema: @@ -2847,14 +2850,20 @@ class MySQLDialect(default.DefaultDialect): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return parsed_state.table_options + if parsed_state.table_options: + return parsed_state.table_options + else: + return ReflectionDefaults.table_options() @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return parsed_state.columns + if parsed_state.columns: + return parsed_state.columns + else: + return ReflectionDefaults.columns() @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): @@ -2866,7 +2875,7 @@ class MySQLDialect(default.DefaultDialect): # There can be only one. cols = [s[0] for s in key["columns"]] return {"constrained_columns": cols, "name": None} - return {"constrained_columns": [], "name": None} + return ReflectionDefaults.pk_constraint() @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -2909,7 +2918,7 @@ class MySQLDialect(default.DefaultDialect): if self._needs_correct_for_88718_96365: self._correct_for_mysql_bugs_88718_96365(fkeys, connection) - return fkeys + return fkeys if fkeys else ReflectionDefaults.foreign_keys() def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): # Foreign key is always in lower case (MySQL 8.0) @@ -3000,21 +3009,22 @@ class MySQLDialect(default.DefaultDialect): connection, table_name, schema, **kw ) - return [ + cks = [ {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] + return cks if cks else ReflectionDefaults.check_constraints() @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return { - "text": parsed_state.table_options.get( - "%s_comment" % self.name, None - ) - } + comment = parsed_state.table_options.get(f"{self.name}_comment", None) + if comment is not None: + return {"text": comment} + else: + return ReflectionDefaults.table_comment() @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): @@ -3058,7 +3068,8 @@ class MySQLDialect(default.DefaultDialect): if flavor: index_d["type"] = flavor indexes.append(index_d) - return indexes + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + return indexes if indexes else ReflectionDefaults.indexes() @reflection.cache def get_unique_constraints( @@ -3068,7 +3079,7 @@ class MySQLDialect(default.DefaultDialect): connection, table_name, schema, **kw ) - return [ + ucs = [ { "name": key["name"], "column_names": [col[0] for col in key["columns"]], @@ -3077,6 +3088,11 @@ class MySQLDialect(default.DefaultDialect): for key in parsed_state.keys if key["type"] == "UNIQUE" ] + ucs.sort(key=lambda d: d["name"] or "~") # sort None as last + if ucs: + return ucs + else: + return ReflectionDefaults.unique_constraints() @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): @@ -3088,6 +3104,9 @@ class MySQLDialect(default.DefaultDialect): sql = self._show_create_table( connection, None, charset, full_name=full_name ) + if sql.upper().startswith("CREATE TABLE"): + # it's a table, not a view + raise exc.NoSuchTableError(full_name) return sql def _parsed_state_or_create( |
