summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorR. Tyler Ballance <tyler@slide.com>2009-07-13 08:00:26 +0800
committerJames Abbatiello <abbeyj@gmail.com>2009-07-13 08:04:51 +0800
commit6429c3a055e1f4029ea59f40c76acb7e6ae50b1f (patch)
tree15fc093f571cee6cde2658caac11e6da6ec74b46
parent714a46b4ed922d7afaf8ef95144890cfcd74007e (diff)
downloadpython-cheetah-6429c3a055e1f4029ea59f40c76acb7e6ae50b1f.tar.gz
Add preliminary support for multiple inheritance via the #extends directive
This is covered in mantis #26 Signed-off-by: James Abbatiello <abbeyj@gmail.com>
-rw-r--r--src/Compiler.py64
-rw-r--r--src/Parser.py39
-rw-r--r--src/Tests/Template.py4
3 files changed, 74 insertions, 33 deletions
diff --git a/src/Compiler.py b/src/Compiler.py
index 4b8e44e..39c7f51 100644
--- a/src/Compiler.py
+++ b/src/Compiler.py
@@ -1739,38 +1739,40 @@ class ModuleCompiler(SettingsManager, GenUtils):
# - We also assume that the final . separates the classname from the
# module name. This might break if people do something really fancy
# with their dots and namespaces.
- chunks = baseClassName.split('.')
- if len(chunks)==1:
- self._getActiveClassCompiler().setBaseClass(baseClassName)
- if baseClassName not in self.importedVarNames():
- modName = baseClassName
- # we assume the class name to be the module name
- # and that it's not a builtin:
- importStatement = "from %s import %s" % (modName, baseClassName)
- self.addImportStatement(importStatement)
- self.addImportedVarNames( [baseClassName,] )
- else:
- needToAddImport = True
- modName = chunks[0]
- #print chunks, ':', self.importedVarNames()
- for chunk in chunks[1:-1]:
- if modName in self.importedVarNames():
- needToAddImport = False
- finalBaseClassName = baseClassName.replace(modName+'.', '')
- self._getActiveClassCompiler().setBaseClass(finalBaseClassName)
- break
- else:
- modName += '.'+chunk
- if needToAddImport:
- modName, finalClassName = '.'.join(chunks[:-1]), chunks[-1]
- #if finalClassName != chunks[:-1][-1]:
- if finalClassName != chunks[-2]:
+ baseclasses = baseClassName.split(',')
+ for klass in baseclasses:
+ chunks = klass.split('.')
+ if len(chunks)==1:
+ self._getActiveClassCompiler().setBaseClass(klass)
+ if klass not in self.importedVarNames():
+ modName = klass
# we assume the class name to be the module name
- modName = '.'.join(chunks)
- self._getActiveClassCompiler().setBaseClass(finalClassName)
- importStatement = "from %s import %s" % (modName, finalClassName)
- self.addImportStatement(importStatement)
- self.addImportedVarNames( [finalClassName,] )
+ # and that it's not a builtin:
+ importStatement = "from %s import %s" % (modName, klass)
+ self.addImportStatement(importStatement)
+ self.addImportedVarNames((klass,))
+ else:
+ needToAddImport = True
+ modName = chunks[0]
+ #print chunks, ':', self.importedVarNames()
+ for chunk in chunks[1:-1]:
+ if modName in self.importedVarNames():
+ needToAddImport = False
+ finalBaseClassName = klass.replace(modName+'.', '')
+ self._getActiveClassCompiler().setBaseClass(finalBaseClassName)
+ break
+ else:
+ modName += '.'+chunk
+ if needToAddImport:
+ modName, finalClassName = '.'.join(chunks[:-1]), chunks[-1]
+ #if finalClassName != chunks[:-1][-1]:
+ if finalClassName != chunks[-2]:
+ # we assume the class name to be the module name
+ modName = '.'.join(chunks)
+ self._getActiveClassCompiler().setBaseClass(finalClassName)
+ importStatement = "from %s import %s" % (modName, finalClassName)
+ self.addImportStatement(importStatement)
+ self.addImportedVarNames( [finalClassName,] )
def setCompilerSetting(self, key, valueExpr):
self.setSetting(key, eval(valueExpr) )
diff --git a/src/Parser.py b/src/Parser.py
index 3e6e7fe..7436e9c 100644
--- a/src/Parser.py
+++ b/src/Parser.py
@@ -596,6 +596,42 @@ class _LowLevelParser(SourceReader):
if not match:
raise ParseError(self, msg='Invalid multi-line comment end token')
return self.readTo(match.end())
+
+ def getCommaSeparatedSymbols(self):
+ """
+ Loosely based on getDottedName to pull out comma separated
+ named chunks
+ """
+ srcLen = len(self)
+ pieces = []
+ nameChunks = []
+
+ if not self.peek() in identchars:
+ raise ParseError(self)
+
+ while self.pos() < srcLen:
+ c = self.peek()
+ if c in namechars:
+ nameChunk = self.getIdentifier()
+ nameChunks.append(nameChunk)
+ elif c == '.':
+ if self.pos()+1 <srcLen and self.peek(1) in identchars:
+ nameChunks.append(self.getc())
+ else:
+ break
+ elif c == ',':
+ self.getc()
+ pieces.append(''.join(nameChunks))
+ nameChunks = []
+ elif c in (' ', '\t'):
+ self.getc()
+ else:
+ break
+
+ if nameChunks:
+ pieces.append(''.join(nameChunks))
+
+ return pieces
def getDottedName(self):
srcLen = len(self)
@@ -2037,7 +2073,8 @@ class _HighLevelParser(_LowLevelParser):
if self.setting('allowExpressionsInExtendsDirective'):
baseName = self.getExpression()
else:
- baseName = self.getDottedName()
+ baseName = self.getCommaSeparatedSymbols()
+ baseName = ', '.join(baseName)
baseName = self._applyExpressionFilters(baseName, 'extends', startPos=startPos)
self._compiler.setBaseClass(baseName) # in compiler
diff --git a/src/Tests/Template.py b/src/Tests/Template.py
index 06f9768..085180d 100644
--- a/src/Tests/Template.py
+++ b/src/Tests/Template.py
@@ -338,7 +338,9 @@ class MultipleInheritanceSupport(TemplateTest):
#return [4,5] + $boink()
#end def
'''
- template = Template.compile(template)
+ template = Template.compile(template,
+ moduleGlobals={'Useless' : Useless},
+ compilerSettings={'autoImportForExtendsDirective' : False})
template = template()
result = template.foo()
print result