summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin J. Hsu <martin.hsu@gmail.com>2015-09-25 16:15:28 +0800
committerMartin J. Hsu <martin.hsu@gmail.com>2015-10-15 10:46:33 +0800
commitc7d04beeac6ad54d638afb01783dee2d769aef9d (patch)
treed085f54e97c06138ecbb5ce30898f5b8aeda936f
parent91255618ddb47553774c620a23479adf88c27b74 (diff)
downloadsqlalchemy-pr/204.tar.gz
- wrap ColumnDefault empty arg callables like functools.wraps, setting __name__, __doc__, and __module__pr/204
-rw-r--r--lib/sqlalchemy/sql/schema.py5
-rw-r--r--lib/sqlalchemy/util/__init__.py2
-rw-r--r--lib/sqlalchemy/util/langhelpers.py26
-rw-r--r--test/sql/test_defaults.py61
4 files changed, 91 insertions, 3 deletions
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 137208584..0c433d16e 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -1981,13 +1981,14 @@ class ColumnDefault(DefaultGenerator):
try:
argspec = util.get_callable_argspec(fn, no_self=True)
except TypeError:
- return lambda ctx: fn()
+ return util.wrap_callable(fn)
defaulted = argspec[3] is not None and len(argspec[3]) or 0
positionals = len(argspec[0]) - defaulted
if positionals == 0:
- return lambda ctx: fn()
+ return util.wrap_callable(fn)
+
elif positionals == 1:
return fn
else:
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index ed968f168..36a81dbce 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -36,7 +36,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \
generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \
safe_reraise,\
get_callable_argspec, only_once, attrsetter, ellipses_string, \
- warn_limited, map_bits, MemoizedSlots, EnsureKWArgType
+ warn_limited, map_bits, MemoizedSlots, EnsureKWArgType, wrap_callable
from .deprecations import warn_deprecated, warn_pending_deprecation, \
deprecated, pending_deprecation, inject_docstring_text
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 743afccfd..9f259aea3 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -1377,3 +1377,29 @@ class EnsureKWArgType(type):
return fn(*arg)
return update_wrapper(wrap, fn)
+
+def wrap_callable(fn):
+ """Wrap callable and set __name__, __doc__, and __module__.
+
+ :param fn:
+ object with __call__ method
+ """
+ if hasattr(fn, '__name__'):
+ _f = update_wrapper(lambda ctx: fn(), fn)
+ _f.__doc__ = _f.__doc__ or fn.__name__
+ return _f
+ else:
+ _f = lambda ctx: fn()
+ _f.__name__ = fn.__class__.__name__
+ _f.__module__ = fn.__module__
+
+ if hasattr(fn.__call__, '__doc__') and fn.__call__.__doc__:
+ _f.__doc__ = fn.__call__.__doc__
+ elif hasattr(fn.__class__, '__doc__') and fn.__class__.__doc__:
+ _f.__doc__ = fn.__class__.__doc__
+ elif fn.__doc__:
+ _f.__doc__ = fn.__doc__
+ else:
+ _f.__doc__ = fn.__class__.__name__
+
+ return _f
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index 673085cf7..e2250e834 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -301,6 +301,67 @@ class DefaultTest(fixtures.TestBase):
c = sa.ColumnDefault(fn)
c.arg("context")
+ def test_wrapping_update_wrapper_fn(self):
+ def my_fancy_default():
+ """run the fancy default"""
+ return 10
+
+ c = sa.ColumnDefault(my_fancy_default)
+ eq_(c.arg.__name__, "my_fancy_default")
+ eq_(c.arg.__doc__, "run the fancy default")
+
+ def test_wrapping_update_wrapper_fn_nodocstring(self):
+ def my_fancy_default():
+ return 10
+
+ c = sa.ColumnDefault(my_fancy_default)
+ eq_(c.arg.__name__, "my_fancy_default")
+ eq_(c.arg.__doc__, "my_fancy_default")
+
+ def test_wrapping_update_wrapper_cls(self):
+ class MyFancyDefault(object):
+ """a fancy default"""
+
+ def __call__(self):
+ """run the fancy default"""
+ return 10
+
+ c = sa.ColumnDefault(MyFancyDefault())
+ eq_(c.arg.__name__, "MyFancyDefault")
+ eq_(c.arg.__doc__, "run the fancy default")
+
+ def test_wrapping_update_wrapper_cls_noclassdocstring(self):
+ class MyFancyDefault(object):
+
+ def __call__(self):
+ """run the fancy default"""
+ return 10
+
+ c = sa.ColumnDefault(MyFancyDefault())
+ eq_(c.arg.__name__, "MyFancyDefault")
+ eq_(c.arg.__doc__, "run the fancy default")
+
+ def test_wrapping_update_wrapper_cls_nomethoddocstring(self):
+ class MyFancyDefault(object):
+ """a fancy default"""
+
+ def __call__(self):
+ return 10
+
+ c = sa.ColumnDefault(MyFancyDefault())
+ eq_(c.arg.__name__, "MyFancyDefault")
+ eq_(c.arg.__doc__, "a fancy default")
+
+ def test_wrapping_update_wrapper_cls_noclassdocstring_nomethoddocstring(self):
+ class MyFancyDefault(object):
+
+ def __call__(self):
+ return 10
+
+ c = sa.ColumnDefault(MyFancyDefault())
+ eq_(c.arg.__name__, "MyFancyDefault")
+ eq_(c.arg.__doc__, "MyFancyDefault")
+
@testing.fails_on('firebird', 'Data type unknown')
def test_standalone(self):
c = testing.db.engine.contextual_connect()