This commit is contained in:
Waylon Walker 2022-03-31 20:20:07 -05:00
commit 38355d2442
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
9083 changed files with 1225834 additions and 0 deletions

View file

@ -0,0 +1,68 @@
"""rope refactor package
This package contains modules that perform python refactorings.
Refactoring classes perform refactorings in 4 steps:
1. Collect some data for performing the refactoring and use them
to construct a refactoring class. Like::
renamer = Rename(project, resource, offset)
2. Some refactorings give you useful information about the
refactoring after their construction. Like::
print(renamer.get_old_name())
3. Give the refactoring class more information about how to
perform the refactoring and get the changes this refactoring is
going to make. This is done by calling `get_changes` method of the
refactoring class. Like::
changes = renamer.get_changes(new_name)
4. You can commit the changes. Like::
project.do(changes)
These steps are like the steps IDEs usually do for performing a
refactoring. These are the things an IDE does in each step:
1. Construct a refactoring object by giving it information like
resource, offset and ... . Some of the refactoring problems (like
performing rename refactoring on language keywords) can be reported
here.
2. Print some information about the refactoring and ask the user
about the information that are necessary for completing the
refactoring (like new name).
3. Call the `get_changes` by passing it information asked from
the user (if necessary) and get and preview the changes returned by
it.
4. perform the refactoring.
From ``0.5m5`` release the `get_changes()` method of some time-
consuming refactorings take an optional `rope.base.taskhandle.
TaskHandle` parameter. You can use this object for stopping or
monitoring the progress of refactorings.
"""
from rope.refactor.importutils import ImportOrganizer # noqa
from rope.refactor.topackage import ModuleToPackage # noqa
__all__ = [
"rename",
"move",
"inline",
"extract",
"restructure",
"topackage",
"importutils",
"usefunction",
"change_signature",
"encapsulate_field",
"introduce_factory",
"introduce_parameter",
"localtofield",
"method_object",
"multiproject",
]

View file

@ -0,0 +1,378 @@
import copy
import rope.base.exceptions
from rope.base import codeanalyze
from rope.base import evaluate
from rope.base import pyobjects
from rope.base import taskhandle
from rope.base import utils
from rope.base import worder
from rope.base.change import ChangeContents, ChangeSet
from rope.refactor import occurrences, functionutils
class ChangeSignature(object):
def __init__(self, project, resource, offset):
self.project = project
self.resource = resource
self.offset = offset
self._set_name_and_pyname()
if (
self.pyname is None
or self.pyname.get_object() is None
or not isinstance(self.pyname.get_object(), pyobjects.PyFunction)
):
raise rope.base.exceptions.RefactoringError(
"Change method signature should be performed on functions"
)
def _set_name_and_pyname(self):
self.name = worder.get_name_at(self.resource, self.offset)
this_pymodule = self.project.get_pymodule(self.resource)
self.primary, self.pyname = evaluate.eval_location2(this_pymodule, self.offset)
if self.pyname is None:
return
pyobject = self.pyname.get_object()
if isinstance(pyobject, pyobjects.PyClass) and "__init__" in pyobject:
self.pyname = pyobject["__init__"]
self.name = "__init__"
pyobject = self.pyname.get_object()
self.others = None
if (
self.name == "__init__"
and isinstance(pyobject, pyobjects.PyFunction)
and isinstance(pyobject.parent, pyobjects.PyClass)
):
pyclass = pyobject.parent
self.others = (pyclass.get_name(), pyclass.parent[pyclass.get_name()])
def _change_calls(
self,
call_changer,
in_hierarchy=None,
resources=None,
handle=taskhandle.NullTaskHandle(),
):
if resources is None:
resources = self.project.get_python_files()
changes = ChangeSet("Changing signature of <%s>" % self.name)
job_set = handle.create_jobset("Collecting Changes", len(resources))
finder = occurrences.create_finder(
self.project,
self.name,
self.pyname,
instance=self.primary,
in_hierarchy=in_hierarchy and self.is_method(),
)
if self.others:
name, pyname = self.others
constructor_finder = occurrences.create_finder(
self.project, name, pyname, only_calls=True
)
finder = _MultipleFinders([finder, constructor_finder])
for file in resources:
job_set.started_job(file.path)
change_calls = _ChangeCallsInModule(
self.project, finder, file, call_changer
)
changed_file = change_calls.get_changed_module()
if changed_file is not None:
changes.add_change(ChangeContents(file, changed_file))
job_set.finished_job()
return changes
def get_args(self):
"""Get function arguments.
Return a list of ``(name, default)`` tuples for all but star
and double star arguments. For arguments that don't have a
default, `None` will be used.
"""
return self._definfo().args_with_defaults
def is_method(self):
pyfunction = self.pyname.get_object()
return isinstance(pyfunction.parent, pyobjects.PyClass)
@utils.deprecated("Use `ChangeSignature.get_args()` instead")
def get_definition_info(self):
return self._definfo()
def _definfo(self):
return functionutils.DefinitionInfo.read(self.pyname.get_object())
@utils.deprecated()
def normalize(self):
changer = _FunctionChangers(
self.pyname.get_object(), self.get_definition_info(), [ArgumentNormalizer()]
)
return self._change_calls(changer)
@utils.deprecated()
def remove(self, index):
changer = _FunctionChangers(
self.pyname.get_object(),
self.get_definition_info(),
[ArgumentRemover(index)],
)
return self._change_calls(changer)
@utils.deprecated()
def add(self, index, name, default=None, value=None):
changer = _FunctionChangers(
self.pyname.get_object(),
self.get_definition_info(),
[ArgumentAdder(index, name, default, value)],
)
return self._change_calls(changer)
@utils.deprecated()
def inline_default(self, index):
changer = _FunctionChangers(
self.pyname.get_object(),
self.get_definition_info(),
[ArgumentDefaultInliner(index)],
)
return self._change_calls(changer)
@utils.deprecated()
def reorder(self, new_ordering):
changer = _FunctionChangers(
self.pyname.get_object(),
self.get_definition_info(),
[ArgumentReorderer(new_ordering)],
)
return self._change_calls(changer)
def get_changes(
self,
changers,
in_hierarchy=False,
resources=None,
task_handle=taskhandle.NullTaskHandle(),
):
"""Get changes caused by this refactoring
`changers` is a list of `_ArgumentChanger`. If `in_hierarchy`
is `True` the changers are applyed to all matching methods in
the class hierarchy.
`resources` can be a list of `rope.base.resource.File` that
should be searched for occurrences; if `None` all python files
in the project are searched.
"""
function_changer = _FunctionChangers(
self.pyname.get_object(), self._definfo(), changers
)
return self._change_calls(
function_changer, in_hierarchy, resources, task_handle
)
class _FunctionChangers(object):
def __init__(self, pyfunction, definition_info, changers=None):
self.pyfunction = pyfunction
self.definition_info = definition_info
self.changers = changers
self.changed_definition_infos = self._get_changed_definition_infos()
def _get_changed_definition_infos(self):
result = []
definition_info = self.definition_info
result.append(definition_info)
for changer in self.changers:
definition_info = copy.deepcopy(definition_info)
changer.change_definition_info(definition_info)
result.append(definition_info)
return result
def change_definition(self, call):
return self.changed_definition_infos[-1].to_string()
def change_call(self, primary, pyname, call):
call_info = functionutils.CallInfo.read(
primary, pyname, self.definition_info, call
)
mapping = functionutils.ArgumentMapping(self.definition_info, call_info)
for definition_info, changer in zip(
self.changed_definition_infos, self.changers
):
changer.change_argument_mapping(definition_info, mapping)
return mapping.to_call_info(self.changed_definition_infos[-1]).to_string()
class _ArgumentChanger(object):
def change_definition_info(self, definition_info):
pass
def change_argument_mapping(self, definition_info, argument_mapping):
pass
class ArgumentNormalizer(_ArgumentChanger):
pass
class ArgumentRemover(_ArgumentChanger):
def __init__(self, index):
self.index = index
def change_definition_info(self, call_info):
if self.index < len(call_info.args_with_defaults):
del call_info.args_with_defaults[self.index]
elif (
self.index == len(call_info.args_with_defaults)
and call_info.args_arg is not None
):
call_info.args_arg = None
elif (
self.index == len(call_info.args_with_defaults)
and call_info.args_arg is None
and call_info.keywords_arg is not None
) or (
self.index == len(call_info.args_with_defaults) + 1
and call_info.args_arg is not None
and call_info.keywords_arg is not None
):
call_info.keywords_arg = None
def change_argument_mapping(self, definition_info, mapping):
if self.index < len(definition_info.args_with_defaults):
name = definition_info.args_with_defaults[0]
if name in mapping.param_dict:
del mapping.param_dict[name]
class ArgumentAdder(_ArgumentChanger):
def __init__(self, index, name, default=None, value=None):
self.index = index
self.name = name
self.default = default
self.value = value
def change_definition_info(self, definition_info):
for pair in definition_info.args_with_defaults:
if pair[0] == self.name:
raise rope.base.exceptions.RefactoringError(
"Adding duplicate parameter: <%s>." % self.name
)
definition_info.args_with_defaults.insert(self.index, (self.name, self.default))
def change_argument_mapping(self, definition_info, mapping):
if self.value is not None:
mapping.param_dict[self.name] = self.value
class ArgumentDefaultInliner(_ArgumentChanger):
def __init__(self, index):
self.index = index
self.remove = False
def change_definition_info(self, definition_info):
if self.remove:
definition_info.args_with_defaults[self.index] = (
definition_info.args_with_defaults[self.index][0],
None,
)
def change_argument_mapping(self, definition_info, mapping):
default = definition_info.args_with_defaults[self.index][1]
name = definition_info.args_with_defaults[self.index][0]
if default is not None and name not in mapping.param_dict:
mapping.param_dict[name] = default
class ArgumentReorderer(_ArgumentChanger):
def __init__(self, new_order, autodef=None):
"""Construct an `ArgumentReorderer`
Note that the `new_order` is a list containing the new
position of parameters; not the position each parameter
is going to be moved to. (changed in ``0.5m4``)
For example changing ``f(a, b, c)`` to ``f(c, a, b)``
requires passing ``[2, 0, 1]`` and *not* ``[1, 2, 0]``.
The `autodef` (automatic default) argument, forces rope to use
it as a default if a default is needed after the change. That
happens when an argument without default is moved after
another that has a default value. Note that `autodef` should
be a string or `None`; the latter disables adding automatic
default.
"""
self.new_order = new_order
self.autodef = autodef
def change_definition_info(self, definition_info):
new_args = list(definition_info.args_with_defaults)
for new_index, index in enumerate(self.new_order):
new_args[new_index] = definition_info.args_with_defaults[index]
seen_default = False
for index, (arg, default) in enumerate(list(new_args)):
if default is not None:
seen_default = True
if seen_default and default is None and self.autodef is not None:
new_args[index] = (arg, self.autodef)
definition_info.args_with_defaults = new_args
class _ChangeCallsInModule(object):
def __init__(self, project, occurrence_finder, resource, call_changer):
self.project = project
self.occurrence_finder = occurrence_finder
self.resource = resource
self.call_changer = call_changer
def get_changed_module(self):
word_finder = worder.Worder(self.source)
change_collector = codeanalyze.ChangeCollector(self.source)
for occurrence in self.occurrence_finder.find_occurrences(self.resource):
if not occurrence.is_called() and not occurrence.is_defined():
continue
start, end = occurrence.get_primary_range()
begin_parens, end_parens = word_finder.get_word_parens_range(end - 1)
if occurrence.is_called():
primary, pyname = occurrence.get_primary_and_pyname()
changed_call = self.call_changer.change_call(
primary, pyname, self.source[start:end_parens]
)
else:
changed_call = self.call_changer.change_definition(
self.source[start:end_parens]
)
if changed_call is not None:
change_collector.add_change(start, end_parens, changed_call)
return change_collector.get_changed()
@property
@utils.saveit
def pymodule(self):
return self.project.get_pymodule(self.resource)
@property
@utils.saveit
def source(self):
if self.resource is not None:
return self.resource.read()
else:
return self.pymodule.source_code
@property
@utils.saveit
def lines(self):
return self.pymodule.lines
class _MultipleFinders(object):
def __init__(self, finders):
self.finders = finders
def find_occurrences(self, resource=None, pymodule=None):
all_occurrences = []
for finder in self.finders:
all_occurrences.extend(finder.find_occurrences(resource, pymodule))
all_occurrences.sort(key=lambda x: x.get_primary_range())
return all_occurrences

View file

@ -0,0 +1,221 @@
from rope.base import evaluate
from rope.base import exceptions
from rope.base import libutils
from rope.base import pynames
from rope.base import taskhandle
from rope.base import utils
from rope.base import worder
from rope.base.change import ChangeSet, ChangeContents
from rope.refactor import sourceutils, occurrences
class EncapsulateField(object):
def __init__(self, project, resource, offset):
self.project = project
self.name = worder.get_name_at(resource, offset)
this_pymodule = self.project.get_pymodule(resource)
self.pyname = evaluate.eval_location(this_pymodule, offset)
if not self._is_an_attribute(self.pyname):
raise exceptions.RefactoringError(
"Encapsulate field should be performed on class attributes."
)
self.resource = self.pyname.get_definition_location()[0].get_resource()
def get_changes(
self,
getter=None,
setter=None,
resources=None,
task_handle=taskhandle.NullTaskHandle(),
):
"""Get the changes this refactoring makes
If `getter` is not `None`, that will be the name of the
getter, otherwise ``get_${field_name}`` will be used. The
same is true for `setter` and if it is None set_${field_name} is
used.
`resources` can be a list of `rope.base.resource.File` that
the refactoring should be applied on; if `None` all python
files in the project are searched.
"""
if resources is None:
resources = self.project.get_python_files()
changes = ChangeSet("Encapsulate field <%s>" % self.name)
job_set = task_handle.create_jobset("Collecting Changes", len(resources))
if getter is None:
getter = "get_" + self.name
if setter is None:
setter = "set_" + self.name
renamer = GetterSetterRenameInModule(
self.project, self.name, self.pyname, getter, setter
)
for file in resources:
job_set.started_job(file.path)
if file == self.resource:
result = self._change_holding_module(changes, renamer, getter, setter)
changes.add_change(ChangeContents(self.resource, result))
else:
result = renamer.get_changed_module(file)
if result is not None:
changes.add_change(ChangeContents(file, result))
job_set.finished_job()
return changes
def get_field_name(self):
"""Get the name of the field to be encapsulated"""
return self.name
def _is_an_attribute(self, pyname):
if pyname is not None and isinstance(pyname, pynames.AssignedName):
pymodule, lineno = self.pyname.get_definition_location()
scope = pymodule.get_scope().get_inner_scope_for_line(lineno)
if scope.get_kind() == "Class":
return pyname in scope.get_names().values()
parent = scope.parent
if parent is not None and parent.get_kind() == "Class":
return pyname in parent.get_names().values()
return False
def _get_defining_class_scope(self):
defining_scope = self._get_defining_scope()
if defining_scope.get_kind() == "Function":
defining_scope = defining_scope.parent
return defining_scope
def _get_defining_scope(self):
pymodule, line = self.pyname.get_definition_location()
return pymodule.get_scope().get_inner_scope_for_line(line)
def _change_holding_module(self, changes, renamer, getter, setter):
pymodule = self.project.get_pymodule(self.resource)
class_scope = self._get_defining_class_scope()
defining_object = self._get_defining_scope().pyobject
start, end = sourceutils.get_body_region(defining_object)
new_source = renamer.get_changed_module(
pymodule=pymodule, skip_start=start, skip_end=end
)
if new_source is not None:
pymodule = libutils.get_string_module(
self.project, new_source, self.resource
)
class_scope = pymodule.get_scope().get_inner_scope_for_line(
class_scope.get_start()
)
indents = sourceutils.get_indent(self.project) * " "
getter = "def %s(self):\n%sreturn self.%s" % (getter, indents, self.name)
setter = "def %s(self, value):\n%sself.%s = value" % (
setter,
indents,
self.name,
)
new_source = sourceutils.add_methods(pymodule, class_scope, [getter, setter])
return new_source
class GetterSetterRenameInModule(object):
def __init__(self, project, name, pyname, getter, setter):
self.project = project
self.name = name
self.finder = occurrences.create_finder(project, name, pyname)
self.getter = getter
self.setter = setter
def get_changed_module(
self, resource=None, pymodule=None, skip_start=0, skip_end=0
):
change_finder = _FindChangesForModule(
self, resource, pymodule, skip_start, skip_end
)
return change_finder.get_changed_module()
class _FindChangesForModule(object):
def __init__(self, finder, resource, pymodule, skip_start, skip_end):
self.project = finder.project
self.finder = finder.finder
self.getter = finder.getter
self.setter = finder.setter
self.resource = resource
self.pymodule = pymodule
self.last_modified = 0
self.last_set = None
self.set_index = None
self.skip_start = skip_start
self.skip_end = skip_end
def get_changed_module(self):
result = []
for occurrence in self.finder.find_occurrences(self.resource, self.pymodule):
start, end = occurrence.get_word_range()
if self.skip_start <= start < self.skip_end:
continue
self._manage_writes(start, result)
result.append(self.source[self.last_modified : start])
if self._is_assigned_in_a_tuple_assignment(occurrence):
raise exceptions.RefactoringError(
"Cannot handle tuple assignments in encapsulate field."
)
if occurrence.is_written():
assignment_type = self.worder.get_assignment_type(start)
if assignment_type == "=":
result.append(self.setter + "(")
else:
var_name = (
self.source[occurrence.get_primary_range()[0] : start]
+ self.getter
+ "()"
)
result.append(
self.setter + "(" + var_name + " %s " % assignment_type[:-1]
)
current_line = self.lines.get_line_number(start)
start_line, end_line = self.pymodule.logical_lines.logical_line_in(
current_line
)
self.last_set = self.lines.get_line_end(end_line)
end = self.source.index("=", end) + 1
self.set_index = len(result)
else:
result.append(self.getter + "()")
self.last_modified = end
if self.last_modified != 0:
self._manage_writes(len(self.source), result)
result.append(self.source[self.last_modified :])
return "".join(result)
return None
def _manage_writes(self, offset, result):
if self.last_set is not None and self.last_set <= offset:
result.append(self.source[self.last_modified : self.last_set])
set_value = "".join(result[self.set_index :]).strip()
del result[self.set_index :]
result.append(set_value + ")")
self.last_modified = self.last_set
self.last_set = None
def _is_assigned_in_a_tuple_assignment(self, occurrence):
offset = occurrence.get_word_range()[0]
return self.worder.is_assigned_in_a_tuple_assignment(offset)
@property
@utils.saveit
def source(self):
if self.resource is not None:
return self.resource.read()
else:
return self.pymodule.source_code
@property
@utils.saveit
def lines(self):
if self.pymodule is None:
self.pymodule = self.project.get_pymodule(self.resource)
return self.pymodule.lines
@property
@utils.saveit
def worder(self):
return worder.Worder(self.source)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,246 @@
import rope.base.exceptions
import rope.base.pyobjects
from rope.base.builtins import Lambda
from rope.base import worder
class DefinitionInfo(object):
def __init__(
self, function_name, is_method, args_with_defaults, args_arg, keywords_arg
):
self.function_name = function_name
self.is_method = is_method
self.args_with_defaults = args_with_defaults
self.args_arg = args_arg
self.keywords_arg = keywords_arg
def to_string(self):
return "%s(%s)" % (self.function_name, self.arguments_to_string())
def arguments_to_string(self, from_index=0):
params = []
for arg, default in self.args_with_defaults:
if default is not None:
params.append("%s=%s" % (arg, default))
else:
params.append(arg)
if self.args_arg is not None:
params.append("*" + self.args_arg)
if self.keywords_arg:
params.append("**" + self.keywords_arg)
return ", ".join(params[from_index:])
@staticmethod
def _read(pyfunction, code):
kind = pyfunction.get_kind()
is_method = kind == "method"
is_lambda = kind == "lambda"
info = _FunctionParser(code, is_method, is_lambda)
args, keywords = info.get_parameters()
args_arg = None
keywords_arg = None
if args and args[-1].startswith("**"):
keywords_arg = args[-1][2:]
del args[-1]
if args and args[-1].startswith("*"):
args_arg = args[-1][1:]
del args[-1]
args_with_defaults = [(name, None) for name in args]
args_with_defaults.extend(keywords)
return DefinitionInfo(
info.get_function_name(),
is_method,
args_with_defaults,
args_arg,
keywords_arg,
)
@staticmethod
def read(pyfunction):
pymodule = pyfunction.get_module()
word_finder = worder.Worder(pymodule.source_code)
lineno = pyfunction.get_ast().lineno
start = pymodule.lines.get_line_start(lineno)
if isinstance(pyfunction, Lambda):
call = word_finder.get_lambda_and_args(start)
else:
call = word_finder.get_function_and_args_in_header(start)
return DefinitionInfo._read(pyfunction, call)
class CallInfo(object):
def __init__(
self,
function_name,
args,
keywords,
args_arg,
keywords_arg,
implicit_arg,
constructor,
):
self.function_name = function_name
self.args = args
self.keywords = keywords
self.args_arg = args_arg
self.keywords_arg = keywords_arg
self.implicit_arg = implicit_arg
self.constructor = constructor
def to_string(self):
function = self.function_name
if self.implicit_arg:
function = self.args[0] + "." + self.function_name
params = []
start = 0
if self.implicit_arg or self.constructor:
start = 1
if self.args[start:]:
params.extend(self.args[start:])
if self.keywords:
params.extend(["%s=%s" % (name, value) for name, value in self.keywords])
if self.args_arg is not None:
params.append("*" + self.args_arg)
if self.keywords_arg:
params.append("**" + self.keywords_arg)
return "%s(%s)" % (function, ", ".join(params))
@staticmethod
def read(primary, pyname, definition_info, code):
is_method_call = CallInfo._is_method_call(primary, pyname)
is_constructor = CallInfo._is_class(pyname)
is_classmethod = CallInfo._is_classmethod(pyname)
info = _FunctionParser(code, is_method_call or is_classmethod)
args, keywords = info.get_parameters()
args_arg = None
keywords_arg = None
if args and args[-1].startswith("**"):
keywords_arg = args[-1][2:]
del args[-1]
if args and args[-1].startswith("*"):
args_arg = args[-1][1:]
del args[-1]
if is_constructor:
args.insert(0, definition_info.args_with_defaults[0][0])
return CallInfo(
info.get_function_name(),
args,
keywords,
args_arg,
keywords_arg,
is_method_call or is_classmethod,
is_constructor,
)
@staticmethod
def _is_method_call(primary, pyname):
return (
primary is not None
and isinstance(primary.get_object().get_type(), rope.base.pyobjects.PyClass)
and CallInfo._is_method(pyname)
)
@staticmethod
def _is_class(pyname):
return pyname is not None and isinstance(
pyname.get_object(), rope.base.pyobjects.PyClass
)
@staticmethod
def _is_method(pyname):
if pyname is not None and isinstance(
pyname.get_object(), rope.base.pyobjects.PyFunction
):
return pyname.get_object().get_kind() == "method"
return False
@staticmethod
def _is_classmethod(pyname):
if pyname is not None and isinstance(
pyname.get_object(), rope.base.pyobjects.PyFunction
):
return pyname.get_object().get_kind() == "classmethod"
return False
class ArgumentMapping(object):
def __init__(self, definition_info, call_info):
self.call_info = call_info
self.param_dict = {}
self.keyword_args = []
self.args_arg = []
for index, value in enumerate(call_info.args):
if index < len(definition_info.args_with_defaults):
name = definition_info.args_with_defaults[index][0]
self.param_dict[name] = value
else:
self.args_arg.append(value)
for name, value in call_info.keywords:
index = -1
for pair in definition_info.args_with_defaults:
if pair[0] == name:
self.param_dict[name] = value
break
else:
self.keyword_args.append((name, value))
def to_call_info(self, definition_info):
args = []
keywords = []
for index in range(len(definition_info.args_with_defaults)):
name = definition_info.args_with_defaults[index][0]
if name in self.param_dict:
args.append(self.param_dict[name])
else:
for i in range(index, len(definition_info.args_with_defaults)):
name = definition_info.args_with_defaults[i][0]
if name in self.param_dict:
keywords.append((name, self.param_dict[name]))
break
args.extend(self.args_arg)
keywords.extend(self.keyword_args)
return CallInfo(
self.call_info.function_name,
args,
keywords,
self.call_info.args_arg,
self.call_info.keywords_arg,
self.call_info.implicit_arg,
self.call_info.constructor,
)
class _FunctionParser(object):
def __init__(self, call, implicit_arg, is_lambda=False):
self.call = call
self.implicit_arg = implicit_arg
self.word_finder = worder.Worder(self.call)
if is_lambda:
self.last_parens = self.call.rindex(":")
else:
self.last_parens = self.call.rindex(")")
self.first_parens = self.word_finder._find_parens_start(self.last_parens)
def get_parameters(self):
args, keywords = self.word_finder.get_parameters(
self.first_parens, self.last_parens
)
if self.is_called_as_a_method():
instance = self.call[: self.call.rindex(".", 0, self.first_parens)]
args.insert(0, instance.strip())
return args, keywords
def get_instance(self):
if self.is_called_as_a_method():
return self.word_finder.get_primary_at(
self.call.rindex(".", 0, self.first_parens) - 1
)
def get_function_name(self):
if self.is_called_as_a_method():
return self.word_finder.get_word_at(self.first_parens - 1)
else:
return self.word_finder.get_primary_at(self.first_parens - 1)
def is_called_as_a_method(self):
return self.implicit_arg and "." in self.call[: self.first_parens]

View file

@ -0,0 +1,337 @@
"""A package for handling imports
This package provides tools for modifying module imports after
refactorings or as a separate task.
"""
import rope.base.evaluate
from rope.base import libutils
from rope.base.change import ChangeSet, ChangeContents
from rope.refactor import occurrences, rename
from rope.refactor.importutils import module_imports, actions
from rope.refactor.importutils.importinfo import NormalImport, FromImport
import rope.base.codeanalyze
class ImportOrganizer(object):
"""Perform some import-related commands
Each method returns a `rope.base.change.Change` object.
"""
def __init__(self, project):
self.project = project
self.import_tools = ImportTools(self.project)
def organize_imports(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.organize_imports, resource, offset
)
def expand_star_imports(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.expand_stars, resource, offset
)
def froms_to_imports(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.froms_to_imports, resource, offset
)
def relatives_to_absolutes(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.relatives_to_absolutes, resource, offset
)
def handle_long_imports(self, resource, offset=None):
return self._perform_command_on_import_tools(
self.import_tools.handle_long_imports, resource, offset
)
def _perform_command_on_import_tools(self, method, resource, offset):
pymodule = self.project.get_pymodule(resource)
before_performing = pymodule.source_code
import_filter = None
if offset is not None:
import_filter = self._line_filter(pymodule.lines.get_line_number(offset))
result = method(pymodule, import_filter=import_filter)
if result is not None and result != before_performing:
changes = ChangeSet(
method.__name__.replace("_", " ") + " in <%s>" % resource.path
)
changes.add_change(ChangeContents(resource, result))
return changes
def _line_filter(self, lineno):
def import_filter(import_stmt):
return import_stmt.start_line <= lineno < import_stmt.end_line
return import_filter
class ImportTools(object):
def __init__(self, project):
self.project = project
def get_import(self, resource):
"""The import statement for `resource`"""
module_name = libutils.modname(resource)
return NormalImport(((module_name, None),))
def get_from_import(self, resource, name):
"""The from import statement for `name` in `resource`"""
module_name = libutils.modname(resource)
names = []
if isinstance(name, list):
names = [(imported, None) for imported in name]
else:
names = [
(name, None),
]
return FromImport(module_name, 0, tuple(names))
def module_imports(self, module, imports_filter=None):
return module_imports.ModuleImports(self.project, module, imports_filter)
def froms_to_imports(self, pymodule, import_filter=None):
pymodule = self._clean_up_imports(pymodule, import_filter)
module_imports = self.module_imports(pymodule, import_filter)
for import_stmt in module_imports.imports:
if import_stmt.readonly or not self._is_transformable_to_normal(
import_stmt.import_info
):
continue
pymodule = self._from_to_normal(pymodule, import_stmt)
# Adding normal imports in place of froms
module_imports = self.module_imports(pymodule, import_filter)
for import_stmt in module_imports.imports:
if not import_stmt.readonly and self._is_transformable_to_normal(
import_stmt.import_info
):
import_stmt.import_info = NormalImport(
((import_stmt.import_info.module_name, None),)
)
module_imports.remove_duplicates()
return module_imports.get_changed_source()
def expand_stars(self, pymodule, import_filter=None):
module_imports = self.module_imports(pymodule, import_filter)
module_imports.expand_stars()
return module_imports.get_changed_source()
def _from_to_normal(self, pymodule, import_stmt):
resource = pymodule.get_resource()
from_import = import_stmt.import_info
module_name = from_import.module_name
for name, alias in from_import.names_and_aliases:
imported = name
if alias is not None:
imported = alias
occurrence_finder = occurrences.create_finder(
self.project, imported, pymodule[imported], imports=False
)
source = rename.rename_in_module(
occurrence_finder,
module_name + "." + name,
pymodule=pymodule,
replace_primary=True,
)
if source is not None:
pymodule = libutils.get_string_module(self.project, source, resource)
return pymodule
def _clean_up_imports(self, pymodule, import_filter):
resource = pymodule.get_resource()
module_with_imports = self.module_imports(pymodule, import_filter)
module_with_imports.expand_stars()
source = module_with_imports.get_changed_source()
if source is not None:
pymodule = libutils.get_string_module(self.project, source, resource)
source = self.relatives_to_absolutes(pymodule)
if source is not None:
pymodule = libutils.get_string_module(self.project, source, resource)
module_with_imports = self.module_imports(pymodule, import_filter)
module_with_imports.remove_duplicates()
module_with_imports.remove_unused_imports()
source = module_with_imports.get_changed_source()
if source is not None:
pymodule = libutils.get_string_module(self.project, source, resource)
return pymodule
def relatives_to_absolutes(self, pymodule, import_filter=None):
module_imports = self.module_imports(pymodule, import_filter)
to_be_absolute_list = module_imports.get_relative_to_absolute_list()
for name, absolute_name in to_be_absolute_list:
pymodule = self._rename_in_module(pymodule, name, absolute_name)
module_imports = self.module_imports(pymodule, import_filter)
module_imports.get_relative_to_absolute_list()
source = module_imports.get_changed_source()
if source is None:
source = pymodule.source_code
return source
def _is_transformable_to_normal(self, import_info):
if not isinstance(import_info, FromImport):
return False
return True
def organize_imports(
self,
pymodule,
unused=True,
duplicates=True,
selfs=True,
sort=True,
import_filter=None,
):
if unused or duplicates:
module_imports = self.module_imports(pymodule, import_filter)
if unused:
module_imports.remove_unused_imports()
if self.project.prefs.get("split_imports"):
module_imports.force_single_imports()
if duplicates:
module_imports.remove_duplicates()
source = module_imports.get_changed_source()
if source is not None:
pymodule = libutils.get_string_module(
self.project, source, pymodule.get_resource()
)
if selfs:
pymodule = self._remove_self_imports(pymodule, import_filter)
if sort:
return self.sort_imports(pymodule, import_filter)
else:
return pymodule.source_code
def _remove_self_imports(self, pymodule, import_filter=None):
module_imports = self.module_imports(pymodule, import_filter)
(
to_be_fixed,
to_be_renamed,
) = module_imports.get_self_import_fix_and_rename_list()
for name in to_be_fixed:
try:
pymodule = self._rename_in_module(pymodule, name, "", till_dot=True)
except ValueError:
# There is a self import with direct access to it
return pymodule
for name, new_name in to_be_renamed:
pymodule = self._rename_in_module(pymodule, name, new_name)
module_imports = self.module_imports(pymodule, import_filter)
module_imports.get_self_import_fix_and_rename_list()
source = module_imports.get_changed_source()
if source is not None:
pymodule = libutils.get_string_module(
self.project, source, pymodule.get_resource()
)
return pymodule
def _rename_in_module(self, pymodule, name, new_name, till_dot=False):
old_name = name.split(".")[-1]
old_pyname = rope.base.evaluate.eval_str(pymodule.get_scope(), name)
occurrence_finder = occurrences.create_finder(
self.project, old_name, old_pyname, imports=False
)
changes = rope.base.codeanalyze.ChangeCollector(pymodule.source_code)
for occurrence in occurrence_finder.find_occurrences(pymodule=pymodule):
start, end = occurrence.get_primary_range()
if till_dot:
new_end = pymodule.source_code.index(".", end) + 1
space = pymodule.source_code[end : new_end - 1].strip()
if not space == "":
for c in space:
if not c.isspace() and c not in "\\":
raise ValueError()
end = new_end
changes.add_change(start, end, new_name)
source = changes.get_changed()
if source is not None:
pymodule = libutils.get_string_module(
self.project, source, pymodule.get_resource()
)
return pymodule
def sort_imports(self, pymodule, import_filter=None):
module_imports = self.module_imports(pymodule, import_filter)
module_imports.sort_imports()
return module_imports.get_changed_source()
def handle_long_imports(
self, pymodule, maxdots=2, maxlength=27, import_filter=None
):
# IDEA: `maxdots` and `maxlength` can be specified in project config
# adding new from imports
module_imports = self.module_imports(pymodule, import_filter)
to_be_fixed = module_imports.handle_long_imports(maxdots, maxlength)
# performing the renaming
pymodule = libutils.get_string_module(
self.project,
module_imports.get_changed_source(),
resource=pymodule.get_resource(),
)
for name in to_be_fixed:
pymodule = self._rename_in_module(pymodule, name, name.split(".")[-1])
# organizing imports
return self.organize_imports(
pymodule, selfs=False, sort=False, import_filter=import_filter
)
def get_imports(project, pydefined):
"""A shortcut for getting the `ImportInfo` used in a scope"""
pymodule = pydefined.get_module()
module = module_imports.ModuleImports(project, pymodule)
if pymodule == pydefined:
return [stmt.import_info for stmt in module.imports]
return module.get_used_imports(pydefined)
def get_module_imports(project, pymodule):
"""A shortcut for creating a `module_imports.ModuleImports` object"""
return module_imports.ModuleImports(project, pymodule)
def add_import(project, pymodule, module_name, name=None):
imports = get_module_imports(project, pymodule)
candidates = []
names = []
selected_import = None
# from mod import name
if name is not None:
from_import = FromImport(module_name, 0, [(name, None)])
names.append(name)
candidates.append(from_import)
# from pkg import mod
if "." in module_name:
pkg, mod = module_name.rsplit(".", 1)
from_import = FromImport(pkg, 0, [(mod, None)])
if project.prefs.get("prefer_module_from_imports"):
selected_import = from_import
candidates.append(from_import)
if name:
names.append(mod + "." + name)
else:
names.append(mod)
# import mod
normal_import = NormalImport([(module_name, None)])
if name:
names.append(module_name + "." + name)
else:
names.append(module_name)
candidates.append(normal_import)
visitor = actions.AddingVisitor(project, candidates)
if selected_import is None:
selected_import = normal_import
for import_statement in imports.imports:
if import_statement.accept(visitor):
selected_import = visitor.import_info
break
imports.add_import(selected_import)
imported_name = names[candidates.index(selected_import)]
return imports.get_changed_source(), imported_name

View file

@ -0,0 +1,367 @@
from rope.base import libutils
from rope.base import pyobjects, exceptions, stdmods
from rope.refactor import occurrences
from rope.refactor.importutils import importinfo
class ImportInfoVisitor(object):
def dispatch(self, import_):
try:
method_name = "visit" + import_.import_info.__class__.__name__
method = getattr(self, method_name)
return method(import_, import_.import_info)
except exceptions.ModuleNotFoundError:
pass
def visitEmptyImport(self, import_stmt, import_info):
pass
def visitNormalImport(self, import_stmt, import_info):
pass
def visitFromImport(self, import_stmt, import_info):
pass
class RelativeToAbsoluteVisitor(ImportInfoVisitor):
def __init__(self, project, current_folder):
self.to_be_absolute = []
self.project = project
self.folder = current_folder
self.context = importinfo.ImportContext(project, current_folder)
def visitNormalImport(self, import_stmt, import_info):
self.to_be_absolute.extend(self._get_relative_to_absolute_list(import_info))
new_pairs = []
for name, alias in import_info.names_and_aliases:
resource = self.project.find_module(name, folder=self.folder)
if resource is None:
new_pairs.append((name, alias))
continue
absolute_name = libutils.modname(resource)
new_pairs.append((absolute_name, alias))
if not import_info._are_name_and_alias_lists_equal(
new_pairs, import_info.names_and_aliases
):
import_stmt.import_info = importinfo.NormalImport(new_pairs)
def _get_relative_to_absolute_list(self, import_info):
result = []
for name, alias in import_info.names_and_aliases:
if alias is not None:
continue
resource = self.project.find_module(name, folder=self.folder)
if resource is None:
continue
absolute_name = libutils.modname(resource)
if absolute_name != name:
result.append((name, absolute_name))
return result
def visitFromImport(self, import_stmt, import_info):
resource = import_info.get_imported_resource(self.context)
if resource is None:
return None
absolute_name = libutils.modname(resource)
if import_info.module_name != absolute_name:
import_stmt.import_info = importinfo.FromImport(
absolute_name, 0, import_info.names_and_aliases
)
class FilteringVisitor(ImportInfoVisitor):
def __init__(self, project, folder, can_select):
self.to_be_absolute = []
self.project = project
self.can_select = self._transform_can_select(can_select)
self.context = importinfo.ImportContext(project, folder)
def _transform_can_select(self, can_select):
def can_select_name_and_alias(name, alias):
imported = name
if alias is not None:
imported = alias
return can_select(imported)
return can_select_name_and_alias
def visitNormalImport(self, import_stmt, import_info):
new_pairs = []
for name, alias in import_info.names_and_aliases:
if self.can_select(name, alias):
new_pairs.append((name, alias))
return importinfo.NormalImport(new_pairs)
def visitFromImport(self, import_stmt, import_info):
if _is_future(import_info):
return import_info
new_pairs = []
if import_info.is_star_import():
for name in import_info.get_imported_names(self.context):
if self.can_select(name, None):
new_pairs.append(import_info.names_and_aliases[0])
break
else:
for name, alias in import_info.names_and_aliases:
if self.can_select(name, alias):
new_pairs.append((name, alias))
return importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
class RemovingVisitor(ImportInfoVisitor):
def __init__(self, project, folder, can_select):
self.to_be_absolute = []
self.project = project
self.filtering = FilteringVisitor(project, folder, can_select)
def dispatch(self, import_):
result = self.filtering.dispatch(import_)
if result is not None:
import_.import_info = result
class AddingVisitor(ImportInfoVisitor):
"""A class for adding imports
Given a list of `ImportInfo`, it tries to add each import to the
module and returns `True` and gives up when an import can be added
to older ones.
"""
def __init__(self, project, import_list):
self.project = project
self.import_list = import_list
self.import_info = None
def dispatch(self, import_):
for import_info in self.import_list:
self.import_info = import_info
if ImportInfoVisitor.dispatch(self, import_):
return True
# TODO: Handle adding relative and absolute imports
def visitNormalImport(self, import_stmt, import_info):
if not isinstance(self.import_info, import_info.__class__):
return False
# Adding ``import x`` and ``import x.y`` that results ``import x.y``
if (
len(import_info.names_and_aliases)
== len(self.import_info.names_and_aliases)
== 1
):
imported1 = import_info.names_and_aliases[0]
imported2 = self.import_info.names_and_aliases[0]
if imported1[1] == imported2[1] is None:
if imported1[0].startswith(imported2[0] + "."):
return True
if imported2[0].startswith(imported1[0] + "."):
import_stmt.import_info = self.import_info
return True
# Multiple imports using a single import statement is discouraged
# so we won't bother adding them.
if self.import_info._are_name_and_alias_lists_equal(
import_info.names_and_aliases, self.import_info.names_and_aliases
):
return True
def visitFromImport(self, import_stmt, import_info):
if (
isinstance(self.import_info, import_info.__class__)
and import_info.module_name == self.import_info.module_name
and import_info.level == self.import_info.level
):
if import_info.is_star_import():
return True
if self.import_info.is_star_import():
import_stmt.import_info = self.import_info
return True
if self.project.prefs.get("split_imports"):
return (
self.import_info.names_and_aliases == import_info.names_and_aliases
)
new_pairs = list(import_info.names_and_aliases)
for pair in self.import_info.names_and_aliases:
if pair not in new_pairs:
new_pairs.append(pair)
import_stmt.import_info = importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
return True
class ExpandStarsVisitor(ImportInfoVisitor):
def __init__(self, project, folder, can_select):
self.project = project
self.filtering = FilteringVisitor(project, folder, can_select)
self.context = importinfo.ImportContext(project, folder)
def visitNormalImport(self, import_stmt, import_info):
self.filtering.dispatch(import_stmt)
def visitFromImport(self, import_stmt, import_info):
if import_info.is_star_import():
new_pairs = []
for name in import_info.get_imported_names(self.context):
new_pairs.append((name, None))
new_import = importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
import_stmt.import_info = self.filtering.visitFromImport(None, new_import)
else:
self.filtering.dispatch(import_stmt)
class SelfImportVisitor(ImportInfoVisitor):
def __init__(self, project, current_folder, resource):
self.project = project
self.folder = current_folder
self.resource = resource
self.to_be_fixed = set()
self.to_be_renamed = set()
self.context = importinfo.ImportContext(project, current_folder)
def visitNormalImport(self, import_stmt, import_info):
new_pairs = []
for name, alias in import_info.names_and_aliases:
resource = self.project.find_module(name, folder=self.folder)
if resource is not None and resource == self.resource:
imported = name
if alias is not None:
imported = alias
self.to_be_fixed.add(imported)
else:
new_pairs.append((name, alias))
if not import_info._are_name_and_alias_lists_equal(
new_pairs, import_info.names_and_aliases
):
import_stmt.import_info = importinfo.NormalImport(new_pairs)
def visitFromImport(self, import_stmt, import_info):
resource = import_info.get_imported_resource(self.context)
if resource is None:
return
if resource == self.resource:
self._importing_names_from_self(import_info, import_stmt)
return
pymodule = self.project.get_pymodule(resource)
new_pairs = []
for name, alias in import_info.names_and_aliases:
try:
result = pymodule[name].get_object()
if (
isinstance(result, pyobjects.PyModule)
and result.get_resource() == self.resource
):
imported = name
if alias is not None:
imported = alias
self.to_be_fixed.add(imported)
else:
new_pairs.append((name, alias))
except exceptions.AttributeNotFoundError:
new_pairs.append((name, alias))
if not import_info._are_name_and_alias_lists_equal(
new_pairs, import_info.names_and_aliases
):
import_stmt.import_info = importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
def _importing_names_from_self(self, import_info, import_stmt):
if not import_info.is_star_import():
for name, alias in import_info.names_and_aliases:
if alias is not None:
self.to_be_renamed.add((alias, name))
import_stmt.empty_import()
class SortingVisitor(ImportInfoVisitor):
def __init__(self, project, current_folder):
self.project = project
self.folder = current_folder
self.standard = set()
self.third_party = set()
self.in_project = set()
self.future = set()
self.context = importinfo.ImportContext(project, current_folder)
def visitNormalImport(self, import_stmt, import_info):
if import_info.names_and_aliases:
name, alias = import_info.names_and_aliases[0]
resource = self.project.find_module(name, folder=self.folder)
self._check_imported_resource(import_stmt, resource, name)
def visitFromImport(self, import_stmt, import_info):
resource = import_info.get_imported_resource(self.context)
self._check_imported_resource(import_stmt, resource, import_info.module_name)
def _check_imported_resource(self, import_stmt, resource, imported_name):
info = import_stmt.import_info
if resource is not None and resource.project == self.project:
self.in_project.add(import_stmt)
elif _is_future(info):
self.future.add(import_stmt)
elif imported_name.split(".")[0] in stdmods.standard_modules():
self.standard.add(import_stmt)
else:
self.third_party.add(import_stmt)
class LongImportVisitor(ImportInfoVisitor):
def __init__(self, current_folder, project, maxdots, maxlength):
self.maxdots = maxdots
self.maxlength = maxlength
self.to_be_renamed = set()
self.current_folder = current_folder
self.project = project
self.new_imports = []
def visitNormalImport(self, import_stmt, import_info):
for name, alias in import_info.names_and_aliases:
if alias is None and self._is_long(name):
self.to_be_renamed.add(name)
last_dot = name.rindex(".")
from_ = name[:last_dot]
imported = name[last_dot + 1 :]
self.new_imports.append(
importinfo.FromImport(from_, 0, ((imported, None),))
)
def _is_long(self, name):
return name.count(".") > self.maxdots or (
"." in name and len(name) > self.maxlength
)
class RemovePyNameVisitor(ImportInfoVisitor):
def __init__(self, project, pymodule, pyname, folder):
self.pymodule = pymodule
self.pyname = pyname
self.context = importinfo.ImportContext(project, folder)
def visitFromImport(self, import_stmt, import_info):
new_pairs = []
if not import_info.is_star_import():
for name, alias in import_info.names_and_aliases:
try:
pyname = self.pymodule[alias or name]
if occurrences.same_pyname(self.pyname, pyname):
continue
except exceptions.AttributeNotFoundError:
pass
new_pairs.append((name, alias))
return importinfo.FromImport(
import_info.module_name, import_info.level, new_pairs
)
def dispatch(self, import_):
result = ImportInfoVisitor.dispatch(self, import_)
if result is not None:
import_.import_info = result
def _is_future(info):
return isinstance(info, importinfo.FromImport) and info.module_name == "__future__"

View file

@ -0,0 +1,203 @@
class ImportStatement(object):
"""Represent an import in a module
`readonly` attribute controls whether this import can be changed
by import actions or not.
"""
def __init__(
self, import_info, start_line, end_line, main_statement=None, blank_lines=0
):
self.start_line = start_line
self.end_line = end_line
self.readonly = False
self.main_statement = main_statement
self._import_info = None
self.import_info = import_info
self._is_changed = False
self.new_start = None
self.blank_lines = blank_lines
def _get_import_info(self):
return self._import_info
def _set_import_info(self, new_import):
if (
not self.readonly
and new_import is not None
and not new_import == self._import_info
):
self._is_changed = True
self._import_info = new_import
import_info = property(_get_import_info, _set_import_info)
def get_import_statement(self):
if self._is_changed or self.main_statement is None:
return self.import_info.get_import_statement()
else:
return self.main_statement
def empty_import(self):
self.import_info = ImportInfo.get_empty_import()
def move(self, lineno, blank_lines=0):
self.new_start = lineno
self.blank_lines = blank_lines
def get_old_location(self):
return self.start_line, self.end_line
def get_new_start(self):
return self.new_start
def is_changed(self):
return self._is_changed or (
self.new_start is not None or self.new_start != self.start_line
)
def accept(self, visitor):
return visitor.dispatch(self)
class ImportInfo(object):
def get_imported_primaries(self, context):
pass
def get_imported_names(self, context):
return [
primary.split(".")[0] for primary in self.get_imported_primaries(context)
]
def get_import_statement(self):
pass
def is_empty(self):
pass
def __hash__(self):
return hash(self.get_import_statement())
def _are_name_and_alias_lists_equal(self, list1, list2):
if len(list1) != len(list2):
return False
for pair1, pair2 in zip(list1, list2):
if pair1 != pair2:
return False
return True
def __eq__(self, obj):
return (
isinstance(obj, self.__class__)
and self.get_import_statement() == obj.get_import_statement()
)
def __ne__(self, obj):
return not self.__eq__(obj)
@staticmethod
def get_empty_import():
return EmptyImport()
class NormalImport(ImportInfo):
def __init__(self, names_and_aliases):
self.names_and_aliases = names_and_aliases
def get_imported_primaries(self, context):
result = []
for name, alias in self.names_and_aliases:
if alias:
result.append(alias)
else:
result.append(name)
return result
def get_import_statement(self):
result = "import "
for name, alias in self.names_and_aliases:
result += name
if alias:
result += " as " + alias
result += ", "
return result[:-2]
def is_empty(self):
return len(self.names_and_aliases) == 0
class FromImport(ImportInfo):
def __init__(self, module_name, level, names_and_aliases):
self.module_name = module_name
self.level = level
self.names_and_aliases = names_and_aliases
def get_imported_primaries(self, context):
if self.names_and_aliases[0][0] == "*":
module = self.get_imported_module(context)
return [name for name in module if not name.startswith("_")]
result = []
for name, alias in self.names_and_aliases:
if alias:
result.append(alias)
else:
result.append(name)
return result
def get_imported_resource(self, context):
"""Get the imported resource
Returns `None` if module was not found.
"""
if self.level == 0:
return context.project.find_module(self.module_name, folder=context.folder)
else:
return context.project.find_relative_module(
self.module_name, context.folder, self.level
)
def get_imported_module(self, context):
"""Get the imported `PyModule`
Raises `rope.base.exceptions.ModuleNotFoundError` if module
could not be found.
"""
if self.level == 0:
return context.project.get_module(self.module_name, context.folder)
else:
return context.project.get_relative_module(
self.module_name, context.folder, self.level
)
def get_import_statement(self):
result = "from " + "." * self.level + self.module_name + " import "
for name, alias in self.names_and_aliases:
result += name
if alias:
result += " as " + alias
result += ", "
return result[:-2]
def is_empty(self):
return len(self.names_and_aliases) == 0
def is_star_import(self):
return len(self.names_and_aliases) > 0 and self.names_and_aliases[0][0] == "*"
class EmptyImport(ImportInfo):
names_and_aliases = []
def is_empty(self):
return True
def get_imported_primaries(self, context):
return []
class ImportContext(object):
def __init__(self, project, folder):
self.project = project
self.folder = folder

View file

@ -0,0 +1,551 @@
from rope.base import ast
from rope.base import exceptions
from rope.base import pynames
from rope.base import utils
from rope.refactor.importutils import actions
from rope.refactor.importutils import importinfo
class ModuleImports(object):
def __init__(self, project, pymodule, import_filter=None):
self.project = project
self.pymodule = pymodule
self.separating_lines = 0
self.filter = import_filter
self.sorted = False
@property
@utils.saveit
def imports(self):
finder = _GlobalImportFinder(self.pymodule)
result = finder.find_import_statements()
self.separating_lines = finder.get_separating_line_count()
if self.filter is not None:
for import_stmt in result:
if not self.filter(import_stmt):
import_stmt.readonly = True
return result
def _get_unbound_names(self, defined_pyobject):
visitor = _GlobalUnboundNameFinder(self.pymodule, defined_pyobject)
ast.walk(self.pymodule.get_ast(), visitor)
return visitor.unbound
def _get_all_star_list(self, pymodule):
result = set()
try:
all_star_list = pymodule.get_attribute("__all__")
except exceptions.AttributeNotFoundError:
return result
# FIXME: Need a better way to recursively infer possible values.
# Currently pyobjects can recursively infer type, but not values.
# Do a very basic 1-level value inference
for assignment in all_star_list.assignments:
if isinstance(assignment.ast_node, ast.List):
stack = list(assignment.ast_node.elts)
while stack:
el = stack.pop()
if isinstance(el, ast.Str):
result.add(el.s)
elif isinstance(el, ast.Name):
name = pymodule.get_attribute(el.id)
if isinstance(name, pynames.AssignedName):
for av in name.assignments:
if isinstance(av.ast_node, ast.Str):
result.add(av.ast_node.s)
elif isinstance(el, ast.IfExp):
stack.append(el.body)
stack.append(el.orelse)
return result
def remove_unused_imports(self):
can_select = _OneTimeSelector(
self._get_unbound_names(self.pymodule)
| self._get_all_star_list(self.pymodule)
)
visitor = actions.RemovingVisitor(
self.project, self._current_folder(), can_select
)
for import_statement in self.imports:
import_statement.accept(visitor)
def get_used_imports(self, defined_pyobject):
result = []
can_select = _OneTimeSelector(self._get_unbound_names(defined_pyobject))
visitor = actions.FilteringVisitor(
self.project, self._current_folder(), can_select
)
for import_statement in self.imports:
new_import = import_statement.accept(visitor)
if new_import is not None and not new_import.is_empty():
result.append(new_import)
return result
def get_changed_source(self):
if not self.project.prefs.get("pull_imports_to_top") and not self.sorted:
return "".join(self._rewrite_imports(self.imports))
# Make sure we forward a removed import's preceding blank
# lines count to the following import statement.
prev_stmt = None
for stmt in self.imports:
if prev_stmt is not None and prev_stmt.import_info.is_empty():
stmt.blank_lines = max(prev_stmt.blank_lines, stmt.blank_lines)
prev_stmt = stmt
# The new list of imports.
imports = [stmt for stmt in self.imports if not stmt.import_info.is_empty()]
after_removing = self._remove_imports(self.imports)
first_non_blank = self._first_non_blank_line(after_removing, 0)
first_import = self._first_import_line() - 1
result = []
# Writing module docs
result.extend(after_removing[first_non_blank:first_import])
# Writing imports
sorted_imports = sorted(imports, key=self._get_location)
for stmt in sorted_imports:
if stmt != sorted_imports[0]:
result.append("\n" * stmt.blank_lines)
result.append(stmt.get_import_statement() + "\n")
if sorted_imports and first_non_blank < len(after_removing):
result.append("\n" * self.separating_lines)
# Writing the body
first_after_imports = self._first_non_blank_line(after_removing, first_import)
result.extend(after_removing[first_after_imports:])
return "".join(result)
def _get_import_location(self, stmt):
start = stmt.get_new_start()
if start is None:
start = stmt.get_old_location()[0]
return start
def _get_location(self, stmt):
if stmt.get_new_start() is not None:
return stmt.get_new_start()
else:
return stmt.get_old_location()[0]
def _remove_imports(self, imports):
lines = self.pymodule.source_code.splitlines(True)
after_removing = []
first_import_line = self._first_import_line()
last_index = 0
for stmt in imports:
start, end = stmt.get_old_location()
blank_lines = 0
if start != first_import_line:
blank_lines = _count_blank_lines(
lines.__getitem__, start - 2, last_index - 1, -1
)
after_removing.extend(lines[last_index : start - 1 - blank_lines])
last_index = end - 1
after_removing.extend(lines[last_index:])
return after_removing
def _rewrite_imports(self, imports):
lines = self.pymodule.source_code.splitlines(True)
after_rewriting = []
last_index = 0
for stmt in imports:
start, end = stmt.get_old_location()
after_rewriting.extend(lines[last_index : start - 1])
if not stmt.import_info.is_empty():
after_rewriting.append(stmt.get_import_statement() + "\n")
last_index = end - 1
after_rewriting.extend(lines[last_index:])
return after_rewriting
def _first_non_blank_line(self, lines, lineno):
return lineno + _count_blank_lines(lines.__getitem__, lineno, len(lines))
def add_import(self, import_info):
visitor = actions.AddingVisitor(self.project, [import_info])
for import_statement in self.imports:
if import_statement.accept(visitor):
break
else:
lineno = self._get_new_import_lineno()
blanks = self._get_new_import_blanks()
self.imports.append(
importinfo.ImportStatement(
import_info, lineno, lineno, blank_lines=blanks
)
)
def _get_new_import_blanks(self):
return 0
def _get_new_import_lineno(self):
if self.imports:
return self.imports[-1].end_line
return 1
def filter_names(self, can_select):
visitor = actions.RemovingVisitor(
self.project, self._current_folder(), can_select
)
for import_statement in self.imports:
import_statement.accept(visitor)
def expand_stars(self):
can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule))
visitor = actions.ExpandStarsVisitor(
self.project, self._current_folder(), can_select
)
for import_statement in self.imports:
import_statement.accept(visitor)
def remove_duplicates(self):
added_imports = []
for import_stmt in self.imports:
visitor = actions.AddingVisitor(self.project, [import_stmt.import_info])
for added_import in added_imports:
if added_import.accept(visitor):
import_stmt.empty_import()
else:
added_imports.append(import_stmt)
def force_single_imports(self):
"""force a single import per statement"""
for import_stmt in self.imports[:]:
import_info = import_stmt.import_info
if import_info.is_empty() or import_stmt.readonly:
continue
if len(import_info.names_and_aliases) > 1:
for name_and_alias in import_info.names_and_aliases:
if hasattr(import_info, "module_name"):
new_import = importinfo.FromImport(
import_info.module_name, import_info.level, [name_and_alias]
)
else:
new_import = importinfo.NormalImport([name_and_alias])
self.add_import(new_import)
import_stmt.empty_import()
def get_relative_to_absolute_list(self):
visitor = actions.RelativeToAbsoluteVisitor(
self.project, self._current_folder()
)
for import_stmt in self.imports:
if not import_stmt.readonly:
import_stmt.accept(visitor)
return visitor.to_be_absolute
def get_self_import_fix_and_rename_list(self):
visitor = actions.SelfImportVisitor(
self.project, self._current_folder(), self.pymodule.get_resource()
)
for import_stmt in self.imports:
if not import_stmt.readonly:
import_stmt.accept(visitor)
return visitor.to_be_fixed, visitor.to_be_renamed
def _current_folder(self):
return self.pymodule.get_resource().parent
def sort_imports(self):
if self.project.prefs.get("sort_imports_alphabetically"):
sort_kwargs = dict(key=self._get_import_name)
else:
sort_kwargs = dict(key=self._key_imports)
# IDEA: Sort from import list
visitor = actions.SortingVisitor(self.project, self._current_folder())
for import_statement in self.imports:
import_statement.accept(visitor)
in_projects = sorted(visitor.in_project, **sort_kwargs)
third_party = sorted(visitor.third_party, **sort_kwargs)
standards = sorted(visitor.standard, **sort_kwargs)
future = sorted(visitor.future, **sort_kwargs)
last_index = self._first_import_line()
last_index = self._move_imports(future, last_index, 0)
last_index = self._move_imports(standards, last_index, 1)
last_index = self._move_imports(third_party, last_index, 1)
last_index = self._move_imports(in_projects, last_index, 1)
self.separating_lines = 2
self.sorted = True
def _first_import_line(self):
nodes = self.pymodule.get_ast().body
lineno = 0
if self.pymodule.get_doc() is not None:
lineno = 1
if len(nodes) > lineno:
if isinstance(nodes[lineno], ast.Import) or isinstance(
nodes[lineno], ast.ImportFrom
):
return nodes[lineno].lineno
lineno = self.pymodule.logical_lines.logical_line_in(nodes[lineno].lineno)[
0
]
else:
lineno = self.pymodule.lines.length()
return lineno - _count_blank_lines(
self.pymodule.lines.get_line, lineno - 1, 1, -1
)
def _get_import_name(self, import_stmt):
import_info = import_stmt.import_info
if hasattr(import_info, "module_name"):
return "%s.%s" % (
import_info.module_name,
import_info.names_and_aliases[0][0],
)
else:
return import_info.names_and_aliases[0][0]
def _key_imports(self, stm1):
str1 = stm1.get_import_statement()
return str1.startswith("from "), str1
# str1 = stmt1.get_import_statement()
# str2 = stmt2.get_import_statement()
# if str1.startswith('from ') and not str2.startswith('from '):
# return 1
# if not str1.startswith('from ') and str2.startswith('from '):
# return -1
# return cmp(str1, str2)
def _move_imports(self, imports, index, blank_lines):
if imports:
imports[0].move(index, blank_lines)
index += 1
if len(imports) > 1:
for stmt in imports[1:]:
stmt.move(index)
index += 1
return index
def handle_long_imports(self, maxdots, maxlength):
visitor = actions.LongImportVisitor(
self._current_folder(), self.project, maxdots, maxlength
)
for import_statement in self.imports:
if not import_statement.readonly:
import_statement.accept(visitor)
for import_info in visitor.new_imports:
self.add_import(import_info)
return visitor.to_be_renamed
def remove_pyname(self, pyname):
"""Removes pyname when imported in ``from mod import x``"""
visitor = actions.RemovePyNameVisitor(
self.project, self.pymodule, pyname, self._current_folder()
)
for import_stmt in self.imports:
import_stmt.accept(visitor)
def _count_blank_lines(get_line, start, end, step=1):
count = 0
for idx in range(start, end, step):
if get_line(idx).strip() == "":
count += 1
else:
break
return count
class _OneTimeSelector(object):
def __init__(self, names):
self.names = names
self.selected_names = set()
def __call__(self, imported_primary):
if self._can_name_be_added(imported_primary):
for name in self._get_dotted_tokens(imported_primary):
self.selected_names.add(name)
return True
return False
def _get_dotted_tokens(self, imported_primary):
tokens = imported_primary.split(".")
for i in range(len(tokens)):
yield ".".join(tokens[: i + 1])
def _can_name_be_added(self, imported_primary):
for name in self._get_dotted_tokens(imported_primary):
if name in self.names and name not in self.selected_names:
return True
return False
class _UnboundNameFinder(object):
def __init__(self, pyobject):
self.pyobject = pyobject
def _visit_child_scope(self, node):
pyobject = (
self.pyobject.get_module()
.get_scope()
.get_inner_scope_for_line(node.lineno)
.pyobject
)
visitor = _LocalUnboundNameFinder(pyobject, self)
for child in ast.get_child_nodes(node):
ast.walk(child, visitor)
def _FunctionDef(self, node):
self._visit_child_scope(node)
def _ClassDef(self, node):
self._visit_child_scope(node)
def _Name(self, node):
if self._get_root()._is_node_interesting(node) and not self.is_bound(node.id):
self.add_unbound(node.id)
def _Attribute(self, node):
result = []
while isinstance(node, ast.Attribute):
result.append(node.attr)
node = node.value
if isinstance(node, ast.Name):
result.append(node.id)
primary = ".".join(reversed(result))
if self._get_root()._is_node_interesting(node) and not self.is_bound(
primary
):
self.add_unbound(primary)
else:
ast.walk(node, self)
def _get_root(self):
pass
def is_bound(self, name, propagated=False):
pass
def add_unbound(self, name):
pass
class _GlobalUnboundNameFinder(_UnboundNameFinder):
def __init__(self, pymodule, wanted_pyobject):
super(_GlobalUnboundNameFinder, self).__init__(pymodule)
self.unbound = set()
self.names = set()
for name, pyname in pymodule._get_structural_attributes().items():
if not isinstance(pyname, (pynames.ImportedName, pynames.ImportedModule)):
self.names.add(name)
wanted_scope = wanted_pyobject.get_scope()
self.start = wanted_scope.get_start()
self.end = wanted_scope.get_end() + 1
def _get_root(self):
return self
def is_bound(self, primary, propagated=False):
name = primary.split(".")[0]
return name in self.names
def add_unbound(self, name):
names = name.split(".")
for i in range(len(names)):
self.unbound.add(".".join(names[: i + 1]))
def _is_node_interesting(self, node):
return self.start <= node.lineno < self.end
class _LocalUnboundNameFinder(_UnboundNameFinder):
def __init__(self, pyobject, parent):
super(_LocalUnboundNameFinder, self).__init__(pyobject)
self.parent = parent
def _get_root(self):
return self.parent._get_root()
def is_bound(self, primary, propagated=False):
name = primary.split(".")[0]
if propagated:
names = self.pyobject.get_scope().get_propagated_names()
else:
names = self.pyobject.get_scope().get_names()
if name in names or self.parent.is_bound(name, propagated=True):
return True
return False
def add_unbound(self, name):
self.parent.add_unbound(name)
class _GlobalImportFinder(object):
def __init__(self, pymodule):
self.current_folder = None
if pymodule.get_resource():
self.current_folder = pymodule.get_resource().parent
self.pymodule = pymodule
self.imports = []
self.pymodule = pymodule
self.lines = self.pymodule.lines
def visit_import(self, node, end_line):
start_line = node.lineno
import_statement = importinfo.ImportStatement(
importinfo.NormalImport(self._get_names(node.names)),
start_line,
end_line,
self._get_text(start_line, end_line),
blank_lines=self._count_empty_lines_before(start_line),
)
self.imports.append(import_statement)
def _count_empty_lines_before(self, lineno):
return _count_blank_lines(self.lines.get_line, lineno - 1, 0, -1)
def _count_empty_lines_after(self, lineno):
return _count_blank_lines(self.lines.get_line, lineno + 1, self.lines.length())
def get_separating_line_count(self):
if not self.imports:
return 0
return self._count_empty_lines_after(self.imports[-1].end_line - 1)
def _get_text(self, start_line, end_line):
result = []
for index in range(start_line, end_line):
result.append(self.lines.get_line(index))
return "\n".join(result)
def visit_from(self, node, end_line):
level = 0
if node.level:
level = node.level
import_info = importinfo.FromImport(
node.module or "", # see comment at rope.base.ast.walk
level,
self._get_names(node.names),
)
start_line = node.lineno
self.imports.append(
importinfo.ImportStatement(
import_info,
node.lineno,
end_line,
self._get_text(start_line, end_line),
blank_lines=self._count_empty_lines_before(start_line),
)
)
def _get_names(self, alias_names):
result = []
for alias in alias_names:
result.append((alias.name, alias.asname))
return result
def find_import_statements(self):
nodes = self.pymodule.get_ast().body
for index, node in enumerate(nodes):
if isinstance(node, (ast.Import, ast.ImportFrom)):
lines = self.pymodule.logical_lines
end_line = lines.logical_line_in(node.lineno)[1] + 1
if isinstance(node, ast.Import):
self.visit_import(node, end_line)
if isinstance(node, ast.ImportFrom):
self.visit_from(node, end_line)
return self.imports

View file

@ -0,0 +1,686 @@
# Known Bugs when inlining a function/method
# The values passed to function are inlined using _inlined_variable.
# This may cause two problems, illustrated in the examples below
#
# def foo(var1):
# var1 = var1*10
# return var1
#
# If a call to foo(20) is inlined, the result of inlined function is 20,
# but it should be 200.
#
# def foo(var1):
# var2 = var1*10
# return var2
#
# 2- If a call to foo(10+10) is inlined the result of inlined function is 110
# but it should be 200.
import re
import rope.base.exceptions
import rope.refactor.functionutils
from rope.base import (
pynames,
pyobjects,
codeanalyze,
taskhandle,
evaluate,
worder,
utils,
libutils,
)
from rope.base.change import ChangeSet, ChangeContents
from rope.refactor import (
occurrences,
rename,
sourceutils,
importutils,
move,
change_signature,
)
def unique_prefix():
n = 0
while True:
yield "__" + str(n) + "__"
n += 1
def create_inline(project, resource, offset):
"""Create a refactoring object for inlining
Based on `resource` and `offset` it returns an instance of
`InlineMethod`, `InlineVariable` or `InlineParameter`.
"""
pyname = _get_pyname(project, resource, offset)
message = (
"Inline refactoring should be performed on "
"a method, local variable or parameter."
)
if pyname is None:
raise rope.base.exceptions.RefactoringError(message)
if isinstance(pyname, pynames.ImportedName):
pyname = pyname._get_imported_pyname()
if isinstance(pyname, pynames.AssignedName):
return InlineVariable(project, resource, offset)
if isinstance(pyname, pynames.ParameterName):
return InlineParameter(project, resource, offset)
if isinstance(pyname.get_object(), pyobjects.PyFunction):
return InlineMethod(project, resource, offset)
else:
raise rope.base.exceptions.RefactoringError(message)
class _Inliner(object):
def __init__(self, project, resource, offset):
self.project = project
self.pyname = _get_pyname(project, resource, offset)
range_finder = worder.Worder(resource.read(), True)
self.region = range_finder.get_primary_range(offset)
self.name = range_finder.get_word_at(offset)
self.offset = offset
self.original = resource
def get_changes(self, *args, **kwds):
pass
def get_kind(self):
"""Return either 'variable', 'method' or 'parameter'"""
class InlineMethod(_Inliner):
def __init__(self, *args, **kwds):
super(InlineMethod, self).__init__(*args, **kwds)
self.pyfunction = self.pyname.get_object()
self.pymodule = self.pyfunction.get_module()
self.resource = self.pyfunction.get_module().get_resource()
self.occurrence_finder = occurrences.create_finder(
self.project, self.name, self.pyname
)
self.normal_generator = _DefinitionGenerator(self.project, self.pyfunction)
self._init_imports()
def _init_imports(self):
body = sourceutils.get_body(self.pyfunction)
body, imports = move.moving_code_with_imports(self.project, self.resource, body)
self.imports = imports
self.others_generator = _DefinitionGenerator(
self.project, self.pyfunction, body=body
)
def _get_scope_range(self):
scope = self.pyfunction.get_scope()
lines = self.pymodule.lines
start_line = scope.get_start()
if self.pyfunction.decorators:
decorators = self.pyfunction.decorators
if hasattr(decorators[0], "lineno"):
start_line = decorators[0].lineno
start_offset = lines.get_line_start(start_line)
end_offset = min(
lines.get_line_end(scope.end) + 1, len(self.pymodule.source_code)
)
return (start_offset, end_offset)
def get_changes(
self,
remove=True,
only_current=False,
resources=None,
task_handle=taskhandle.NullTaskHandle(),
):
"""Get the changes this refactoring makes
If `remove` is `False` the definition will not be removed. If
`only_current` is `True`, the the current occurrence will be
inlined, only.
"""
changes = ChangeSet("Inline method <%s>" % self.name)
if resources is None:
resources = self.project.get_python_files()
if only_current:
resources = [self.original]
if remove:
resources.append(self.resource)
job_set = task_handle.create_jobset("Collecting Changes", len(resources))
for file in resources:
job_set.started_job(file.path)
if file == self.resource:
changes.add_change(
self._defining_file_changes(
changes, remove=remove, only_current=only_current
)
)
else:
aim = None
if only_current and self.original == file:
aim = self.offset
handle = _InlineFunctionCallsForModuleHandle(
self.project, file, self.others_generator, aim
)
result = move.ModuleSkipRenamer(
self.occurrence_finder, file, handle
).get_changed_module()
if result is not None:
result = _add_imports(self.project, result, file, self.imports)
if remove:
result = _remove_from(self.project, self.pyname, result, file)
changes.add_change(ChangeContents(file, result))
job_set.finished_job()
return changes
def _get_removed_range(self):
scope = self.pyfunction.get_scope()
lines = self.pymodule.lines
start, end = self._get_scope_range()
end_line = scope.get_end()
for i in range(end_line + 1, lines.length()):
if lines.get_line(i).strip() == "":
end_line = i
else:
break
end = min(lines.get_line_end(end_line) + 1, len(self.pymodule.source_code))
return (start, end)
def _defining_file_changes(self, changes, remove, only_current):
start_offset, end_offset = self._get_removed_range()
aim = None
if only_current:
if self.resource == self.original:
aim = self.offset
else:
# we don't want to change any of them
aim = len(self.resource.read()) + 100
handle = _InlineFunctionCallsForModuleHandle(
self.project, self.resource, self.normal_generator, aim_offset=aim
)
replacement = None
if remove:
replacement = self._get_method_replacement()
result = move.ModuleSkipRenamer(
self.occurrence_finder,
self.resource,
handle,
start_offset,
end_offset,
replacement,
).get_changed_module()
return ChangeContents(self.resource, result)
def _get_method_replacement(self):
if self._is_the_last_method_of_a_class():
indents = sourceutils.get_indents(
self.pymodule.lines, self.pyfunction.get_scope().get_start()
)
return " " * indents + "pass\n"
return ""
def _is_the_last_method_of_a_class(self):
pyclass = self.pyfunction.parent
if not isinstance(pyclass, pyobjects.PyClass):
return False
class_start, class_end = sourceutils.get_body_region(pyclass)
source = self.pymodule.source_code
func_start, func_end = self._get_scope_range()
if (
source[class_start:func_start].strip() == ""
and source[func_end:class_end].strip() == ""
):
return True
return False
def get_kind(self):
return "method"
class InlineVariable(_Inliner):
def __init__(self, *args, **kwds):
super(InlineVariable, self).__init__(*args, **kwds)
self.pymodule = self.pyname.get_definition_location()[0]
self.resource = self.pymodule.get_resource()
self._check_exceptional_conditions()
self._init_imports()
def _check_exceptional_conditions(self):
if len(self.pyname.assignments) != 1:
raise rope.base.exceptions.RefactoringError(
"Local variable should be assigned once for inlining."
)
def get_changes(
self,
remove=True,
only_current=False,
resources=None,
docs=False,
task_handle=taskhandle.NullTaskHandle(),
):
if resources is None:
if rename._is_local(self.pyname):
resources = [self.resource]
else:
resources = self.project.get_python_files()
if only_current:
resources = [self.original]
if remove and self.original != self.resource:
resources.append(self.resource)
changes = ChangeSet("Inline variable <%s>" % self.name)
jobset = task_handle.create_jobset("Calculating changes", len(resources))
for resource in resources:
jobset.started_job(resource.path)
if resource == self.resource:
source = self._change_main_module(remove, only_current, docs)
changes.add_change(ChangeContents(self.resource, source))
else:
result = self._change_module(resource, remove, only_current)
if result is not None:
result = _add_imports(self.project, result, resource, self.imports)
changes.add_change(ChangeContents(resource, result))
jobset.finished_job()
return changes
def _change_main_module(self, remove, only_current, docs):
region = None
if only_current and self.original == self.resource:
region = self.region
return _inline_variable(
self.project,
self.pymodule,
self.pyname,
self.name,
remove=remove,
region=region,
docs=docs,
)
def _init_imports(self):
vardef = _getvardef(self.pymodule, self.pyname)
self.imported, self.imports = move.moving_code_with_imports(
self.project, self.resource, vardef
)
def _change_module(self, resource, remove, only_current):
filters = [occurrences.NoImportsFilter(), occurrences.PyNameFilter(self.pyname)]
if only_current and resource == self.original:
def check_aim(occurrence):
start, end = occurrence.get_primary_range()
if self.offset < start or end < self.offset:
return False
filters.insert(0, check_aim)
finder = occurrences.Finder(self.project, self.name, filters=filters)
changed = rename.rename_in_module(
finder, self.imported, resource=resource, replace_primary=True
)
if changed and remove:
changed = _remove_from(self.project, self.pyname, changed, resource)
return changed
def get_kind(self):
return "variable"
class InlineParameter(_Inliner):
def __init__(self, *args, **kwds):
super(InlineParameter, self).__init__(*args, **kwds)
resource, offset = self._function_location()
index = self.pyname.index
self.changers = [change_signature.ArgumentDefaultInliner(index)]
self.signature = change_signature.ChangeSignature(
self.project, resource, offset
)
def _function_location(self):
pymodule, lineno = self.pyname.get_definition_location()
resource = pymodule.get_resource()
start = pymodule.lines.get_line_start(lineno)
word_finder = worder.Worder(pymodule.source_code)
offset = word_finder.find_function_offset(start)
return resource, offset
def get_changes(self, **kwds):
"""Get the changes needed by this refactoring
See `rope.refactor.change_signature.ChangeSignature.get_changes()`
for arguments.
"""
return self.signature.get_changes(self.changers, **kwds)
def get_kind(self):
return "parameter"
def _join_lines(lines):
definition_lines = []
for unchanged_line in lines:
line = unchanged_line.strip()
if line.endswith("\\"):
line = line[:-1].strip()
definition_lines.append(line)
joined = " ".join(definition_lines)
return joined
class _DefinitionGenerator(object):
unique_prefix = unique_prefix()
def __init__(self, project, pyfunction, body=None):
self.project = project
self.pyfunction = pyfunction
self.pymodule = pyfunction.get_module()
self.resource = self.pymodule.get_resource()
self.definition_info = self._get_definition_info()
self.definition_params = self._get_definition_params()
self._calculated_definitions = {}
if body is not None:
self.body = body
else:
self.body = sourceutils.get_body(self.pyfunction)
def _get_definition_info(self):
return rope.refactor.functionutils.DefinitionInfo.read(self.pyfunction)
def _get_definition_params(self):
definition_info = self.definition_info
paramdict = dict([pair for pair in definition_info.args_with_defaults])
if (
definition_info.args_arg is not None
or definition_info.keywords_arg is not None
):
raise rope.base.exceptions.RefactoringError(
"Cannot inline functions with list and keyword arguements."
)
if self.pyfunction.get_kind() == "classmethod":
paramdict[
definition_info.args_with_defaults[0][0]
] = self.pyfunction.parent.get_name()
return paramdict
def get_function_name(self):
return self.pyfunction.get_name()
def get_definition(self, primary, pyname, call, host_vars=[], returns=False):
# caching already calculated definitions
return self._calculate_definition(primary, pyname, call, host_vars, returns)
def _calculate_header(self, primary, pyname, call):
# A header is created which initializes parameters
# to the values passed to the function.
call_info = rope.refactor.functionutils.CallInfo.read(
primary, pyname, self.definition_info, call
)
paramdict = self.definition_params
mapping = rope.refactor.functionutils.ArgumentMapping(
self.definition_info, call_info
)
for param_name, value in mapping.param_dict.items():
paramdict[param_name] = value
header = ""
to_be_inlined = []
for name, value in paramdict.items():
if name != value and value is not None:
header += name + " = " + value.replace("\n", " ") + "\n"
to_be_inlined.append(name)
return header, to_be_inlined
def _calculate_definition(self, primary, pyname, call, host_vars, returns):
header, to_be_inlined = self._calculate_header(primary, pyname, call)
source = header + self.body
mod = libutils.get_string_module(self.project, source)
name_dict = mod.get_scope().get_names()
all_names = [
x
for x in name_dict
if not isinstance(name_dict[x], rope.base.builtins.BuiltinName)
]
# If there is a name conflict, all variable names
# inside the inlined function are renamed
if len(set(all_names).intersection(set(host_vars))) > 0:
prefix = next(_DefinitionGenerator.unique_prefix)
guest = libutils.get_string_module(self.project, source, self.resource)
to_be_inlined = [prefix + item for item in to_be_inlined]
for item in all_names:
pyname = guest[item]
occurrence_finder = occurrences.create_finder(
self.project, item, pyname
)
source = rename.rename_in_module(
occurrence_finder, prefix + item, pymodule=guest
)
guest = libutils.get_string_module(self.project, source, self.resource)
# parameters not reassigned inside the functions are now inlined.
for name in to_be_inlined:
pymodule = libutils.get_string_module(self.project, source, self.resource)
pyname = pymodule[name]
source = _inline_variable(self.project, pymodule, pyname, name)
return self._replace_returns_with(source, returns)
def _replace_returns_with(self, source, returns):
result = []
returned = None
last_changed = 0
for match in _DefinitionGenerator._get_return_pattern().finditer(source):
for key, value in match.groupdict().items():
if value and key == "return":
result.append(source[last_changed : match.start("return")])
if returns:
self._check_nothing_after_return(source, match.end("return"))
beg_idx = match.end("return")
returned = _join_lines(
source[beg_idx : len(source)].splitlines()
)
last_changed = len(source)
else:
current = match.end("return")
while current < len(source) and source[current] in " \t":
current += 1
last_changed = current
if current == len(source) or source[current] == "\n":
result.append("pass")
result.append(source[last_changed:])
return "".join(result), returned
def _check_nothing_after_return(self, source, offset):
lines = codeanalyze.SourceLinesAdapter(source)
lineno = lines.get_line_number(offset)
logical_lines = codeanalyze.LogicalLineFinder(lines)
lineno = logical_lines.logical_line_in(lineno)[1]
if source[lines.get_line_end(lineno) : len(source)].strip() != "":
raise rope.base.exceptions.RefactoringError(
"Cannot inline functions with statements " + "after return statement."
)
@classmethod
def _get_return_pattern(cls):
if not hasattr(cls, "_return_pattern"):
def named_pattern(name, list_):
return "(?P<%s>" % name + "|".join(list_) + ")"
comment_pattern = named_pattern("comment", [r"#[^\n]*"])
string_pattern = named_pattern("string", [codeanalyze.get_string_pattern()])
return_pattern = r"\b(?P<return>return)\b"
cls._return_pattern = re.compile(
comment_pattern + "|" + string_pattern + "|" + return_pattern
)
return cls._return_pattern
class _InlineFunctionCallsForModuleHandle(object):
def __init__(self, project, resource, definition_generator, aim_offset=None):
"""Inlines occurrences
If `aim` is not `None` only the occurrences that intersect
`aim` offset will be inlined.
"""
self.project = project
self.generator = definition_generator
self.resource = resource
self.aim = aim_offset
def occurred_inside_skip(self, change_collector, occurrence):
if not occurrence.is_defined():
raise rope.base.exceptions.RefactoringError(
"Cannot inline functions that reference themselves"
)
def occurred_outside_skip(self, change_collector, occurrence):
start, end = occurrence.get_primary_range()
# we remove out of date imports later
if occurrence.is_in_import_statement():
return
# the function is referenced outside an import statement
if not occurrence.is_called():
raise rope.base.exceptions.RefactoringError(
"Reference to inlining function other than function call"
" in <file: %s, offset: %d>" % (self.resource.path, start)
)
if self.aim is not None and (self.aim < start or self.aim > end):
return
end_parens = self._find_end_parens(self.source, end - 1)
lineno = self.lines.get_line_number(start)
start_line, end_line = self.pymodule.logical_lines.logical_line_in(lineno)
line_start = self.lines.get_line_start(start_line)
line_end = self.lines.get_line_end(end_line)
returns = (
self.source[line_start:start].strip() != ""
or self.source[end_parens:line_end].strip() != ""
)
indents = sourceutils.get_indents(self.lines, start_line)
primary, pyname = occurrence.get_primary_and_pyname()
host = self.pymodule
scope = host.scope.get_inner_scope_for_line(lineno)
definition, returned = self.generator.get_definition(
primary,
pyname,
self.source[start:end_parens],
scope.get_names(),
returns=returns,
)
end = min(line_end + 1, len(self.source))
change_collector.add_change(
line_start, end, sourceutils.fix_indentation(definition, indents)
)
if returns:
name = returned
if name is None:
name = "None"
change_collector.add_change(
line_end,
end,
self.source[line_start:start] + name + self.source[end_parens:end],
)
def _find_end_parens(self, source, offset):
finder = worder.Worder(source)
return finder.get_word_parens_range(offset)[1]
@property
@utils.saveit
def pymodule(self):
return self.project.get_pymodule(self.resource)
@property
@utils.saveit
def source(self):
if self.resource is not None:
return self.resource.read()
else:
return self.pymodule.source_code
@property
@utils.saveit
def lines(self):
return self.pymodule.lines
def _inline_variable(
project, pymodule, pyname, name, remove=True, region=None, docs=False
):
definition = _getvardef(pymodule, pyname)
start, end = _assigned_lineno(pymodule, pyname)
occurrence_finder = occurrences.create_finder(project, name, pyname, docs=docs)
changed_source = rename.rename_in_module(
occurrence_finder,
definition,
pymodule=pymodule,
replace_primary=True,
writes=False,
region=region,
)
if changed_source is None:
changed_source = pymodule.source_code
if remove:
lines = codeanalyze.SourceLinesAdapter(changed_source)
source = (
changed_source[: lines.get_line_start(start)]
+ changed_source[lines.get_line_end(end) + 1 :]
)
else:
source = changed_source
return source
def _getvardef(pymodule, pyname):
assignment = pyname.assignments[0]
lines = pymodule.lines
start, end = _assigned_lineno(pymodule, pyname)
definition_with_assignment = _join_lines(
[lines.get_line(n) for n in range(start, end + 1)]
)
if assignment.levels:
raise rope.base.exceptions.RefactoringError("Cannot inline tuple assignments.")
definition = definition_with_assignment[
definition_with_assignment.index("=") + 1 :
].strip()
return definition
def _assigned_lineno(pymodule, pyname):
definition_line = pyname.assignments[0].ast_node.lineno
return pymodule.logical_lines.logical_line_in(definition_line)
def _add_imports(project, source, resource, imports):
if not imports:
return source
pymodule = libutils.get_string_module(project, source, resource)
module_import = importutils.get_module_imports(project, pymodule)
for import_info in imports:
module_import.add_import(import_info)
source = module_import.get_changed_source()
pymodule = libutils.get_string_module(project, source, resource)
import_tools = importutils.ImportTools(project)
return import_tools.organize_imports(pymodule, unused=False, sort=False)
def _get_pyname(project, resource, offset):
pymodule = project.get_pymodule(resource)
pyname = evaluate.eval_location(pymodule, offset)
if isinstance(pyname, pynames.ImportedName):
pyname = pyname._get_imported_pyname()
return pyname
def _remove_from(project, pyname, source, resource):
pymodule = libutils.get_string_module(project, source, resource)
module_import = importutils.get_module_imports(project, pymodule)
module_import.remove_pyname(pyname)
return module_import.get_changed_source()

View file

@ -0,0 +1,146 @@
import rope.base.exceptions
import rope.base.pyobjects
from rope.base import libutils
from rope.base import taskhandle, evaluate
from rope.base.change import ChangeSet, ChangeContents
from rope.refactor import rename, occurrences, sourceutils, importutils
class IntroduceFactory(object):
def __init__(self, project, resource, offset):
self.project = project
self.offset = offset
this_pymodule = self.project.get_pymodule(resource)
self.old_pyname = evaluate.eval_location(this_pymodule, offset)
if self.old_pyname is None or not isinstance(
self.old_pyname.get_object(), rope.base.pyobjects.PyClass
):
raise rope.base.exceptions.RefactoringError(
"Introduce factory should be performed on a class."
)
self.old_name = self.old_pyname.get_object().get_name()
self.pymodule = self.old_pyname.get_object().get_module()
self.resource = self.pymodule.get_resource()
def get_changes(
self,
factory_name,
global_factory=False,
resources=None,
task_handle=taskhandle.NullTaskHandle(),
):
"""Get the changes this refactoring makes
`factory_name` indicates the name of the factory function to
be added. If `global_factory` is `True` the factory will be
global otherwise a static method is added to the class.
`resources` can be a list of `rope.base.resource.File` that
this refactoring should be applied on; if `None` all python
files in the project are searched.
"""
if resources is None:
resources = self.project.get_python_files()
changes = ChangeSet("Introduce factory method <%s>" % factory_name)
job_set = task_handle.create_jobset("Collecting Changes", len(resources))
self._change_module(resources, changes, factory_name, global_factory, job_set)
return changes
def get_name(self):
"""Return the name of the class"""
return self.old_name
def _change_module(self, resources, changes, factory_name, global_, job_set):
if global_:
replacement = "__rope_factory_%s_" % factory_name
else:
replacement = self._new_function_name(factory_name, global_)
for file_ in resources:
job_set.started_job(file_.path)
if file_ == self.resource:
self._change_resource(changes, factory_name, global_)
job_set.finished_job()
continue
changed_code = self._rename_occurrences(file_, replacement, global_)
if changed_code is not None:
if global_:
new_pymodule = libutils.get_string_module(
self.project, changed_code, self.resource
)
modname = libutils.modname(self.resource)
changed_code, imported = importutils.add_import(
self.project, new_pymodule, modname, factory_name
)
changed_code = changed_code.replace(replacement, imported)
changes.add_change(ChangeContents(file_, changed_code))
job_set.finished_job()
def _change_resource(self, changes, factory_name, global_):
class_scope = self.old_pyname.get_object().get_scope()
source_code = self._rename_occurrences(
self.resource, self._new_function_name(factory_name, global_), global_
)
if source_code is None:
source_code = self.pymodule.source_code
else:
self.pymodule = libutils.get_string_module(
self.project, source_code, resource=self.resource
)
lines = self.pymodule.lines
start = self._get_insertion_offset(class_scope, lines)
result = source_code[:start]
result += self._get_factory_method(lines, class_scope, factory_name, global_)
result += source_code[start:]
changes.add_change(ChangeContents(self.resource, result))
def _get_insertion_offset(self, class_scope, lines):
start_line = class_scope.get_end()
if class_scope.get_scopes():
start_line = class_scope.get_scopes()[-1].get_end()
start = lines.get_line_end(start_line) + 1
return start
def _get_factory_method(self, lines, class_scope, factory_name, global_):
unit_indents = " " * sourceutils.get_indent(self.project)
if global_:
if self._get_scope_indents(lines, class_scope) > 0:
raise rope.base.exceptions.RefactoringError(
"Cannot make global factory method for nested classes."
)
return "\ndef %s(*args, **kwds):\n%sreturn %s(*args, **kwds)\n" % (
factory_name,
unit_indents,
self.old_name,
)
unindented_factory = (
"@staticmethod\ndef %s(*args, **kwds):\n" % factory_name
+ "%sreturn %s(*args, **kwds)\n" % (unit_indents, self.old_name)
)
indents = self._get_scope_indents(lines, class_scope) + sourceutils.get_indent(
self.project
)
return "\n" + sourceutils.indent_lines(unindented_factory, indents)
def _get_scope_indents(self, lines, scope):
return sourceutils.get_indents(lines, scope.get_start())
def _new_function_name(self, factory_name, global_):
if global_:
return factory_name
else:
return self.old_name + "." + factory_name
def _rename_occurrences(self, file_, changed_name, global_factory):
finder = occurrences.create_finder(
self.project, self.old_name, self.old_pyname, only_calls=True
)
result = rename.rename_in_module(
finder, changed_name, resource=file_, replace_primary=global_factory
)
return result
IntroduceFactoryRefactoring = IntroduceFactory

View file

@ -0,0 +1,96 @@
import rope.base.change
from rope.base import exceptions, evaluate, worder, codeanalyze
from rope.refactor import functionutils, sourceutils, occurrences
class IntroduceParameter(object):
"""Introduce parameter refactoring
This refactoring adds a new parameter to a function and replaces
references to an expression in it with the new parameter.
The parameter finding part is different from finding similar
pieces in extract refactorings. In this refactoring parameters
are found based on the object they reference to. For instance
in::
class A(object):
var = None
class B(object):
a = A()
b = B()
a = b.a
def f(a):
x = b.a.var + a.var
using this refactoring on ``a.var`` with ``p`` as the new
parameter name, will result in::
def f(p=a.var):
x = p + p
"""
def __init__(self, project, resource, offset):
self.project = project
self.resource = resource
self.offset = offset
self.pymodule = self.project.get_pymodule(self.resource)
scope = self.pymodule.get_scope().get_inner_scope_for_offset(offset)
if scope.get_kind() != "Function":
raise exceptions.RefactoringError(
"Introduce parameter should be performed inside functions"
)
self.pyfunction = scope.pyobject
self.name, self.pyname = self._get_name_and_pyname()
if self.pyname is None:
raise exceptions.RefactoringError(
"Cannot find the definition of <%s>" % self.name
)
def _get_primary(self):
word_finder = worder.Worder(self.resource.read())
return word_finder.get_primary_at(self.offset)
def _get_name_and_pyname(self):
return (
worder.get_name_at(self.resource, self.offset),
evaluate.eval_location(self.pymodule, self.offset),
)
def get_changes(self, new_parameter):
definition_info = functionutils.DefinitionInfo.read(self.pyfunction)
definition_info.args_with_defaults.append((new_parameter, self._get_primary()))
collector = codeanalyze.ChangeCollector(self.resource.read())
header_start, header_end = self._get_header_offsets()
body_start, body_end = sourceutils.get_body_region(self.pyfunction)
collector.add_change(header_start, header_end, definition_info.to_string())
self._change_function_occurrences(
collector, body_start, body_end, new_parameter
)
changes = rope.base.change.ChangeSet("Introduce parameter <%s>" % new_parameter)
change = rope.base.change.ChangeContents(self.resource, collector.get_changed())
changes.add_change(change)
return changes
def _get_header_offsets(self):
lines = self.pymodule.lines
start_line = self.pyfunction.get_scope().get_start()
end_line = self.pymodule.logical_lines.logical_line_in(start_line)[1]
start = lines.get_line_start(start_line)
end = lines.get_line_end(end_line)
start = self.pymodule.source_code.find("def", start) + 4
end = self.pymodule.source_code.rfind(":", start, end)
return start, end
def _change_function_occurrences(
self, collector, function_start, function_end, new_name
):
finder = occurrences.create_finder(self.project, self.name, self.pyname)
for occurrence in finder.find_occurrences(resource=self.resource):
start, end = occurrence.get_primary_range()
if function_start <= start < function_end:
collector.add_change(start, end, new_name)

View file

@ -0,0 +1,52 @@
from rope.base import pynames, evaluate, exceptions, worder
from rope.refactor.rename import Rename
class LocalToField(object):
def __init__(self, project, resource, offset):
self.project = project
self.resource = resource
self.offset = offset
def get_changes(self):
name = worder.get_name_at(self.resource, self.offset)
this_pymodule = self.project.get_pymodule(self.resource)
pyname = evaluate.eval_location(this_pymodule, self.offset)
if not self._is_a_method_local(pyname):
raise exceptions.RefactoringError(
"Convert local variable to field should be performed on \n"
"a local variable of a method."
)
pymodule, lineno = pyname.get_definition_location()
function_scope = pymodule.get_scope().get_inner_scope_for_line(lineno)
# Not checking redefinition
# self._check_redefinition(name, function_scope)
new_name = self._get_field_name(function_scope.pyobject, name)
changes = Rename(self.project, self.resource, self.offset).get_changes(
new_name, resources=[self.resource]
)
return changes
def _check_redefinition(self, name, function_scope):
class_scope = function_scope.parent
if name in class_scope.pyobject:
raise exceptions.RefactoringError("The field %s already exists" % name)
def _get_field_name(self, pyfunction, name):
self_name = pyfunction.get_param_names()[0]
new_name = self_name + "." + name
return new_name
def _is_a_method_local(self, pyname):
pymodule, lineno = pyname.get_definition_location()
holding_scope = pymodule.get_scope().get_inner_scope_for_line(lineno)
parent = holding_scope.parent
return (
isinstance(pyname, pynames.AssignedName)
and pyname in holding_scope.get_names().values()
and holding_scope.get_kind() == "Function"
and parent is not None
and parent.get_kind() == "Class"
)

View file

@ -0,0 +1,96 @@
import warnings
from rope.base import libutils
from rope.base import pyobjects, exceptions, change, evaluate, codeanalyze
from rope.refactor import sourceutils, occurrences, rename
class MethodObject(object):
def __init__(self, project, resource, offset):
self.project = project
this_pymodule = self.project.get_pymodule(resource)
pyname = evaluate.eval_location(this_pymodule, offset)
if pyname is None or not isinstance(pyname.get_object(), pyobjects.PyFunction):
raise exceptions.RefactoringError(
"Replace method with method object refactoring should be "
"performed on a function."
)
self.pyfunction = pyname.get_object()
self.pymodule = self.pyfunction.get_module()
self.resource = self.pymodule.get_resource()
def get_new_class(self, name):
body = sourceutils.fix_indentation(
self._get_body(), sourceutils.get_indent(self.project) * 2
)
return "class %s(object):\n\n%s%sdef __call__(self):\n%s" % (
name,
self._get_init(),
" " * sourceutils.get_indent(self.project),
body,
)
def get_changes(self, classname=None, new_class_name=None):
if new_class_name is not None:
warnings.warn(
"new_class_name parameter is deprecated; use classname",
DeprecationWarning,
stacklevel=2,
)
classname = new_class_name
collector = codeanalyze.ChangeCollector(self.pymodule.source_code)
start, end = sourceutils.get_body_region(self.pyfunction)
indents = sourceutils.get_indents(
self.pymodule.lines, self.pyfunction.get_scope().get_start()
) + sourceutils.get_indent(self.project)
new_contents = " " * indents + "return %s(%s)()\n" % (
classname,
", ".join(self._get_parameter_names()),
)
collector.add_change(start, end, new_contents)
insertion = self._get_class_insertion_point()
collector.add_change(
insertion, insertion, "\n\n" + self.get_new_class(classname)
)
changes = change.ChangeSet("Replace method with method object refactoring")
changes.add_change(
change.ChangeContents(self.resource, collector.get_changed())
)
return changes
def _get_class_insertion_point(self):
current = self.pyfunction
while current.parent != self.pymodule:
current = current.parent
end = self.pymodule.lines.get_line_end(current.get_scope().get_end())
return min(end + 1, len(self.pymodule.source_code))
def _get_body(self):
body = sourceutils.get_body(self.pyfunction)
for param in self._get_parameter_names():
body = param + " = None\n" + body
pymod = libutils.get_string_module(self.project, body, self.resource)
pyname = pymod[param]
finder = occurrences.create_finder(self.project, param, pyname)
result = rename.rename_in_module(finder, "self." + param, pymodule=pymod)
body = result[result.index("\n") + 1 :]
return body
def _get_init(self):
params = self._get_parameter_names()
indents = " " * sourceutils.get_indent(self.project)
if not params:
return ""
header = indents + "def __init__(self"
body = ""
for arg in params:
new_name = arg
if arg == "self":
new_name = "host"
header += ", %s" % new_name
body += indents * 2 + "self.%s = %s\n" % (arg, new_name)
header += "):"
return "%s\n%s\n" % (header, body)
def _get_parameter_names(self):
return self.pyfunction.get_param_names()

View file

@ -0,0 +1,845 @@
"""A module containing classes for move refactoring
`create_move()` is a factory for creating move refactoring objects
based on inputs.
"""
from rope.base import (
pyobjects,
codeanalyze,
exceptions,
pynames,
taskhandle,
evaluate,
worder,
libutils,
)
from rope.base.change import ChangeSet, ChangeContents, MoveResource
from rope.refactor import importutils, rename, occurrences, sourceutils, functionutils
def create_move(project, resource, offset=None):
"""A factory for creating Move objects
Based on `resource` and `offset`, return one of `MoveModule`,
`MoveGlobal` or `MoveMethod` for performing move refactoring.
"""
if offset is None:
return MoveModule(project, resource)
this_pymodule = project.get_pymodule(resource)
pyname = evaluate.eval_location(this_pymodule, offset)
if pyname is not None:
pyobject = pyname.get_object()
if isinstance(pyobject, pyobjects.PyModule) or isinstance(
pyobject, pyobjects.PyPackage
):
return MoveModule(project, pyobject.get_resource())
if isinstance(pyobject, pyobjects.PyFunction) and isinstance(
pyobject.parent, pyobjects.PyClass
):
return MoveMethod(project, resource, offset)
if (
isinstance(pyobject, pyobjects.PyDefinedObject)
and isinstance(pyobject.parent, pyobjects.PyModule)
or isinstance(pyname, pynames.AssignedName)
):
return MoveGlobal(project, resource, offset)
raise exceptions.RefactoringError(
"Move only works on global classes/functions/variables, modules and " "methods."
)
class MoveMethod(object):
"""For moving methods
It makes a new method in the destination class and changes
the body of the old method to call the new method. You can
inline the old method to change all of its occurrences.
"""
def __init__(self, project, resource, offset):
self.project = project
this_pymodule = self.project.get_pymodule(resource)
pyname = evaluate.eval_location(this_pymodule, offset)
self.method_name = worder.get_name_at(resource, offset)
self.pyfunction = pyname.get_object()
if self.pyfunction.get_kind() != "method":
raise exceptions.RefactoringError("Only normal methods" " can be moved.")
def get_changes(
self,
dest_attr,
new_name=None,
resources=None,
task_handle=taskhandle.NullTaskHandle(),
):
"""Return the changes needed for this refactoring
Parameters:
- `dest_attr`: the name of the destination attribute
- `new_name`: the name of the new method; if `None` uses
the old name
- `resources` can be a list of `rope.base.resources.File` to
apply this refactoring on. If `None`, the restructuring
will be applied to all python files.
"""
changes = ChangeSet("Moving method <%s>" % self.method_name)
if resources is None:
resources = self.project.get_python_files()
if new_name is None:
new_name = self.get_method_name()
resource1, start1, end1, new_content1 = self._get_changes_made_by_old_class(
dest_attr, new_name
)
collector1 = codeanalyze.ChangeCollector(resource1.read())
collector1.add_change(start1, end1, new_content1)
resource2, start2, end2, new_content2 = self._get_changes_made_by_new_class(
dest_attr, new_name
)
if resource1 == resource2:
collector1.add_change(start2, end2, new_content2)
else:
collector2 = codeanalyze.ChangeCollector(resource2.read())
collector2.add_change(start2, end2, new_content2)
result = collector2.get_changed()
import_tools = importutils.ImportTools(self.project)
new_imports = self._get_used_imports(import_tools)
if new_imports:
goal_pymodule = libutils.get_string_module(
self.project, result, resource2
)
result = _add_imports_to_module(
import_tools, goal_pymodule, new_imports
)
if resource2 in resources:
changes.add_change(ChangeContents(resource2, result))
if resource1 in resources:
changes.add_change(ChangeContents(resource1, collector1.get_changed()))
return changes
def get_method_name(self):
return self.method_name
def _get_used_imports(self, import_tools):
return importutils.get_imports(self.project, self.pyfunction)
def _get_changes_made_by_old_class(self, dest_attr, new_name):
pymodule = self.pyfunction.get_module()
indents = self._get_scope_indents(self.pyfunction)
body = "return self.%s.%s(%s)\n" % (
dest_attr,
new_name,
self._get_passed_arguments_string(),
)
region = sourceutils.get_body_region(self.pyfunction)
return (
pymodule.get_resource(),
region[0],
region[1],
sourceutils.fix_indentation(body, indents),
)
def _get_scope_indents(self, pyobject):
pymodule = pyobject.get_module()
return sourceutils.get_indents(
pymodule.lines, pyobject.get_scope().get_start()
) + sourceutils.get_indent(self.project)
def _get_changes_made_by_new_class(self, dest_attr, new_name):
old_pyclass = self.pyfunction.parent
if dest_attr not in old_pyclass:
raise exceptions.RefactoringError(
"Destination attribute <%s> not found" % dest_attr
)
pyclass = old_pyclass[dest_attr].get_object().get_type()
if not isinstance(pyclass, pyobjects.PyClass):
raise exceptions.RefactoringError(
"Unknown class type for attribute <%s>" % dest_attr
)
pymodule = pyclass.get_module()
resource = pyclass.get_module().get_resource()
start, end = sourceutils.get_body_region(pyclass)
pre_blanks = "\n"
if pymodule.source_code[start:end].strip() != "pass":
pre_blanks = "\n\n"
start = end
indents = self._get_scope_indents(pyclass)
body = pre_blanks + sourceutils.fix_indentation(
self.get_new_method(new_name), indents
)
return resource, start, end, body
def get_new_method(self, name):
return "%s\n%s" % (
self._get_new_header(name),
sourceutils.fix_indentation(
self._get_body(), sourceutils.get_indent(self.project)
),
)
def _get_unchanged_body(self):
return sourceutils.get_body(self.pyfunction)
def _get_body(self, host="host"):
self_name = self._get_self_name()
body = self_name + " = None\n" + self._get_unchanged_body()
pymodule = libutils.get_string_module(self.project, body)
finder = occurrences.create_finder(self.project, self_name, pymodule[self_name])
result = rename.rename_in_module(finder, host, pymodule=pymodule)
if result is None:
result = body
return result[result.index("\n") + 1 :]
def _get_self_name(self):
return self.pyfunction.get_param_names()[0]
def _get_new_header(self, name):
header = "def %s(self" % name
if self._is_host_used():
header += ", host"
definition_info = functionutils.DefinitionInfo.read(self.pyfunction)
others = definition_info.arguments_to_string(1)
if others:
header += ", " + others
return header + "):"
def _get_passed_arguments_string(self):
result = ""
if self._is_host_used():
result = "self"
definition_info = functionutils.DefinitionInfo.read(self.pyfunction)
others = definition_info.arguments_to_string(1)
if others:
if result:
result += ", "
result += others
return result
def _is_host_used(self):
return self._get_body("__old_self") != self._get_unchanged_body()
class MoveGlobal(object):
"""For moving global function and classes"""
def __init__(self, project, resource, offset):
self.project = project
this_pymodule = self.project.get_pymodule(resource)
self.old_pyname = evaluate.eval_location(this_pymodule, offset)
if self.old_pyname is None:
raise exceptions.RefactoringError(
"Move refactoring should be performed on a class/function/variable."
)
if self._is_variable(self.old_pyname):
self.old_name = worder.get_name_at(resource, offset)
pymodule = this_pymodule
else:
self.old_name = self.old_pyname.get_object().get_name()
pymodule = self.old_pyname.get_object().get_module()
self._check_exceptional_conditions()
self.source = pymodule.get_resource()
self.tools = _MoveTools(
self.project, self.source, self.old_pyname, self.old_name
)
self.import_tools = self.tools.import_tools
def _import_filter(self, stmt):
module_name = libutils.modname(self.source)
if isinstance(stmt.import_info, importutils.NormalImport):
# Affect any statement that imports the source module
return any(
module_name == name
for name, alias in stmt.import_info.names_and_aliases
)
elif isinstance(stmt.import_info, importutils.FromImport):
# Affect statements importing from the source package
if "." in module_name:
package_name, basename = module_name.rsplit(".", 1)
if stmt.import_info.module_name == package_name and any(
basename == name
for name, alias in stmt.import_info.names_and_aliases
):
return True
return stmt.import_info.module_name == module_name
return False
def _check_exceptional_conditions(self):
if self._is_variable(self.old_pyname):
pymodule = self.old_pyname.get_definition_location()[0]
try:
pymodule.get_scope().get_name(self.old_name)
except exceptions.NameNotFoundError:
self._raise_refactoring_error()
elif not (
isinstance(self.old_pyname.get_object(), pyobjects.PyDefinedObject)
and self._is_global(self.old_pyname.get_object())
):
self._raise_refactoring_error()
def _raise_refactoring_error(self):
raise exceptions.RefactoringError(
"Move refactoring should be performed on a global class, function "
"or variable."
)
def _is_global(self, pyobject):
return pyobject.get_scope().parent == pyobject.get_module().get_scope()
def _is_variable(self, pyname):
return isinstance(pyname, pynames.AssignedName)
def get_changes(
self, dest, resources=None, task_handle=taskhandle.NullTaskHandle()
):
if resources is None:
resources = self.project.get_python_files()
if dest is None or not dest.exists():
raise exceptions.RefactoringError("Move destination does not exist.")
if dest.is_folder() and dest.has_child("__init__.py"):
dest = dest.get_child("__init__.py")
if dest.is_folder():
raise exceptions.RefactoringError(
"Move destination for non-modules should not be folders."
)
if self.source == dest:
raise exceptions.RefactoringError(
"Moving global elements to the same module."
)
return self._calculate_changes(dest, resources, task_handle)
def _calculate_changes(self, dest, resources, task_handle):
changes = ChangeSet("Moving global <%s>" % self.old_name)
job_set = task_handle.create_jobset("Collecting Changes", len(resources))
for file_ in resources:
job_set.started_job(file_.path)
if file_ == self.source:
changes.add_change(self._source_module_changes(dest))
elif file_ == dest:
changes.add_change(self._dest_module_changes(dest))
elif self.tools.occurs_in_module(resource=file_):
pymodule = self.project.get_pymodule(file_)
# Changing occurrences
placeholder = "__rope_renaming_%s_" % self.old_name
source = self.tools.rename_in_module(placeholder, resource=file_)
should_import = source is not None
# Removing out of date imports
pymodule = self.tools.new_pymodule(pymodule, source)
source = self.import_tools.organize_imports(
pymodule, sort=False, import_filter=self._import_filter
)
# Adding new import
if should_import:
pymodule = self.tools.new_pymodule(pymodule, source)
source, imported = importutils.add_import(
self.project, pymodule, self._new_modname(dest), self.old_name
)
source = source.replace(placeholder, imported)
source = self.tools.new_source(pymodule, source)
if source != file_.read():
changes.add_change(ChangeContents(file_, source))
job_set.finished_job()
return changes
def _source_module_changes(self, dest):
placeholder = "__rope_moving_%s_" % self.old_name
handle = _ChangeMoveOccurrencesHandle(placeholder)
occurrence_finder = occurrences.create_finder(
self.project, self.old_name, self.old_pyname
)
start, end = self._get_moving_region()
renamer = ModuleSkipRenamer(occurrence_finder, self.source, handle, start, end)
source = renamer.get_changed_module()
pymodule = libutils.get_string_module(self.project, source, self.source)
source = self.import_tools.organize_imports(pymodule, sort=False)
if handle.occurred:
pymodule = libutils.get_string_module(self.project, source, self.source)
# Adding new import
source, imported = importutils.add_import(
self.project, pymodule, self._new_modname(dest), self.old_name
)
source = source.replace(placeholder, imported)
return ChangeContents(self.source, source)
def _new_modname(self, dest):
return libutils.modname(dest)
def _dest_module_changes(self, dest):
# Changing occurrences
pymodule = self.project.get_pymodule(dest)
source = self.tools.rename_in_module(self.old_name, pymodule)
pymodule = self.tools.new_pymodule(pymodule, source)
moving, imports = self._get_moving_element_with_imports()
pymodule, has_changed = self._add_imports2(pymodule, imports)
module_with_imports = self.import_tools.module_imports(pymodule)
source = pymodule.source_code
lineno = 0
if module_with_imports.imports:
lineno = module_with_imports.imports[-1].end_line - 1
else:
while lineno < pymodule.lines.length() and pymodule.lines.get_line(
lineno + 1
).lstrip().startswith("#"):
lineno += 1
if lineno > 0:
cut = pymodule.lines.get_line_end(lineno) + 1
result = source[:cut] + "\n\n" + moving + source[cut:]
else:
result = moving + source
# Organizing imports
source = result
pymodule = libutils.get_string_module(self.project, source, dest)
source = self.import_tools.organize_imports(pymodule, sort=False, unused=False)
# Remove unused imports of the old module
pymodule = libutils.get_string_module(self.project, source, dest)
source = self.import_tools.organize_imports(
pymodule,
sort=False,
selfs=False,
unused=True,
import_filter=self._import_filter,
)
return ChangeContents(dest, source)
def _get_moving_element_with_imports(self):
return moving_code_with_imports(
self.project, self.source, self._get_moving_element()
)
def _get_module_with_imports(self, source_code, resource):
pymodule = libutils.get_string_module(self.project, source_code, resource)
return self.import_tools.module_imports(pymodule)
def _get_moving_element(self):
start, end = self._get_moving_region()
moving = self.source.read()[start:end]
return moving.rstrip() + "\n"
def _get_moving_region(self):
pymodule = self.project.get_pymodule(self.source)
lines = pymodule.lines
if self._is_variable(self.old_pyname):
logical_lines = pymodule.logical_lines
lineno = logical_lines.logical_line_in(
self.old_pyname.get_definition_location()[1]
)[0]
start = lines.get_line_start(lineno)
end_line = logical_lines.logical_line_in(lineno)[1]
else:
scope = self.old_pyname.get_object().get_scope()
start = lines.get_line_start(scope.get_start())
end_line = scope.get_end()
# Include comment lines before the definition
start_line = lines.get_line_number(start)
while start_line > 1 and lines.get_line(start_line - 1).startswith("#"):
start_line -= 1
start = lines.get_line_start(start_line)
while end_line < lines.length() and lines.get_line(end_line + 1).strip() == "":
end_line += 1
end = min(lines.get_line_end(end_line) + 1, len(pymodule.source_code))
return start, end
def _add_imports2(self, pymodule, new_imports):
source = self.tools.add_imports(pymodule, new_imports)
if source is None:
return pymodule, False
else:
resource = pymodule.get_resource()
pymodule = libutils.get_string_module(self.project, source, resource)
return pymodule, True
class MoveModule(object):
"""For moving modules and packages"""
def __init__(self, project, resource):
self.project = project
if not resource.is_folder() and resource.name == "__init__.py":
resource = resource.parent
if resource.is_folder() and not resource.has_child("__init__.py"):
raise exceptions.RefactoringError("Cannot move non-package folder.")
dummy_pymodule = libutils.get_string_module(self.project, "")
self.old_pyname = pynames.ImportedModule(dummy_pymodule, resource=resource)
self.source = self.old_pyname.get_object().get_resource()
if self.source.is_folder():
self.old_name = self.source.name
else:
self.old_name = self.source.name[:-3]
self.tools = _MoveTools(
self.project, self.source, self.old_pyname, self.old_name
)
self.import_tools = self.tools.import_tools
def get_changes(
self, dest, resources=None, task_handle=taskhandle.NullTaskHandle()
):
if resources is None:
resources = self.project.get_python_files()
if dest is None or not dest.is_folder():
raise exceptions.RefactoringError(
"Move destination for modules should be packages."
)
return self._calculate_changes(dest, resources, task_handle)
def _calculate_changes(self, dest, resources, task_handle):
changes = ChangeSet("Moving module <%s>" % self.old_name)
job_set = task_handle.create_jobset("Collecting changes", len(resources))
for module in resources:
job_set.started_job(module.path)
if module == self.source:
self._change_moving_module(changes, dest)
else:
source = self._change_occurrences_in_module(dest, resource=module)
if source is not None:
changes.add_change(ChangeContents(module, source))
job_set.finished_job()
if self.project == self.source.project:
changes.add_change(MoveResource(self.source, dest.path))
return changes
def _new_modname(self, dest):
destname = libutils.modname(dest)
if destname:
return destname + "." + self.old_name
return self.old_name
def _new_import(self, dest):
return importutils.NormalImport([(self._new_modname(dest), None)])
def _change_moving_module(self, changes, dest):
if not self.source.is_folder():
pymodule = self.project.get_pymodule(self.source)
source = self.import_tools.relatives_to_absolutes(pymodule)
pymodule = self.tools.new_pymodule(pymodule, source)
source = self._change_occurrences_in_module(dest, pymodule)
source = self.tools.new_source(pymodule, source)
if source != self.source.read():
changes.add_change(ChangeContents(self.source, source))
def _change_occurrences_in_module(self, dest, pymodule=None, resource=None):
if not self.tools.occurs_in_module(pymodule=pymodule, resource=resource):
return
if pymodule is None:
pymodule = self.project.get_pymodule(resource)
new_name = self._new_modname(dest)
module_imports = importutils.get_module_imports(self.project, pymodule)
changed = False
source = None
if libutils.modname(dest):
changed = self._change_import_statements(dest, new_name, module_imports)
if changed:
source = module_imports.get_changed_source()
source = self.tools.new_source(pymodule, source)
pymodule = self.tools.new_pymodule(pymodule, source)
new_import = self._new_import(dest)
source = self.tools.rename_in_module(
new_name,
imports=True,
pymodule=pymodule,
resource=resource if not changed else None,
)
should_import = self.tools.occurs_in_module(
pymodule=pymodule, resource=resource, imports=False
)
pymodule = self.tools.new_pymodule(pymodule, source)
source = self.tools.remove_old_imports(pymodule)
if should_import:
pymodule = self.tools.new_pymodule(pymodule, source)
source = self.tools.add_imports(pymodule, [new_import])
source = self.tools.new_source(pymodule, source)
if source is not None and source != pymodule.resource.read():
return source
return None
def _change_import_statements(self, dest, new_name, module_imports):
moving_module = self.source
parent_module = moving_module.parent
changed = False
for import_stmt in module_imports.imports:
if not any(
name_and_alias[0] == self.old_name
for name_and_alias in import_stmt.import_info.names_and_aliases
) and not any(
name_and_alias[0] == libutils.modname(self.source)
for name_and_alias in import_stmt.import_info.names_and_aliases
):
continue
# Case 1: Look for normal imports of the moving module.
if isinstance(import_stmt.import_info, importutils.NormalImport):
continue
# Case 2: The moving module is from-imported.
changed = (
self._handle_moving_in_from_import_stmt(
dest, import_stmt, module_imports, parent_module
)
or changed
)
# Case 3: Names are imported from the moving module.
context = importutils.importinfo.ImportContext(self.project, None)
if (
not import_stmt.import_info.is_empty()
and import_stmt.import_info.get_imported_resource(context)
== moving_module
):
import_stmt.import_info = importutils.FromImport(
new_name,
import_stmt.import_info.level,
import_stmt.import_info.names_and_aliases,
)
changed = True
return changed
def _handle_moving_in_from_import_stmt(
self, dest, import_stmt, module_imports, parent_module
):
changed = False
context = importutils.importinfo.ImportContext(self.project, None)
if import_stmt.import_info.get_imported_resource(context) == parent_module:
imports = import_stmt.import_info.names_and_aliases
new_imports = []
for name, alias in imports:
# The moving module was imported.
if name == self.old_name:
changed = True
new_import = importutils.FromImport(
libutils.modname(dest), 0, [(self.old_name, alias)]
)
module_imports.add_import(new_import)
else:
new_imports.append((name, alias))
# Update the imports if the imported names were changed.
if new_imports != imports:
changed = True
if new_imports:
import_stmt.import_info = importutils.FromImport(
import_stmt.import_info.module_name,
import_stmt.import_info.level,
new_imports,
)
else:
import_stmt.empty_import()
return changed
class _ChangeMoveOccurrencesHandle(object):
def __init__(self, new_name):
self.new_name = new_name
self.occurred = False
def occurred_inside_skip(self, change_collector, occurrence):
pass
def occurred_outside_skip(self, change_collector, occurrence):
start, end = occurrence.get_primary_range()
change_collector.add_change(start, end, self.new_name)
self.occurred = True
class _MoveTools(object):
def __init__(self, project, source, pyname, old_name):
self.project = project
self.source = source
self.old_pyname = pyname
self.old_name = old_name
self.import_tools = importutils.ImportTools(self.project)
def remove_old_imports(self, pymodule):
old_source = pymodule.source_code
module_with_imports = self.import_tools.module_imports(pymodule)
class CanSelect(object):
changed = False
old_name = self.old_name
old_pyname = self.old_pyname
def __call__(self, name):
try:
if (
name == self.old_name
and pymodule[name].get_object() == self.old_pyname.get_object()
):
self.changed = True
return False
except exceptions.AttributeNotFoundError:
pass
return True
can_select = CanSelect()
module_with_imports.filter_names(can_select)
new_source = module_with_imports.get_changed_source()
if old_source != new_source:
return new_source
def rename_in_module(self, new_name, pymodule=None, imports=False, resource=None):
occurrence_finder = self._create_finder(imports)
source = rename.rename_in_module(
occurrence_finder,
new_name,
replace_primary=True,
pymodule=pymodule,
resource=resource,
)
return source
def occurs_in_module(self, pymodule=None, resource=None, imports=True):
finder = self._create_finder(imports)
for occurrence in finder.find_occurrences(pymodule=pymodule, resource=resource):
return True
return False
def _create_finder(self, imports):
return occurrences.create_finder(
self.project,
self.old_name,
self.old_pyname,
imports=imports,
keywords=False,
)
def new_pymodule(self, pymodule, source):
if source is not None:
return libutils.get_string_module(
self.project, source, pymodule.get_resource()
)
return pymodule
def new_source(self, pymodule, source):
if source is None:
return pymodule.source_code
return source
def add_imports(self, pymodule, new_imports):
return _add_imports_to_module(self.import_tools, pymodule, new_imports)
def _add_imports_to_module(import_tools, pymodule, new_imports):
module_with_imports = import_tools.module_imports(pymodule)
for new_import in new_imports:
module_with_imports.add_import(new_import)
return module_with_imports.get_changed_source()
def moving_code_with_imports(project, resource, source):
import_tools = importutils.ImportTools(project)
pymodule = libutils.get_string_module(project, source, resource)
# Strip comment prefix, if any. These need to stay before the moving
# section, but imports would be added between them.
lines = codeanalyze.SourceLinesAdapter(source)
start = 1
while start < lines.length() and lines.get_line(start).startswith("#"):
start += 1
moving_prefix = source[: lines.get_line_start(start)]
pymodule = libutils.get_string_module(
project, source[lines.get_line_start(start) :], resource
)
origin = project.get_pymodule(resource)
imports = []
for stmt in import_tools.module_imports(origin).imports:
imports.append(stmt.import_info)
back_names = []
for name in origin:
if name not in pymodule:
back_names.append(name)
imports.append(import_tools.get_from_import(resource, back_names))
source = _add_imports_to_module(import_tools, pymodule, imports)
pymodule = libutils.get_string_module(project, source, resource)
source = import_tools.relatives_to_absolutes(pymodule)
pymodule = libutils.get_string_module(project, source, resource)
source = import_tools.organize_imports(pymodule, selfs=False)
pymodule = libutils.get_string_module(project, source, resource)
# extracting imports after changes
module_imports = import_tools.module_imports(pymodule)
imports = [import_stmt.import_info for import_stmt in module_imports.imports]
start = 1
if module_imports.imports:
start = module_imports.imports[-1].end_line
lines = codeanalyze.SourceLinesAdapter(source)
while start < lines.length() and not lines.get_line(start).strip():
start += 1
# Reinsert the prefix which was removed at the beginning
moving = moving_prefix + source[lines.get_line_start(start) :]
return moving, imports
class ModuleSkipRenamerHandle(object):
def occurred_outside_skip(self, change_collector, occurrence):
pass
def occurred_inside_skip(self, change_collector, occurrence):
pass
class ModuleSkipRenamer(object):
"""Rename occurrences in a module
This class can be used when you want to treat a region in a file
separately from other parts when renaming.
"""
def __init__(
self,
occurrence_finder,
resource,
handle=None,
skip_start=0,
skip_end=0,
replacement="",
):
"""Constructor
if replacement is `None` the region is not changed. Otherwise
it is replaced with `replacement`.
"""
self.occurrence_finder = occurrence_finder
self.resource = resource
self.skip_start = skip_start
self.skip_end = skip_end
self.replacement = replacement
self.handle = handle
if self.handle is None:
self.handle = ModuleSkipRenamerHandle()
def get_changed_module(self):
source = self.resource.read()
change_collector = codeanalyze.ChangeCollector(source)
if self.replacement is not None:
change_collector.add_change(
self.skip_start, self.skip_end, self.replacement
)
for occurrence in self.occurrence_finder.find_occurrences(self.resource):
start, end = occurrence.get_primary_range()
if self.skip_start <= start < self.skip_end:
self.handle.occurred_inside_skip(change_collector, occurrence)
else:
self.handle.occurred_outside_skip(change_collector, occurrence)
result = change_collector.get_changed()
if result is not None and result != source:
return result

View file

@ -0,0 +1,76 @@
"""This module can be used for performing cross-project refactorings
See the "cross-project refactorings" section of ``docs/library.rst``
file.
"""
from rope.base import resources, libutils
class MultiProjectRefactoring(object):
def __init__(self, refactoring, projects, addpath=True):
"""Create a multiproject proxy for the main refactoring
`projects` are other project.
"""
self.refactoring = refactoring
self.projects = projects
self.addpath = addpath
def __call__(self, project, *args, **kwds):
"""Create the refactoring"""
return _MultiRefactoring(
self.refactoring, self.projects, self.addpath, project, *args, **kwds
)
class _MultiRefactoring(object):
def __init__(self, refactoring, other_projects, addpath, project, *args, **kwds):
self.refactoring = refactoring
self.projects = [project] + other_projects
for other_project in other_projects:
for folder in self.project.get_source_folders():
other_project.get_prefs().add("python_path", folder.real_path)
self.refactorings = []
for other in self.projects:
args, kwds = self._resources_for_args(other, args, kwds)
self.refactorings.append(self.refactoring(other, *args, **kwds))
def get_all_changes(self, *args, **kwds):
"""Get a project to changes dict"""
result = []
for project, refactoring in zip(self.projects, self.refactorings):
args, kwds = self._resources_for_args(project, args, kwds)
result.append((project, refactoring.get_changes(*args, **kwds)))
return result
def __getattr__(self, name):
return getattr(self.main_refactoring, name)
def _resources_for_args(self, project, args, kwds):
newargs = [self._change_project_resource(project, arg) for arg in args]
newkwds = dict(
(name, self._change_project_resource(project, value))
for name, value in kwds.items()
)
return newargs, newkwds
def _change_project_resource(self, project, obj):
if isinstance(obj, resources.Resource) and obj.project != project:
return libutils.path_to_resource(project, obj.real_path)
return obj
@property
def project(self):
return self.projects[0]
@property
def main_refactoring(self):
return self.refactorings[0]
def perform(project_changes):
for project, changes in project_changes:
project.do(changes)

View file

@ -0,0 +1,424 @@
"""Find occurrences of a name in a project.
This module consists of a `Finder` that finds all occurrences of a name
in a project. The `Finder.find_occurrences()` method is a generator that
yields `Occurrence` instances for each occurrence of the name. To create
a `Finder` object, use the `create_finder()` function:
finder = occurrences.create_finder(project, 'foo', pyname)
for occurrence in finder.find_occurrences():
pass
It's possible to filter the occurrences. They can be specified when
calling the `create_finder()` function.
* `only_calls`: If True, return only those instances where the name is
a function that's being called.
* `imports`: If False, don't return instances that are in import
statements.
* `unsure`: If a prediate function, return instances where we don't
know what the name references. It also filters based on the
predicate function.
* `docs`: If True, it will search for occurrences in regions normally
ignored. E.g., strings and comments.
* `in_hierarchy`: If True, it will find occurrences if the name is in
the class's hierarchy.
* `instance`: Used only when you want implicit interfaces to be
considered.
* `keywords`: If False, don't return instances that are the names of keyword
arguments
"""
import ast
import re
from rope.base import codeanalyze
from rope.base import evaluate
from rope.base import exceptions
from rope.base import pynames
from rope.base import pyobjects
from rope.base import utils
from rope.base import worder
class Finder(object):
"""For finding occurrences of a name
The constructor takes a `filters` argument. It should be a list
of functions that take a single argument. For each possible
occurrence, these functions are called in order with the an
instance of `Occurrence`:
* If it returns `None` other filters are tried.
* If it returns `True`, the occurrence will be a match.
* If it returns `False`, the occurrence will be skipped.
* If all of the filters return `None`, it is skipped also.
"""
def __init__(self, project, name, filters=[lambda o: True], docs=False):
self.project = project
self.name = name
self.docs = docs
self.filters = filters
self._textual_finder = _TextualFinder(name, docs=docs)
def find_occurrences(self, resource=None, pymodule=None):
"""Generate `Occurrence` instances"""
tools = _OccurrenceToolsCreator(
self.project, resource=resource, pymodule=pymodule, docs=self.docs
)
for offset in self._textual_finder.find_offsets(tools.source_code):
occurrence = Occurrence(tools, offset)
for filter in self.filters:
result = filter(occurrence)
if result is None:
continue
if result:
yield occurrence
break
def create_finder(
project,
name,
pyname,
only_calls=False,
imports=True,
unsure=None,
docs=False,
instance=None,
in_hierarchy=False,
keywords=True,
):
"""A factory for `Finder`
Based on the arguments it creates a list of filters. `instance`
argument is needed only when you want implicit interfaces to be
considered.
"""
pynames_ = set([pyname])
filters = []
if only_calls:
filters.append(CallsFilter())
if not imports:
filters.append(NoImportsFilter())
if not keywords:
filters.append(NoKeywordsFilter())
if isinstance(instance, pynames.ParameterName):
for pyobject in instance.get_objects():
try:
pynames_.add(pyobject[name])
except exceptions.AttributeNotFoundError:
pass
for pyname in pynames_:
filters.append(PyNameFilter(pyname))
if in_hierarchy:
filters.append(InHierarchyFilter(pyname))
if unsure:
filters.append(UnsureFilter(unsure))
return Finder(project, name, filters=filters, docs=docs)
class Occurrence(object):
def __init__(self, tools, offset):
self.tools = tools
self.offset = offset
self.resource = tools.resource
@utils.saveit
def get_word_range(self):
return self.tools.word_finder.get_word_range(self.offset)
@utils.saveit
def get_primary_range(self):
return self.tools.word_finder.get_primary_range(self.offset)
@utils.saveit
def get_pyname(self):
try:
return self.tools.name_finder.get_pyname_at(self.offset)
except exceptions.BadIdentifierError:
pass
@utils.saveit
def get_primary_and_pyname(self):
try:
return self.tools.name_finder.get_primary_and_pyname_at(self.offset)
except exceptions.BadIdentifierError:
pass
@utils.saveit
def is_in_import_statement(self):
return self.tools.word_finder.is_from_statement(
self.offset
) or self.tools.word_finder.is_import_statement(self.offset)
def is_called(self):
return self.tools.word_finder.is_a_function_being_called(self.offset)
def is_defined(self):
return self.tools.word_finder.is_a_class_or_function_name_in_header(self.offset)
def is_a_fixed_primary(self):
return self.tools.word_finder.is_a_class_or_function_name_in_header(
self.offset
) or self.tools.word_finder.is_a_name_after_from_import(self.offset)
def is_written(self):
return self.tools.word_finder.is_assigned_here(self.offset)
def is_unsure(self):
return unsure_pyname(self.get_pyname())
def is_function_keyword_parameter(self):
return self.tools.word_finder.is_function_keyword_parameter(self.offset)
@property
@utils.saveit
def lineno(self):
offset = self.get_word_range()[0]
return self.tools.pymodule.lines.get_line_number(offset)
def same_pyname(expected, pyname):
"""Check whether `expected` and `pyname` are the same"""
if expected is None or pyname is None:
return False
if expected == pyname:
return True
if type(expected) not in (pynames.ImportedModule, pynames.ImportedName) and type(
pyname
) not in (pynames.ImportedModule, pynames.ImportedName):
return False
return (
expected.get_definition_location() == pyname.get_definition_location()
and expected.get_object() == pyname.get_object()
)
def unsure_pyname(pyname, unbound=True):
"""Return `True` if we don't know what this name references"""
if pyname is None:
return True
if unbound and not isinstance(pyname, pynames.UnboundName):
return False
if pyname.get_object() == pyobjects.get_unknown():
return True
class PyNameFilter(object):
"""For finding occurrences of a name."""
def __init__(self, pyname):
self.pyname = pyname
def __call__(self, occurrence):
if same_pyname(self.pyname, occurrence.get_pyname()):
return True
class InHierarchyFilter(object):
"""Finds the occurrence if the name is in the class's hierarchy."""
def __init__(self, pyname, implementations_only=False):
self.pyname = pyname
self.impl_only = implementations_only
self.pyclass = self._get_containing_class(pyname)
if self.pyclass is not None:
self.name = pyname.get_object().get_name()
self.roots = self._get_root_classes(self.pyclass, self.name)
else:
self.roots = None
def __call__(self, occurrence):
if self.roots is None:
return
pyclass = self._get_containing_class(occurrence.get_pyname())
if pyclass is not None:
roots = self._get_root_classes(pyclass, self.name)
if self.roots.intersection(roots):
return True
def _get_containing_class(self, pyname):
if isinstance(pyname, pynames.DefinedName):
scope = pyname.get_object().get_scope()
parent = scope.parent
if parent is not None and parent.get_kind() == "Class":
return parent.pyobject
def _get_root_classes(self, pyclass, name):
if self.impl_only and pyclass == self.pyclass:
return set([pyclass])
result = set()
for superclass in pyclass.get_superclasses():
if name in superclass:
result.update(self._get_root_classes(superclass, name))
if not result:
return set([pyclass])
return result
class UnsureFilter(object):
"""Occurrences where we don't knoow what the name references."""
def __init__(self, unsure):
self.unsure = unsure
def __call__(self, occurrence):
if occurrence.is_unsure() and self.unsure(occurrence):
return True
class NoImportsFilter(object):
"""Don't include import statements as occurrences."""
def __call__(self, occurrence):
if occurrence.is_in_import_statement():
return False
class CallsFilter(object):
"""Filter out non-call occurrences."""
def __call__(self, occurrence):
if not occurrence.is_called():
return False
class NoKeywordsFilter(object):
"""Filter out keyword parameters."""
def __call__(self, occurrence):
if occurrence.is_function_keyword_parameter():
return False
class _TextualFinder(object):
def __init__(self, name, docs=False):
self.name = name
self.docs = docs
self.comment_pattern = _TextualFinder.any("comment", [r"#[^\n]*"])
self.string_pattern = _TextualFinder.any(
"string", [codeanalyze.get_string_pattern()]
)
self.f_string_pattern = _TextualFinder.any(
"fstring", [codeanalyze.get_formatted_string_pattern()]
)
self.pattern = self._get_occurrence_pattern(self.name)
def find_offsets(self, source):
if not self._fast_file_query(source):
return
if self.docs:
searcher = self._normal_search
else:
searcher = self._re_search
for matched in searcher(source):
yield matched
def _re_search(self, source):
for match in self.pattern.finditer(source):
if match.groupdict()["occurrence"]:
yield match.start("occurrence")
elif utils.pycompat.PY36 and match.groupdict()["fstring"]:
f_string = match.groupdict()["fstring"]
for occurrence_node in self._search_in_f_string(f_string):
yield match.start("fstring") + occurrence_node.col_offset
def _search_in_f_string(self, f_string):
tree = ast.parse(f_string)
for node in ast.walk(tree):
if isinstance(node, ast.Name) and node.id == self.name:
yield node
def _normal_search(self, source):
current = 0
while True:
try:
found = source.index(self.name, current)
current = found + len(self.name)
if (found == 0 or not self._is_id_char(source[found - 1])) and (
current == len(source) or not self._is_id_char(source[current])
):
yield found
except ValueError:
break
def _is_id_char(self, c):
return c.isalnum() or c == "_"
def _fast_file_query(self, source):
try:
source.index(self.name)
return True
except ValueError:
return False
def _get_source(self, resource, pymodule):
if resource is not None:
return resource.read()
else:
return pymodule.source_code
def _get_occurrence_pattern(self, name):
occurrence_pattern = _TextualFinder.any("occurrence", ["\\b" + name + "\\b"])
pattern = re.compile(
occurrence_pattern
+ "|"
+ self.comment_pattern
+ "|"
+ self.string_pattern
+ "|"
+ self.f_string_pattern
)
return pattern
@staticmethod
def any(name, list_):
return "(?P<%s>" % name + "|".join(list_) + ")"
class _OccurrenceToolsCreator(object):
def __init__(self, project, resource=None, pymodule=None, docs=False):
self.project = project
self.__resource = resource
self.__pymodule = pymodule
self.docs = docs
@property
@utils.saveit
def name_finder(self):
return evaluate.ScopeNameFinder(self.pymodule)
@property
@utils.saveit
def source_code(self):
return self.pymodule.source_code
@property
@utils.saveit
def word_finder(self):
return worder.Worder(self.source_code, self.docs)
@property
@utils.saveit
def resource(self):
if self.__resource is not None:
return self.__resource
if self.__pymodule is not None:
return self.__pymodule.resource
@property
@utils.saveit
def pymodule(self):
if self.__pymodule is not None:
return self.__pymodule
return self.project.get_pymodule(self.resource)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,265 @@
import warnings
from rope.base import (
exceptions,
pyobjects,
pynames,
taskhandle,
evaluate,
worder,
codeanalyze,
libutils,
)
from rope.base.change import ChangeSet, ChangeContents, MoveResource
from rope.refactor import occurrences
class Rename(object):
"""A class for performing rename refactoring
It can rename everything: classes, functions, modules, packages,
methods, variables and keyword arguments.
"""
def __init__(self, project, resource, offset=None):
"""If `offset` is None, the `resource` itself will be renamed"""
self.project = project
self.resource = resource
if offset is not None:
self.old_name = worder.get_name_at(self.resource, offset)
this_pymodule = self.project.get_pymodule(self.resource)
self.old_instance, self.old_pyname = evaluate.eval_location2(
this_pymodule, offset
)
if self.old_pyname is None:
raise exceptions.RefactoringError(
"Rename refactoring should be performed"
" on resolvable python identifiers."
)
else:
if not resource.is_folder() and resource.name == "__init__.py":
resource = resource.parent
dummy_pymodule = libutils.get_string_module(self.project, "")
self.old_instance = None
self.old_pyname = pynames.ImportedModule(dummy_pymodule, resource=resource)
if resource.is_folder():
self.old_name = resource.name
else:
self.old_name = resource.name[:-3]
def get_old_name(self):
return self.old_name
def get_changes(
self,
new_name,
in_file=None,
in_hierarchy=False,
unsure=None,
docs=False,
resources=None,
task_handle=taskhandle.NullTaskHandle(),
):
"""Get the changes needed for this refactoring
Parameters:
- `in_hierarchy`: when renaming a method this keyword forces
to rename all matching methods in the hierarchy
- `docs`: when `True` rename refactoring will rename
occurrences in comments and strings where the name is
visible. Setting it will make renames faster, too.
- `unsure`: decides what to do about unsure occurrences.
If `None`, they are ignored. Otherwise `unsure` is
called with an instance of `occurrence.Occurrence` as
parameter. If it returns `True`, the occurrence is
considered to be a match.
- `resources` can be a list of `rope.base.resources.File` to
apply this refactoring on. If `None`, the restructuring
will be applied to all python files.
- `in_file`: this argument has been deprecated; use
`resources` instead.
"""
if unsure in (True, False):
warnings.warn(
"unsure parameter should be a function that returns " "True or False",
DeprecationWarning,
stacklevel=2,
)
def unsure_func(value=unsure):
return value
unsure = unsure_func
if in_file is not None:
warnings.warn(
"`in_file` argument has been deprecated; use `resources` " "instead. ",
DeprecationWarning,
stacklevel=2,
)
if in_file:
resources = [self.resource]
if _is_local(self.old_pyname):
resources = [self.resource]
if resources is None:
resources = self.project.get_python_files()
changes = ChangeSet("Renaming <%s> to <%s>" % (self.old_name, new_name))
finder = occurrences.create_finder(
self.project,
self.old_name,
self.old_pyname,
unsure=unsure,
docs=docs,
instance=self.old_instance,
in_hierarchy=in_hierarchy and self.is_method(),
)
job_set = task_handle.create_jobset("Collecting Changes", len(resources))
for file_ in resources:
job_set.started_job(file_.path)
new_content = rename_in_module(finder, new_name, resource=file_)
if new_content is not None:
changes.add_change(ChangeContents(file_, new_content))
job_set.finished_job()
if self._is_renaming_a_module():
resource = self.old_pyname.get_object().get_resource()
if self._is_allowed_to_move(resources, resource):
self._rename_module(resource, new_name, changes)
return changes
def _is_allowed_to_move(self, resources, resource):
if resource.is_folder():
try:
return resource.get_child("__init__.py") in resources
except exceptions.ResourceNotFoundError:
return False
else:
return resource in resources
def _is_renaming_a_module(self):
if isinstance(self.old_pyname.get_object(), pyobjects.AbstractModule):
return True
return False
def is_method(self):
pyname = self.old_pyname
return (
isinstance(pyname, pynames.DefinedName)
and isinstance(pyname.get_object(), pyobjects.PyFunction)
and isinstance(pyname.get_object().parent, pyobjects.PyClass)
)
def _rename_module(self, resource, new_name, changes):
if not resource.is_folder():
new_name = new_name + ".py"
parent_path = resource.parent.path
if parent_path == "":
new_location = new_name
else:
new_location = parent_path + "/" + new_name
changes.add_change(MoveResource(resource, new_location))
class ChangeOccurrences(object):
"""A class for changing the occurrences of a name in a scope
This class replaces the occurrences of a name. Note that it only
changes the scope containing the offset passed to the constructor.
What's more it does not have any side-effects. That is for
example changing occurrences of a module does not rename the
module; it merely replaces the occurrences of that module in a
scope with the given expression. This class is useful for
performing many custom refactorings.
"""
def __init__(self, project, resource, offset):
self.project = project
self.resource = resource
self.offset = offset
self.old_name = worder.get_name_at(resource, offset)
self.pymodule = project.get_pymodule(self.resource)
self.old_pyname = evaluate.eval_location(self.pymodule, offset)
def get_old_name(self):
word_finder = worder.Worder(self.resource.read())
return word_finder.get_primary_at(self.offset)
def _get_scope_offset(self):
scope = self.pymodule.get_scope().get_inner_scope_for_offset(self.offset)
return scope.get_region()
def get_changes(self, new_name, only_calls=False, reads=True, writes=True):
changes = ChangeSet(
"Changing <%s> occurrences to <%s>" % (self.old_name, new_name)
)
scope_start, scope_end = self._get_scope_offset()
finder = occurrences.create_finder(
self.project,
self.old_name,
self.old_pyname,
imports=False,
only_calls=only_calls,
)
new_contents = rename_in_module(
finder,
new_name,
pymodule=self.pymodule,
replace_primary=True,
region=(scope_start, scope_end),
reads=reads,
writes=writes,
)
if new_contents is not None:
changes.add_change(ChangeContents(self.resource, new_contents))
return changes
def rename_in_module(
occurrences_finder,
new_name,
resource=None,
pymodule=None,
replace_primary=False,
region=None,
reads=True,
writes=True,
):
"""Returns the changed source or `None` if there is no changes"""
if resource is not None:
source_code = resource.read()
else:
source_code = pymodule.source_code
change_collector = codeanalyze.ChangeCollector(source_code)
for occurrence in occurrences_finder.find_occurrences(resource, pymodule):
if replace_primary and occurrence.is_a_fixed_primary():
continue
if replace_primary:
start, end = occurrence.get_primary_range()
else:
start, end = occurrence.get_word_range()
if (not reads and not occurrence.is_written()) or (
not writes and occurrence.is_written()
):
continue
if region is None or region[0] <= start < region[1]:
change_collector.add_change(start, end, new_name)
return change_collector.get_changed()
def _is_local(pyname):
module, lineno = pyname.get_definition_location()
if lineno is None:
return False
scope = module.get_scope().get_inner_scope_for_line(lineno)
if isinstance(pyname, pynames.DefinedName) and scope.get_kind() in (
"Function",
"Class",
):
scope = scope.parent
return (
scope.get_kind() == "Function"
and pyname in scope.get_names().values()
and isinstance(pyname, pynames.AssignedName)
)

View file

@ -0,0 +1,321 @@
import warnings
from rope.base import change, taskhandle, builtins, ast, codeanalyze
from rope.base import libutils
from rope.refactor import patchedast, similarfinder, sourceutils
from rope.refactor.importutils import module_imports
class Restructure(object):
"""A class to perform python restructurings
A restructuring transforms pieces of code matching `pattern` to
`goal`. In the `pattern` wildcards can appear. Wildcards match
some piece of code based on their kind and arguments that are
passed to them through `args`.
`args` is a dictionary of wildcard names to wildcard arguments.
If the argument is a tuple, the first item of the tuple is
considered to be the name of the wildcard to use; otherwise the
"default" wildcard is used. For getting the list arguments a
wildcard supports, see the pydoc of the wildcard. (see
`rope.refactor.wildcard.DefaultWildcard` for the default
wildcard.)
`wildcards` is the list of wildcard types that can appear in
`pattern`. See `rope.refactor.wildcards`. If a wildcard does not
specify its kind (by using a tuple in args), the wildcard named
"default" is used. So there should be a wildcard with "default"
name in `wildcards`.
`imports` is the list of imports that changed modules should
import. Note that rope handles duplicate imports and does not add
the import if it already appears.
Example #1::
pattern ${pyobject}.get_attribute(${name})
goal ${pyobject}[${name}]
args pyobject: instance=rope.base.pyobjects.PyObject
Example #2::
pattern ${name} in ${pyobject}.get_attributes()
goal ${name} in {pyobject}
args pyobject: instance=rope.base.pyobjects.PyObject
Example #3::
pattern ${pycore}.create_module(${project}.root, ${name})
goal generate.create_module(${project}, ${name})
imports
from rope.contrib import generate
args
project: type=rope.base.project.Project
Example #4::
pattern ${pow}(${param1}, ${param2})
goal ${param1} ** ${param2}
args pow: name=mod.pow, exact
Example #5::
pattern ${inst}.longtask(${p1}, ${p2})
goal
${inst}.subtask1(${p1})
${inst}.subtask2(${p2})
args
inst: type=mod.A,unsure
"""
def __init__(self, project, pattern, goal, args=None, imports=None, wildcards=None):
"""Construct a restructuring
See class pydoc for more info about the arguments.
"""
self.project = project
self.pattern = pattern
self.goal = goal
self.args = args
if self.args is None:
self.args = {}
self.imports = imports
if self.imports is None:
self.imports = []
self.wildcards = wildcards
self.template = similarfinder.CodeTemplate(self.goal)
def get_changes(
self,
checks=None,
imports=None,
resources=None,
task_handle=taskhandle.NullTaskHandle(),
):
"""Get the changes needed by this restructuring
`resources` can be a list of `rope.base.resources.File` to
apply the restructuring on. If `None`, the restructuring will
be applied to all python files.
`checks` argument has been deprecated. Use the `args` argument
of the constructor. The usage of::
strchecks = {'obj1.type': 'mod.A', 'obj2': 'mod.B',
'obj3.object': 'mod.C'}
checks = restructuring.make_checks(strchecks)
can be replaced with::
args = {'obj1': 'type=mod.A', 'obj2': 'name=mod.B',
'obj3': 'object=mod.C'}
where obj1, obj2 and obj3 are wildcard names that appear
in restructuring pattern.
"""
if checks is not None:
warnings.warn(
"The use of checks parameter is deprecated; "
"use the args parameter of the constructor instead.",
DeprecationWarning,
stacklevel=2,
)
for name, value in checks.items():
self.args[name] = similarfinder._pydefined_to_str(value)
if imports is not None:
warnings.warn(
"The use of imports parameter is deprecated; "
"use imports parameter of the constructor, instead.",
DeprecationWarning,
stacklevel=2,
)
self.imports = imports
changes = change.ChangeSet(
"Restructuring <%s> to <%s>" % (self.pattern, self.goal)
)
if resources is not None:
files = [
resource
for resource in resources
if libutils.is_python_file(self.project, resource)
]
else:
files = self.project.get_python_files()
job_set = task_handle.create_jobset("Collecting Changes", len(files))
for resource in files:
job_set.started_job(resource.path)
pymodule = self.project.get_pymodule(resource)
finder = similarfinder.SimilarFinder(pymodule, wildcards=self.wildcards)
matches = list(finder.get_matches(self.pattern, self.args))
computer = self._compute_changes(matches, pymodule)
result = computer.get_changed()
if result is not None:
imported_source = self._add_imports(resource, result, self.imports)
changes.add_change(change.ChangeContents(resource, imported_source))
job_set.finished_job()
return changes
def _compute_changes(self, matches, pymodule):
return _ChangeComputer(
pymodule.source_code,
pymodule.get_ast(),
pymodule.lines,
self.template,
matches,
)
def _add_imports(self, resource, source, imports):
if not imports:
return source
import_infos = self._get_import_infos(resource, imports)
pymodule = libutils.get_string_module(self.project, source, resource)
imports = module_imports.ModuleImports(self.project, pymodule)
for import_info in import_infos:
imports.add_import(import_info)
return imports.get_changed_source()
def _get_import_infos(self, resource, imports):
pymodule = libutils.get_string_module(
self.project, "\n".join(imports), resource
)
imports = module_imports.ModuleImports(self.project, pymodule)
return [imports.import_info for imports in imports.imports]
def make_checks(self, string_checks):
"""Convert str to str dicts to str to PyObject dicts
This function is here to ease writing a UI.
"""
checks = {}
for key, value in string_checks.items():
is_pyname = not key.endswith(".object") and not key.endswith(".type")
evaluated = self._evaluate(value, is_pyname=is_pyname)
if evaluated is not None:
checks[key] = evaluated
return checks
def _evaluate(self, code, is_pyname=True):
attributes = code.split(".")
pyname = None
if attributes[0] in ("__builtin__", "__builtins__"):
class _BuiltinsStub(object):
def get_attribute(self, name):
return builtins.builtins[name]
pyobject = _BuiltinsStub()
else:
pyobject = self.project.get_module(attributes[0])
for attribute in attributes[1:]:
pyname = pyobject[attribute]
if pyname is None:
return None
pyobject = pyname.get_object()
return pyname if is_pyname else pyobject
def replace(code, pattern, goal):
"""used by other refactorings"""
finder = similarfinder.RawSimilarFinder(code)
matches = list(finder.get_matches(pattern))
ast = patchedast.get_patched_ast(code)
lines = codeanalyze.SourceLinesAdapter(code)
template = similarfinder.CodeTemplate(goal)
computer = _ChangeComputer(code, ast, lines, template, matches)
result = computer.get_changed()
if result is None:
return code
return result
class _ChangeComputer(object):
def __init__(self, code, ast, lines, goal, matches):
self.source = code
self.goal = goal
self.matches = matches
self.ast = ast
self.lines = lines
self.matched_asts = {}
self._nearest_roots = {}
if self._is_expression():
for match in self.matches:
self.matched_asts[match.ast] = match
def get_changed(self):
if self._is_expression():
result = self._get_node_text(self.ast)
if result == self.source:
return None
return result
else:
collector = codeanalyze.ChangeCollector(self.source)
last_end = -1
for match in self.matches:
start, end = match.get_region()
if start < last_end:
if not self._is_expression():
continue
last_end = end
replacement = self._get_matched_text(match)
collector.add_change(start, end, replacement)
return collector.get_changed()
def _is_expression(self):
return self.matches and isinstance(
self.matches[0], similarfinder.ExpressionMatch
)
def _get_matched_text(self, match):
mapping = {}
for name in self.goal.get_names():
node = match.get_ast(name)
if node is None:
raise similarfinder.BadNameInCheckError("Unknown name <%s>" % name)
force = self._is_expression() and match.ast == node
mapping[name] = self._get_node_text(node, force)
unindented = self.goal.substitute(mapping)
return self._auto_indent(match.get_region()[0], unindented)
def _get_node_text(self, node, force=False):
if not force and node in self.matched_asts:
return self._get_matched_text(self.matched_asts[node])
start, end = patchedast.node_region(node)
main_text = self.source[start:end]
collector = codeanalyze.ChangeCollector(main_text)
for node in self._get_nearest_roots(node):
sub_start, sub_end = patchedast.node_region(node)
collector.add_change(
sub_start - start, sub_end - start, self._get_node_text(node)
)
result = collector.get_changed()
if result is None:
return main_text
return result
def _auto_indent(self, offset, text):
lineno = self.lines.get_line_number(offset)
indents = sourceutils.get_indents(self.lines, lineno)
result = []
for index, line in enumerate(text.splitlines(True)):
if index != 0 and line.strip():
result.append(" " * indents)
result.append(line)
return "".join(result)
def _get_nearest_roots(self, node):
if node not in self._nearest_roots:
result = []
for child in ast.get_child_nodes(node):
if child in self.matched_asts:
result.append(child)
else:
result.extend(self._get_nearest_roots(child))
self._nearest_roots[node] = result
return self._nearest_roots[node]

View file

@ -0,0 +1,369 @@
"""This module can be used for finding similar code"""
import re
import rope.refactor.wildcards
from rope.base import libutils
from rope.base import codeanalyze, exceptions, ast, builtins
from rope.refactor import patchedast, wildcards
from rope.refactor.patchedast import MismatchedTokenError
class BadNameInCheckError(exceptions.RefactoringError):
pass
class SimilarFinder(object):
"""`SimilarFinder` can be used to find similar pieces of code
See the notes in the `rope.refactor.restructure` module for more
info.
"""
def __init__(self, pymodule, wildcards=None):
"""Construct a SimilarFinder"""
self.source = pymodule.source_code
try:
self.raw_finder = RawSimilarFinder(
pymodule.source_code, pymodule.get_ast(), self._does_match
)
except MismatchedTokenError:
print("in file %s" % pymodule.resource.path)
raise
self.pymodule = pymodule
if wildcards is None:
self.wildcards = {}
for wildcard in [
rope.refactor.wildcards.DefaultWildcard(pymodule.pycore.project)
]:
self.wildcards[wildcard.get_name()] = wildcard
else:
self.wildcards = wildcards
def get_matches(self, code, args={}, start=0, end=None):
self.args = args
if end is None:
end = len(self.source)
skip_region = None
if "skip" in args.get("", {}):
resource, region = args[""]["skip"]
if resource == self.pymodule.get_resource():
skip_region = region
return self.raw_finder.get_matches(code, start=start, end=end, skip=skip_region)
def get_match_regions(self, *args, **kwds):
for match in self.get_matches(*args, **kwds):
yield match.get_region()
def _does_match(self, node, name):
arg = self.args.get(name, "")
kind = "default"
if isinstance(arg, (tuple, list)):
kind = arg[0]
arg = arg[1]
suspect = wildcards.Suspect(self.pymodule, node, name)
return self.wildcards[kind].matches(suspect, arg)
class RawSimilarFinder(object):
"""A class for finding similar expressions and statements"""
def __init__(self, source, node=None, does_match=None):
if node is None:
try:
node = ast.parse(source)
except SyntaxError:
# needed to parse expression containing := operator
node = ast.parse("(" + source + ")")
if does_match is None:
self.does_match = self._simple_does_match
else:
self.does_match = does_match
self._init_using_ast(node, source)
def _simple_does_match(self, node, name):
return isinstance(node, (ast.expr, ast.Name))
def _init_using_ast(self, node, source):
self.source = source
self._matched_asts = {}
if not hasattr(node, "region"):
patchedast.patch_ast(node, source)
self.ast = node
def get_matches(self, code, start=0, end=None, skip=None):
"""Search for `code` in source and return a list of `Match`-es
`code` can contain wildcards. ``${name}`` matches normal
names and ``${?name} can match any expression. You can use
`Match.get_ast()` for getting the node that has matched a
given pattern.
"""
if end is None:
end = len(self.source)
for match in self._get_matched_asts(code):
match_start, match_end = match.get_region()
if start <= match_start and match_end <= end:
if skip is not None and (skip[0] < match_end and skip[1] > match_start):
continue
yield match
def _get_matched_asts(self, code):
if code not in self._matched_asts:
wanted = self._create_pattern(code)
matches = _ASTMatcher(self.ast, wanted, self.does_match).find_matches()
self._matched_asts[code] = matches
return self._matched_asts[code]
def _create_pattern(self, expression):
expression = self._replace_wildcards(expression)
node = ast.parse(expression)
# Getting Module.Stmt.nodes
nodes = node.body
if len(nodes) == 1 and isinstance(nodes[0], ast.Expr):
# Getting Discard.expr
wanted = nodes[0].value
else:
wanted = nodes
return wanted
def _replace_wildcards(self, expression):
ropevar = _RopeVariable()
template = CodeTemplate(expression)
mapping = {}
for name in template.get_names():
mapping[name] = ropevar.get_var(name)
return template.substitute(mapping)
class _ASTMatcher(object):
def __init__(self, body, pattern, does_match):
"""Searches the given pattern in the body AST.
body is an AST node and pattern can be either an AST node or
a list of ASTs nodes
"""
self.body = body
self.pattern = pattern
self.matches = None
self.ropevar = _RopeVariable()
self.matches_callback = does_match
def find_matches(self):
if self.matches is None:
self.matches = []
ast.call_for_nodes(self.body, self._check_node, recursive=True)
return self.matches
def _check_node(self, node):
if isinstance(self.pattern, list):
self._check_statements(node)
else:
self._check_expression(node)
def _check_expression(self, node):
mapping = {}
if self._match_nodes(self.pattern, node, mapping):
self.matches.append(ExpressionMatch(node, mapping))
def _check_statements(self, node):
for child in ast.get_children(node):
if isinstance(child, (list, tuple)):
self.__check_stmt_list(child)
def __check_stmt_list(self, nodes):
for index in range(len(nodes)):
if len(nodes) - index >= len(self.pattern):
current_stmts = nodes[index : index + len(self.pattern)]
mapping = {}
if self._match_stmts(current_stmts, mapping):
self.matches.append(StatementMatch(current_stmts, mapping))
def _match_nodes(self, expected, node, mapping):
if isinstance(expected, ast.Name):
if self.ropevar.is_var(expected.id):
return self._match_wildcard(expected, node, mapping)
if not isinstance(expected, ast.AST):
return expected == node
if expected.__class__ != node.__class__:
return False
children1 = self._get_children(expected)
children2 = self._get_children(node)
if len(children1) != len(children2):
return False
for child1, child2 in zip(children1, children2):
if isinstance(child1, ast.AST):
if not self._match_nodes(child1, child2, mapping):
return False
elif isinstance(child1, (list, tuple)):
if not isinstance(child2, (list, tuple)) or len(child1) != len(child2):
return False
for c1, c2 in zip(child1, child2):
if not self._match_nodes(c1, c2, mapping):
return False
else:
if type(child1) is not type(child2) or child1 != child2:
return False
return True
def _get_children(self, node):
"""Return not `ast.expr_context` children of `node`"""
children = ast.get_children(node)
return [child for child in children if not isinstance(child, ast.expr_context)]
def _match_stmts(self, current_stmts, mapping):
if len(current_stmts) != len(self.pattern):
return False
for stmt, expected in zip(current_stmts, self.pattern):
if not self._match_nodes(expected, stmt, mapping):
return False
return True
def _match_wildcard(self, node1, node2, mapping):
name = self.ropevar.get_base(node1.id)
if name not in mapping:
if self.matches_callback(node2, name):
mapping[name] = node2
return True
return False
else:
return self._match_nodes(mapping[name], node2, {})
class Match(object):
def __init__(self, mapping):
self.mapping = mapping
def get_region(self):
"""Returns match region"""
def get_ast(self, name):
"""Return the ast node that has matched rope variables"""
return self.mapping.get(name, None)
class ExpressionMatch(Match):
def __init__(self, ast, mapping):
super(ExpressionMatch, self).__init__(mapping)
self.ast = ast
def get_region(self):
return self.ast.region
class StatementMatch(Match):
def __init__(self, ast_list, mapping):
super(StatementMatch, self).__init__(mapping)
self.ast_list = ast_list
def get_region(self):
return self.ast_list[0].region[0], self.ast_list[-1].region[1]
class CodeTemplate(object):
def __init__(self, template):
self.template = template
self._find_names()
def _find_names(self):
self.names = {}
for match in CodeTemplate._get_pattern().finditer(self.template):
if "name" in match.groupdict() and match.group("name") is not None:
start, end = match.span("name")
name = self.template[start + 2 : end - 1]
if name not in self.names:
self.names[name] = []
self.names[name].append((start, end))
def get_names(self):
return self.names.keys()
def substitute(self, mapping):
collector = codeanalyze.ChangeCollector(self.template)
for name, occurrences in self.names.items():
for region in occurrences:
collector.add_change(region[0], region[1], mapping[name])
result = collector.get_changed()
if result is None:
return self.template
return result
_match_pattern = None
@classmethod
def _get_pattern(cls):
if cls._match_pattern is None:
pattern = (
codeanalyze.get_comment_pattern()
+ "|"
+ codeanalyze.get_string_pattern()
+ "|"
+ r"(?P<name>\$\{[^\s\$\}]*\})"
)
cls._match_pattern = re.compile(pattern)
return cls._match_pattern
class _RopeVariable(object):
"""Transform and identify rope inserted wildcards"""
_normal_prefix = "__rope__variable_normal_"
_any_prefix = "__rope__variable_any_"
def get_var(self, name):
if name.startswith("?"):
return self._get_any(name)
else:
return self._get_normal(name)
def is_var(self, name):
return self._is_normal(name) or self._is_var(name)
def get_base(self, name):
if self._is_normal(name):
return name[len(self._normal_prefix) :]
if self._is_var(name):
return "?" + name[len(self._any_prefix) :]
def _get_normal(self, name):
return self._normal_prefix + name
def _get_any(self, name):
return self._any_prefix + name[1:]
def _is_normal(self, name):
return name.startswith(self._normal_prefix)
def _is_var(self, name):
return name.startswith(self._any_prefix)
def make_pattern(code, variables):
variables = set(variables)
collector = codeanalyze.ChangeCollector(code)
def does_match(node, name):
return isinstance(node, ast.Name) and node.id == name
finder = RawSimilarFinder(code, does_match=does_match)
for variable in variables:
for match in finder.get_matches("${%s}" % variable):
start, end = match.get_region()
collector.add_change(start, end, "${%s}" % variable)
result = collector.get_changed()
return result if result is not None else code
def _pydefined_to_str(pydefined):
address = []
if isinstance(pydefined, (builtins.BuiltinClass, builtins.BuiltinFunction)):
return "__builtins__." + pydefined.get_name()
else:
while pydefined.parent is not None:
address.insert(0, pydefined.get_name())
pydefined = pydefined.parent
module_name = libutils.modname(pydefined.resource)
return ".".join(module_name.split(".") + address)

View file

@ -0,0 +1,93 @@
from rope.base import codeanalyze
def get_indents(lines, lineno):
return codeanalyze.count_line_indents(lines.get_line(lineno))
def find_minimum_indents(source_code):
result = 80
lines = source_code.split("\n")
for line in lines:
if line.strip() == "":
continue
result = min(result, codeanalyze.count_line_indents(line))
return result
def indent_lines(source_code, amount):
if amount == 0:
return source_code
lines = source_code.splitlines(True)
result = []
for l in lines:
if l.strip() == "":
result.append("\n")
continue
if amount < 0:
indents = codeanalyze.count_line_indents(l)
result.append(max(0, indents + amount) * " " + l.lstrip())
else:
result.append(" " * amount + l)
return "".join(result)
def fix_indentation(code, new_indents):
"""Change the indentation of `code` to `new_indents`"""
min_indents = find_minimum_indents(code)
return indent_lines(code, new_indents - min_indents)
def add_methods(pymodule, class_scope, methods_sources):
source_code = pymodule.source_code
lines = pymodule.lines
insertion_line = class_scope.get_end()
if class_scope.get_scopes():
insertion_line = class_scope.get_scopes()[-1].get_end()
insertion_offset = lines.get_line_end(insertion_line)
methods = "\n\n" + "\n\n".join(methods_sources)
indented_methods = fix_indentation(
methods,
get_indents(lines, class_scope.get_start())
+ get_indent(pymodule.pycore.project),
)
result = []
result.append(source_code[:insertion_offset])
result.append(indented_methods)
result.append(source_code[insertion_offset:])
return "".join(result)
def get_body(pyfunction):
"""Return unindented function body"""
# FIXME scope = pyfunction.get_scope()
pymodule = pyfunction.get_module()
start, end = get_body_region(pyfunction)
return fix_indentation(pymodule.source_code[start:end], 0)
def get_body_region(defined):
"""Return the start and end offsets of function body"""
scope = defined.get_scope()
pymodule = defined.get_module()
lines = pymodule.lines
node = defined.get_ast()
start_line = node.lineno
if defined.get_doc() is None:
start_line = node.body[0].lineno
elif len(node.body) > 1:
start_line = node.body[1].lineno
start = lines.get_line_start(start_line)
scope_start = pymodule.logical_lines.logical_line_in(scope.start)
if scope_start[1] >= start_line:
# a one-liner!
# XXX: what if colon appears in a string
start = pymodule.source_code.index(":", start) + 1
while pymodule.source_code[start].isspace():
start += 1
end = min(lines.get_line_end(scope.end) + 1, len(pymodule.source_code))
return start, end
def get_indent(project):
return project.prefs.get("indent_size", 4)

View file

@ -0,0 +1,171 @@
from itertools import chain
from rope.base import ast
from rope.base.utils import pycompat
def find_visible(node, lines):
"""Return the line which is visible from all `lines`"""
root = ast_suite_tree(node)
return find_visible_for_suite(root, lines)
def find_visible_for_suite(root, lines):
if len(lines) == 1:
return lines[0]
line1 = lines[0]
line2 = find_visible_for_suite(root, lines[1:])
suite1 = root.find_suite(line1)
suite2 = root.find_suite(line2)
def valid(suite):
return suite is not None and not suite.ignored
if valid(suite1) and not valid(suite2):
return line1
if not valid(suite1) and valid(suite2):
return line2
if not valid(suite1) and not valid(suite2):
return None
while suite1 != suite2 and suite1.parent != suite2.parent:
if suite1._get_level() < suite2._get_level():
line2 = suite2.get_start()
suite2 = suite2.parent
elif suite1._get_level() > suite2._get_level():
line1 = suite1.get_start()
suite1 = suite1.parent
else:
line1 = suite1.get_start()
line2 = suite2.get_start()
suite1 = suite1.parent
suite2 = suite2.parent
if suite1 == suite2:
return min(line1, line2)
return min(suite1.get_start(), suite2.get_start())
def ast_suite_tree(node):
if hasattr(node, "lineno"):
lineno = node.lineno
else:
lineno = 1
return Suite(node.body, lineno)
class Suite(object):
def __init__(self, child_nodes, lineno, parent=None, ignored=False):
self.parent = parent
self.lineno = lineno
self.child_nodes = child_nodes
self._children = None
self.ignored = ignored
def get_start(self):
if self.parent is None:
if self.child_nodes:
return self.local_start()
else:
return 1
return self.lineno
def get_children(self):
if self._children is None:
walker = _SuiteWalker(self)
for child in self.child_nodes:
ast.walk(child, walker)
self._children = walker.suites
return self._children
def local_start(self):
return self.child_nodes[0].lineno
def local_end(self):
end = self.child_nodes[-1].lineno
if self.get_children():
end = max(end, self.get_children()[-1].local_end())
return end
def find_suite(self, line):
if line is None:
return None
for child in self.get_children():
if child.local_start() <= line <= child.local_end():
return child.find_suite(line)
return self
def _get_level(self):
if self.parent is None:
return 0
return self.parent._get_level() + 1
class _SuiteWalker(object):
def __init__(self, suite):
self.suite = suite
self.suites = []
def _If(self, node):
self._add_if_like_node(node)
def _For(self, node):
self._add_if_like_node(node)
def _While(self, node):
self._add_if_like_node(node)
def _With(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite))
def _AsyncWith(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite))
def _Match(self, node):
case_bodies = list(
chain.from_iterable([[case.pattern] + case.body for case in node.cases])
)
self.suites.append(Suite(case_bodies, node.lineno, self.suite))
def _TryFinally(self, node):
proceed_to_except_handler = False
if len(node.finalbody) == 1:
if pycompat.PY2:
proceed_to_except_handler = isinstance(node.body[0], ast.TryExcept)
elif pycompat.PY3:
try:
proceed_to_except_handler = isinstance(
node.handlers[0], ast.ExceptHandler
)
except IndexError:
pass
if proceed_to_except_handler:
self._TryExcept(node if pycompat.PY3 else node.body[0])
else:
self.suites.append(Suite(node.body, node.lineno, self.suite))
self.suites.append(Suite(node.finalbody, node.lineno, self.suite))
def _Try(self, node):
if len(node.finalbody) == 1:
self._TryFinally(node)
else:
self._TryExcept(node)
def _TryExcept(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite))
for handler in node.handlers:
self.suites.append(Suite(handler.body, node.lineno, self.suite))
if node.orelse:
self.suites.append(Suite(node.orelse, node.lineno, self.suite))
def _add_if_like_node(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite))
if node.orelse:
self.suites.append(Suite(node.orelse, node.lineno, self.suite))
def _FunctionDef(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True))
def _AsyncFunctionDef(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True))
def _ClassDef(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True))

View file

@ -0,0 +1,29 @@
import rope.refactor.importutils
from rope.base.change import ChangeSet, ChangeContents, MoveResource, CreateFolder
class ModuleToPackage(object):
def __init__(self, project, resource):
self.project = project
self.resource = resource
def get_changes(self):
changes = ChangeSet("Transform <%s> module to package" % self.resource.path)
new_content = self._transform_relatives_to_absolute(self.resource)
if new_content is not None:
changes.add_change(ChangeContents(self.resource, new_content))
parent = self.resource.parent
name = self.resource.name[:-3]
changes.add_change(CreateFolder(parent, name))
parent_path = parent.path + "/"
if not parent.path:
parent_path = ""
new_path = parent_path + "%s/__init__.py" % name
if self.resource.project == self.project:
changes.add_change(MoveResource(self.resource, new_path))
return changes
def _transform_relatives_to_absolute(self, resource):
pymodule = self.project.get_pymodule(resource)
import_tools = rope.refactor.importutils.ImportTools(self.project)
return import_tools.relatives_to_absolutes(pymodule)

View file

@ -0,0 +1,194 @@
from rope.base import change, taskhandle, evaluate, exceptions, pyobjects, pynames, ast
from rope.base import libutils
from rope.refactor import restructure, sourceutils, similarfinder
class UseFunction(object):
"""Try to use a function wherever possible"""
def __init__(self, project, resource, offset):
self.project = project
self.offset = offset
this_pymodule = project.get_pymodule(resource)
pyname = evaluate.eval_location(this_pymodule, offset)
if pyname is None:
raise exceptions.RefactoringError("Unresolvable name selected")
self.pyfunction = pyname.get_object()
if not isinstance(self.pyfunction, pyobjects.PyFunction) or not isinstance(
self.pyfunction.parent, pyobjects.PyModule
):
raise exceptions.RefactoringError(
"Use function works for global functions, only."
)
self.resource = self.pyfunction.get_module().get_resource()
self._check_returns()
def _check_returns(self):
node = self.pyfunction.get_ast()
if _yield_count(node):
raise exceptions.RefactoringError(
"Use function should not be used on generatorS."
)
returns = _return_count(node)
if returns > 1:
raise exceptions.RefactoringError(
"usefunction: Function has more than one return statement."
)
if returns == 1 and not _returns_last(node):
raise exceptions.RefactoringError(
"usefunction: return should be the last statement."
)
def get_changes(self, resources=None, task_handle=taskhandle.NullTaskHandle()):
if resources is None:
resources = self.project.get_python_files()
changes = change.ChangeSet("Using function <%s>" % self.pyfunction.get_name())
if self.resource in resources:
newresources = list(resources)
newresources.remove(self.resource)
for c in self._restructure(newresources, task_handle).changes:
changes.add_change(c)
if self.resource in resources:
for c in self._restructure(
[self.resource], task_handle, others=False
).changes:
changes.add_change(c)
return changes
def get_function_name(self):
return self.pyfunction.get_name()
def _restructure(self, resources, task_handle, others=True):
pattern = self._make_pattern()
goal = self._make_goal(import_=others)
imports = None
if others:
imports = ["import %s" % self._module_name()]
body_region = sourceutils.get_body_region(self.pyfunction)
args_value = {"skip": (self.resource, body_region)}
args = {"": args_value}
restructuring = restructure.Restructure(
self.project, pattern, goal, args=args, imports=imports
)
return restructuring.get_changes(resources=resources, task_handle=task_handle)
def _find_temps(self):
return find_temps(self.project, self._get_body())
def _module_name(self):
return libutils.modname(self.resource)
def _make_pattern(self):
params = self.pyfunction.get_param_names()
body = self._get_body()
body = restructure.replace(body, "return", "pass")
wildcards = list(params)
wildcards.extend(self._find_temps())
if self._does_return():
if self._is_expression():
replacement = "${%s}" % self._rope_returned
else:
replacement = "%s = ${%s}" % (self._rope_result, self._rope_returned)
body = restructure.replace(
body, "return ${%s}" % self._rope_returned, replacement
)
wildcards.append(self._rope_result)
return similarfinder.make_pattern(body, wildcards)
def _get_body(self):
return sourceutils.get_body(self.pyfunction)
def _make_goal(self, import_=False):
params = self.pyfunction.get_param_names()
function_name = self.pyfunction.get_name()
if import_:
function_name = self._module_name() + "." + function_name
goal = "%s(%s)" % (function_name, ", ".join(("${%s}" % p) for p in params))
if self._does_return() and not self._is_expression():
goal = "${%s} = %s" % (self._rope_result, goal)
return goal
def _does_return(self):
body = self._get_body()
removed_return = restructure.replace(body, "return ${result}", "")
return removed_return != body
def _is_expression(self):
return len(self.pyfunction.get_ast().body) == 1
_rope_result = "_rope__result"
_rope_returned = "_rope__returned"
def find_temps(project, code):
code = "def f():\n" + sourceutils.indent_lines(code, 4)
pymodule = libutils.get_string_module(project, code)
result = []
function_scope = pymodule.get_scope().get_scopes()[0]
for name, pyname in function_scope.get_names().items():
if isinstance(pyname, pynames.AssignedName):
result.append(name)
return result
def _returns_last(node):
return node.body and isinstance(node.body[-1], ast.Return)
def _namedexpr_last(node):
if not hasattr(ast, "NamedExpr"): # python<3.8
return False
return (
bool(node.body)
and len(node.body) == 1
and isinstance(node.body[-1].value, ast.NamedExpr)
)
def _yield_count(node):
visitor = _ReturnOrYieldFinder()
visitor.start_walking(node)
return visitor.yields
def _return_count(node):
visitor = _ReturnOrYieldFinder()
visitor.start_walking(node)
return visitor.returns
def _named_expr_count(node):
visitor = _ReturnOrYieldFinder()
visitor.start_walking(node)
return visitor.named_expression
class _ReturnOrYieldFinder(object):
def __init__(self):
self.returns = 0
self.named_expression = 0
self.yields = 0
def _Return(self, node):
self.returns += 1
def _NamedExpr(self, node):
self.named_expression += 1
def _Yield(self, node):
self.yields += 1
def _FunctionDef(self, node):
pass
def _ClassDef(self, node):
pass
def start_walking(self, node):
nodes = [node]
if isinstance(node, ast.FunctionDef):
nodes = ast.get_child_nodes(node)
for child in nodes:
ast.walk(child, self)

View file

@ -0,0 +1,175 @@
from rope.base import ast, evaluate, builtins, pyobjects
from rope.refactor import patchedast, occurrences
class Wildcard(object):
def get_name(self):
"""Return the name of this wildcard"""
def matches(self, suspect, arg):
"""Return `True` if `suspect` matches this wildcard"""
class Suspect(object):
def __init__(self, pymodule, node, name):
self.name = name
self.pymodule = pymodule
self.node = node
class DefaultWildcard(object):
"""The default restructuring wildcard
The argument passed to this wildcard is in the
``key1=value1,key2=value2,...`` format. Possible keys are:
* name - for checking the reference
* type - for checking the type
* object - for checking the object
* instance - for checking types but similar to builtin isinstance
* exact - matching only occurrences with the same name as the wildcard
* unsure - matching unsure occurrences
"""
def __init__(self, project):
self.project = project
def get_name(self):
return "default"
def matches(self, suspect, arg=""):
args = parse_arg(arg)
if not self._check_exact(args, suspect):
return False
if not self._check_object(args, suspect):
return False
return True
def _check_object(self, args, suspect):
kind = None
expected = None
unsure = args.get("unsure", False)
for check in ["name", "object", "type", "instance"]:
if check in args:
kind = check
expected = args[check]
if expected is not None:
checker = _CheckObject(self.project, expected, kind, unsure=unsure)
return checker(suspect.pymodule, suspect.node)
return True
def _check_exact(self, args, suspect):
node = suspect.node
if args.get("exact"):
if not isinstance(node, ast.Name) or not node.id == suspect.name:
return False
else:
if not isinstance(node, ast.expr):
return False
return True
def parse_arg(arg):
if isinstance(arg, dict):
return arg
result = {}
tokens = arg.split(",")
for token in tokens:
if "=" in token:
parts = token.split("=", 1)
result[parts[0].strip()] = parts[1].strip()
else:
result[token.strip()] = True
return result
class _CheckObject(object):
def __init__(self, project, expected, kind="object", unsure=False):
self.project = project
self.kind = kind
self.unsure = unsure
self.expected = self._evaluate(expected)
def __call__(self, pymodule, node):
pyname = self._evaluate_node(pymodule, node)
if pyname is None or self.expected is None:
return self.unsure
if self._unsure_pyname(pyname, unbound=self.kind == "name"):
return True
if self.kind == "name":
return self._same_pyname(self.expected, pyname)
else:
pyobject = pyname.get_object()
if self.kind == "object":
objects = [pyobject]
if self.kind == "type":
objects = [pyobject.get_type()]
if self.kind == "instance":
objects = [pyobject]
objects.extend(self._get_super_classes(pyobject))
objects.extend(self._get_super_classes(pyobject.get_type()))
for pyobject in objects:
if self._same_pyobject(self.expected.get_object(), pyobject):
return True
return False
def _get_super_classes(self, pyobject):
result = []
if isinstance(pyobject, pyobjects.AbstractClass):
for superclass in pyobject.get_superclasses():
result.append(superclass)
result.extend(self._get_super_classes(superclass))
return result
def _same_pyobject(self, expected, pyobject):
return expected == pyobject
def _same_pyname(self, expected, pyname):
return occurrences.same_pyname(expected, pyname)
def _unsure_pyname(self, pyname, unbound=True):
return self.unsure and occurrences.unsure_pyname(pyname, unbound)
def _split_name(self, name):
parts = name.split(".")
expression, kind = parts[0], parts[-1]
if len(parts) == 1:
kind = "name"
return expression, kind
def _evaluate_node(self, pymodule, node):
scope = pymodule.get_scope().get_inner_scope_for_line(node.lineno)
expression = node
if isinstance(expression, ast.Name) and isinstance(expression.ctx, ast.Store):
start, end = patchedast.node_region(expression)
text = pymodule.source_code[start:end]
return evaluate.eval_str(scope, text)
else:
return evaluate.eval_node(scope, expression)
def _evaluate(self, code):
attributes = code.split(".")
pyname = None
if attributes[0] in ("__builtin__", "__builtins__"):
class _BuiltinsStub(object):
def get_attribute(self, name):
return builtins.builtins[name]
def __getitem__(self, name):
return builtins.builtins[name]
def __contains__(self, name):
return name in builtins.builtins
pyobject = _BuiltinsStub()
else:
pyobject = self.project.get_module(attributes[0])
for attribute in attributes[1:]:
pyname = pyobject[attribute]
if pyname is None:
return None
pyobject = pyname.get_object()
return pyname