diff options
Diffstat (limited to 'Lib/test/string_tests.py')
| -rw-r--r-- | Lib/test/string_tests.py | 39 | 
1 files changed, 39 insertions, 0 deletions
| diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index 527f505c01..840d7bb755 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -5,6 +5,7 @@ Common tests shared by test_unicode, test_userstring and test_bytes.  import unittest, string, sys, struct  from test import support  from collections import UserList +import random  class Sequence:      def __init__(self, seq='wxyz'): self.seq = seq @@ -317,6 +318,44 @@ class BaseTest:          else:              self.checkraises(TypeError, 'hello', 'rindex', 42) +    def test_find_periodic_pattern(self): +        """Cover the special path for periodic patterns.""" +        def reference_find(p, s): +            for i in range(len(s)): +                if s.startswith(p, i): +                    return i +            return -1 + +        rr = random.randrange +        choices = random.choices +        for _ in range(1000): +            p0 = ''.join(choices('abcde', k=rr(10))) * rr(10, 20) +            p = p0[:len(p0) - rr(10)] # pop off some characters +            left = ''.join(choices('abcdef', k=rr(2000))) +            right = ''.join(choices('abcdef', k=rr(2000))) +            text = left + p + right +            with self.subTest(p=p, text=text): +                self.checkequal(reference_find(p, text), +                                text, 'find', p) + +    def test_find_shift_table_overflow(self): +        """When the table of 8-bit shifts overflows.""" +        N = 2**8 + 100 + +        # first check the periodic case +        # here, the shift for 'b' is N + 1. +        pattern1 = 'a' * N + 'b' + 'a' * N +        text1 = 'babbaa' * N + pattern1 +        self.checkequal(len(text1)-len(pattern1), +                        text1, 'find', pattern1) + +        # now check the non-periodic case +        # here, the shift for 'd' is 3*(N+1)+1 +        pattern2 = 'ddd' + 'abc' * N + "eee" +        text2 = pattern2[:-1] + "ddeede" * 2 * N + pattern2 + "de" * N +        self.checkequal(len(text2) - N*len("de") - len(pattern2), +                        text2, 'find', pattern2) +      def test_lower(self):          self.checkequal('hello', 'HeLLo', 'lower')          self.checkequal('hello', 'hello', 'lower') | 
