diff options
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 27 | ||||
-rw-r--r-- | test/sql/test_tablesample.py | 14 |
3 files changed, 24 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ab1b42b2f..dbf919cf3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1325,10 +1325,9 @@ class SQLCompiler(Compiled): return "LATERAL %s" % self.visit_alias(lateral, **kw) def visit_tablesample(self, tablesample, asfrom=False, **kw): - text = "%s TABLESAMPLE %s(%s)" % ( + text = "%s TABLESAMPLE %s" % ( self.visit_alias(tablesample, asfrom=True, **kw), - tablesample.method, - tablesample.arg) + tablesample._get_method()._compiler_dispatch(self, **kw)) if tablesample.seed is not None: text += " REPEATABLE (%s)" % ( diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 0bb590bd0..c554429be 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -183,7 +183,7 @@ def lateral(selectable, name=None): return selectable.lateral(name=name) -def tablesample(selectable, arg, name=None, method=None, seed=None): +def tablesample(selectable, sampling, name=None, seed=None): """Return a :class:`.TableSample` object. :class:`.TableSample` is an :class:`.Alias` subclass that represents @@ -193,15 +193,13 @@ def tablesample(selectable, arg, name=None, method=None, seed=None): percentage of rows from a table. It supports multiple sampling methods, most commonly BERNOULLI and SYSTEM. - :param arg: a ``float`` percentage between 0 and 100. - - :param method: string name of the method to use. - Commonly accepted methods are ``"BERNOULLI"`` and ``"SYSTEM"``. + :param sampling: a ``float`` percentage between 0 and 100 or + :class:`.functions.Function`. :param seed: any real-valued SQL expression. """ - return selectable.sample(arg, name=name, method=method, seed=seed) + return selectable.sample(sampling, name=name, seed=seed) class Selectable(ClauseElement): @@ -471,14 +469,14 @@ class FromClause(Selectable): """ return Lateral(self, name) - def sample(self, arg, name=None, method=None, seed=None): + def sample(self, sampling, name=None, seed=None): """Return a TABLESAMPLE alias of this :class:`.FromClause`. The return value is the :class:`.TableSample` construct also provided by the top-level :func:`~.expression.tablesample` function. """ - return TableSample(self, arg, name, method, seed) + return TableSample(self, sampling, name, seed) def is_derived_from(self, fromclause): """Return True if this FromClause is 'derived' from the given @@ -1285,15 +1283,20 @@ class TableSample(Alias): __visit_name__ = 'tablesample' - def __init__(self, selectable, arg, + def __init__(self, selectable, sampling, name=None, - method=None, seed=None): - self.arg = arg - self.method = method or 'SYSTEM' + self.sampling = sampling self.seed = seed super(TableSample, self).__init__(selectable, name=name) + @util.dependencies("sqlalchemy.sql.functions") + def _get_method(self, functions): + if isinstance(self.sampling, functions.Function): + return self.sampling + else: + return functions.func.system(self.sampling) + class CTE(Generative, HasSuffixes, Alias): """Represent a Common Table Expression. diff --git a/test/sql/test_tablesample.py b/test/sql/test_tablesample.py index 6c21f1173..5470e1748 100644 --- a/test/sql/test_tablesample.py +++ b/test/sql/test_tablesample.py @@ -1,6 +1,6 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message -from sqlalchemy.sql import select, func +from sqlalchemy.sql import select, func, text from sqlalchemy.engine import default from sqlalchemy import exc from sqlalchemy import Table, Integer, String, Column @@ -28,26 +28,26 @@ class TableSampleTest(fixtures.TablesTest, AssertsCompiledSQL): # context of a FROM clause self.assert_compile( tablesample(table1, 1, name='alias'), - 'people AS alias TABLESAMPLE SYSTEM(1)' + 'people AS alias TABLESAMPLE system(:system_1)' ) self.assert_compile( table1.sample(1, name='alias'), - 'people AS alias TABLESAMPLE SYSTEM(1)' + 'people AS alias TABLESAMPLE system(:system_1)' ) self.assert_compile( - tablesample(table1, 1, name='alias', method='BERNOULLI', + tablesample(table1, func.bernoulli(1), name='alias', seed=func.random()), - 'people AS alias TABLESAMPLE BERNOULLI(1) REPEATABLE (random())' + 'people AS alias TABLESAMPLE bernoulli(:bernoulli_1) REPEATABLE (random())' ) def test_select_from(self): table1 = self.tables.people self.assert_compile( - select([table1.sample(1, name='alias').c.people_id]), + select([table1.sample(text('1'), name='alias').c.people_id]), 'SELECT alias.people_id FROM ' - 'people AS alias TABLESAMPLE SYSTEM(1)' + 'people AS alias TABLESAMPLE system(1)' ) |