summaryrefslogtreecommitdiff
path: root/isort/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'isort/api.py')
-rw-r--r--isort/api.py41
1 files changed, 36 insertions, 5 deletions
diff --git a/isort/api.py b/isort/api.py
index 0aae7711..b7126999 100644
--- a/isort/api.py
+++ b/isort/api.py
@@ -18,6 +18,9 @@ from .io import File
from .settings import DEFAULT_CONFIG, FILE_SKIP_COMMENTS, Config
IMPORT_START_IDENTIFIERS = ("from ", "from.import", "import ", "import*")
+
+CIMPORT_IDENTIFIERS = ("cimport ", "cimport*", "from.cimport")
+IMPORT_START_IDENTIFIERS = ("from ", "from.import", "import ", "import*") + CIMPORT_IDENTIFIERS
COMMENT_INDICATORS = ('"""', "'''", "'", '"', "#")
@@ -149,6 +152,7 @@ def sort_imports(
line_separator: str = config.line_ending
add_imports: List[str] = [format_natural(addition) for addition in config.add_imports]
import_section: str = ""
+ next_import_section: str = ""
in_quote: str = ""
first_comment_index_start: int = -1
first_comment_index_end: int = -1
@@ -158,6 +162,7 @@ def sort_imports(
section_comments = [f"# {heading}" for heading in config.import_headings.values()]
indent: str = ""
isort_off: bool = False
+ cimports: bool = False
for index, line in enumerate(chain(input_stream, (None,))):
if line is None:
@@ -226,7 +231,7 @@ def sort_imports(
contains_imports = True
indent = line[: -len(line.lstrip())]
- import_section += line
+ import_statement = line
while stripped_line.endswith("\\") or (
"(" in stripped_line and ")" not in stripped_line
):
@@ -234,12 +239,32 @@ def sort_imports(
while stripped_line and stripped_line.endswith("\\"):
line = input_stream.readline()
stripped_line = line.strip().split("#")[0]
- import_section += line
+ import_statement += line
else:
while ")" not in stripped_line:
line = input_stream.readline()
stripped_line = line.strip().split("#")[0]
- import_section += line
+ import_statement += line
+
+ cimport_statement: bool = False
+ if (
+ import_statement.lstrip().startswith(CIMPORT_IDENTIFIERS)
+ or "cimport " in import_statement
+ or "cimport*" in import_statement
+ or "cimport(" in import_statement
+ ):
+ cimport_statement = True
+
+ if cimport_statement != cimports:
+ if import_section:
+ next_import_section = import_statement
+ import_statement = ""
+ not_imports = True
+ line = ""
+ else:
+ cimports = cimport_statement
+
+ import_section + import_statement
else:
not_imports = True
@@ -277,7 +302,10 @@ def sort_imports(
line.lstrip() for line in import_section.split(line_separator)
)
sorted_import_section = output.sorted_imports(
- parse.file_contents(import_section, config=config), config, extension
+ parse.file_contents(import_section, config=config),
+ config,
+ extension,
+ import_type="cimport" if cimports else "import",
)
if indent:
sorted_import_section = (
@@ -291,7 +319,10 @@ def sort_imports(
indent = ""
contains_imports = False
- import_section = ""
+ if next_import_section:
+ cimports = not cimports
+ import_section = next_import_section
+ next_import_section = ""
else:
output_stream.write(line)
not_imports = False