# pyflyby/_importclns.py.
# Copyright (C) 2011, 2012, 2013, 2014 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT

from __future__ import (absolute_import, division, print_function,
                        with_statement)

from collections import defaultdict
from functools import total_ordering
import six

from pyflyby._flags import CompilerFlags
from pyflyby._idents import dotted_prefixes, is_identifier
from pyflyby._importstmt import (Import, ImportFormatParams,
                                 ImportStatement,
                                 NonImportStatementError)
from pyflyby._parse import PythonBlock
from pyflyby._util import (cached_attribute, cmp, partition,
                           stable_unique)


class NoSuchImportError(ValueError):
    pass


class ConflictingImportsError(ValueError):
    pass


@total_ordering
class ImportSet(object):
    r"""
    Representation of a set of imports organized into import statements.

      >>> ImportSet('''
      ...     from m1 import f1
      ...     from m2 import f1
      ...     from m1 import f2
      ...     import m3.m4 as m34
      ... ''')
      ImportSet('''
        from m1 import f1, f2
        from m2 import f1
        from m3 import m4 as m34
      ''')

    An ``ImportSet`` is an immutable data structure.
    """

    def __new__(cls, arg, ignore_nonimports=False, ignore_shadowed=False):
        """
        Return as an `ImportSet`.

        :param ignore_nonimports:
          If ``False``, complain about non-imports.  If ``True``, ignore
          non-imports.
        :param ignore_shadowed:
          Whether to ignore shadowed imports.  If ``False``, then keep all
          unique imports, even if they shadow each other.  Note that an
          ``ImportSet`` is unordered; an ``ImportSet`` with conflicts will
          only be useful for very specific cases (e.g. set of imports to
          forget from known-imports database), and not useful for outputting
          as code.  If ``ignore_shadowed`` is ``True``, then earlier shadowed
          imports are ignored.
        :rtype:
          `ImportSet`
        """
        if isinstance(arg, cls):
            if ignore_shadowed:
                return cls._from_imports(arg._importset, ignore_shadowed=True)
            else:
                # Immutable, so an existing instance can be returned as-is.
                return arg
        return cls._from_args(
            arg,
            ignore_nonimports=ignore_nonimports,
            ignore_shadowed=ignore_shadowed)

    @classmethod
    def _from_imports(cls, imports, ignore_shadowed=False):
        """
        Construct an `ImportSet` directly from a sequence of imports.

        :type imports:
          Sequence of `Import` s
        :param ignore_shadowed:
          See `ImportSet.__new__`.
        :rtype:
          `ImportSet`
        """
        # Canonicalize inputs.
        imports = [Import(imp) for imp in imports]
        if ignore_shadowed:
            # Filter by overshadowed imports.  Later imports take precedence.
            by_import_as = {}
            for imp in imports:
                if imp.import_as == "*":
                    # Keep all unique star imports.
                    by_import_as[imp] = imp
                else:
                    by_import_as[imp.import_as] = imp
            filtered_imports = by_import_as.values()
        else:
            filtered_imports = imports
        # Construct and return.
        self = object.__new__(cls)
        self._importset = frozenset(filtered_imports)
        return self

    @classmethod
    def _from_args(cls, args, ignore_nonimports=False, ignore_shadowed=False):
        """
        Construct an `ImportSet` from heterogeneous import-like arguments.

        :type args:
          ``tuple`` or ``list`` of `ImportStatement` s, `PythonStatement` s,
          `PythonBlock` s, `FileText`, and/or `Filename` s
        :param ignore_nonimports:
          If ``False``, complain about non-imports.  If ``True``, ignore
          non-imports.
        :param ignore_shadowed:
          See `ImportSet.__new__`.
        :raise NonImportStatementError:
          If a non-import statement is encountered and ``ignore_nonimports``
          is ``False``.
        :rtype:
          `ImportSet`
        """
        if not isinstance(args, (tuple, list)):
            args = [args]
        # Filter empty arguments to allow the subsequent optimizations to work
        # more often.
        args = [a for a in args if a]
        if not args:
            return cls._EMPTY
        # If we only got one ``ImportSet``, just return it.
        if len(args) == 1 and type(args[0]) is cls and not ignore_shadowed:
            return args[0]
        # Collect all `Import` s from arguments.
        imports = []
        for arg in args:
            if isinstance(arg, Import):
                imports.append(arg)
            elif isinstance(arg, ImportSet):
                imports.extend(arg.imports)
            elif isinstance(arg, ImportStatement):
                imports.extend(arg.imports)
            elif isinstance(arg, str) and is_identifier(arg, dotted=True):
                imports.append(Import(arg))
            else:
                # PythonBlock, PythonStatement, Filename, FileText, str
                block = PythonBlock(arg)
                for statement in block.statements:
                    # Ignore comments/blanks.
                    if statement.is_comment_or_blank:
                        pass
                    elif statement.is_import:
                        imports.extend(ImportStatement(statement).imports)
                    elif ignore_nonimports:
                        pass
                    else:
                        raise NonImportStatementError(
                            "Got non-import statement %r" % (statement,))
        return cls._from_imports(imports, ignore_shadowed=ignore_shadowed)

    def with_imports(self, other):
        """
        Return a new `ImportSet` that is the union of ``self`` and ``other``.

          >>> impset = ImportSet('from m import t1, t2, t3')
          >>> impset.with_imports('import m.t2a as t2b')
          ImportSet('''
            from m import t1, t2, t2a as t2b, t3
          ''')

        :type other:
          `ImportSet` (or convertible)
        :rtype:
          `ImportSet`
        """
        other = ImportSet(other)
        return type(self)._from_imports(self._importset | other._importset)

    def without_imports(self, removals):
        """
        Return a copy of self without the given imports.

          >>> imports = ImportSet('from m import t1, t2, t3, t4')
          >>> imports.without_imports(['from m import t3'])
          ImportSet('''
            from m import t1, t2, t4
          ''')

        A star import in ``removals`` (``from m import *``) additionally
        removes every ``from``-import whose module is ``m`` or a submodule of
        ``m``.

        :type removals:
          `ImportSet` (or convertible)
        :rtype:
          `ImportSet`
        """
        removals = ImportSet(removals)
        if not removals:
            return self  # Optimization
        # Preprocess star imports to remove.
        star_module_removals = set(
            imp.split.module_name for imp in removals
            if imp.split.member_name == "*")
        # Filter imports.
        new_imports = []
        for imp in self:
            if imp in removals:
                continue
            if star_module_removals and imp.split.module_name:
                prefixes = dotted_prefixes(imp.split.module_name)
                if any(pfx in star_module_removals for pfx in prefixes):
                    continue
            new_imports.append(imp)
        # Return.
        if len(new_imports) == len(self):
            return self  # Space optimization
        return type(self)._from_imports(new_imports)

    @cached_attribute
    def _by_module_name(self):
        """
        Group this set's imports by module.

        :return:
          A 3-tuple of dicts, each mapping a name to a ``frozenset`` of
          `Import` s: (``__future__`` imports keyed by ``'__future__'``,
          non-'from' imports keyed by imported name, 'from' imports keyed by
          module name).
        """
        ftr_imports = defaultdict(set)
        pkg_imports = defaultdict(set)
        frm_imports = defaultdict(set)
        for imp in self._importset:
            module_name, member_name, import_as = imp.split
            if module_name is None:
                pkg_imports[member_name].add(imp)
            elif module_name == '__future__':
                ftr_imports[module_name].add(imp)
            else:
                frm_imports[module_name].add(imp)
        return tuple(
            dict((k, frozenset(v)) for k, v in six.iteritems(imports))
            for imports in [ftr_imports, pkg_imports, frm_imports])

    def get_statements(self, separate_from_imports=True):
        """
        Canonicalized `ImportStatement` s.
        These have been merged by module and sorted.

          >>> importset = ImportSet('''
          ...     import a, b as B, c, d.dd as DD
          ...     from __future__ import division
          ...     from _hello import there
          ...     from _hello import *
          ...     from _hello import world
          ... ''')

          >>> for s in importset.get_statements(): print(s)
          from __future__ import division
          import a
          import b as B
          import c
          from _hello import *
          from _hello import there, world
          from d import dd as DD

        :param separate_from_imports:
          If ``True``, emit all non-'from' imports before all 'from' imports.
          If ``False``, interleave them, sorted by module name.
        :rtype:
          ``tuple`` of `ImportStatement` s
        """
        groups = self._by_module_name
        if not separate_from_imports:
            def union_dicts(*dicts):
                # Key by (name, label) so that sorting interleaves the groups
                # by name while keeping entries from different groups apart.
                result = {}
                for label, mapping in enumerate(dicts):
                    for k, v in six.iteritems(mapping):
                        result[(k, label)] = v
                return result
            groups = [groups[0], union_dicts(*groups[1:])]
        result = []
        for importgroup in groups:
            for _, imports in sorted(importgroup.items()):
                star_imports, nonstar_imports = (
                    partition(imports, lambda imp: imp.import_as == "*"))
                assert len(star_imports) <= 1
                if star_imports:
                    result.append(ImportStatement(star_imports))
                if nonstar_imports:
                    result.append(ImportStatement(sorted(nonstar_imports)))
        return tuple(result)

    @cached_attribute
    def statements(self):
        """
        Canonicalized `ImportStatement` s.
        These have been merged by module and sorted.

        :rtype:
          ``tuple`` of `ImportStatement` s
        """
        return self.get_statements(separate_from_imports=True)

    @cached_attribute
    def imports(self):
        """
        Canonicalized imports, in the same order as ``self.statements``.

        :rtype:
          ``tuple`` of `Import` s
        """
        return tuple(
            imp
            for importgroup in self._by_module_name
            for _, imports in sorted(importgroup.items())
            for imp in sorted(imports))

    @cached_attribute
    def by_import_as(self):
        """
        Map from ``import_as`` to `Import`.

          >>> ImportSet('from aa.bb import cc as dd').by_import_as
          {'dd': (Import('from aa.bb import cc as dd'),)}

        :rtype:
          ``dict`` mapping from ``str`` to tuple of `Import` s
        """
        d = defaultdict(list)
        for imp in self._importset:
            d[imp.import_as].append(imp)
        return dict(
            (k, tuple(sorted(stable_unique(v))))
            for k, v in six.iteritems(d))

    @cached_attribute
    def member_names(self):
        r"""
        Map from parent module/package ``fullname`` to known member names.

          >>> impset = ImportSet("import numpy.linalg.info\nfrom sys import exit as EXIT")
          >>> import pprint
          >>> pprint.pprint(impset.member_names)
          {'': ('EXIT', 'numpy', 'sys'),
           'numpy': ('linalg',),
           'numpy.linalg': ('info',),
           'sys': ('exit',)}

        This is used by the autoimporter module for implementing tab
        completion.

        :rtype:
          ``dict`` mapping from ``str`` to tuple of ``str``
        """
        d = defaultdict(set)
        for imp in self._importset:
            if '.' not in imp.import_as:
                d[""].add(imp.import_as)
            prefixes = dotted_prefixes(imp.fullname)
            d[""].add(prefixes[0])
            for prefix in prefixes[1:]:
                parent, member = prefix.rsplit(".", 1)
                d[parent].add(member)
        return dict(
            (k, tuple(sorted(v)))
            for k, v in six.iteritems(d))

    @cached_attribute
    def conflicting_imports(self):
        r"""
        Returns imports that conflict with each other.

          >>> ImportSet('import b\nfrom f import a as b\n').conflicting_imports
          ('b',)

          >>> ImportSet('import b\nfrom f import a\n').conflicting_imports
          ()

        :rtype:
          ``tuple`` of ``str`` (the conflicting ``import_as`` names)
        """
        return tuple(
            k for k, v in six.iteritems(self.by_import_as)
            if len(v) > 1 and k != "*")

    @cached_attribute
    def flags(self):
        """
        If this contains __future__ imports, then the bitwise-ORed of the
        compiler_flag values associated with the features.  Otherwise, 0.
        """
        imports = self._by_module_name[0].get("__future__", [])
        return CompilerFlags(*[imp.flags for imp in imports])

    def __repr__(self):
        printed = self.pretty_print(allow_conflicts=True)
        lines = "".join("  " + line for line in printed.splitlines(True))
        return "%s('''\n%s''')" % (type(self).__name__, lines)

    def pretty_print(self, params=None, allow_conflicts=False):
        """
        Pretty-print a block of import statements into a single string.

        :type params:
          `ImportFormatParams`
        :param allow_conflicts:
          If ``False``, refuse to print when two different imports bind the
          same name.
        :raise ConflictingImportsError:
          If there are conflicting imports and ``allow_conflicts`` is false.
        :raise TypeError:
          If ``params.align_imports`` has an unsupported type, or is a
          collection containing non-integers.
        :raise ValueError:
          If ``params.align_imports`` is an empty collection.
        :rtype:
          ``str``
        """
        params = ImportFormatParams(params)
        # TODO: instead of complaining about conflicts, just filter out the
        # shadowed imports at construction time.
        if not allow_conflicts and self.conflicting_imports:
            raise ConflictingImportsError(
                "Refusing to pretty-print because of conflicting imports: " +
                '; '.join(
                    "%r imported as %r" % (
                        [imp.fullname for imp in self.by_import_as[i]], i)
                    for i in self.conflicting_imports))
        from_spaces = max(1, params.from_spaces)

        def do_align(statement):
            # __future__ imports are aligned only if explicitly requested.
            return statement.fromname != '__future__' or params.align_future

        def pp(statement, import_column):
            if do_align(statement):
                return statement.pretty_print(
                    params=params, import_column=import_column,
                    from_spaces=from_spaces)
            else:
                return statement.pretty_print(
                    params=params, import_column=None, from_spaces=1)

        def isint(x):
            # ``bool`` subclasses ``int``, but True/False are not columns.
            return isinstance(x, int) and not isinstance(x, bool)

        statements = self.get_statements(
            separate_from_imports=params.separate_from_imports)
        if not statements:
            import_column = None
        elif isinstance(params.align_imports, bool):
            # Must be tested before ``int`` since bool is an int subclass.
            if params.align_imports:
                fromimp_stmts = [
                    s for s in statements if s.fromname and do_align(s)]
                if fromimp_stmts:
                    # len("from ") == 5, then the module name, then the
                    # spaces separating it from "import".
                    import_column = (
                        max(len(s.fromname) for s in fromimp_stmts)
                        + from_spaces + 5)
                else:
                    import_column = None
            else:
                import_column = None
        elif isinstance(params.align_imports, int):
            import_column = params.align_imports
        elif isinstance(params.align_imports, (tuple, list, set, frozenset)):
            # If given a set of candidate alignment columns, then try each
            # alignment column and pick the one that yields the fewest number
            # of output lines.
            if not all(isint(x) for x in params.align_imports):
                raise TypeError("expected set of integers; got %r"
                                % (params.align_imports,))
            candidates = sorted(set(params.align_imports))
            if not candidates:
                raise ValueError(
                    "list of zero candidate alignment columns specified")
            elif len(candidates) == 1:
                # Optimization.
                import_column = candidates[0]
            else:
                def count_lines(import_column):
                    # Render exactly as the final output will (via ``pp``) so
                    # that statements which are never aligned don't skew the
                    # choice of column.
                    return sum(
                        pp(s, import_column).count("\n")
                        for s in statements)
                # Construct a map from alignment column to total number of
                # lines.
                col2length = dict((c, count_lines(c)) for c in candidates)
                # Pick the column that yields the fewest lines.  Break ties by
                # picking the smaller column.
                import_column = min(
                    candidates, key=lambda c: (col2length[c], c))
        else:
            raise TypeError(
                "ImportSet.pretty_print(): unexpected params.align_imports type %s"
                % (type(params.align_imports).__name__,))
        return ''.join(pp(statement, import_column) for statement in statements)

    def __contains__(self, x):
        return x in self._importset

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, ImportSet):
            return NotImplemented
        return self._importset == other._importset

    def __ne__(self, other):
        return not (self == other)

    # The rest are defined by total_ordering
    def __lt__(self, other):
        if not isinstance(other, ImportSet):
            return NotImplemented
        # NOTE: frozenset ``<`` is a proper-subset test, so this is a partial
        # order rather than a total one.
        return self._importset < other._importset

    def __cmp__(self, other):
        if self is other:
            return 0
        if not isinstance(other, ImportSet):
            return NotImplemented
        return cmp(self._importset, other._importset)

    def __hash__(self):
        return hash(self._importset)

    def __len__(self):
        return len(self.imports)

    def __iter__(self):
        return iter(self.imports)


ImportSet._EMPTY = ImportSet._from_imports([])


@total_ordering
class ImportMap(object):
    r"""
    A map from import fullname identifier to fullname identifier.

      >>> ImportMap({'a.b': 'aa.bb', 'a.b.c': 'aa.bb.cc'})
      ImportMap({'a.b': 'aa.bb', 'a.b.c': 'aa.bb.cc'})

    An ``ImportMap`` is an immutable data structure.
    """

    def __new__(cls, arg):
        """
        Return ``arg`` as an `ImportMap`.

        :param arg:
          An `ImportMap` (returned as-is), a ``dict``, or a ``tuple``/``list``
          of such maps (merged, later maps winning on key collisions).
        :raise TypeError:
          For any other argument type.
        """
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, (tuple, list)):
            return cls._merge(arg)
        if isinstance(arg, dict):
            if not len(arg):
                return cls._EMPTY
            return cls._from_map(arg)
        else:
            raise TypeError("ImportMap: expected a dict, not a %s"
                            % (type(arg).__name__,))

    @classmethod
    def _from_map(cls, arg):
        # Canonicalize both keys and values to import fullnames.
        data = dict((Import(k).fullname, Import(v).fullname)
                    for k, v in arg.items())
        self = object.__new__(cls)
        self._data = data
        return self

    @classmethod
    def _merge(cls, maps):
        maps = [cls(m) for m in maps]
        maps = [m for m in maps if m]
        if not maps:
            return cls._EMPTY
        data = {}
        for import_map in maps:
            data.update(import_map._data)
        return cls(data)

    def __getitem__(self, k):
        k = Import(k).fullname
        return self._data.__getitem__(k)

    def __iter__(self):
        return iter(self._data)

    def items(self):
        return self._data.items()

    def iteritems(self):
        return six.iteritems(self._data)

    def iterkeys(self):
        return six.iterkeys(self._data)

    def keys(self):
        return self._data.keys()

    def values(self):
        return self._data.values()

    def __len__(self):
        return len(self._data)

    def without_imports(self, removals):
        """
        Return a copy of self without the given imports.
        Matches both keys and values.

        :type removals:
          `ImportSet` (or convertible)
        :rtype:
          `ImportMap`
        """
        removals = ImportSet(removals)
        if not removals:
            return self  # Optimization
        cls = type(self)
        result = [(k, v) for k, v in self._data.items()
                  if Import(k) not in removals and Import(v) not in removals]
        if len(result) == len(self._data):
            return self  # Space optimization
        return cls(dict(result))

    def __repr__(self):
        s = ", ".join("%r: %r" % (k, v) for k, v in sorted(self.items()))
        return "ImportMap({%s})" % s

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, ImportMap):
            return NotImplemented
        return self._data == other._data

    def __ne__(self, other):
        return not (self == other)

    def _sorted_items(self):
        # Deterministic, orderable representation of the mapping.
        return sorted(six.iteritems(self._data))

    # The rest are defined by total_ordering
    def __lt__(self, other):
        if not isinstance(other, ImportMap):
            return NotImplemented
        # ``dict`` objects are not orderable on Python 3, so compare the
        # sorted (key, value) pairs instead.
        return self._sorted_items() < other._sorted_items()

    def __cmp__(self, other):
        if self is other:
            return 0
        if not isinstance(other, ImportMap):
            return NotImplemented
        return cmp(self._sorted_items(), other._sorted_items())

    def __hash__(self):
        # ``dict`` is unhashable, so hash an immutable snapshot of the items
        # (consistent with ``__eq__``, which compares the dicts).  Cache it in
        # an instance attribute: assigning ``self.__hash__`` would not affect
        # ``hash(self)``, which looks the special method up on the type.
        try:
            return self._hash
        except AttributeError:
            h = hash(frozenset(six.iteritems(self._data)))
            self._hash = h
            return h


ImportMap._EMPTY = ImportMap._from_map({})