summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-03-04 20:23:37 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-03-04 20:23:37 +0000
commitc1d0c2dffc0eedfa63de5b90addb70bfd3a81540 (patch)
treef9caefebacd3b3ed4eac8000f74697fcead46f9b
parentffc8dfbdec0903db84370c3466ccce5cb8316b7e (diff)
downloadsqlalchemy-c1d0c2dffc0eedfa63de5b90addb70bfd3a81540.tar.gz
got column defaults to be executeable
-rw-r--r--lib/sqlalchemy/engine.py11
-rw-r--r--lib/sqlalchemy/schema.py30
-rw-r--r--test/defaults.py31
3 files changed, 53 insertions, 19 deletions
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py
index 0f6b65909..7d158cb7e 100644
--- a/lib/sqlalchemy/engine.py
+++ b/lib/sqlalchemy/engine.py
@@ -596,6 +596,17 @@ class SQLEngine(schema.SchemaEngine):
def _executemany(self, c, statement, parameters):
c.executemany(statement, parameters)
self.context.rowcount = c.rowcount
+
+ def proxy(self, statement=None, parameters=None):
+ executemany = parameters is not None and isinstance(parameters, list)
+
+ if self.positional:
+ if executemany:
+ parameters = [p.values() for p in parameters]
+ else:
+ parameters = parameters.values()
+
+ return self.execute(statement, parameters)
def log(self, msg):
"""logs a message using this SQLEngine's logger stream."""
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 17e421f22..57ae7ba5a 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -283,6 +283,7 @@ class Column(sql.ColumnClause, SchemaItem):
self.default = kwargs.pop('default', None)
self.index = kwargs.pop('index', None)
self.unique = kwargs.pop('unique', None)
+ self.onupdate = kwargs.pop('onupdate', None)
if self.index is not None and self.unique is not None:
raise ArgumentError("Column may not define both index and unique")
self._foreign_key = None
@@ -302,7 +303,7 @@ class Column(sql.ColumnClause, SchemaItem):
return "Column(%s)" % string.join(
[repr(self.name)] + [repr(self.type)] +
[repr(x) for x in [self.foreign_key] if x is not None] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default']]
+ ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']]
, ',')
def append_item(self, item):
@@ -326,6 +327,9 @@ class Column(sql.ColumnClause, SchemaItem):
if self.default is not None:
self.default = ColumnDefault(self.default)
self._init_items(self.default)
+ if self.onupdate is not None:
+ self.onupdate = ColumnDefault(self.onupdate, for_update=True)
+ self._init_items(self.onupdate)
self._init_items(*self.args)
self.args = None
@@ -435,17 +439,26 @@ class ForeignKey(SchemaItem):
class DefaultGenerator(SchemaItem):
"""Base class for column "default" values."""
+ def __init__(self, for_update=False, engine=None):
+ self.for_update = for_update
+ self.engine = engine
def _set_parent(self, column):
self.column = column
- self.column.default = self
+ if self.engine is None:
+ self.engine = column.table.engine
+ if self.for_update:
+ self.column.onupdate = self
+ else:
+ self.column.default = self
def execute(self):
- return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.execute))
+ return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.proxy))
def __repr__(self):
return "DefaultGenerator()"
class PassiveDefault(DefaultGenerator):
"""a default that takes effect on the database side"""
- def __init__(self, arg):
+ def __init__(self, arg, **kwargs):
+ super(PassiveDefault, self).__init__(**kwargs)
self.arg = arg
def accept_schema_visitor(self, visitor):
return visitor.visit_passive_default(self)
@@ -455,7 +468,8 @@ class PassiveDefault(DefaultGenerator):
class ColumnDefault(DefaultGenerator):
"""A plain default value on a column. this could correspond to a constant,
a callable function, or a SQL clause."""
- def __init__(self, arg):
+ def __init__(self, arg, **kwargs):
+ super(ColumnDefault, self).__init__(**kwargs)
self.arg = arg
def accept_schema_visitor(self, visitor):
"""calls the visit_column_default method on the given visitor."""
@@ -465,12 +479,12 @@ class ColumnDefault(DefaultGenerator):
class Sequence(DefaultGenerator):
"""represents a sequence, which applies to Oracle and Postgres databases."""
- def __init__(self, name, start = None, increment = None, optional=False, engine=None):
+ def __init__(self, name, start = None, increment = None, optional=False, **kwargs):
+ super(Sequence, self).__init__(**kwargs)
self.name = name
self.start = start
self.increment = increment
self.optional=optional
- self.engine = engine
def __repr__(self):
return "Sequence(%s)" % string.join(
[repr(self.name)] +
@@ -479,8 +493,6 @@ class Sequence(DefaultGenerator):
def _set_parent(self, column):
super(Sequence, self)._set_parent(column)
column.sequence = self
- if self.engine is None:
- self.engine = column.table.engine
def create(self):
self.engine.create(self)
return self
diff --git a/test/defaults.py b/test/defaults.py
index fcf852a86..459b3abfe 100644
--- a/test/defaults.py
+++ b/test/defaults.py
@@ -10,7 +10,8 @@ db = testbase.db
class DefaultTest(PersistTest):
- def testdefaults(self):
+ def setUpAll(self):
+ global t, f, ts
x = {'x':50}
def mydefault():
x['x'] += 1
@@ -56,16 +57,26 @@ class DefaultTest(PersistTest):
Column('col5', deftype, PassiveDefault(def2))
)
t.create()
- try:
- t.insert().execute()
- self.assert_(t.engine.lastrow_has_defaults())
- t.insert().execute()
- t.insert().execute()
+
+ def teststandalonedefaults(self):
+ x = t.c.col1.default.execute()
+ y = t.c.col2.default.execute()
+ z = t.c.col3.default.execute()
+ self.assert_(50 <= x <= 57)
+ self.assert_(y == 'imthedefault')
+ self.assert_(z == 6)
- l = t.select().execute()
- self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
- finally:
- t.drop()
+ def testinsertdefaults(self):
+ t.insert().execute()
+ self.assert_(t.engine.lastrow_has_defaults())
+ t.insert().execute()
+ t.insert().execute()
+
+ l = t.select().execute()
+ self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
+
+ def tearDownAll(self):
+ t.drop()
class SequenceTest(PersistTest):