diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/mysqldb.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_dialect.py | 34 |
4 files changed, 57 insertions, 28 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 6af860133..7941d4c41 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -64,19 +64,11 @@ class MySQLExecutionContext_mysqldb(MySQLExecutionContext): class MySQLCompiler_mysqldb(MySQLCompiler): - def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) - - def post_process_text(self, text): - return text.replace('%', '%%') + pass class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): - - def _escape_identifier(self, value): - value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace("%", "%%") + pass class MySQLDialect_mysqldb(MySQLDialect): diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 50328143e..31792a492 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -455,18 +455,11 @@ class PGExecutionContext_psycopg2(PGExecutionContext): class PGCompiler_psycopg2(PGCompiler): - def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) - - def post_process_text(self, text): - return text.replace('%', '%%') + pass class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): - def _escape_identifier(self, value): - value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace('%', '%%') + pass class PGDialect_psycopg2(PGDialect): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b18f90312..cc4248009 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -695,7 +695,6 @@ class SQLCompiler(Compiled): name = self.escape_literal_column(name) else: name = self.preparer.quote(name) - table = column.table if table is None or not include_table or not table.named_with_column: return name @@ -715,12 +714,6 @@ class SQLCompiler(Compiled): self.preparer.quote(tablename) + \ "." + name - def escape_literal_column(self, text): - """provide escaping for the literal_column() construct.""" - - # TODO: some dialects might need different behavior here - return text.replace('%', '%%') - def visit_fromclause(self, fromclause, **kwargs): return fromclause.name @@ -732,6 +725,13 @@ class SQLCompiler(Compiled): return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): + if self.preparer._double_percents: + text = text.replace('%', '%%') + return text + + def escape_literal_column(self, text): + if self.preparer._double_percents: + text = text.replace('%', '%%') return text def visit_textclause(self, textclause, **kw): @@ -1048,6 +1048,14 @@ class SQLCompiler(Compiled): else: return self._generate_generic_binary(binary, opstring, **kw) + def visit_mod_binary(self, binary, operator, **kw): + if self.preparer._double_percents: + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + else: + return self.process(binary.left, **kw) + " % " + \ + self.process(binary.right, **kw) + def visit_custom_op_binary(self, element, operator, **kw): kw['eager_grouping'] = operator.eager_grouping return self._generate_generic_binary( @@ -2888,6 +2896,7 @@ class IdentifierPreparer(object): self.escape_to_quote = self.escape_quote * 2 self.omit_schema = omit_schema self._strings = {} + self._double_percents = self.dialect.paramstyle in ('format', 'pyformat') def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) @@ -2902,7 +2911,10 @@ class IdentifierPreparer(object): escaping behavior. """ - return value.replace(self.escape_quote, self.escape_to_quote) + value = value.replace(self.escape_quote, self.escape_to_quote) + if self._double_percents: + value = value.replace('%', '%%') + return value def _unescape_identifier(self, value): """Canonicalize an escaped identifier. diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 00884a212..0e62c347f 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -1,9 +1,11 @@ from .. import fixtures, config from ..config import requirements from sqlalchemy import exc -from sqlalchemy import Integer, String +from sqlalchemy import Integer, String, select, literal_column from .. import assert_raises from ..schema import Table, Column +from .. import provide_metadata +from .. import eq_ class ExceptionTest(fixtures.TablesTest): @@ -39,3 +41,33 @@ class ExceptionTest(fixtures.TablesTest): self.tables.manual_pk.insert(), {'id': 1, 'data': 'd1'} ) + + +class EscapingTest(fixtures.TestBase): + @provide_metadata + def test_percent_sign_round_trip(self): + """test that the DBAPI accommodates for escaped / nonescaped + percent signs in a way that matches the compiler + + """ + m = self.metadata + t = Table('t', m, Column('data', String(50))) + t.create(config.db) + with config.db.begin() as conn: + conn.execute(t.insert(), dict(data="some % value")) + conn.execute(t.insert(), dict(data="some %% other value")) + + eq_( + conn.scalar( + select([t.c.data]).where( + t.c.data == literal_column("'some % value'")) + ), + "some % value" + ) + + eq_( + conn.scalar( + select([t.c.data]).where( + t.c.data == literal_column("'some %% other value'")) + ), "some %% other value" + ) |
