summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArmin Ronacher <armin.ronacher@active-4.com>2012-05-27 08:34:05 -0700
committerArmin Ronacher <armin.ronacher@active-4.com>2012-05-27 08:34:05 -0700
commit7415f6f8816e8eae1e34db76504062e98e5e6ab0 (patch)
tree3a95e625866fe66c95107905a9651ee34364ba87
parent98caea1496846935dd60a0e170c401e91ce9029a (diff)
parent931232563fc4dc5d9795f4c478517a4e9d2fe4cf (diff)
downloadmarkupsafe-7415f6f8816e8eae1e34db76504062e98e5e6ab0.tar.gz
Merge pull request #6 from bukzor/subclass_tests
tests for subclassing Markup, overriding escape
-rw-r--r--markupsafe/__init__.py34
-rw-r--r--markupsafe/tests.py44
2 files changed, 62 insertions, 16 deletions
diff --git a/markupsafe/__init__.py b/markupsafe/__init__.py
index 1fe38f1..480252c 100644
--- a/markupsafe/__init__.py
+++ b/markupsafe/__init__.py
@@ -75,13 +75,13 @@ class Markup(unicode):
return self
def __add__(self, other):
- if hasattr(other, '__html__') or isinstance(other, basestring):
- return self.__class__(unicode(self) + unicode(escape(other)))
+ if isinstance(other, basestring) or hasattr(other, '__html__'):
+ return self.__class__(super(Markup, self).__add__(self.escape(other)))
return NotImplemented
def __radd__(self, other):
if hasattr(other, '__html__') or isinstance(other, basestring):
- return self.__class__(unicode(escape(other)) + unicode(self))
+ return self.escape(other).__add__(self)
return NotImplemented
def __mul__(self, num):
@@ -92,9 +92,9 @@ class Markup(unicode):
def __mod__(self, arg):
if isinstance(arg, tuple):
- arg = tuple(imap(_MarkupEscapeHelper, arg))
+ arg = tuple(imap(_MarkupEscapeHelper, arg, self.escape))
else:
- arg = _MarkupEscapeHelper(arg)
+ arg = _MarkupEscapeHelper(arg, self.escape)
return self.__class__(unicode.__mod__(self, arg))
def __repr__(self):
@@ -104,7 +104,7 @@ class Markup(unicode):
)
def join(self, seq):
- return self.__class__(unicode.join(self, imap(escape, seq)))
+ return self.__class__(unicode.join(self, imap(self.escape, seq)))
join.__doc__ = unicode.join.__doc__
def split(self, *args, **kwargs):
@@ -166,8 +166,8 @@ class Markup(unicode):
def make_wrapper(name):
orig = getattr(unicode, name)
def func(self, *args, **kwargs):
- args = _escape_argspec(list(args), enumerate(args))
- _escape_argspec(kwargs, kwargs.iteritems())
+ args = _escape_argspec(list(args), enumerate(args), self.escape)
+ #_escape_argspec(kwargs, kwargs.iteritems(), None)
return self.__class__(orig(self, *args, **kwargs))
func.__name__ = orig.__name__
func.__doc__ = orig.__doc__
@@ -183,10 +183,10 @@ class Markup(unicode):
if hasattr(unicode, 'partition'):
def partition(self, sep):
return tuple(map(self.__class__,
- unicode.partition(self, escape(sep))))
+ unicode.partition(self, self.escape(sep))))
def rpartition(self, sep):
return tuple(map(self.__class__,
- unicode.rpartition(self, escape(sep))))
+ unicode.rpartition(self, self.escape(sep))))
# new in python 2.6
if hasattr(unicode, 'format'):
@@ -199,7 +199,7 @@ class Markup(unicode):
del method, make_wrapper
-def _escape_argspec(obj, iterable):
+def _escape_argspec(obj, iterable, escape):
"""Helper for various string-wrapped functions."""
for key, value in iterable:
if hasattr(value, '__html__') or isinstance(value, basestring):
@@ -210,13 +210,13 @@ def _escape_argspec(obj, iterable):
class _MarkupEscapeHelper(object):
"""Helper for Markup.__mod__"""
- def __init__(self, obj):
+ def __init__(self, obj, escape):
self.obj = obj
+ self.escape = escape
- __getitem__ = lambda s, x: _MarkupEscapeHelper(s.obj[x])
- __str__ = lambda s: str(escape(s.obj))
- __unicode__ = lambda s: unicode(escape(s.obj))
- __repr__ = lambda s: str(escape(repr(s.obj)))
+ __getitem__ = lambda s, x: _MarkupEscapeHelper(s.obj[x], s.escape)
+ __unicode__ = lambda s: unicode(s.escape(s.obj))
+ __repr__ = lambda s: str(s.escape(repr(s.obj)))
__int__ = lambda s: int(s.obj)
__float__ = lambda s: float(s.obj)
@@ -227,3 +227,5 @@ try:
from markupsafe._speedups import escape, escape_silent, soft_unicode
except ImportError:
from markupsafe._native import escape, escape_silent, soft_unicode
+
+# vim:sts=4:sw=4:et:
diff --git a/markupsafe/tests.py b/markupsafe/tests.py
index 77a0efb..dbb3125 100644
--- a/markupsafe/tests.py
+++ b/markupsafe/tests.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
import gc
import unittest
from markupsafe import Markup, escape, escape_silent
@@ -18,6 +19,9 @@ class MarkupTestCase(unittest.TestCase):
'username': '<bad user>'
} == '<em>&lt;bad user&gt;</em>'
+ assert Markup('%i') % 3.14 == '3'
+ assert Markup('%.2f') % 3.14 == '3.14'
+
# an escaped object is markup too
assert type(Markup('foo') + 'bar') is Markup
@@ -64,10 +68,48 @@ class MarkupLeakTestCase(unittest.TestCase):
counts.add(len(gc.get_objects()))
assert len(counts) == 1, 'ouch, c extension seems to leak objects'
+class EncodedMarkup(Markup):
+ __slots__ = ()
+ encoding = 'utf8'
+
+ @classmethod
+ def escape(cls, s):
+ if isinstance(s, str):
+ s = s.decode('utf8')
+ return super(EncodedMarkup, cls).escape(s)
+
+class MarkupSubclassTestCase(unittest.TestCase):
+ # The Russian name of Russia (Rossija)
+ russia = u'Россия'
+ utf8 = russia.encode('utf8')
+
+ def test_escape(self):
+ myval = EncodedMarkup.escape(self.utf8)
+ assert myval == self.russia, repr(myval)
+ def test_add(self):
+ myval = EncodedMarkup() + self.utf8
+ assert myval == self.russia, repr(myval)
+ def test_radd(self):
+ myval = self.utf8 + EncodedMarkup()
+ assert myval == self.russia, repr(myval)
+ def test_join(self):
+ myval = EncodedMarkup().join([self.utf8])
+ assert myval == self.russia, repr(myval)
+ def test_partition(self):
+ assert EncodedMarkup(self.russia).partition(self.utf8)[1] == self.russia
+ assert EncodedMarkup(self.russia).rpartition(self.utf8)[1] == self.russia
+ def test_mod(self):
+ assert EncodedMarkup('%s') % self.utf8 == self.russia
+ assert EncodedMarkup('%r') % self.utf8 == escape(repr(self.utf8))
+ def test_strip(self):
+ assert EncodedMarkup(self.russia).strip(self.utf8) == u''
+ assert EncodedMarkup(self.russia).rstrip(self.utf8) == u''
+
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(MarkupTestCase))
+ suite.addTest(unittest.makeSuite(MarkupSubclassTestCase))
# this test only tests the c extension
if not hasattr(escape, 'func_code'):
@@ -78,3 +120,5 @@ def suite():
if __name__ == '__main__':
unittest.main(defaultTest='suite')
+
+# vim:sts=4:sw=4:et: