summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTimothy Crosley <timothy.crosley@gmail.com>2019-12-23 01:37:14 -0800
committerTimothy Crosley <timothy.crosley@gmail.com>2019-12-23 01:37:14 -0800
commit058331e60b57a81e5094cdf418dc7a61071d3cfb (patch)
tree88c381a1f89ac3b13b0fb79024f9f2259b51bfc6
parentdd339f68e89c8cfa821f66039808d9cd272e3b42 (diff)
downloadisort-058331e60b57a81e5094cdf418dc7a61071d3cfb.tar.gz
Fully move add_imports to main sort_imports entry point
-rw-r--r--isort/api.py43
1 files changed, 38 insertions, 5 deletions
diff --git a/isort/api.py b/isort/api.py
index 2ce9ead3..00b4c4fd 100644
--- a/isort/api.py
+++ b/isort/api.py
@@ -11,11 +11,12 @@ from .exceptions import (
IntroducedSyntaxErrors,
UnableToDetermineEncoding,
)
-from .format import remove_whitespace, show_unified_diff, format_natural
+from .format import format_natural, remove_whitespace, show_unified_diff
from .io import File
from .settings import DEFAULT_CONFIG, FILE_SKIP_COMMENT, Config
IMPORT_START_IDENTIFIERS = ("from ", "from.import", "import ", "import*")
+COMMENT_INDICATORS = ('"""', "'''", "'", '"', "#")
def _config(
@@ -142,15 +143,38 @@ def sort_imports(
- `output_stream`: Text stream to output sorted inputs into.
- `config`: Config settings to use when sorting imports. Defaults settings.DEFAULT_CONFIG.
"""
- add_imports = (format_natural(addition) for addition in config.add_imports)
+ line_separator: str = config.line_ending
+ add_imports = [format_natural(addition) for addition in config.add_imports]
import_section: str = ""
in_quote: str = ""
first_comment_index_start: int = -1
first_comment_index_end: int = -1
contains_imports: bool = False
in_top_comment: bool = False
+ first_import_section: bool = True
section_comments = [f"# {heading}" for heading in config.import_headings.values()]
+
+ def additional_imports() -> str:
+ nonlocal add_imports
+ nonlocal line_separator
+ nonlocal contains_imports
+
+ if not add_imports:
+ return ""
+
+ if not line_separator:
+ line_separator = "\n"
+
+ fomatted_imports: str = line_separator.join(add_imports) + line_separator
+ contains_imports = True
+ add_imports = []
+ return fomatted_imports
+
+ index: int = 0
for index, line in enumerate(input_stream):
+ if not line_separator:
+ line_separator = line[-1]
+
if index == 1 and line.startswith("#"):
in_top_comment = True
elif in_top_comment:
@@ -208,12 +232,18 @@ def sort_imports(
not_imports = True
if not_imports:
+ if not in_top_comment and not in_quote and not import_section and not line.lstrip().startswith(COMMENT_INDICATORS):
+ import_section = additional_imports()
+
if import_section:
- import_section += config.line_ending.join(add_imports)
+ import_section += additional_imports()
import_section += line
if not contains_imports:
output_stream.write(import_section)
else:
+ if first_import_section:
+ import_section = import_section.lstrip(line_separator)
+ first_import_section = False
output_stream.write(
output.sorted_imports(
parse.file_contents(import_section, config=config), config, extension
@@ -226,13 +256,16 @@ def sort_imports(
not_imports = False
if import_section:
+ import_section += additional_imports()
if not contains_imports:
output_stream.write(import_section)
else:
- import_section += config.line_ending.join(add_imports)
+ if first_import_section:
+ import_section = import_section.lstrip(line_separator)
output_stream.write(
output.sorted_imports(
parse.file_contents(import_section, config=config), config, extension
)
)
- output_stream.write(config.line_ending.join(add_imports))
+ elif index > 1 or config.force_adds:
+ output_stream.write(additional_imports())