summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2009-04-27 06:10:40 +0200
committerAndi Albrecht <albrecht.andi@gmail.com>2009-04-27 06:10:40 +0200
commit5366e32e3bc1f83cf4273776f2efd93212921977 (patch)
tree12c98d3922214535cdba61c0a406e7ce1b91fe55
parent8ef584ca6ff45bebd2934d712ee33f7c6bb10235 (diff)
downloadsqlparse-5366e32e3bc1f83cf4273776f2efd93212921977.tar.gz
Improved handling of invalid identifiers, like for example "a.".
-rw-r--r--sqlparse/sql.py20
-rw-r--r--tests/test_grouping.py8
2 files changed, 26 insertions, 2 deletions
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 3ab93ba..2d94440 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -325,8 +325,24 @@ class Identifier(TokenList):
if dot is None:
return self.token_next_by_type(0, T.Name).value
else:
- return self.token_next_by_type(self.token_index(dot),
- (T.Name, T.Wildcard)).value
+ next_ = self.token_next_by_type(self.token_index(dot),
+ (T.Name, T.Wildcard))
+ if next_ is None: # invalid identifier, e.g. "a."
+ return None
+ return next_.value
+
+ def get_parent_name(self):
+ """Return name of the parent object if any.
+
+ A parent object is identified by the first occuring dot.
+ """
+ dot = self.token_next_match(0, T.Punctuation, '.')
+ if dot is None:
+ return None
+ prev_ = self.token_prev(self.token_index(dot))
+ if prev_ is None: # something must be verry wrong here..
+ return None
+ return prev_.value
def is_wildcard(self):
"""Return ``True`` if this identifier contains a wildcard."""
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index f3617ad..3c32b18 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -55,6 +55,14 @@ class TestGrouping(TestCaseBase):
self.assertEqual(t.get_name(), '*')
self.assertEqual(t.is_wildcard(), True)
+ def test_indentifier_invalid(self):
+ p = sqlparse.parse('a.')[0]
+ self.assert_(isinstance(p.tokens[0], Identifier))
+ self.assertEqual(p.tokens[0].has_alias(), False)
+ self.assertEqual(p.tokens[0].get_name(), None)
+ self.assertEqual(p.tokens[0].get_real_name(), None)
+ self.assertEqual(p.tokens[0].get_parent_name(), 'a')
+
def test_identifier_list(self):
p = sqlparse.parse('a, b, c')[0]
self.assert_(isinstance(p.tokens[0], IdentifierList))