This commit is contained in:
Waylon Walker 2022-03-31 20:20:07 -05:00
commit 38355d2442
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
9083 changed files with 1225834 additions and 0 deletions

View file

@ -0,0 +1,337 @@
"""A package for handling imports
This package provides tools for modifying module imports after
refactorings or as a separate task.
"""
import rope.base.evaluate
from rope.base import libutils
from rope.base.change import ChangeSet, ChangeContents
from rope.refactor import occurrences, rename
from rope.refactor.importutils import module_imports, actions
from rope.refactor.importutils.importinfo import NormalImport, FromImport
import rope.base.codeanalyze
class ImportOrganizer(object):
"""Perform some import-related commands
Each method returns a `rope.base.change.Change` object.
"""
def __init__(self, project):
self.project = project
self.import_tools = ImportTools(self.project)
def organize_imports(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.organize_imports, resource, offset
)
def expand_star_imports(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.expand_stars, resource, offset
)
def froms_to_imports(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.froms_to_imports, resource, offset
)
def relatives_to_absolutes(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.relatives_to_absolutes, resource, offset
)
def handle_long_imports(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.handle_long_imports, resource, offset
)
def _perform_command_on_import_tools(self, method, resource, offset):
pymodule = self.project.get_pymodule(resource)
before_performing = pymodule.source_code
import_filter = None
if offset is not None:
import_filter = self._line_filter(pymodule.lines.get_line_number(offset))
result = method(pymodule, import_filter=import_filter)
if result is not None and result != before_performing:
changes = ChangeSet(
method.__name__.replace("_", " ") + " in <%s>" % resource.path
)
changes.add_change(ChangeContents(resource, result))
return changes
def _line_filter(self, lineno):
def import_filter(import_stmt):
return import_stmt.start_line <= lineno < import_stmt.end_line
return import_filter
class ImportTools(object):
def __init__(self, project):
self.project = project
def get_import(self, resource):
"""The import statement for `resource`"""
module_name = libutils.modname(resource)
return NormalImport(((module_name, None),))
def get_from_import(self, resource, name):
"""The from import statement for `name` in `resource`"""
module_name = libutils.modname(resource)
names = []
if isinstance(name, list):
names = [(imported, None) for imported in name]
else:
names = [
(name, None),
]
return FromImport(module_name, 0, tuple(names))
def module_imports(self, module, imports_filter=None):
return module_imports.ModuleImports(self.project, module, imports_filter)
def froms_to_imports(self, pymodule, import_filter=None):
pymodule = self._clean_up_imports(pymodule, import_filter)
module_imports = self.module_imports(pymodule, import_filter)
for import_stmt in module_imports.imports:
if import_stmt.readonly or not self._is_transformable_to_normal(
import_stmt.import_info
):
continue
pymodule = self._from_to_normal(pymodule, import_stmt)
# Adding normal imports in place of froms
module_imports = self.module_imports(pymodule, import_filter)
for import_stmt in module_imports.imports:
if not import_stmt.readonly and self._is_transformable_to_normal(
import_stmt.import_info
):
import_stmt.import_info = NormalImport(
((import_stmt.import_info.module_name, None),)
)
module_imports.remove_duplicates()
return module_imports.get_changed_source()
def expand_stars(self, pymodule, import_filter=None):
module_imports = self.module_imports(pymodule, import_filter)
module_imports.expand_stars()
return module_imports.get_changed_source()
def _from_to_normal(self, pymodule, import_stmt):
resource = pymodule.get_resource()
from_import = import_stmt.import_info
module_name = from_import.module_name
for name, alias in from_import.names_and_aliases:
imported = name
if alias is not None:
imported = alias
occurrence_finder = occurrences.create_finder(
self.project, imported, pymodule[imported], imports=False
)
source = rename.rename_in_module(
occurrence_finder,
module_name + "." + name,
pymodule=pymodule,
replace_primary=True,
)
if source is not None:
pymodule = libutils.get_string_module(self.project, source, resource)
return pymodule
def _clean_up_imports(self, pymodule, import_filter):
resource = pymodule.get_resource()
module_with_imports = self.module_imports(pymodule, import_filter)
module_with_imports.expand_stars()
source = module_with_imports.get_changed_source()
if source is not None:
pymodule = libutils.get_string_module(self.project, source, resource)
source = self.relatives_to_absolutes(pymodule)
if source is not None:
pymodule = libutils.get_string_module(self.project, source, resource)
module_with_imports = self.module_imports(pymodule, import_filter)
module_with_imports.remove_duplicates()
module_with_imports.remove_unused_imports()
source = module_with_imports.get_changed_source()
if source is not None:
pymodule = libutils.get_string_module(self.project, source, resource)
return pymodule
def relatives_to_absolutes(self, pymodule, import_filter=None):
module_imports = self.module_imports(pymodule, import_filter)
to_be_absolute_list = module_imports.get_relative_to_absolute_list()
for name, absolute_name in to_be_absolute_list:
pymodule = self._rename_in_module(pymodule, name, absolute_name)
module_imports = self.module_imports(pymodule, import_filter)
module_imports.get_relative_to_absolute_list()
source = module_imports.get_changed_source()
if source is None:
source = pymodule.source_code
return source
def _is_transformable_to_normal(self, import_info):
if not isinstance(import_info, FromImport):
return False
return True
def organize_imports(
self,
pymodule,
unused=True,
duplicates=True,
selfs=True,
sort=True,
import_filter=None,
):
if unused or duplicates:
module_imports = self.module_imports(pymodule, import_filter)
if unused:
module_imports.remove_unused_imports()
if self.project.prefs.get("split_imports"):
module_imports.force_single_imports()
if duplicates:
module_imports.remove_duplicates()
source = module_imports.get_changed_source()
if source is not None:
pymodule = libutils.get_string_module(
self.project, source, pymodule.get_resource()
)
if selfs:
pymodule = self._remove_self_imports(pymodule, import_filter)
if sort:
return self.sort_imports(pymodule, import_filter)
else:
return pymodule.source_code
def _remove_self_imports(self, pymodule, import_filter=None):
module_imports = self.module_imports(pymodule, import_filter)
(
to_be_fixed,
to_be_renamed,
) = module_imports.get_self_import_fix_and_rename_list()
for name in to_be_fixed:
try:
pymodule = self._rename_in_module(pymodule, name, "", till_dot=True)
except ValueError:
# There is a self import with direct access to it
return pymodule
for name, new_name in to_be_renamed:
pymodule = self._rename_in_module(pymodule, name, new_name)
module_imports = self.module_imports(pymodule, import_filter)
module_imports.get_self_import_fix_and_rename_list()
source = module_imports.get_changed_source()
if source is not None:
pymodule = libutils.get_string_module(
self.project, source, pymodule.get_resource()
)
return pymodule
def _rename_in_module(self, pymodule, name, new_name, till_dot=False):
old_name = name.split(".")[-1]
old_pyname = rope.base.evaluate.eval_str(pymodule.get_scope(), name)
occurrence_finder = occurrences.create_finder(
self.project, old_name, old_pyname, imports=False
)
changes = rope.base.codeanalyze.ChangeCollector(pymodule.source_code)
for occurrence in occurrence_finder.find_occurrences(pymodule=pymodule):
start, end = occurrence.get_primary_range()
if till_dot:
new_end = pymodule.source_code.index(".", end) + 1
space = pymodule.source_code[end : new_end - 1].strip()
if not space == "":
for c in space:
if not c.isspace() and c not in "\\":
raise ValueError()
end = new_end
changes.add_change(start, end, new_name)
source = changes.get_changed()
if source is not None:
pymodule = libutils.get_string_module(
self.project, source, pymodule.get_resource()
)
return pymodule
def sort_imports(self, pymodule, import_filter=None):
module_imports = self.module_imports(pymodule, import_filter)
module_imports.sort_imports()
return module_imports.get_changed_source()
def handle_long_imports(
self, pymodule, maxdots=2, maxlength=27, import_filter=None
):
# IDEA: `maxdots` and `maxlength` can be specified in project config
# adding new from imports
module_imports = self.module_imports(pymodule, import_filter)
to_be_fixed = module_imports.handle_long_imports(maxdots, maxlength)
# performing the renaming
pymodule = libutils.get_string_module(
self.project,
module_imports.get_changed_source(),
resource=pymodule.get_resource(),
)
for name in to_be_fixed:
pymodule = self._rename_in_module(pymodule, name, name.split(".")[-1])
# organizing imports
return self.organize_imports(
pymodule, selfs=False, sort=False, import_filter=import_filter
)
def get_imports(project, pydefined):
"""A shortcut for getting the `ImportInfo` used in a scope"""
pymodule = pydefined.get_module()
module = module_imports.ModuleImports(project, pymodule)
if pymodule == pydefined:
return [stmt.import_info for stmt in module.imports]
return module.get_used_imports(pydefined)
def get_module_imports(project, pymodule):
"""A shortcut for creating a `module_imports.ModuleImports` object"""
return module_imports.ModuleImports(project, pymodule)
def add_import(project, pymodule, module_name, name=None):
imports = get_module_imports(project, pymodule)
candidates = []
names = []
selected_import = None
# from mod import name
if name is not None:
from_import = FromImport(module_name, 0, [(name, None)])
names.append(name)
candidates.append(from_import)
# from pkg import mod
if "." in module_name:
pkg, mod = module_name.rsplit(".", 1)
from_import = FromImport(pkg, 0, [(mod, None)])
if project.prefs.get("prefer_module_from_imports"):
selected_import = from_import
candidates.append(from_import)
if name:
names.append(mod + "." + name)
else:
names.append(mod)
# import mod
normal_import = NormalImport([(module_name, None)])
if name:
names.append(module_name + "." + name)
else:
names.append(module_name)
candidates.append(normal_import)
visitor = actions.AddingVisitor(project, candidates)
if selected_import is None:
selected_import = normal_import
for import_statement in imports.imports:
if import_statement.accept(visitor):
selected_import = visitor.import_info
break
imports.add_import(selected_import)
imported_name = names[candidates.index(selected_import)]
return imports.get_changed_source(), imported_name

View file

@ -0,0 +1,367 @@
from rope.base import libutils
from rope.base import pyobjects, exceptions, stdmods
from rope.refactor import occurrences
from rope.refactor.importutils import importinfo
class ImportInfoVisitor(object):
def dispatch(self, import_):
try:
method_name = "visit" + import_.import_info.__class__.__name__
method = getattr(self, method_name)
return method(import_, import_.import_info)
except exceptions.ModuleNotFoundError:
pass
def visitEmptyImport(self, import_stmt, import_info):
pass
def visitNormalImport(self, import_stmt, import_info):
pass
def visitFromImport(self, import_stmt, import_info):
pass
class RelativeToAbsoluteVisitor(ImportInfoVisitor):
def __init__(self, project, current_folder):
self.to_be_absolute = []
self.project = project
self.folder = current_folder
self.context = importinfo.ImportContext(project, current_folder)
def visitNormalImport(self, import_stmt, import_info):
self.to_be_absolute.extend(self._get_relative_to_absolute_list(import_info))
new_pairs = []
for name, alias in import_info.names_and_aliases:
resource = self.project.find_module(name, folder=self.folder)
if resource is None:
new_pairs.append((name, alias))
continue
absolute_name = libutils.modname(resource)
new_pairs.append((absolute_name, alias))
if not import_info._are_name_and_alias_lists_equal(
new_pairs, import_info.names_and_aliases
):
import_stmt.import_info = importinfo.NormalImport(new_pairs)
def _get_relative_to_absolute_list(self, import_info):
result = []
for name, alias in import_info.names_and_aliases:
if alias is not None:
continue
resource = self.project.find_module(name, folder=self.folder)
if resource is None:
continue
absolute_name = libutils.modname(resource)
if absolute_name != name:
result.append((name, absolute_name))
return result
def visitFromImport(self, import_stmt, import_info):
resource = import_info.get_imported_resource(self.context)
if resource is None:
return None
absolute_name = libutils.modname(resource)
if import_info.module_name != absolute_name:
import_stmt.import_info = importinfo.FromImport(
absolute_name, 0, import_info.names_and_aliases
)
class FilteringVisitor(ImportInfoVisitor):
def __init__(self, project, folder, can_select):
self.to_be_absolute = []
self.project = project
self.can_select = self._transform_can_select(can_select)
self.context = importinfo.ImportContext(project, folder)
def _transform_can_select(self, can_select):
def can_select_name_and_alias(name, alias):
imported = name
if alias is not None:
imported = alias
return can_select(imported)
return can_select_name_and_alias
def visitNormalImport(self, import_stmt, import_info):
new_pairs = []
for name, alias in import_info.names_and_aliases:
if self.can_select(name, alias):
new_pairs.append((name, alias))
return importinfo.NormalImport(new_pairs)
def visitFromImport(self, import_stmt, import_info):
if _is_future(import_info):
return import_info
new_pairs = []
if import_info.is_star_import():
for name in import_info.get_imported_names(self.context):
if self.can_select(name, None):
new_pairs.append(import_info.names_and_aliases[0])
break
else:
for name, alias in import_info.names_and_aliases:
if self.can_select(name, alias):
new_pairs.append((name, alias))
return importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
class RemovingVisitor(ImportInfoVisitor):
def __init__(self, project, folder, can_select):
self.to_be_absolute = []
self.project = project
self.filtering = FilteringVisitor(project, folder, can_select)
def dispatch(self, import_):
result = self.filtering.dispatch(import_)
if result is not None:
import_.import_info = result
class AddingVisitor(ImportInfoVisitor):
"""A class for adding imports
Given a list of `ImportInfo`, it tries to add each import to the
module and returns `True` and gives up when an import can be added
to older ones.
"""
def __init__(self, project, import_list):
self.project = project
self.import_list = import_list
self.import_info = None
def dispatch(self, import_):
for import_info in self.import_list:
self.import_info = import_info
if ImportInfoVisitor.dispatch(self, import_):
return True
# TODO: Handle adding relative and absolute imports
def visitNormalImport(self, import_stmt, import_info):
if not isinstance(self.import_info, import_info.__class__):
return False
# Adding ``import x`` and ``import x.y`` that results ``import x.y``
if (
len(import_info.names_and_aliases)
== len(self.import_info.names_and_aliases)
== 1
):
imported1 = import_info.names_and_aliases[0]
imported2 = self.import_info.names_and_aliases[0]
if imported1[1] == imported2[1] is None:
if imported1[0].startswith(imported2[0] + "."):
return True
if imported2[0].startswith(imported1[0] + "."):
import_stmt.import_info = self.import_info
return True
# Multiple imports using a single import statement is discouraged
# so we won't bother adding them.
if self.import_info._are_name_and_alias_lists_equal(
import_info.names_and_aliases, self.import_info.names_and_aliases
):
return True
def visitFromImport(self, import_stmt, import_info):
if (
isinstance(self.import_info, import_info.__class__)
and import_info.module_name == self.import_info.module_name
and import_info.level == self.import_info.level
):
if import_info.is_star_import():
return True
if self.import_info.is_star_import():
import_stmt.import_info = self.import_info
return True
if self.project.prefs.get("split_imports"):
return (
self.import_info.names_and_aliases == import_info.names_and_aliases
)
new_pairs = list(import_info.names_and_aliases)
for pair in self.import_info.names_and_aliases:
if pair not in new_pairs:
new_pairs.append(pair)
import_stmt.import_info = importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
return True
class ExpandStarsVisitor(ImportInfoVisitor):
def __init__(self, project, folder, can_select):
self.project = project
self.filtering = FilteringVisitor(project, folder, can_select)
self.context = importinfo.ImportContext(project, folder)
def visitNormalImport(self, import_stmt, import_info):
self.filtering.dispatch(import_stmt)
def visitFromImport(self, import_stmt, import_info):
if import_info.is_star_import():
new_pairs = []
for name in import_info.get_imported_names(self.context):
new_pairs.append((name, None))
new_import = importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
import_stmt.import_info = self.filtering.visitFromImport(None, new_import)
else:
self.filtering.dispatch(import_stmt)
class SelfImportVisitor(ImportInfoVisitor):
def __init__(self, project, current_folder, resource):
self.project = project
self.folder = current_folder
self.resource = resource
self.to_be_fixed = set()
self.to_be_renamed = set()
self.context = importinfo.ImportContext(project, current_folder)
def visitNormalImport(self, import_stmt, import_info):
new_pairs = []
for name, alias in import_info.names_and_aliases:
resource = self.project.find_module(name, folder=self.folder)
if resource is not None and resource == self.resource:
imported = name
if alias is not None:
imported = alias
self.to_be_fixed.add(imported)
else:
new_pairs.append((name, alias))
if not import_info._are_name_and_alias_lists_equal(
new_pairs, import_info.names_and_aliases
):
import_stmt.import_info = importinfo.NormalImport(new_pairs)
def visitFromImport(self, import_stmt, import_info):
resource = import_info.get_imported_resource(self.context)
if resource is None:
return
if resource == self.resource:
self._importing_names_from_self(import_info, import_stmt)
return
pymodule = self.project.get_pymodule(resource)
new_pairs = []
for name, alias in import_info.names_and_aliases:
try:
result = pymodule[name].get_object()
if (
isinstance(result, pyobjects.PyModule)
and result.get_resource() == self.resource
):
imported = name
if alias is not None:
imported = alias
self.to_be_fixed.add(imported)
else:
new_pairs.append((name, alias))
except exceptions.AttributeNotFoundError:
new_pairs.append((name, alias))
if not import_info._are_name_and_alias_lists_equal(
new_pairs, import_info.names_and_aliases
):
import_stmt.import_info = importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
def _importing_names_from_self(self, import_info, import_stmt):
if not import_info.is_star_import():
for name, alias in import_info.names_and_aliases:
if alias is not None:
self.to_be_renamed.add((alias, name))
import_stmt.empty_import()
class SortingVisitor(ImportInfoVisitor):
def __init__(self, project, current_folder):
self.project = project
self.folder = current_folder
self.standard = set()
self.third_party = set()
self.in_project = set()
self.future = set()
self.context = importinfo.ImportContext(project, current_folder)
def visitNormalImport(self, import_stmt, import_info):
if import_info.names_and_aliases:
name, alias = import_info.names_and_aliases[0]
resource = self.project.find_module(name, folder=self.folder)
self._check_imported_resource(import_stmt, resource, name)
def visitFromImport(self, import_stmt, import_info):
resource = import_info.get_imported_resource(self.context)
self._check_imported_resource(import_stmt, resource, import_info.module_name)
def _check_imported_resource(self, import_stmt, resource, imported_name):
info = import_stmt.import_info
if resource is not None and resource.project == self.project:
self.in_project.add(import_stmt)
elif _is_future(info):
self.future.add(import_stmt)
elif imported_name.split(".")[0] in stdmods.standard_modules():
self.standard.add(import_stmt)
else:
self.third_party.add(import_stmt)
class LongImportVisitor(ImportInfoVisitor):
def __init__(self, current_folder, project, maxdots, maxlength):
self.maxdots = maxdots
self.maxlength = maxlength
self.to_be_renamed = set()
self.current_folder = current_folder
self.project = project
self.new_imports = []
def visitNormalImport(self, import_stmt, import_info):
for name, alias in import_info.names_and_aliases:
if alias is None and self._is_long(name):
self.to_be_renamed.add(name)
last_dot = name.rindex(".")
from_ = name[:last_dot]
imported = name[last_dot + 1 :]
self.new_imports.append(
importinfo.FromImport(from_, 0, ((imported, None),))
)
def _is_long(self, name):
return name.count(".") > self.maxdots or (
"." in name and len(name) > self.maxlength
)
class RemovePyNameVisitor(ImportInfoVisitor):
def __init__(self, project, pymodule, pyname, folder):
self.pymodule = pymodule
self.pyname = pyname
self.context = importinfo.ImportContext(project, folder)
def visitFromImport(self, import_stmt, import_info):
new_pairs = []
if not import_info.is_star_import():
for name, alias in import_info.names_and_aliases:
try:
pyname = self.pymodule[alias or name]
if occurrences.same_pyname(self.pyname, pyname):
continue
except exceptions.AttributeNotFoundError:
pass
new_pairs.append((name, alias))
return importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
def dispatch(self, import_):
result = ImportInfoVisitor.dispatch(self, import_)
if result is not None:
import_.import_info = result
def _is_future(info):
return isinstance(info, importinfo.FromImport) and info.module_name == "__future__"

View file

@ -0,0 +1,203 @@
class ImportStatement(object):
"""Represent an import in a module
`readonly` attribute controls whether this import can be changed
by import actions or not.
"""
def __init__(
self, import_info, start_line, end_line, main_statement=None, blank_lines=0
):
self.start_line = start_line
self.end_line = end_line
self.readonly = False
self.main_statement = main_statement
self._import_info = None
self.import_info = import_info
self._is_changed = False
self.new_start = None
self.blank_lines = blank_lines
def _get_import_info(self):
return self._import_info
def _set_import_info(self, new_import):
if (
not self.readonly
and new_import is not None
and not new_import == self._import_info
):
self._is_changed = True
self._import_info = new_import
import_info = property(_get_import_info, _set_import_info)
def get_import_statement(self):
if self._is_changed or self.main_statement is None:
return self.import_info.get_import_statement()
else:
return self.main_statement
def empty_import(self):
self.import_info = ImportInfo.get_empty_import()
def move(self, lineno, blank_lines=0):
self.new_start = lineno
self.blank_lines = blank_lines
def get_old_location(self):
return self.start_line, self.end_line
def get_new_start(self):
return self.new_start
def is_changed(self):
return self._is_changed or (
self.new_start is not None or self.new_start != self.start_line
)
def accept(self, visitor):
return visitor.dispatch(self)
class ImportInfo(object):
def get_imported_primaries(self, context):
pass
def get_imported_names(self, context):
return [
primary.split(".")[0] for primary in self.get_imported_primaries(context)
]
def get_import_statement(self):
pass
def is_empty(self):
pass
def __hash__(self):
return hash(self.get_import_statement())
def _are_name_and_alias_lists_equal(self, list1, list2):
if len(list1) != len(list2):
return False
for pair1, pair2 in zip(list1, list2):
if pair1 != pair2:
return False
return True
def __eq__(self, obj):
return (
isinstance(obj, self.__class__)
and self.get_import_statement() == obj.get_import_statement()
)
def __ne__(self, obj):
return not self.__eq__(obj)
@staticmethod
def get_empty_import():
return EmptyImport()
class NormalImport(ImportInfo):
def __init__(self, names_and_aliases):
self.names_and_aliases = names_and_aliases
def get_imported_primaries(self, context):
result = []
for name, alias in self.names_and_aliases:
if alias:
result.append(alias)
else:
result.append(name)
return result
def get_import_statement(self):
result = "import "
for name, alias in self.names_and_aliases:
result += name
if alias:
result += " as " + alias
result += ", "
return result[:-2]
def is_empty(self):
return len(self.names_and_aliases) == 0
class FromImport(ImportInfo):
def __init__(self, module_name, level, names_and_aliases):
self.module_name = module_name
self.level = level
self.names_and_aliases = names_and_aliases
def get_imported_primaries(self, context):
if self.names_and_aliases[0][0] == "*":
module = self.get_imported_module(context)
return [name for name in module if not name.startswith("_")]
result = []
for name, alias in self.names_and_aliases:
if alias:
result.append(alias)
else:
result.append(name)
return result
def get_imported_resource(self, context):
"""Get the imported resource
Returns `None` if module was not found.
"""
if self.level == 0:
return context.project.find_module(self.module_name, folder=context.folder)
else:
return context.project.find_relative_module(
self.module_name, context.folder, self.level
)
def get_imported_module(self, context):
"""Get the imported `PyModule`
Raises `rope.base.exceptions.ModuleNotFoundError` if module
could not be found.
"""
if self.level == 0:
return context.project.get_module(self.module_name, context.folder)
else:
return context.project.get_relative_module(
self.module_name, context.folder, self.level
)
def get_import_statement(self):
result = "from " + "." * self.level + self.module_name + " import "
for name, alias in self.names_and_aliases:
result += name
if alias:
result += " as " + alias
result += ", "
return result[:-2]
def is_empty(self):
return len(self.names_and_aliases) == 0
def is_star_import(self):
return len(self.names_and_aliases) > 0 and self.names_and_aliases[0][0] == "*"
class EmptyImport(ImportInfo):
names_and_aliases = []
def is_empty(self):
return True
def get_imported_primaries(self, context):
return []
class ImportContext(object):
def __init__(self, project, folder):
self.project = project
self.folder = folder

View file

@ -0,0 +1,551 @@
from rope.base import ast
from rope.base import exceptions
from rope.base import pynames
from rope.base import utils
from rope.refactor.importutils import actions
from rope.refactor.importutils import importinfo
class ModuleImports(object):
def __init__(self, project, pymodule, import_filter=None):
self.project = project
self.pymodule = pymodule
self.separating_lines = 0
self.filter = import_filter
self.sorted = False
@property
@utils.saveit
def imports(self):
finder = _GlobalImportFinder(self.pymodule)
result = finder.find_import_statements()
self.separating_lines = finder.get_separating_line_count()
if self.filter is not None:
for import_stmt in result:
if not self.filter(import_stmt):
import_stmt.readonly = True
return result
def _get_unbound_names(self, defined_pyobject):
visitor = _GlobalUnboundNameFinder(self.pymodule, defined_pyobject)
ast.walk(self.pymodule.get_ast(), visitor)
return visitor.unbound
def _get_all_star_list(self, pymodule):
result = set()
try:
all_star_list = pymodule.get_attribute("__all__")
except exceptions.AttributeNotFoundError:
return result
# FIXME: Need a better way to recursively infer possible values.
# Currently pyobjects can recursively infer type, but not values.
# Do a very basic 1-level value inference
for assignment in all_star_list.assignments:
if isinstance(assignment.ast_node, ast.List):
stack = list(assignment.ast_node.elts)
while stack:
el = stack.pop()
if isinstance(el, ast.Str):
result.add(el.s)
elif isinstance(el, ast.Name):
name = pymodule.get_attribute(el.id)
if isinstance(name, pynames.AssignedName):
for av in name.assignments:
if isinstance(av.ast_node, ast.Str):
result.add(av.ast_node.s)
elif isinstance(el, ast.IfExp):
stack.append(el.body)
stack.append(el.orelse)
return result
def remove_unused_imports(self):
can_select = _OneTimeSelector(
self._get_unbound_names(self.pymodule)
| self._get_all_star_list(self.pymodule)
)
visitor = actions.RemovingVisitor(
self.project, self._current_folder(), can_select
)
for import_statement in self.imports:
import_statement.accept(visitor)
def get_used_imports(self, defined_pyobject):
result = []
can_select = _OneTimeSelector(self._get_unbound_names(defined_pyobject))
visitor = actions.FilteringVisitor(
self.project, self._current_folder(), can_select
)
for import_statement in self.imports:
new_import = import_statement.accept(visitor)
if new_import is not None and not new_import.is_empty():
result.append(new_import)
return result
def get_changed_source(self):
if not self.project.prefs.get("pull_imports_to_top") and not self.sorted:
return "".join(self._rewrite_imports(self.imports))
# Make sure we forward a removed import's preceding blank
# lines count to the following import statement.
prev_stmt = None
for stmt in self.imports:
if prev_stmt is not None and prev_stmt.import_info.is_empty():
stmt.blank_lines = max(prev_stmt.blank_lines, stmt.blank_lines)
prev_stmt = stmt
# The new list of imports.
imports = [stmt for stmt in self.imports if not stmt.import_info.is_empty()]
after_removing = self._remove_imports(self.imports)
first_non_blank = self._first_non_blank_line(after_removing, 0)
first_import = self._first_import_line() - 1
result = []
# Writing module docs
result.extend(after_removing[first_non_blank:first_import])
# Writing imports
sorted_imports = sorted(imports, key=self._get_location)
for stmt in sorted_imports:
if stmt != sorted_imports[0]:
result.append("\n" * stmt.blank_lines)
result.append(stmt.get_import_statement() + "\n")
if sorted_imports and first_non_blank < len(after_removing):
result.append("\n" * self.separating_lines)
# Writing the body
first_after_imports = self._first_non_blank_line(after_removing, first_import)
result.extend(after_removing[first_after_imports:])
return "".join(result)
def _get_import_location(self, stmt):
start = stmt.get_new_start()
if start is None:
start = stmt.get_old_location()[0]
return start
def _get_location(self, stmt):
if stmt.get_new_start() is not None:
return stmt.get_new_start()
else:
return stmt.get_old_location()[0]
def _remove_imports(self, imports):
lines = self.pymodule.source_code.splitlines(True)
after_removing = []
first_import_line = self._first_import_line()
last_index = 0
for stmt in imports:
start, end = stmt.get_old_location()
blank_lines = 0
if start != first_import_line:
blank_lines = _count_blank_lines(
lines.__getitem__, start - 2, last_index - 1, -1
)
after_removing.extend(lines[last_index : start - 1 - blank_lines])
last_index = end - 1
after_removing.extend(lines[last_index:])
return after_removing
def _rewrite_imports(self, imports):
lines = self.pymodule.source_code.splitlines(True)
after_rewriting = []
last_index = 0
for stmt in imports:
start, end = stmt.get_old_location()
after_rewriting.extend(lines[last_index : start - 1])
if not stmt.import_info.is_empty():
after_rewriting.append(stmt.get_import_statement() + "\n")
last_index = end - 1
after_rewriting.extend(lines[last_index:])
return after_rewriting
def _first_non_blank_line(self, lines, lineno):
return lineno + _count_blank_lines(lines.__getitem__, lineno, len(lines))
def add_import(self, import_info):
visitor = actions.AddingVisitor(self.project, [import_info])
for import_statement in self.imports:
if import_statement.accept(visitor):
break
else:
lineno = self._get_new_import_lineno()
blanks = self._get_new_import_blanks()
self.imports.append(
importinfo.ImportStatement(
import_info, lineno, lineno, blank_lines=blanks
)
)
def _get_new_import_blanks(self):
return 0
def _get_new_import_lineno(self):
if self.imports:
return self.imports[-1].end_line
return 1
def filter_names(self, can_select):
visitor = actions.RemovingVisitor(
self.project, self._current_folder(), can_select
)
for import_statement in self.imports:
import_statement.accept(visitor)
def expand_stars(self):
can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule))
visitor = actions.ExpandStarsVisitor(
self.project, self._current_folder(), can_select
)
for import_statement in self.imports:
import_statement.accept(visitor)
def remove_duplicates(self):
added_imports = []
for import_stmt in self.imports:
visitor = actions.AddingVisitor(self.project, [import_stmt.import_info])
for added_import in added_imports:
if added_import.accept(visitor):
import_stmt.empty_import()
else:
added_imports.append(import_stmt)
def force_single_imports(self):
"""force a single import per statement"""
for import_stmt in self.imports[:]:
import_info = import_stmt.import_info
if import_info.is_empty() or import_stmt.readonly:
continue
if len(import_info.names_and_aliases) > 1:
for name_and_alias in import_info.names_and_aliases:
if hasattr(import_info, "module_name"):
new_import = importinfo.FromImport(
import_info.module_name, import_info.level, [name_and_alias]
)
else:
new_import = importinfo.NormalImport([name_and_alias])
self.add_import(new_import)
import_stmt.empty_import()
def get_relative_to_absolute_list(self):
visitor = actions.RelativeToAbsoluteVisitor(
self.project, self._current_folder()
)
for import_stmt in self.imports:
if not import_stmt.readonly:
import_stmt.accept(visitor)
return visitor.to_be_absolute
def get_self_import_fix_and_rename_list(self):
visitor = actions.SelfImportVisitor(
self.project, self._current_folder(), self.pymodule.get_resource()
)
for import_stmt in self.imports:
if not import_stmt.readonly:
import_stmt.accept(visitor)
return visitor.to_be_fixed, visitor.to_be_renamed
def _current_folder(self):
return self.pymodule.get_resource().parent
def sort_imports(self):
if self.project.prefs.get("sort_imports_alphabetically"):
sort_kwargs = dict(key=self._get_import_name)
else:
sort_kwargs = dict(key=self._key_imports)
# IDEA: Sort from import list
visitor = actions.SortingVisitor(self.project, self._current_folder())
for import_statement in self.imports:
import_statement.accept(visitor)
in_projects = sorted(visitor.in_project, **sort_kwargs)
third_party = sorted(visitor.third_party, **sort_kwargs)
standards = sorted(visitor.standard, **sort_kwargs)
future = sorted(visitor.future, **sort_kwargs)
last_index = self._first_import_line()
last_index = self._move_imports(future, last_index, 0)
last_index = self._move_imports(standards, last_index, 1)
last_index = self._move_imports(third_party, last_index, 1)
last_index = self._move_imports(in_projects, last_index, 1)
self.separating_lines = 2
self.sorted = True
def _first_import_line(self):
nodes = self.pymodule.get_ast().body
lineno = 0
if self.pymodule.get_doc() is not None:
lineno = 1
if len(nodes) > lineno:
if isinstance(nodes[lineno], ast.Import) or isinstance(
nodes[lineno], ast.ImportFrom
):
return nodes[lineno].lineno
lineno = self.pymodule.logical_lines.logical_line_in(nodes[lineno].lineno)[
0
]
else:
lineno = self.pymodule.lines.length()
return lineno - _count_blank_lines(
self.pymodule.lines.get_line, lineno - 1, 1, -1
)
def _get_import_name(self, import_stmt):
import_info = import_stmt.import_info
if hasattr(import_info, "module_name"):
return "%s.%s" % (
import_info.module_name,
import_info.names_and_aliases[0][0],
)
else:
return import_info.names_and_aliases[0][0]
def _key_imports(self, stm1):
str1 = stm1.get_import_statement()
return str1.startswith("from "), str1
# str1 = stmt1.get_import_statement()
# str2 = stmt2.get_import_statement()
# if str1.startswith('from ') and not str2.startswith('from '):
# return 1
# if not str1.startswith('from ') and str2.startswith('from '):
# return -1
# return cmp(str1, str2)
def _move_imports(self, imports, index, blank_lines):
if imports:
imports[0].move(index, blank_lines)
index += 1
if len(imports) > 1:
for stmt in imports[1:]:
stmt.move(index)
index += 1
return index
def handle_long_imports(self, maxdots, maxlength):
visitor = actions.LongImportVisitor(
self._current_folder(), self.project, maxdots, maxlength
)
for import_statement in self.imports:
if not import_statement.readonly:
import_statement.accept(visitor)
for import_info in visitor.new_imports:
self.add_import(import_info)
return visitor.to_be_renamed
def remove_pyname(self, pyname):
"""Removes pyname when imported in ``from mod import x``"""
visitor = actions.RemovePyNameVisitor(
self.project, self.pymodule, pyname, self._current_folder()
)
for import_stmt in self.imports:
import_stmt.accept(visitor)
def _count_blank_lines(get_line, start, end, step=1):
count = 0
for idx in range(start, end, step):
if get_line(idx).strip() == "":
count += 1
else:
break
return count
class _OneTimeSelector(object):
def __init__(self, names):
self.names = names
self.selected_names = set()
def __call__(self, imported_primary):
if self._can_name_be_added(imported_primary):
for name in self._get_dotted_tokens(imported_primary):
self.selected_names.add(name)
return True
return False
def _get_dotted_tokens(self, imported_primary):
tokens = imported_primary.split(".")
for i in range(len(tokens)):
yield ".".join(tokens[: i + 1])
def _can_name_be_added(self, imported_primary):
for name in self._get_dotted_tokens(imported_primary):
if name in self.names and name not in self.selected_names:
return True
return False
class _UnboundNameFinder(object):
def __init__(self, pyobject):
self.pyobject = pyobject
def _visit_child_scope(self, node):
pyobject = (
self.pyobject.get_module()
.get_scope()
.get_inner_scope_for_line(node.lineno)
.pyobject
)
visitor = _LocalUnboundNameFinder(pyobject, self)
for child in ast.get_child_nodes(node):
ast.walk(child, visitor)
def _FunctionDef(self, node):
self._visit_child_scope(node)
def _ClassDef(self, node):
self._visit_child_scope(node)
def _Name(self, node):
if self._get_root()._is_node_interesting(node) and not self.is_bound(node.id):
self.add_unbound(node.id)
def _Attribute(self, node):
result = []
while isinstance(node, ast.Attribute):
result.append(node.attr)
node = node.value
if isinstance(node, ast.Name):
result.append(node.id)
primary = ".".join(reversed(result))
if self._get_root()._is_node_interesting(node) and not self.is_bound(
primary
):
self.add_unbound(primary)
else:
ast.walk(node, self)
def _get_root(self):
pass
def is_bound(self, name, propagated=False):
pass
def add_unbound(self, name):
pass
class _GlobalUnboundNameFinder(_UnboundNameFinder):
def __init__(self, pymodule, wanted_pyobject):
super(_GlobalUnboundNameFinder, self).__init__(pymodule)
self.unbound = set()
self.names = set()
for name, pyname in pymodule._get_structural_attributes().items():
if not isinstance(pyname, (pynames.ImportedName, pynames.ImportedModule)):
self.names.add(name)
wanted_scope = wanted_pyobject.get_scope()
self.start = wanted_scope.get_start()
self.end = wanted_scope.get_end() + 1
def _get_root(self):
return self
def is_bound(self, primary, propagated=False):
name = primary.split(".")[0]
return name in self.names
def add_unbound(self, name):
names = name.split(".")
for i in range(len(names)):
self.unbound.add(".".join(names[: i + 1]))
def _is_node_interesting(self, node):
return self.start <= node.lineno < self.end
class _LocalUnboundNameFinder(_UnboundNameFinder):
def __init__(self, pyobject, parent):
super(_LocalUnboundNameFinder, self).__init__(pyobject)
self.parent = parent
def _get_root(self):
return self.parent._get_root()
def is_bound(self, primary, propagated=False):
name = primary.split(".")[0]
if propagated:
names = self.pyobject.get_scope().get_propagated_names()
else:
names = self.pyobject.get_scope().get_names()
if name in names or self.parent.is_bound(name, propagated=True):
return True
return False
def add_unbound(self, name):
self.parent.add_unbound(name)
class _GlobalImportFinder(object):
def __init__(self, pymodule):
self.current_folder = None
if pymodule.get_resource():
self.current_folder = pymodule.get_resource().parent
self.pymodule = pymodule
self.imports = []
self.pymodule = pymodule
self.lines = self.pymodule.lines
def visit_import(self, node, end_line):
start_line = node.lineno
import_statement = importinfo.ImportStatement(
importinfo.NormalImport(self._get_names(node.names)),
start_line,
end_line,
self._get_text(start_line, end_line),
blank_lines=self._count_empty_lines_before(start_line),
)
self.imports.append(import_statement)
def _count_empty_lines_before(self, lineno):
return _count_blank_lines(self.lines.get_line, lineno - 1, 0, -1)
def _count_empty_lines_after(self, lineno):
return _count_blank_lines(self.lines.get_line, lineno + 1, self.lines.length())
def get_separating_line_count(self):
if not self.imports:
return 0
return self._count_empty_lines_after(self.imports[-1].end_line - 1)
def _get_text(self, start_line, end_line):
result = []
for index in range(start_line, end_line):
result.append(self.lines.get_line(index))
return "\n".join(result)
def visit_from(self, node, end_line):
level = 0
if node.level:
level = node.level
import_info = importinfo.FromImport(
node.module or "", # see comment at rope.base.ast.walk
level,
self._get_names(node.names),
)
start_line = node.lineno
self.imports.append(
importinfo.ImportStatement(
import_info,
node.lineno,
end_line,
self._get_text(start_line, end_line),
blank_lines=self._count_empty_lines_before(start_line),
)
)
def _get_names(self, alias_names):
result = []
for alias in alias_names:
result.append((alias.name, alias.asname))
return result
def find_import_statements(self):
nodes = self.pymodule.get_ast().body
for index, node in enumerate(nodes):
if isinstance(node, (ast.Import, ast.ImportFrom)):
lines = self.pymodule.logical_lines
end_line = lines.logical_line_in(node.lineno)[1] + 1
if isinstance(node, ast.Import):
self.visit_import(node, end_line)
if isinstance(node, ast.ImportFrom):
self.visit_from(node, end_line)
return self.imports