This commit is contained in:
Waylon Walker 2022-03-31 20:20:07 -05:00
commit 38355d2442
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
9083 changed files with 1225834 additions and 0 deletions

View file

@ -0,0 +1,194 @@
from rope.base import change, taskhandle, evaluate, exceptions, pyobjects, pynames, ast
from rope.base import libutils
from rope.refactor import restructure, sourceutils, similarfinder
class UseFunction(object):
"""Try to use a function wherever possible"""
def __init__(self, project, resource, offset):
self.project = project
self.offset = offset
this_pymodule = project.get_pymodule(resource)
pyname = evaluate.eval_location(this_pymodule, offset)
if pyname is None:
raise exceptions.RefactoringError("Unresolvable name selected")
self.pyfunction = pyname.get_object()
if not isinstance(self.pyfunction, pyobjects.PyFunction) or not isinstance(
self.pyfunction.parent, pyobjects.PyModule
):
raise exceptions.RefactoringError(
"Use function works for global functions, only."
)
self.resource = self.pyfunction.get_module().get_resource()
self._check_returns()
def _check_returns(self):
node = self.pyfunction.get_ast()
if _yield_count(node):
raise exceptions.RefactoringError(
"Use function should not be used on generatorS."
)
returns = _return_count(node)
if returns > 1:
raise exceptions.RefactoringError(
"usefunction: Function has more than one return statement."
)
if returns == 1 and not _returns_last(node):
raise exceptions.RefactoringError(
"usefunction: return should be the last statement."
)
def get_changes(self, resources=None, task_handle=taskhandle.NullTaskHandle()):
if resources is None:
resources = self.project.get_python_files()
changes = change.ChangeSet("Using function <%s>" % self.pyfunction.get_name())
if self.resource in resources:
newresources = list(resources)
newresources.remove(self.resource)
for c in self._restructure(newresources, task_handle).changes:
changes.add_change(c)
if self.resource in resources:
for c in self._restructure(
[self.resource], task_handle, others=False
).changes:
changes.add_change(c)
return changes
def get_function_name(self):
return self.pyfunction.get_name()
def _restructure(self, resources, task_handle, others=True):
pattern = self._make_pattern()
goal = self._make_goal(import_=others)
imports = None
if others:
imports = ["import %s" % self._module_name()]
body_region = sourceutils.get_body_region(self.pyfunction)
args_value = {"skip": (self.resource, body_region)}
args = {"": args_value}
restructuring = restructure.Restructure(
self.project, pattern, goal, args=args, imports=imports
)
return restructuring.get_changes(resources=resources, task_handle=task_handle)
def _find_temps(self):
return find_temps(self.project, self._get_body())
def _module_name(self):
return libutils.modname(self.resource)
def _make_pattern(self):
params = self.pyfunction.get_param_names()
body = self._get_body()
body = restructure.replace(body, "return", "pass")
wildcards = list(params)
wildcards.extend(self._find_temps())
if self._does_return():
if self._is_expression():
replacement = "${%s}" % self._rope_returned
else:
replacement = "%s = ${%s}" % (self._rope_result, self._rope_returned)
body = restructure.replace(
body, "return ${%s}" % self._rope_returned, replacement
)
wildcards.append(self._rope_result)
return similarfinder.make_pattern(body, wildcards)
def _get_body(self):
return sourceutils.get_body(self.pyfunction)
def _make_goal(self, import_=False):
params = self.pyfunction.get_param_names()
function_name = self.pyfunction.get_name()
if import_:
function_name = self._module_name() + "." + function_name
goal = "%s(%s)" % (function_name, ", ".join(("${%s}" % p) for p in params))
if self._does_return() and not self._is_expression():
goal = "${%s} = %s" % (self._rope_result, goal)
return goal
def _does_return(self):
body = self._get_body()
removed_return = restructure.replace(body, "return ${result}", "")
return removed_return != body
def _is_expression(self):
return len(self.pyfunction.get_ast().body) == 1
_rope_result = "_rope__result"
_rope_returned = "_rope__returned"
def find_temps(project, code):
code = "def f():\n" + sourceutils.indent_lines(code, 4)
pymodule = libutils.get_string_module(project, code)
result = []
function_scope = pymodule.get_scope().get_scopes()[0]
for name, pyname in function_scope.get_names().items():
if isinstance(pyname, pynames.AssignedName):
result.append(name)
return result
def _returns_last(node):
return node.body and isinstance(node.body[-1], ast.Return)
def _namedexpr_last(node):
if not hasattr(ast, "NamedExpr"): # python<3.8
return False
return (
bool(node.body)
and len(node.body) == 1
and isinstance(node.body[-1].value, ast.NamedExpr)
)
def _yield_count(node):
visitor = _ReturnOrYieldFinder()
visitor.start_walking(node)
return visitor.yields
def _return_count(node):
visitor = _ReturnOrYieldFinder()
visitor.start_walking(node)
return visitor.returns
def _named_expr_count(node):
visitor = _ReturnOrYieldFinder()
visitor.start_walking(node)
return visitor.named_expression
class _ReturnOrYieldFinder(object):
def __init__(self):
self.returns = 0
self.named_expression = 0
self.yields = 0
def _Return(self, node):
self.returns += 1
def _NamedExpr(self, node):
self.named_expression += 1
def _Yield(self, node):
self.yields += 1
def _FunctionDef(self, node):
pass
def _ClassDef(self, node):
pass
def start_walking(self, node):
nodes = [node]
if isinstance(node, ast.FunctionDef):
nodes = ast.get_child_nodes(node)
for child in nodes:
ast.walk(child, self)