summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-03-05 20:31:44 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-03-05 20:31:44 +0000
commit9c4f3c0480f54e08b3aa2800ed76e89f957f8131 (patch)
treee7cad83cbd55ff0e2a3f4103160e7e8fed6b6a2c
parentc1d0c2dffc0eedfa63de5b90addb70bfd3a81540 (diff)
downloadsqlalchemy-9c4f3c0480f54e08b3aa2800ed76e89f957f8131.tar.gz
got column onupdate working
improvement to Function so that they can more easily be called standalone without having to throw them into a select().
-rw-r--r--doc/build/content/sqlconstruction.myt11
-rw-r--r--lib/sqlalchemy/ansisql.py62
-rw-r--r--lib/sqlalchemy/databases/oracle.py5
-rw-r--r--lib/sqlalchemy/databases/postgres.py16
-rw-r--r--lib/sqlalchemy/engine.py46
-rw-r--r--lib/sqlalchemy/schema.py10
-rw-r--r--lib/sqlalchemy/sql.py9
-rw-r--r--test/defaults.py53
8 files changed, 169 insertions, 43 deletions
diff --git a/doc/build/content/sqlconstruction.myt b/doc/build/content/sqlconstruction.myt
index c38670506..065ef2bcc 100644
--- a/doc/build/content/sqlconstruction.myt
+++ b/doc/build/content/sqlconstruction.myt
@@ -341,6 +341,17 @@ WHERE substr(users.user_name, :substr) = :substr_1
</&>
</&>
+ <p>Functions also are callable as standalone values:</p>
+ <&|formatting.myt:code &>
+ # call the "now()" function
+ time = func.now(engine=myengine).scalar()
+
+ # call myfunc(1,2,3)
+ myvalue = func.myfunc(1, 2, 3, engine=db).execute()
+
+ # or call them off the engine
+ db.func.now().scalar()
+ </&>
</&>
<&|doclib.myt:item, name="literals", description="Literals" &>
<p>You can drop in a literal value anywhere there isnt a column to attach to via the <span class="codeline">literal</span> keyword:</p>
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 7c0002aa5..7b39d5358 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -15,6 +15,18 @@ from sqlalchemy.sql import *
from sqlalchemy.util import *
import string, re
+ANSI_FUNCS = HashSet([
+'CURRENT_TIME',
+'CURRENT_TIMESTAMP',
+'CURRENT_DATE',
+'LOCAL_TIME',
+'LOCAL_TIMESTAMP',
+'CURRENT_USER',
+'SESSION_USER',
+'USER'
+])
+
+
def engine(**params):
return ANSISQLEngine(**params)
@@ -57,6 +69,7 @@ class ANSICompiler(sql.Compiled):
self.select_stack = []
self.typemap = typemap or {}
self.isinsert = False
+ self.isupdate = False
self.bindtemplate = ":%s"
if engine is not None:
self.paramstyle = engine.paramstyle
@@ -89,7 +102,7 @@ class ANSICompiler(sql.Compiled):
self.strings[self.statement] = re.sub(match, getnum, self.strings[self.statement])
def get_from_text(self, obj):
- return self.froms[obj]
+ return self.froms.get(obj, None)
def get_str(self, obj):
return self.strings[obj]
@@ -158,6 +171,11 @@ class ANSICompiler(sql.Compiled):
else:
return parameters
+ def default_from(self):
+ """called when a SELECT statement has no froms, and no FROM clause is to be appended.
+ gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+ return ""
+
def visit_label(self, label):
if len(self.select_stack):
self.typemap.setdefault(label.name.lower(), label.obj.type)
@@ -211,7 +229,12 @@ class ANSICompiler(sql.Compiled):
self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
def visit_function(self, func):
- self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
+ if len(self.select_stack):
+ self.typemap.setdefault(func.name, func.type)
+ if func.name.upper() in ANSI_FUNCS and not len(func.clauses):
+ self.strings[func] = func.name
+ else:
+ self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
def visit_compound_select(self, cs):
text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
@@ -325,7 +348,9 @@ class ANSICompiler(sql.Compiled):
if len(froms):
text += " \nFROM "
text += string.join(froms, ', ')
-
+ else:
+ text += self.default_from()
+
if whereclause is not None:
t = self.get_str(whereclause)
if t:
@@ -384,21 +409,33 @@ class ANSICompiler(sql.Compiled):
def visit_insert_column_default(self, column, default):
"""called when visiting an Insert statement, for each column in the table that
- contains a ColumnDefault object."""
+ contains a ColumnDefault object. adds a blank 'placeholder' parameter so the
+ Insert gets compiled with this column's name in its column and VALUES clauses."""
+ self.parameters.setdefault(column.key, None)
+
+ def visit_update_column_default(self, column, default):
+ """called when visiting an Update statement, for each column in the table that
+ contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the
+ Update gets compiled with this column's name as one of its SET clauses."""
self.parameters.setdefault(column.key, None)
def visit_insert_sequence(self, column, sequence):
"""called when visiting an Insert statement, for each column in the table that
- contains a Sequence object."""
+ contains a Sequence object. Overridden by compilers that support sequences to place
+ a blank 'placeholder' parameter, so the Insert gets compiled with this column's
+ name in its column and VALUES clauses."""
pass
def visit_insert_column(self, column):
"""called when visiting an Insert statement, for each column in the table
- that is a NULL insert into the table"""
+ that is a NULL insert into the table. Overridden by compilers who disallow
+ NULL columns being set in an Insert where there is a default value on the column
+ (i.e. postgres), to remove the column from the parameter list."""
pass
def visit_insert(self, insert_stmt):
- # set up a call for the defaults and sequences inside the table
+ # scan the table's columns for defaults that have to be pre-set for an INSERT
+ # add these columns to the parameter list via visit_insert_XXX methods
class DefaultVisitor(schema.SchemaVisitor):
def visit_column(s, c):
self.visit_insert_column(c)
@@ -424,6 +461,17 @@ class ANSICompiler(sql.Compiled):
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
+ # scan the table's columns for onupdates that have to be pre-set for an UPDATE
+ # add these columns to the parameter list via visit_update_XXX methods
+ class OnUpdateVisitor(schema.SchemaVisitor):
+ def visit_column_onupdate(s, cd):
+ self.visit_update_column_default(c, cd)
+ vis = OnUpdateVisitor()
+ for c in update_stmt.table.c:
+ if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
+ c.accept_schema_visitor(vis)
+
+ self.isupdate = True
colparams = self._get_colparams(update_stmt)
def create_param(p):
if isinstance(p, sql.BindParamClause):
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 6f5e98265..eab200317 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -209,6 +209,11 @@ class OracleCompiler(ansisql.ANSICompiler):
self._use_ansi = use_ansi
ansisql.ANSICompiler.__init__(self, statement, parameters, engine=engine, **kwargs)
+ def default_from(self):
+ """called when a SELECT statement has no froms, and no FROM clause is to be appended.
+ gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+ return " FROM DUAL"
+
def visit_join(self, join):
if self._use_ansi:
return ansisql.ANSICompiler.visit_join(self, join)
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 105fe7a76..db20b636c 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -103,16 +103,6 @@ class PGBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BOOLEAN"
-ANSI_FUNCS = util.HashSet([
-'CURRENT_TIME',
-'CURRENT_TIMESTAMP',
-'CURRENT_DATE',
-'LOCAL_TIME',
-'LOCAL_TIMESTAMP',
-'CURRENT_USER',
-'SESSION_USER',
-'USER'
-])
pg2_colspecs = {
sqltypes.Integer : PGInteger,
@@ -283,12 +273,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
class PGCompiler(ansisql.ANSICompiler):
- def visit_function(self, func):
- # PG has a bunch of funcs that explicitly need no parenthesis
- if func.name.upper() in ANSI_FUNCS and not len(func.clauses):
- self.strings[func] = func.name
- else:
- super(PGCompiler, self).visit_function(func)
def visit_insert_column(self, column):
# Postgres advises against OID usage and turns it off in 8.1,
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py
index 7d158cb7e..3703169fa 100644
--- a/lib/sqlalchemy/engine.py
+++ b/lib/sqlalchemy/engine.py
@@ -135,6 +135,12 @@ class DefaultRunner(schema.SchemaVisitor):
else:
return None
+ def get_column_onupdate(self, column):
+ if column.onupdate is not None:
+ return column.onupdate.accept_schema_visitor(self)
+ else:
+ return None
+
def visit_passive_default(self, default):
"""passive defaults by definition return None on the app side,
and are post-fetched to get the DB-side value"""
@@ -147,7 +153,15 @@ class DefaultRunner(schema.SchemaVisitor):
def exec_default_sql(self, default):
c = sql.select([default.arg], engine=self.engine).compile()
return self.proxy(str(c), c.get_params()).fetchone()[0]
-
+
+ def visit_column_onupdate(self, onupdate):
+ if isinstance(onupdate.arg, sql.ClauseElement):
+ return self.exec_default_sql(onupdate)
+ elif callable(onupdate.arg):
+ return onupdate.arg()
+ else:
+ return onupdate.arg
+
def visit_column_default(self, default):
if isinstance(default.arg, sql.ClauseElement):
return self.exec_default_sql(default)
@@ -245,6 +259,13 @@ class SQLEngine(schema.SchemaEngine):
typeobj = typeobj()
return typeobj
+ def _func(self):
+ class FunctionGateway(object):
+ def __getattr__(s, name):
+ return lambda *c, **kwargs: sql.Function(name, engine=self, *c, **kwargs)
+ return FunctionGateway()
+ func = property(_func)
+
def text(self, text, *args, **kwargs):
"""returns a sql.text() object for performing literal queries."""
return sql.text(text, engine=self, *args, **kwargs)
@@ -426,6 +447,15 @@ class SQLEngine(schema.SchemaEngine):
self.context.tcount = None
def _process_defaults(self, proxy, compiled, parameters, **kwargs):
+ """INSERT and UPDATE statements, when compiled, may have additional columns added to their
+ VALUES and SET lists corresponding to column defaults/onupdates that are present on the
+ Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-execs those
+ DefaultGenerator objects that require pre-execution and sets their values within the
+ parameter list, and flags the thread-local state about
+ PassiveDefault objects that may require post-fetching the row after it is inserted/updated.
+ This method relies upon logic within the ANSISQLCompiler in its visit_insert and
+ visit_update methods that add the appropriate column clauses to the statement when its
+ being compiled, so that these parameters can be bound to the statement."""
if compiled is None: return
if getattr(compiled, "isinsert", False):
if isinstance(parameters, list):
@@ -454,7 +484,19 @@ class SQLEngine(schema.SchemaEngine):
self.context.last_inserted_ids = None
else:
self.context.last_inserted_ids = last_inserted_ids
-
+ elif getattr(compiled, 'isupdate', False):
+ if isinstance(parameters, list):
+ plist = parameters
+ else:
+ plist = [parameters]
+ drunner = self.defaultrunner(proxy)
+ for param in plist:
+ for c in compiled.statement.table.c:
+ if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None):
+ value = drunner.get_column_onupdate(c)
+ if value is not None:
+ param[c.name] = value
+
def lastrow_has_defaults(self):
return self.context.lastrow_has_defaults
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 57ae7ba5a..5cb9f2043 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -364,6 +364,8 @@ class Column(sql.ColumnClause, SchemaItem):
then calls visit_column on the visitor."""
if self.default is not None:
self.default.accept_schema_visitor(visitor)
+ if self.onupdate is not None:
+ self.onupdate.accept_schema_visitor(visitor)
if self.foreign_key is not None:
self.foreign_key.accept_schema_visitor(visitor)
visitor.visit_column(self)
@@ -473,7 +475,10 @@ class ColumnDefault(DefaultGenerator):
self.arg = arg
def accept_schema_visitor(self, visitor):
"""calls the visit_column_default method on the given visitor."""
- return visitor.visit_column_default(self)
+ if self.for_update:
+ return visitor.visit_column_onupdate(self)
+ else:
+ return visitor.visit_column_default(self)
def __repr__(self):
return "ColumnDefault(%s)" % repr(self.arg)
@@ -599,6 +604,9 @@ class SchemaVisitor(sql.ClauseVisitor):
def visit_column_default(self, default):
"""visit a ColumnDefault."""
pass
+ def visit_column_onupdate(self, onupdate):
+ """visit a ColumnDefault with the "for_update" flag set."""
+ pass
def visit_sequence(self, sequence):
"""visit a Sequence."""
pass
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index f05310e42..cee328b53 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -762,6 +762,9 @@ class Function(ClauseList, ColumnElement):
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = kwargs.get('type', sqltypes.NULLTYPE)
+ self._engine = kwargs.get('engine', None)
+ if self._engine is not None:
+ self.type = self._engine.type_descriptor(self.type)
ClauseList.__init__(self, parens=True, *clauses)
key = property(lambda self:self.name)
def append(self, clause):
@@ -771,6 +774,8 @@ class Function(ClauseList, ColumnElement):
else:
clause = BindParamClause(self.name, clause, shortname=self.name, type=None)
self.clauses.append(clause)
+ def _process_from_dict(self, data, asfrom):
+ data.setdefault(self, self)
def copy_container(self):
clauses = [clause.copy_container() for clause in self.clauses]
return Function(self.name, type=self.type, *clauses)
@@ -782,6 +787,10 @@ class Function(ClauseList, ColumnElement):
return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
def select(self):
return select([self])
+ def scalar(self):
+ return select([self]).scalar()
+ def execute(self):
+ return select([self]).execute()
def _compare_type(self, obj):
return self.type
diff --git a/test/defaults.py b/test/defaults.py
index 459b3abfe..c2c8877eb 100644
--- a/test/defaults.py
+++ b/test/defaults.py
@@ -7,11 +7,11 @@ from sqlalchemy import *
import sqlalchemy
db = testbase.db
-
+testbase.echo=False
class DefaultTest(PersistTest):
def setUpAll(self):
- global t, f, ts
+ global t, f, ts, currenttime
x = {'x':50}
def mydefault():
x['x'] += 1
@@ -22,18 +22,19 @@ class DefaultTest(PersistTest):
# select "count(1)" from the DB which returns different results
# on different DBs
+ currenttime = db.func.current_date(type=Date);
if is_oracle:
- f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar()
- ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar()
- def1 = func.sysdate()
+ ts = db.func.sysdate().scalar()
+ f = select([func.count(1) + 5], engine=db).scalar()
+ def1 = currenttime
def2 = text("sysdate")
deftype = Date
elif use_function_defaults:
f = select([func.count(1) + 5], engine=db).scalar()
- def1 = func.current_date()
+ def1 = currenttime
def2 = text("current_date")
deftype = Date
- ts = select([func.current_date()], engine=db).scalar()
+ ts = db.func.current_date().scalar()
else:
f = select([func.count(1) + 5], engine=db).scalar()
def1 = def2 = "3"
@@ -45,20 +46,29 @@ class DefaultTest(PersistTest):
Column('col1', Integer, primary_key=True, default=mydefault),
# python literal
- Column('col2', String(20), default="imthedefault"),
+ Column('col2', String(20), default="imthedefault", onupdate="im the update"),
# preexecute expression
- Column('col3', Integer, default=func.count(1) + 5),
+ Column('col3', Integer, default=func.count(1) + 5, onupdate=func.count(1) + 14),
# SQL-side default from sql expression
Column('col4', deftype, PassiveDefault(def1)),
# SQL-side default from literal expression
- Column('col5', deftype, PassiveDefault(def2))
+ Column('col5', deftype, PassiveDefault(def2)),
+
+ # preexecute + update timestamp
+ Column('col6', Date, default=currenttime, onupdate=currenttime)
)
t.create()
- def teststandalonedefaults(self):
+ def tearDownAll(self):
+ t.drop()
+
+ def tearDown(self):
+ t.delete().execute()
+
+ def teststandalone(self):
x = t.c.col1.default.execute()
y = t.c.col2.default.execute()
z = t.c.col3.default.execute()
@@ -66,18 +76,27 @@ class DefaultTest(PersistTest):
self.assert_(y == 'imthedefault')
self.assert_(z == 6)
- def testinsertdefaults(self):
+ def testinsert(self):
t.insert().execute()
self.assert_(t.engine.lastrow_has_defaults())
t.insert().execute()
t.insert().execute()
-
- l = t.select().execute()
- self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
- def tearDownAll(self):
- t.drop()
+ ctexec = currenttime.scalar()
+ self.echo("Currenttime "+ repr(ctexec))
+ l = t.select().execute()
+ self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec), (52, 'imthedefault', f, ts, ts, ctexec), (53, 'imthedefault', f, ts, ts, ctexec)])
+ def testupdate(self):
+ t.insert().execute()
+ pk = t.engine.last_inserted_ids()[0]
+ t.update(t.c.col1==pk).execute(col4=None, col5=None)
+ ctexec = currenttime.scalar()
+ self.echo("Currenttime "+ repr(ctexec))
+ l = t.select(t.c.col1==pk).execute()
+ l = l.fetchone()
+ self.assert_(l == (pk, 'im the update', 15, None, None, ctexec))
+
class SequenceTest(PersistTest):
def setUpAll(self):