diff options
| author | John Bodley <john.bodley@airbnb.com> | 2020-03-28 15:25:34 -0700 |
|---|---|---|
| committer | Andi Albrecht <albrecht.andi@gmail.com> | 2020-03-29 19:08:25 +0200 |
| commit | 28a2acdd9b307c95d2111f1f831b1a5afc634691 (patch) | |
| tree | 3e5b40d852869eec83df554fa6bba2e9fdcc2a16 | |
| parent | 6e39e02e3898b77415c558a822079a1f885524d5 (diff) | |
| download | sqlparse-28a2acdd9b307c95d2111f1f831b1a5afc634691.tar.gz | |
[fix] Adding TypedLiteral to comparison
| -rw-r--r-- | sqlparse/engine/grouping.py | 2 | ||||
| -rw-r--r-- | tests/test_grouping.py | 13 |
2 files changed, 13 insertions, 2 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index daaffb0..e7a4211 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -192,7 +192,7 @@ def group_assignment(tlist): def group_comparison(tlist): sqlcls = (sql.Parenthesis, sql.Function, sql.Identifier, - sql.Operation) + sql.Operation, sql.TypedLiteral) ttypes = T_NUMERICAL + T_STRING + T_NAME def match(token): diff --git a/tests/test_grouping.py b/tests/test_grouping.py index a147063..87dcf11 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -36,7 +36,7 @@ def test_grouping_assignment(s): @pytest.mark.parametrize('s', ["x > DATE '2020-01-01'", "x > TIMESTAMP '2020-01-01 00:00:00'"]) def test_grouping_typed_literal(s): parsed = sqlparse.parse(s)[0] - assert isinstance(parsed[4], sql.TypedLiteral) + assert isinstance(parsed[0][4], sql.TypedLiteral) @pytest.mark.parametrize('s, a, b', [ @@ -550,6 +550,17 @@ def test_comparison_with_functions(): assert p.tokens[0].right.value == 'bar.baz' +def test_comparison_with_typed_literal(): + p = sqlparse.parse("foo = DATE 'bar.baz'")[0] + assert len(p.tokens) == 1 + comp = p.tokens[0] + assert isinstance(comp, sql.Comparison) + assert len(comp.tokens) == 5 + assert comp.left.value == 'foo' + assert isinstance(comp.right, sql.TypedLiteral) + assert comp.right.value == "DATE 'bar.baz'" + + @pytest.mark.parametrize('start', ['FOR', 'FOREACH']) def test_forloops(start): p = sqlparse.parse('{0} foo in bar LOOP foobar END LOOP'.format(start))[0] |
