diff options
| -rw-r--r-- | CHANGES | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 25 | ||||
| -rw-r--r-- | test/engine/test_execute.py | 10 | ||||
| -rw-r--r-- | test/orm/test_query.py | 6 | ||||
| -rw-r--r-- | test/sql/test_compiler.py | 26 | ||||
| -rw-r--r-- | test/sql/test_functions.py | 2 | ||||
| -rw-r--r-- | test/sql/test_query.py | 63 |
7 files changed, 123 insertions, 18 deletions
@@ -328,6 +328,15 @@ underneath "0.7.xx". docs for "Registering New Dialects". [ticket:2462] + - [feature] The "required" flag is set to + True by default, if not passed explicitly, + on bindparam() if the "value" or "callable" + parameters are not passed. + This will cause statement execution to check + for the parameter being present in the final + collection of bound parameters, rather than + implicitly assigning None. [ticket:2556] + - [bug] The names of the columns on the .c. attribute of a select().apply_labels() is now based on <tablename>_<colkey> instead diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b7b965ea9..6b184d1ca 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -54,6 +54,7 @@ __all__ = [ 'tuple_', 'type_coerce', 'union', 'union_all', 'update', ] PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT') +NO_ARG = util.symbol('NO_ARG') def nullsfirst(column): """Return a NULLS FIRST ``ORDER BY`` clause element. @@ -990,7 +991,7 @@ def table(name, *columns): """ return TableClause(name, *columns) -def bindparam(key, value=None, type_=None, unique=False, required=False, +def bindparam(key, value=NO_ARG, type_=None, unique=False, required=NO_ARG, quote=None, callable_=None): """Create a bind parameter clause with the given key. @@ -1007,6 +1008,14 @@ def bindparam(key, value=None, type_=None, unique=False, required=False, overridden by the dictionary of parameters sent to statement compilation/execution. + Defaults to ``None``, however if neither ``value`` nor + ``callable`` are passed explicitly, the ``required`` flag will be set to + ``True`` which has the effect of requiring a value be present + when the statement is actually executed. + + .. versionchanged:: 0.8 The ``required`` flag is set to ``True`` + automatically if ``value`` or ``callable`` is not passed. + :param callable\_: A callable function that takes the place of "value". The function will be called at statement execution time to determine the @@ -1026,7 +1035,14 @@ def bindparam(key, value=None, type_=None, unique=False, required=False, :class:`.ClauseElement`. :param required: - a value is required at execution time. + If ``True``, a value is required at execution time. If not passed, + is set to ``True`` or ``False`` based on whether or not + one of ``value`` or ``callable`` were passed.. + + .. versionchanged:: 0.8 If the ``required`` flag is not specified, + it will be set automatically to ``True`` or ``False`` depending + on whether or not the ``value`` or ``callable`` parameters + were specified. :param quote: True if this parameter name requires quoting and is not @@ -1037,6 +1053,10 @@ def bindparam(key, value=None, type_=None, unique=False, required=False, if isinstance(key, ColumnClause): type_ = key.type key = key.name + if required is NO_ARG: + required = (value is NO_ARG and callable_ is None) + if value is NO_ARG: + value = None return BindParameter(key, value, type_=type_, callable_=callable_, unique=unique, required=required, @@ -1703,6 +1723,7 @@ class ClauseElement(Visitable): def visit_bindparam(bind): if bind.key in kwargs: bind.value = kwargs[bind.key] + bind.required = False if unique: bind._convert_to_unique() return cloned_traverse(self, {}, {'bindparam': visit_bindparam}) diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 1067600df..900a3c8ee 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -19,7 +19,7 @@ from sqlalchemy.engine.base import Connection, Engine from test.lib import fixtures import StringIO -users, metadata = None, None +users, metadata, users_autoinc = None, None, None class ExecuteTest(fixtures.TestBase): @classmethod def setup_class(cls): @@ -315,11 +315,9 @@ class ExecuteTest(fixtures.TestBase): def test_empty_insert(self): """test that execute() interprets [] as a list with no params""" - result = \ - testing.db.execute(users_autoinc.insert(). - values(user_name=bindparam('name')), []) - eq_(testing.db.execute(users_autoinc.select()).fetchall(), [(1, - None)]) + testing.db.execute(users_autoinc.insert(). + values(user_name=bindparam('name', None)), []) + eq_(testing.db.execute(users_autoinc.select()).fetchall(), [(1, None)]) @testing.requires.ad_hoc_engines def test_engine_level_options(self): diff --git a/test/orm/test_query.py b/test/orm/test_query.py index a2f1ff1dc..04b62f8c9 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -947,7 +947,8 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): session = create_session() - q = session.query(User.id).filter(User.id==bindparam('foo')).params(foo=7).subquery() + q = session.query(User.id).filter(User.id == bindparam('foo')).\ + params(foo=7).subquery() q = session.query(User).filter(User.id.in_(q)) @@ -957,7 +958,8 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address session = create_session() - s = session.query(User.id).join(User.addresses).group_by(User.id).having(func.count(Address.id) > 2) + s = session.query(User.id).join(User.addresses).group_by(User.id).\ + having(func.count(Address.id) > 2) eq_( session.query(User).filter(User.id.in_(s)).all(), [User(id=8)] diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 55b583071..40d29f222 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -1,5 +1,15 @@ #! coding:utf-8 +""" +compiler tests. + +These tests are among the very first that were written when SQLAlchemy +began in 2005. As a result the testing style here is very dense; +it's an ongoing job to break these into much smaller tests with correct pep8 +styling and coherent test organization. + +""" + from test.lib.testing import eq_, is_, assert_raises, assert_raises_message import datetime, re, operator, decimal from sqlalchemy import * @@ -1446,21 +1456,24 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): # test Text embedded within select_from(), using binds generate_series = text( "generate_series(:x, :y, :z) as s(a)", - bindparams=[bindparam('x'), bindparam('y'), bindparam('z')] + bindparams=[bindparam('x', None), + bindparam('y', None), bindparam('z', None)] ) - s =select([ + s = select([ (func.current_date() + literal_column("s.a")).label("dates") ]).select_from(generate_series) self.assert_compile( s, - "SELECT CURRENT_DATE + s.a AS dates FROM generate_series(:x, :y, :z) as s(a)", + "SELECT CURRENT_DATE + s.a AS dates FROM " + "generate_series(:x, :y, :z) as s(a)", checkparams={'y': None, 'x': None, 'z': None} ) self.assert_compile( s.params(x=5, y=6, z=7), - "SELECT CURRENT_DATE + s.a AS dates FROM generate_series(:x, :y, :z) as s(a)", + "SELECT CURRENT_DATE + s.a AS dates FROM " + "generate_series(:x, :y, :z) as s(a)", checkparams={'y': 6, 'x': 5, 'z': 7} ) @@ -1879,7 +1892,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "UNION (SELECT foo, bar FROM bat INTERSECT SELECT foo, bar FROM bat)" ) - @testing.uses_deprecated() def test_binds(self): for ( stmt, @@ -1947,13 +1959,15 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): {'myid_1':5, 'myid_2': 6}, {'myid_1':5, 'myid_2':6}, [5,6] ), ( - bindparam('test', type_=String) + text("'hi'"), + bindparam('test', type_=String, required=False) + text("'hi'"), ":test || 'hi'", "? || 'hi'", {'test':None}, [None], {}, {'test':None}, [None] ), ( + # testing select.params() here - bindparam() objects + # must get required flag set to False select([table1], or_(table1.c.myid==bindparam('myid'), table2.c.otherid==bindparam('myotherid'))).\ params({'myid':8, 'myotherid':7}), diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index f0fcd4b72..8e5c6bc58 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -268,7 +268,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_functions_with_cols(self): users = table('users', column('id'), column('name'), column('fullname')) calculate = select([column('q'), column('z'), column('r')], - from_obj=[func.calculate(bindparam('x'), bindparam('y'))]) + from_obj=[func.calculate(bindparam('x', None), bindparam('y', None))]) self.assert_compile(select([users], users.c.id > calculate.c.z), "SELECT users.id, users.name, users.fullname " diff --git a/test/sql/test_query.py b/test/sql/test_query.py index e79bf32e3..670fb2c64 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -1,4 +1,4 @@ -from test.lib.testing import eq_, assert_raises_message, assert_raises +from test.lib.testing import eq_, assert_raises_message, assert_raises, is_ import datetime from sqlalchemy import * from sqlalchemy import exc, sql, util @@ -1216,6 +1216,67 @@ class QueryTest(fixtures.TestBase): r = s.execute().fetchall() assert len(r) == 1 +class RequiredBindTest(fixtures.TablesTest): + run_create_tables = None + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table('foo', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('x', Integer) + ) + + def _assert_raises(self, stmt, params): + assert_raises_message( + exc.StatementError, + "A value is required for bind parameter 'x'", + testing.db.execute, stmt, **params) + + assert_raises_message( + exc.StatementError, + "A value is required for bind parameter 'x'", + testing.db.execute, stmt, params) + + def test_insert(self): + stmt = self.tables.foo.insert().values(x=bindparam('x'), + data=bindparam('data')) + self._assert_raises( + stmt, {'data': 'data'} + ) + + def test_select_where(self): + stmt = select([self.tables.foo]).\ + where(self.tables.foo.c.data == bindparam('data')).\ + where(self.tables.foo.c.x == bindparam('x')) + self._assert_raises( + stmt, {'data': 'data'} + ) + + def test_select_columns(self): + stmt = select([bindparam('data'), bindparam('x')]) + self._assert_raises( + stmt, {'data': 'data'} + ) + + def test_text(self): + stmt = text("select * from foo where x=:x and data=:data1") + self._assert_raises( + stmt, {'data1': 'data'} + ) + + def test_required_flag(self): + is_(bindparam('foo').required, True) + is_(bindparam('foo', required=False).required, False) + is_(bindparam('foo', 'bar').required, False) + is_(bindparam('foo', 'bar', required=True).required, True) + + c = lambda: None + is_(bindparam('foo', callable_=c, required=True).required, True) + is_(bindparam('foo', callable_=c).required, False) + is_(bindparam('foo', callable_=c, required=False).required, False) + class TableInsertTest(fixtures.TablesTest): """test for consistent insert behavior across dialects regarding the inline=True flag, lower-case 't' tables. |
