summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/sql/compiler.py5
-rw-r--r--lib/sqlalchemy/sql/selectable.py27
-rw-r--r--test/sql/test_tablesample.py14
3 files changed, 24 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ab1b42b2f..dbf919cf3 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1325,10 +1325,9 @@ class SQLCompiler(Compiled):
return "LATERAL %s" % self.visit_alias(lateral, **kw)
def visit_tablesample(self, tablesample, asfrom=False, **kw):
- text = "%s TABLESAMPLE %s(%s)" % (
+ text = "%s TABLESAMPLE %s" % (
self.visit_alias(tablesample, asfrom=True, **kw),
- tablesample.method,
- tablesample.arg)
+ tablesample._get_method()._compiler_dispatch(self, **kw))
if tablesample.seed is not None:
text += " REPEATABLE (%s)" % (
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 0bb590bd0..c554429be 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -183,7 +183,7 @@ def lateral(selectable, name=None):
return selectable.lateral(name=name)
-def tablesample(selectable, arg, name=None, method=None, seed=None):
+def tablesample(selectable, sampling, name=None, seed=None):
"""Return a :class:`.TableSample` object.
:class:`.TableSample` is an :class:`.Alias` subclass that represents
@@ -193,15 +193,13 @@ def tablesample(selectable, arg, name=None, method=None, seed=None):
percentage of rows from a table. It supports multiple sampling methods,
most commonly BERNOULLI and SYSTEM.
- :param arg: a ``float`` percentage between 0 and 100.
-
- :param method: string name of the method to use.
- Commonly accepted methods are ``"BERNOULLI"`` and ``"SYSTEM"``.
+ :param sampling: a ``float`` percentage between 0 and 100 or
+ :class:`.functions.Function`.
:param seed: any real-valued SQL expression.
"""
- return selectable.sample(arg, name=name, method=method, seed=seed)
+ return selectable.sample(sampling, name=name, seed=seed)
class Selectable(ClauseElement):
@@ -471,14 +469,14 @@ class FromClause(Selectable):
"""
return Lateral(self, name)
- def sample(self, arg, name=None, method=None, seed=None):
+ def sample(self, sampling, name=None, seed=None):
"""Return a TABLESAMPLE alias of this :class:`.FromClause`.
The return value is the :class:`.TableSample` construct also
provided by the top-level :func:`~.expression.tablesample` function.
"""
- return TableSample(self, arg, name, method, seed)
+ return TableSample(self, sampling, name, seed)
def is_derived_from(self, fromclause):
"""Return True if this FromClause is 'derived' from the given
@@ -1285,15 +1283,20 @@ class TableSample(Alias):
__visit_name__ = 'tablesample'
- def __init__(self, selectable, arg,
+ def __init__(self, selectable, sampling,
name=None,
- method=None,
seed=None):
- self.arg = arg
- self.method = method or 'SYSTEM'
+ self.sampling = sampling
self.seed = seed
super(TableSample, self).__init__(selectable, name=name)
+ @util.dependencies("sqlalchemy.sql.functions")
+ def _get_method(self, functions):
+ if isinstance(self.sampling, functions.Function):
+ return self.sampling
+ else:
+ return functions.func.system(self.sampling)
+
class CTE(Generative, HasSuffixes, Alias):
"""Represent a Common Table Expression.
diff --git a/test/sql/test_tablesample.py b/test/sql/test_tablesample.py
index 6c21f1173..5470e1748 100644
--- a/test/sql/test_tablesample.py
+++ b/test/sql/test_tablesample.py
@@ -1,6 +1,6 @@
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message
-from sqlalchemy.sql import select, func
+from sqlalchemy.sql import select, func, text
from sqlalchemy.engine import default
from sqlalchemy import exc
from sqlalchemy import Table, Integer, String, Column
@@ -28,26 +28,26 @@ class TableSampleTest(fixtures.TablesTest, AssertsCompiledSQL):
# context of a FROM clause
self.assert_compile(
tablesample(table1, 1, name='alias'),
- 'people AS alias TABLESAMPLE SYSTEM(1)'
+ 'people AS alias TABLESAMPLE system(:system_1)'
)
self.assert_compile(
table1.sample(1, name='alias'),
- 'people AS alias TABLESAMPLE SYSTEM(1)'
+ 'people AS alias TABLESAMPLE system(:system_1)'
)
self.assert_compile(
- tablesample(table1, 1, name='alias', method='BERNOULLI',
+ tablesample(table1, func.bernoulli(1), name='alias',
seed=func.random()),
- 'people AS alias TABLESAMPLE BERNOULLI(1) REPEATABLE (random())'
+ 'people AS alias TABLESAMPLE bernoulli(:bernoulli_1) REPEATABLE (random())'
)
def test_select_from(self):
table1 = self.tables.people
self.assert_compile(
- select([table1.sample(1, name='alias').c.people_id]),
+ select([table1.sample(text('1'), name='alias').c.people_id]),
'SELECT alias.people_id FROM '
- 'people AS alias TABLESAMPLE SYSTEM(1)'
+ 'people AS alias TABLESAMPLE system(1)'
)