diff options
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 30 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 19 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/reflection.py | 20 | ||||
-rw-r--r-- | test/engine/test_reflection.py | 37 |
6 files changed, 141 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 9d856e271..2642b5fdc 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2212,6 +2212,21 @@ class MySQLDialect(default.DefaultDialect): return indexes @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw) + + return [ + { + 'name': key['name'], + 'column_names': [col[0] for col in key['columns']] + } + for key in parsed_state.keys + if key['type'] == 'UNIQUE' + ] + + @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): charset = self._connection_charset diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 00d0acc2c..0810e0384 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1950,6 +1950,36 @@ class PGDialect(default.DefaultDialect): index_d['unique'] = unique return indexes + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + UNIQUE_SQL = """ + SELECT + cons.conname as name, + ARRAY_AGG(a.attname) as column_names + FROM + pg_catalog.pg_constraint cons + left outer join pg_attribute a + on cons.conrelid = a.attrelid and a.attnum = ANY(cons.conkey) + WHERE + cons.conrelid = :table_oid AND + cons.contype = 'u' + GROUP BY + cons.conname + """ + + t = sql.text(UNIQUE_SQL, + typemap={'column_names': ARRAY(sqltypes.Unicode)}) + c = connection.execute(t, table_oid=table_oid) + + return [ + {'name': row.name, 'column_names': row.column_names} + for row in c.fetchall() + ] + def _load_enums(self, connection): if not self.supports_native_enum: return {} diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c7e09b164..3e2a158a0 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -917,6 +917,26 @@ class SQLiteDialect(default.DefaultDialect): cols.append(row[2]) return indexes + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + UNIQUE_SQL = """ + SELECT sql + FROM + sqlite_master + WHERE + type='table' AND + name=:table_name + """ + c = connection.execute(UNIQUE_SQL, table_name=table_name) + table_data = c.fetchone()[0] + + UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)' + return [ + {'name': name, 'column_names': [c.strip() for c in cols.split(',')]} + for name, cols in re.findall(UNIQUE_PATTERN, table_data) + ] + def _pragma_cursor(cursor): """work around SQLite issue whereby cursor.description diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index f623a2a61..d5fe5c5e2 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -338,6 +338,25 @@ class Dialect(object): raise NotImplementedError() + def get_unique_constraints(self, table_name, schema=None, **kw): + """Return information about unique constraints in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + unique constraint information as a list of dicts with these keys: + + name + the unique constraint's name + + column_names + list of column names in order + + \**kw + other options passed to the dialect's get_unique_constraints() method. + + """ + + raise NotImplementedError() + def normalize_name(self, name): """convert the given name to lowercase if it is detected as case insensitive. diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index cf2caf679..1926e693a 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -347,6 +347,26 @@ class Inspector(object): schema, info_cache=self.info_cache, **kw) + def get_unique_constraints(self, table_name, schema=None, **kw): + """Return information about unique constraints in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + unique constraint information as a list of dicts with these keys: + + name + the unique constraint's name + + column_names + list of column names in order + + \**kw + other options passed to the dialect's get_unique_constraints() method. + + """ + + return self.dialect.get_unique_constraints( + self.bind, table_name, schema, info_cache=self.info_cache, **kw) + def reflecttable(self, table, include_columns, exclude_columns=()): """Given a Table object, load its internal constructs based on introspection. diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index ac0fa5153..fd9411874 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1,3 +1,5 @@ +import operator + import unicodedata import sqlalchemy as sa from sqlalchemy import schema, events, event, inspect @@ -878,6 +880,41 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert set([t2.c.name, t2.c.id]) == set(r2.columns) assert set([t2.c.name]) == set(r3.columns) + @testing.provide_metadata + def test_unique_constraints_reflection(self): + uniques = sorted( + [ + {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']}, + {'name': 'unique_a_c', 'column_names': ['a', 'c']}, + {'name': 'unique_b_c', 'column_names': ['b', 'c']}, + ], + key=operator.itemgetter('name') + ) + + try: + orig_meta = sa.MetaData(bind=testing.db) + table = Table( + 'testtbl', orig_meta, + Column('a', sa.String(20)), + Column('b', sa.String(30)), + Column('c', sa.Integer), + ) + for uc in uniques: + table.append_constraint( + sa.UniqueConstraint(*uc['column_names'], name=uc['name']) + ) + orig_meta.create_all() + + inspector = inspect(testing.db) + reflected = sorted( + inspector.get_unique_constraints('testtbl'), + key=operator.itemgetter('name') + ) + + assert uniques == reflected + finally: + testing.db.execute('drop table if exists testtbl;') + @testing.requires.views @testing.provide_metadata def test_views(self): |