diff options
author | Timothy Crosley <timothy.crosley@gmail.com> | 2019-12-26 14:14:22 -0800 |
---|---|---|
committer | Timothy Crosley <timothy.crosley@gmail.com> | 2019-12-26 14:14:22 -0800 |
commit | 955e5b73cff6ac54c91a3a4df46e9ef2fb772e18 (patch) | |
tree | 0bf582587f60c9a634ff4b907145bccd9c9278be | |
parent | ecfe14a5745ca394267838cb12e4c93df57638b1 (diff) | |
download | isort-feature/contiguous-import-sorting.tar.gz |
Add support for nested import sectionsfeature/contiguous-import-sorting
-rw-r--r-- | isort/api.py | 34 | ||||
-rw-r--r-- | tests/test_isort.py | 29 |
2 files changed, 53 insertions, 10 deletions
diff --git a/isort/api.py b/isort/api.py index 4c6bc949..fa512d61 100644 --- a/isort/api.py +++ b/isort/api.py @@ -1,4 +1,5 @@ import re +import textwrap from io import StringIO from itertools import chain from pathlib import Path @@ -154,6 +155,7 @@ def sort_imports( in_top_comment: bool = False first_import_section: bool = True section_comments = [f"# {heading}" for heading in config.import_headings.values()] + indent: str = "" for index, line in enumerate(chain(input_stream, (None,))): if line is None: @@ -171,7 +173,7 @@ def sort_imports( stripped_line = line.strip() if ( (index == 0 or (index == 1 and not contains_imports)) - and line.startswith("#") + and stripped_line.startswith("#") and stripped_line not in section_comments ): in_top_comment = True @@ -180,7 +182,7 @@ def sort_imports( in_top_comment = False first_comment_index_end = index - 1 - if not line.startswith("#") and '"' in line or "'" in line: + if not stripped_line.startswith("#") and '"' in line or "'" in line: char_index = 0 if first_comment_index_start == -1 and ( line.startswith('"') or line.startswith("'") @@ -211,6 +213,9 @@ def sort_imports( if not stripped_line or stripped_line.startswith("#"): import_section += line elif stripped_line.startswith(IMPORT_START_IDENTIFIERS): + contains_imports = True + + indent = line[:-len(line.lstrip())] import_section += line while stripped_line.endswith("\\") or ( "(" in stripped_line and ")" not in stripped_line @@ -225,8 +230,6 @@ def sort_imports( line = input_stream.readline() stripped_line = line.strip().split("#")[0] import_section += line - - contains_imports = True else: not_imports = True @@ -243,12 +246,13 @@ def sort_imports( add_imports = [] if import_section: - if add_imports: + if add_imports and not indent: import_section += line_separator.join(add_imports) + line_separator contains_imports = True add_imports = [] - import_section += line + if not indent: + import_section += line if not contains_imports: output_stream.write(import_section) else: @@ -257,11 +261,21 @@ def sort_imports( ).startswith(COMMENT_INDICATORS): import_section = import_section.lstrip(line_separator) first_import_section = False - output_stream.write( - output.sorted_imports( - parse.file_contents(import_section, config=config), config, extension - ) + + if indent: + import_section = line_separator.join(line.lstrip() for line in import_section.split(line_separator)) + sorted_import_section = output.sorted_imports( + parse.file_contents(import_section, config=config), config, extension ) + if indent: + sorted_import_section = textwrap.indent(sorted_import_section, indent) + line_separator + + output_stream.write(sorted_import_section) + + if indent: + output_stream.write(line) + indent = "" + contains_imports = False import_section = "" else: diff --git a/tests/test_isort.py b/tests/test_isort.py index e4632f41..8b6ae620 100644 --- a/tests/test_isort.py +++ b/tests/test_isort.py @@ -2922,6 +2922,7 @@ def test_to_ensure_importing_from_imports_module_works_issue_662() -> None: "@wraps(fun)\n" "def __inner(*args, **kwargs):\n" " from .imports import qualname\n" + "\n" " warn(description=description or qualname(fun), deprecation=deprecation, " "removal=removal)\n" ) @@ -4158,3 +4159,31 @@ def test_isort_with_single_character_import() -> None: """ test_input = "from django.db.models import CASCADE, SET_NULL, Q\n" assert SortImports(file_contents=test_input).output == test_input + + +def test_isort_nested_imports() -> None: + """Ensure imports in a nested block get sorted correctly""" + test_input = """ + def import_test(): + import sys + import os + + # my imports + from . import def + from . import abc + + return True + """ + assert ( + SortImports(file_contents=test_input).output + == """ + def import_test(): + import os + import sys + + # my imports + from . import abc, def + + return True + """ + ) |