summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <pcmanticore@gmail.com>2020-06-23 08:17:33 +0200
committerGitHub <noreply@github.com>2020-06-23 08:17:33 +0200
commitec96745c0fdb9432549d182e381164d1836e8a4b (patch)
tree3e2e9354011aec6cfc5f23f1055d3dc22bf8f356
parent4b7566b0c8365613198af493d4115d32d7d4c66e (diff)
downloadastroid-git-ec96745c0fdb9432549d182e381164d1836e8a4b.tar.gz
Separate string and bytes classes patching (#807)
Fixes PyCQA/pylint#3599
-rw-r--r--ChangeLog4
-rw-r--r--astroid/brain/brain_builtin_inference.py132
-rw-r--r--tests/unittest_brain.py15
-rw-r--r--tests/unittest_inference.py12
4 files changed, 108 insertions, 55 deletions
diff --git a/ChangeLog b/ChangeLog
index 8fab2804..737b682e 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -9,6 +9,10 @@ Release Date: TBA
* Added a brain for ``sqlalchemy.orm.session``
+* Separate string and bytes classes patching
+
+ Fixes PyCQA/pylint#3599
+
* Added missing methods to the brain for ``mechanize``, to fix pylint false positives
Close #793
diff --git a/astroid/brain/brain_builtin_inference.py b/astroid/brain/brain_builtin_inference.py
index f3b4c155..074ec476 100644
--- a/astroid/brain/brain_builtin_inference.py
+++ b/astroid/brain/brain_builtin_inference.py
@@ -40,52 +40,90 @@ from astroid import util
OBJECT_DUNDER_NEW = "object.__new__"
-
-def _extend_str(class_node, rvalue):
+STR_CLASS = """
+class whatever(object):
+ def join(self, iterable):
+ return {rvalue}
+ def replace(self, old, new, count=None):
+ return {rvalue}
+ def format(self, *args, **kwargs):
+ return {rvalue}
+ def encode(self, encoding='ascii', errors=None):
+ return b''
+ def decode(self, encoding='ascii', errors=None):
+ return u''
+ def capitalize(self):
+ return {rvalue}
+ def title(self):
+ return {rvalue}
+ def lower(self):
+ return {rvalue}
+ def upper(self):
+ return {rvalue}
+ def swapcase(self):
+ return {rvalue}
+ def index(self, sub, start=None, end=None):
+ return 0
+ def find(self, sub, start=None, end=None):
+ return 0
+ def count(self, sub, start=None, end=None):
+ return 0
+ def strip(self, chars=None):
+ return {rvalue}
+ def lstrip(self, chars=None):
+ return {rvalue}
+ def rstrip(self, chars=None):
+ return {rvalue}
+ def rjust(self, width, fillchar=None):
+ return {rvalue}
+ def center(self, width, fillchar=None):
+ return {rvalue}
+ def ljust(self, width, fillchar=None):
+ return {rvalue}
+"""
+
+
+BYTES_CLASS = """
+class whatever(object):
+ def join(self, iterable):
+ return {rvalue}
+ def replace(self, old, new, count=None):
+ return {rvalue}
+ def decode(self, encoding='ascii', errors=None):
+ return u''
+ def capitalize(self):
+ return {rvalue}
+ def title(self):
+ return {rvalue}
+ def lower(self):
+ return {rvalue}
+ def upper(self):
+ return {rvalue}
+ def swapcase(self):
+ return {rvalue}
+ def index(self, sub, start=None, end=None):
+ return 0
+ def find(self, sub, start=None, end=None):
+ return 0
+ def count(self, sub, start=None, end=None):
+ return 0
+ def strip(self, chars=None):
+ return {rvalue}
+ def lstrip(self, chars=None):
+ return {rvalue}
+ def rstrip(self, chars=None):
+ return {rvalue}
+ def rjust(self, width, fillchar=None):
+ return {rvalue}
+ def center(self, width, fillchar=None):
+ return {rvalue}
+ def ljust(self, width, fillchar=None):
+ return {rvalue}
+"""
+
+
+def _extend_string_class(class_node, code, rvalue):
"""function to extend builtin str/unicode class"""
- code = dedent(
- """
- class whatever(object):
- def join(self, iterable):
- return {rvalue}
- def replace(self, old, new, count=None):
- return {rvalue}
- def format(self, *args, **kwargs):
- return {rvalue}
- def encode(self, encoding='ascii', errors=None):
- return ''
- def decode(self, encoding='ascii', errors=None):
- return u''
- def capitalize(self):
- return {rvalue}
- def title(self):
- return {rvalue}
- def lower(self):
- return {rvalue}
- def upper(self):
- return {rvalue}
- def swapcase(self):
- return {rvalue}
- def index(self, sub, start=None, end=None):
- return 0
- def find(self, sub, start=None, end=None):
- return 0
- def count(self, sub, start=None, end=None):
- return 0
- def strip(self, chars=None):
- return {rvalue}
- def lstrip(self, chars=None):
- return {rvalue}
- def rstrip(self, chars=None):
- return {rvalue}
- def rjust(self, width, fillchar=None):
- return {rvalue}
- def center(self, width, fillchar=None):
- return {rvalue}
- def ljust(self, width, fillchar=None):
- return {rvalue}
- """
- )
code = code.format(rvalue=rvalue)
fake = AstroidBuilder(MANAGER).string_build(code)["whatever"]
for method in fake.mymethods():
@@ -106,8 +144,8 @@ def _extend_builtins(class_transforms):
_extend_builtins(
{
- "bytes": partial(_extend_str, rvalue="b''"),
- "str": partial(_extend_str, rvalue="''"),
+ "bytes": partial(_extend_string_class, code=BYTES_CLASS, rvalue="b''"),
+ "str": partial(_extend_string_class, code=STR_CLASS, rvalue="''"),
}
)
diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py
index 0a833664..25e2bb5b 100644
--- a/tests/unittest_brain.py
+++ b/tests/unittest_brain.py
@@ -2020,5 +2020,20 @@ def test_dataclasses():
assert isinstance(name[0], astroid.Unknown)
+@pytest.mark.parametrize(
+ "code,expected_class,expected_value",
+ [
+ ("'hey'.encode()", astroid.Const, b""),
+ ("b'hey'.decode()", astroid.Const, ""),
+ ("'hey'.encode().decode()", astroid.Const, ""),
+ ],
+)
+def test_str_and_bytes(code, expected_class, expected_value):
+ node = astroid.extract_node(code)
+ inferred = next(node.infer())
+ assert isinstance(inferred, expected_class)
+ assert inferred.value == expected_value
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py
index 140648d6..76c7e879 100644
--- a/tests/unittest_inference.py
+++ b/tests/unittest_inference.py
@@ -2083,8 +2083,6 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
def test_str_methods(self):
code = """
' '.decode() #@
-
- ' '.encode() #@
' '.join('abcd') #@
' '.replace('a', 'b') #@
' '.format('a') #@
@@ -2106,15 +2104,13 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
"""
ast = extract_node(code, __name__)
self.assertInferConst(ast[0], "")
- for i in range(1, 16):
+ for i in range(1, 15):
self.assertInferConst(ast[i], "")
- for i in range(16, 19):
+ for i in range(15, 18):
self.assertInferConst(ast[i], 0)
def test_unicode_methods(self):
code = """
- u' '.encode() #@
-
u' '.decode() #@
u' '.join('abcd') #@
u' '.replace('a', 'b') #@
@@ -2137,9 +2133,9 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
"""
ast = extract_node(code, __name__)
self.assertInferConst(ast[0], "")
- for i in range(1, 16):
+ for i in range(1, 15):
self.assertInferConst(ast[i], "")
- for i in range(16, 19):
+ for i in range(15, 18):
self.assertInferConst(ast[i], 0)
def test_scope_lookup_same_attributes(self):