summaryrefslogtreecommitdiff
path: root/isort/isort.py
diff options
context:
space:
mode:
Diffstat (limited to 'isort/isort.py')
-rw-r--r--isort/isort.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/isort/isort.py b/isort/isort.py
index 991aed51..3ab68c2d 100644
--- a/isort/isort.py
+++ b/isort/isort.py
@@ -49,7 +49,7 @@ class SortImports(object):
def __init__(self, file_path=None, file_contents=None, file_=None, write_to_stdout=False, check=False,
show_diff=False, settings_path=None, ask_to_apply=False, run_path='', check_skip=True,
- **setting_overrides):
+ extension=None, **setting_overrides):
if not settings_path and file_path:
settings_path = os.path.dirname(os.path.abspath(file_path))
settings_path = settings_path or os.getcwd()
@@ -181,6 +181,11 @@ class SortImports(object):
self.in_lines.append(add_import)
self.number_of_lines = len(self.in_lines)
+ if not extension:
+ self.extension = file_name.split('.')[-1] if file_name else "py"
+ else:
+ self.extension = extension
+
self.out_lines = []
self.comments = {'from': {}, 'straight': {}, 'nested': {}, 'above': {'straight': {}, 'from': {}}}
self.imports = OrderedDict()
@@ -672,8 +677,10 @@ class SortImports(object):
if self.config['lines_after_imports'] != -1:
self.out_lines[imports_tail:0] = ["" for line in range(self.config['lines_after_imports'])]
- elif next_construct.startswith("def ") or next_construct.startswith("class ") or \
- next_construct.startswith("@") or next_construct.startswith("async def"):
+ elif self.extension != "pyi" and (next_construct.startswith("def ") or
+ next_construct.startswith("class ") or
+ next_construct.startswith("@") or
+ next_construct.startswith("async def")):
self.out_lines[imports_tail:0] = ["", ""]
else:
self.out_lines[imports_tail:0] = [""]