From 4dc2a71af88d59b664a8bfdcfcc7acd6412bed76 Mon Sep 17 00:00:00 2001 From: Timothy Crosley Date: Mon, 6 Jan 2020 08:50:52 -0800 Subject: Add cdef and cpdef support --- isort/api.py | 4 +++- isort/output.py | 11 ++++------- tests/test_isort.py | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/isort/api.py b/isort/api.py index bf86cd0e..004c5f5b 100644 --- a/isort/api.py +++ b/isort/api.py @@ -323,9 +323,11 @@ def sort_imports( output_stream.write(line) indent = "" - contains_imports = False if next_import_section: cimports = not cimports + contains_imports = True + else: + contains_imports = False import_section = next_import_section next_import_section = "" else: diff --git a/isort/output.py b/isort/output.py index 5f2659f7..cd925c91 100644 --- a/isort/output.py +++ b/isort/output.py @@ -1,7 +1,7 @@ import copy import itertools from functools import partial -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple from isort.format import format_simplified @@ -9,6 +9,8 @@ from . import parse, sorting, wrap from .comments import add_to_line as with_comments from .settings import DEFAULT_CONFIG, Config +STATEMENT_DECLERATIONS: Tuple[str, ...] = ("def ", "cdef ", "cpdef ", "class ", "@", "async def") + def sorted_imports( parsed: parse.ParsedContent, @@ -202,12 +204,7 @@ def sorted_imports( if config.lines_after_imports != -1: formatted_output[imports_tail:0] = ["" for line in range(config.lines_after_imports)] - elif extension != "pyi" and ( - next_construct.startswith("def ") - or next_construct.startswith("class ") - or next_construct.startswith("@") - or next_construct.startswith("async def") - ): + elif extension != "pyi" and next_construct.startswith(STATEMENT_DECLERATIONS): formatted_output[imports_tail:0] = ["", ""] else: formatted_output[imports_tail:0] = [""] diff --git a/tests/test_isort.py b/tests/test_isort.py index 11a7079f..0228b414 100644 --- a/tests/test_isort.py +++ b/tests/test_isort.py @@ -4627,6 +4627,44 @@ IF CEF_VERSION == 3: SortImports(file_contents=test_input).output == expected_output +def test_cdef_support(): + assert ( + SortImports( + file_contents=""" +from cpython.version cimport PY_MAJOR_VERSION + +cdef extern from *: + ctypedef CefString ConstCefString "const CefString" +""" + ).output + == """ +from cpython.version cimport PY_MAJOR_VERSION + + +cdef extern from *: + ctypedef CefString ConstCefString "const CefString" +""" + ) + + assert ( + SortImports( + file_contents=""" +from cpython.version cimport PY_MAJOR_VERSION + +cpdef extern from *: + ctypedef CefString ConstCefString "const CefString" +""" + ).output + == """ +from cpython.version cimport PY_MAJOR_VERSION + + +cpdef extern from *: + ctypedef CefString ConstCefString "const CefString" +""" + ) + + def test_top_level_import_order() -> None: test_input = ( "from rest_framework import throttling, viewsets\n" -- cgit v1.2.1