This commit is contained in:
Waylon Walker 2022-03-31 20:20:07 -05:00
commit 38355d2442
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
9083 changed files with 1225834 additions and 0 deletions

View file

@ -0,0 +1,631 @@
# pyflyby/_importclns.py.
# Copyright (C) 2011, 2012, 2013, 2014 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from collections import defaultdict
from functools import total_ordering
import six
from pyflyby._flags import CompilerFlags
from pyflyby._idents import dotted_prefixes, is_identifier
from pyflyby._importstmt import (Import, ImportFormatParams,
ImportStatement,
NonImportStatementError)
from pyflyby._parse import PythonBlock
from pyflyby._util import (cached_attribute, cmp, partition,
stable_unique)
class NoSuchImportError(ValueError):
pass
class ConflictingImportsError(ValueError):
pass
@total_ordering
class ImportSet(object):
r"""
Representation of a set of imports organized into import statements.
>>> ImportSet('''
... from m1 import f1
... from m2 import f1
... from m1 import f2
... import m3.m4 as m34
... ''')
ImportSet('''
from m1 import f1, f2
from m2 import f1
from m3 import m4 as m34
''')
An ``ImportSet`` is an immutable data structure.
"""
def __new__(cls, arg, ignore_nonimports=False, ignore_shadowed=False):
"""
Return as an `ImportSet`.
:param ignore_nonimports:
If ``False``, complain about non-imports. If ``True``, ignore
non-imports.
:param ignore_shadowed:
Whether to ignore shadowed imports. If ``False``, then keep all
unique imports, even if they shadow each other. Note that an
``ImportSet`` is unordered; an ``ImportSet`` with conflicts will only
be useful for very specific cases (e.g. set of imports to forget
from known-imports database), and not useful for outputting as code.
If ``ignore_shadowed`` is ``True``, then earlier shadowed imports are
ignored.
:rtype:
`ImportSet`
"""
if isinstance(arg, cls):
if ignore_shadowed:
return cls._from_imports(arg._importset, ignore_shadowed=True)
else:
return arg
return cls._from_args(
arg,
ignore_nonimports=ignore_nonimports,
ignore_shadowed=ignore_shadowed)
@classmethod
def _from_imports(cls, imports, ignore_shadowed=False):
"""
:type imports:
Sequence of `Import` s
:param ignore_shadowed:
See `ImportSet.__new__`.
:rtype:
`ImportSet`
"""
# Canonicalize inputs.
imports = [Import(imp) for imp in imports]
if ignore_shadowed:
# Filter by overshadowed imports. Later imports take precedence.
by_import_as = {}
for imp in imports:
if imp.import_as == "*":
# Keep all unique star imports.
by_import_as[imp] = imp
else:
by_import_as[imp.import_as] = imp
filtered_imports = by_import_as.values()
else:
filtered_imports = imports
# Construct and return.
self = object.__new__(cls)
self._importset = frozenset(filtered_imports)
return self
@classmethod
def _from_args(cls, args, ignore_nonimports=False, ignore_shadowed=False):
"""
:type args:
``tuple`` or ``list`` of `ImportStatement` s, `PythonStatement` s,
`PythonBlock` s, `FileText`, and/or `Filename` s
:param ignore_nonimports:
If ``False``, complain about non-imports. If ``True``, ignore
non-imports.
:param ignore_shadowed:
See `ImportSet.__new__`.
:rtype:
`ImportSet`
"""
if not isinstance(args, (tuple, list)):
args = [args]
# Filter empty arguments to allow the subsequent optimizations to work
# more often.
args = [a for a in args if a]
if not args:
return cls._EMPTY
# If we only got one ``ImportSet``, just return it.
if len(args) == 1 and type(args[0]) is cls and not ignore_shadowed:
return args[0]
# Collect all `Import` s from arguments.
imports = []
for arg in args:
if isinstance(arg, Import):
imports.append(arg)
elif isinstance(arg, ImportSet):
imports.extend(arg.imports)
elif isinstance(arg, ImportStatement):
imports.extend(arg.imports)
elif isinstance(arg, str) and is_identifier(arg, dotted=True):
imports.append(Import(arg))
else: # PythonBlock, PythonStatement, Filename, FileText, str
block = PythonBlock(arg)
for statement in block.statements:
# Ignore comments/blanks.
if statement.is_comment_or_blank:
pass
elif statement.is_import:
imports.extend(ImportStatement(statement).imports)
elif ignore_nonimports:
pass
else:
raise NonImportStatementError(
"Got non-import statement %r" % (statement,))
return cls._from_imports(imports, ignore_shadowed=ignore_shadowed)
def with_imports(self, other):
"""
Return a new `ImportSet` that is the union of ``self`` and
``new_imports``.
>>> impset = ImportSet('from m import t1, t2, t3')
>>> impset.with_imports('import m.t2a as t2b')
ImportSet('''
from m import t1, t2, t2a as t2b, t3
''')
:type other:
`ImportSet` (or convertible)
:rtype:
`ImportSet`
"""
other = ImportSet(other)
return type(self)._from_imports(self._importset | other._importset)
def without_imports(self, removals):
"""
Return a copy of self without the given imports.
>>> imports = ImportSet('from m import t1, t2, t3, t4')
>>> imports.without_imports(['from m import t3'])
ImportSet('''
from m import t1, t2, t4
''')
:type removals:
`ImportSet` (or convertible)
:rtype:
`ImportSet`
"""
removals = ImportSet(removals)
if not removals:
return self # Optimization
# Preprocess star imports to remove.
star_module_removals = set(
[imp.split.module_name
for imp in removals if imp.split.member_name == "*"])
# Filter imports.
new_imports = []
for imp in self:
if imp in removals:
continue
if star_module_removals and imp.split.module_name:
prefixes = dotted_prefixes(imp.split.module_name)
if any(pfx in star_module_removals for pfx in prefixes):
continue
new_imports.append(imp)
# Return.
if len(new_imports) == len(self):
return self # Space optimization
return type(self)._from_imports(new_imports)
@cached_attribute
def _by_module_name(self):
"""
:return:
(mapping from name to __future__ imports,
mapping from name to non-'from' imports,
mapping from name to 'from' imports)
"""
ftr_imports = defaultdict(set)
pkg_imports = defaultdict(set)
frm_imports = defaultdict(set)
for imp in self._importset:
module_name, member_name, import_as = imp.split
if module_name is None:
pkg_imports[member_name].add(imp)
elif module_name == '__future__':
ftr_imports[module_name].add(imp)
else:
frm_imports[module_name].add(imp)
return tuple(
dict( (k, frozenset(v))
for k, v in six.iteritems(imports))
for imports in [ftr_imports, pkg_imports, frm_imports])
def get_statements(self, separate_from_imports=True):
"""
Canonicalized `ImportStatement` s.
These have been merged by module and sorted.
>>> importset = ImportSet('''
... import a, b as B, c, d.dd as DD
... from __future__ import division
... from _hello import there
... from _hello import *
... from _hello import world
... ''')
>>> for s in importset.get_statements(): print(s)
from __future__ import division
import a
import b as B
import c
from _hello import *
from _hello import there, world
from d import dd as DD
:rtype:
``tuple`` of `ImportStatement` s
"""
groups = self._by_module_name
if not separate_from_imports:
def union_dicts(*dicts):
result = {}
for label, dict in enumerate(dicts):
for k, v in six.iteritems(dict):
result[(k, label)] = v
return result
groups = [groups[0], union_dicts(*groups[1:])]
result = []
for importgroup in groups:
for _, imports in sorted(importgroup.items()):
star_imports, nonstar_imports = (
partition(imports, lambda imp: imp.import_as == "*"))
assert len(star_imports) <= 1
if star_imports:
result.append(ImportStatement(star_imports))
if nonstar_imports:
result.append(ImportStatement(sorted(nonstar_imports)))
return tuple(result)
@cached_attribute
def statements(self):
"""
Canonicalized `ImportStatement` s.
These have been merged by module and sorted.
:rtype:
``tuple`` of `ImportStatement` s
"""
return self.get_statements(separate_from_imports=True)
@cached_attribute
def imports(self):
"""
Canonicalized imports, in the same order as ``self.statements``.
:rtype:
``tuple`` of `Import` s
"""
return tuple(
imp
for importgroup in self._by_module_name
for _, imports in sorted(importgroup.items())
for imp in sorted(imports))
@cached_attribute
def by_import_as(self):
"""
Map from ``import_as`` to `Import`.
>>> ImportSet('from aa.bb import cc as dd').by_import_as
{'dd': (Import('from aa.bb import cc as dd'),)}
:rtype:
``dict`` mapping from ``str`` to tuple of `Import` s
"""
d = defaultdict(list)
for imp in self._importset:
d[imp.import_as].append(imp)
return dict( (k, tuple(sorted(stable_unique(v))))
for k, v in six.iteritems(d) )
@cached_attribute
def member_names(self):
r"""
Map from parent module/package ``fullname`` to known member names.
>>> impset = ImportSet("import numpy.linalg.info\nfrom sys import exit as EXIT")
>>> import pprint
>>> pprint.pprint(impset.member_names)
{'': ('EXIT', 'numpy', 'sys'),
'numpy': ('linalg',),
'numpy.linalg': ('info',),
'sys': ('exit',)}
This is used by the autoimporter module for implementing tab completion.
:rtype:
``dict`` mapping from ``str`` to tuple of ``str``
"""
d = defaultdict(set)
for imp in self._importset:
if '.' not in imp.import_as:
d[""].add(imp.import_as)
prefixes = dotted_prefixes(imp.fullname)
d[""].add(prefixes[0])
for prefix in prefixes[1:]:
splt = prefix.rsplit(".", 1)
d[splt[0]].add(splt[1])
return dict( (k, tuple(sorted(v)))
for k, v in six.iteritems(d) )
@cached_attribute
def conflicting_imports(self):
r"""
Returns imports that conflict with each other.
>>> ImportSet('import b\nfrom f import a as b\n').conflicting_imports
('b',)
>>> ImportSet('import b\nfrom f import a\n').conflicting_imports
()
:rtype:
``bool``
"""
return tuple(
k
for k, v in six.iteritems(self.by_import_as)
if len(v) > 1 and k != "*")
@cached_attribute
def flags(self):
"""
If this contains __future__ imports, then the bitwise-ORed of the
compiler_flag values associated with the features. Otherwise, 0.
"""
imports = self._by_module_name[0].get("__future__", [])
return CompilerFlags(*[imp.flags for imp in imports])
def __repr__(self):
printed = self.pretty_print(allow_conflicts=True)
lines = "".join(" "+line for line in printed.splitlines(True))
return "%s('''\n%s''')" % (type(self).__name__, lines)
def pretty_print(self, params=None, allow_conflicts=False):
"""
Pretty-print a block of import statements into a single string.
:type params:
`ImportFormatParams`
:rtype:
``str``
"""
params = ImportFormatParams(params)
# TODO: instead of complaining about conflicts, just filter out the
# shadowed imports at construction time.
if not allow_conflicts and self.conflicting_imports:
raise ConflictingImportsError(
"Refusing to pretty-print because of conflicting imports: " +
'; '.join(
"%r imported as %r" % (
[imp.fullname for imp in self.by_import_as[i]], i)
for i in self.conflicting_imports))
from_spaces = max(1, params.from_spaces)
def do_align(statement):
return statement.fromname != '__future__' or params.align_future
def pp(statement, import_column):
if do_align(statement):
return statement.pretty_print(
params=params, import_column=import_column,
from_spaces=from_spaces)
else:
return statement.pretty_print(
params=params, import_column=None, from_spaces=1)
statements = self.get_statements(
separate_from_imports=params.separate_from_imports)
def isint(x): return isinstance(x, int) and not isinstance(x, bool)
if not statements:
import_column = None
elif isinstance(params.align_imports, bool):
if params.align_imports:
fromimp_stmts = [
s for s in statements if s.fromname and do_align(s)]
if fromimp_stmts:
import_column = (
max(len(s.fromname) for s in fromimp_stmts)
+ from_spaces + 5)
else:
import_column = None
else:
import_column = None
elif isinstance(params.align_imports, int):
import_column = params.align_imports
elif isinstance(params.align_imports, (tuple, list, set)):
# If given a set of candidate alignment columns, then try each
# alignment column and pick the one that yields the fewest number
# of output lines.
if not all(isinstance(x, int) for x in params.align_imports):
raise TypeError("expected set of integers; got %r"
% (params.align_imports,))
candidates = sorted(set(params.align_imports))
if len(candidates) == 0:
raise ValueError("list of zero candidate alignment columns specified")
elif len(candidates) == 1:
# Optimization.
import_column = next(iter(candidates))
else:
def argmin(map):
items = iter(sorted(map.items()))
min_k, min_v = next(items)
for k, v in items:
if v < min_v:
min_k = k
min_v = v
return min_k
def count_lines(import_column):
return sum(
s.pretty_print(
params=params, import_column=import_column,
from_spaces=from_spaces).count("\n")
for s in statements)
# Construct a map from alignment column to total number of
# lines.
col2length = dict((c, count_lines(c)) for c in candidates)
# Pick the column that yields the fewest lines. Break ties by
# picking the smaller column.
import_column = argmin(col2length)
else:
raise TypeError(
"ImportSet.pretty_print(): unexpected params.align_imports type %s"
% (type(params.align_imports).__name__,))
return ''.join(pp(statement, import_column) for statement in statements)
def __contains__(self, x):
return x in self._importset
def __eq__(self, other):
if self is other:
return True
if not isinstance(other, ImportSet):
return NotImplemented
return self._importset == other._importset
def __ne__(self, other):
return not (self == other)
# The rest are defined by total_ordering
def __lt__(self, other):
if not isinstance(other, ImportSet):
return NotImplemented
return self._importset < other._importset
def __cmp__(self, other):
if self is other:
return 0
if not isinstance(other, ImportSet):
return NotImplemented
return cmp(self._importset, other._importset)
def __hash__(self):
return hash(self._importset)
def __len__(self):
return len(self.imports)
def __iter__(self):
return iter(self.imports)
ImportSet._EMPTY = ImportSet._from_imports([])
@total_ordering
class ImportMap(object):
r"""
A map from import fullname identifier to fullname identifier.
>>> ImportMap({'a.b': 'aa.bb', 'a.b.c': 'aa.bb.cc'})
ImportMap({'a.b': 'aa.bb', 'a.b.c': 'aa.bb.cc'})
An ``ImportMap`` is an immutable data structure.
"""
def __new__(cls, arg):
if isinstance(arg, cls):
return arg
if isinstance(arg, (tuple, list)):
return cls._merge(arg)
if isinstance(arg, dict):
if not len(arg):
return cls._EMPTY
return cls._from_map(arg)
else:
raise TypeError("ImportMap: expected a dict, not a %s"
% (type(arg).__name__,))
@classmethod
def _from_map(cls, arg):
data = dict((Import(k).fullname, Import(v).fullname)
for k, v in arg.items())
self = object.__new__(cls)
self._data = data
return self
@classmethod
def _merge(cls, maps):
maps = [cls(m) for m in maps]
maps = [m for m in maps if m]
if not maps:
return cls._EMPTY
data = {}
for map in maps:
data.update(map._data)
return cls(data)
def __getitem__(self, k):
k = Import(k).fullname
return self._data.__getitem__(k)
def __iter__(self):
return iter(self._data)
def items(self):
return self._data.items()
def iteritems(self):
return six.iteritems(self._data)
def iterkeys(self):
return six.iterkeys(self._data)
def keys(self):
return self._data.keys()
def values(self):
return self._data.values()
def __len__(self):
return len(self._data)
def without_imports(self, removals):
"""
Return a copy of self without the given imports.
Matches both keys and values.
"""
removals = ImportSet(removals)
if not removals:
return self # Optimization
cls = type(self)
result = [(k, v) for k, v in self._data.items()
if Import(k) not in removals and Import(v) not in removals]
if len(result) == len(self._data):
return self # Space optimization
return cls(dict(result))
def __repr__(self):
s = ", ".join("%r: %r" % (k,v) for k,v in sorted(self.items()))
return "ImportMap({%s})" % s
def __eq__(self, other):
if self is other:
return True
if not isinstance(other, ImportMap):
return NotImplemented
return self._data == other._data
def __ne__(self, other):
return not (self == other)
# The rest are defined by total_ordering
def __lt__(self, other):
if not isinstance(other, ImportMap):
return NotImplemented
return self._data < other._data
def __cmp__(self, other):
if self is other:
return 0
if not isinstance(other, ImportMap):
return NotImplemented
return cmp(self._data, other._data)
def __hash__(self):
h = hash(self._data)
self.__hash__ = lambda: h
return h
ImportMap._EMPTY = ImportMap._from_map({})