This commit is contained in:
Waylon Walker 2022-03-31 20:20:07 -05:00
commit 38355d2442
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
9083 changed files with 1225834 additions and 0 deletions

View file

@ -0,0 +1,55 @@
# pyflyby/__init__.py.
# Copyright (C) 2011, 2012, 2013, 2014, 2015, 2018 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from pyflyby._autoimp import (auto_eval, auto_import,
find_missing_imports)
from pyflyby._dbg import (add_debug_functions_to_builtins,
attach_debugger, debug_on_exception,
debug_statement, debugger,
enable_exception_handler_debugger,
enable_faulthandler,
enable_signal_handler_debugger,
print_traceback, remote_print_stack)
from pyflyby._file import Filename
from pyflyby._flags import CompilerFlags
from pyflyby._importdb import ImportDB
from pyflyby._imports2s import (canonicalize_imports,
reformat_import_statements,
remove_broken_imports,
replace_star_imports,
transform_imports)
from pyflyby._importstmt import (Import, ImportStatement,
NonImportStatementError)
from pyflyby._interactive import (disable_auto_importer,
enable_auto_importer,
install_in_ipython_config_file,
load_ipython_extension,
unload_ipython_extension)
from pyflyby._livepatch import livepatch, xreload
from pyflyby._log import logger
from pyflyby._parse import PythonBlock, PythonStatement
from pyflyby._version import __version__
# Deprecated:
from pyflyby._dbg import (breakpoint, debug_exception,
debug_statement,
enable_exception_handler,
enable_signal_handler_breakpoint,
waitpoint)
# Promote the function & classes that we've chosen to expose publicly to be
# known as pyflyby.Foo instead of pyflyby._module.Foo.
# NOTE(review): getattr(..., "__module__", "") could return None for some
# objects, in which case .startswith would raise — confirm all exposed
# objects carry a string __module__.
for x in list(globals().values()):
    if getattr(x, "__module__", "").startswith("pyflyby."):
        x.__module__ = "pyflyby"
del x
# Discourage "from pyflyby import *".
# Use the tidy-imports/autoimporter instead!
__all__ = []

View file

@ -0,0 +1,10 @@
# pyflyby/__main__.py
# Copyright (C) 2014, 2015 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
# Entry point: delegate to the 'py' command-line driver.
if __name__ == "__main__":
    from pyflyby._py import py_main
    py_main()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,515 @@
# pyflyby/_cmdline.py.
# Copyright (C) 2011, 2012, 2013, 2014, 2015, 2018 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
import optparse
import os
import signal
import six
from six import reraise
from six.moves import input
import sys
from textwrap import dedent
import traceback
from pyflyby._file import (FileText, Filename, atomic_write_file,
expand_py_files_from_args, read_file)
from pyflyby._importstmt import ImportFormatParams
from pyflyby._log import logger
from pyflyby._util import cached_attribute, indent
def hfmt(s):
    """Normalize a help-text block: dedent it and strip outer whitespace."""
    cleaned = dedent(s)
    return cleaned.strip()
def maindoc():
    """Return the stripped docstring of the running script ('' if absent)."""
    import __main__
    doc = __main__.__doc__ or ''
    return doc.strip()
def _sigpipe_handler(*args):
# The parent process piped our stdout and closed the pipe before we
# finished writing, e.g. "tidy-imports ... | head" or "tidy-imports ... |
# less". Exit quietly - squelch the "close failed in file object
# destructor" message would otherwise be raised.
raise SystemExit(1)
def parse_args(addopts=None, import_format_params=False, modify_action_params=False):
    """
    Do setup for a top-level script and parse arguments.

    :param addopts:
      Optional callback taking the OptionParser, used by callers to add
      extra options.
    :param import_format_params:
      If true, add pretty-printing options and build ``options.params``
      (an `ImportFormatParams`) from them.
    :param modify_action_params:
      If true, add the --actions/--print/--diff/--replace/--interactive
      and --symlinks options.
    :return:
      ``(options, args)`` as returned by ``parser.parse_args``.
    """
    ### Setup.
    # Register a SIGPIPE handler.
    signal.signal(signal.SIGPIPE, _sigpipe_handler)
    ### Parse args.
    parser = optparse.OptionParser(usage='\n'+maindoc())
    def log_level_callbacker(level):
        # Build an optparse callback that sets the logger to ``level``.
        def callback(option, opt_str, value, parser):
            logger.set_level(level)
        return callback
    def debug_callback(option, opt_str, value, parser):
        logger.set_level("DEBUG")
    parser.add_option("--debug", action="callback",
                      callback=debug_callback,
                      help="Debug mode (noisy and fail fast).")
    parser.add_option("--verbose", action="callback",
                      callback=log_level_callbacker("DEBUG"),
                      help="Be noisy.")
    parser.add_option("--quiet", action="callback",
                      callback=log_level_callbacker("ERROR"),
                      help="Be quiet.")
    parser.add_option("--version", action="callback",
                      callback=lambda *args: print_version_and_exit(),
                      help="Print pyflyby version and exit.")
    if modify_action_params:
        group = optparse.OptionGroup(parser, "Action options")
        action_diff = action_external_command('pyflyby-diff')
        def parse_action(v):
            # Map one --actions token (case-insensitive) to an action function.
            V = v.strip().upper()
            if V == 'PRINT':
                return action_print
            elif V == 'REPLACE':
                return action_replace
            elif V == 'QUERY':
                return action_query()
            elif V == "DIFF":
                return action_diff
            elif V.startswith("QUERY:"):
                return action_query(v[6:])
            elif V.startswith("EXECUTE:"):
                return action_external_command(v[8:])
            elif V == "IFCHANGED":
                return action_ifchanged
            else:
                raise Exception(
                    "Bad argument %r to --action; "
                    "expected PRINT or REPLACE or QUERY or IFCHANGED "
                    "or EXECUTE:..." % (v,))
        def set_actions(actions):
            actions = tuple(actions)
            parser.values.actions = actions
        def action_callback(option, opt_str, value, parser):
            action_args = value.split(',')
            set_actions([parse_action(v) for v in action_args])
        def action_callbacker(actions):
            # Build a callback installing a fixed sequence of actions.
            def callback(option, opt_str, value, parser):
                set_actions(actions)
            return callback
        group.add_option(
            "--actions", type='string', action='callback',
            callback=action_callback,
            metavar='PRINT|REPLACE|IFCHANGED|QUERY|DIFF|EXECUTE:mycommand',
            help=hfmt('''
                Comma-separated list of action(s) to take. If PRINT, print
                the changed file to stdout. If REPLACE, then modify the
                file in-place. If EXECUTE:mycommand, then execute
                'mycommand oldfile tmpfile'. If DIFF, then execute
                'pyflyby-diff'. If QUERY, then query user to continue.
                If IFCHANGED, then continue actions only if file was
                changed.'''))
        group.add_option(
            "--print", "-p", action='callback',
            callback=action_callbacker([action_print]),
            help=hfmt('''
                Equivalent to --action=PRINT (default when stdin or stdout is
                not a tty) '''))
        group.add_option(
            "--diff", "-d", action='callback',
            callback=action_callbacker([action_diff]),
            help=hfmt('''Equivalent to --action=DIFF'''))
        group.add_option(
            "--replace", "-r", action='callback',
            callback=action_callbacker([action_ifchanged, action_replace]),
            help=hfmt('''Equivalent to --action=IFCHANGED,REPLACE'''))
        group.add_option(
            "--diff-replace", "-R", action='callback',
            callback=action_callbacker([action_ifchanged, action_diff, action_replace]),
            help=hfmt('''Equivalent to --action=IFCHANGED,DIFF,REPLACE'''))
        actions_interactive = [
            action_ifchanged, action_diff,
            action_query("Replace (unknown)?"), action_replace]
        group.add_option(
            "--interactive", "-i", action='callback',
            callback=action_callbacker(actions_interactive),
            help=hfmt('''
                Equivalent to --action=IFCHANGED,DIFF,QUERY,REPLACE (default
                when stdin & stdout are ttys) '''))
        # Interactive by default only when attached to a terminal at both ends.
        if os.isatty(0) and os.isatty(1):
            default_actions = actions_interactive
        else:
            default_actions = [action_print]
        parser.set_default('actions', tuple(default_actions))
        parser.add_option_group(group)
        parser.add_option(
            '--symlinks', action='callback', nargs=1, type=str,
            dest='symlinks', callback=symlink_callback, help="--symlinks should be one of: " + symlinks_help,
        )
        parser.set_defaults(symlinks='error')
    if import_format_params:
        group = optparse.OptionGroup(parser, "Pretty-printing options")
        group.add_option('--align-imports', '--align', type='str', default="32",
                         metavar='N',
                         help=hfmt('''
                             Whether and how to align the 'import' keyword in
                             'from modulename import aliases...'. If 0, then
                             don't align. If 1, then align within each block
                             of imports. If an integer > 1, then align at
                             that column, wrapping with a backslash if
                             necessary. If a comma-separated list of integers
                             (tab stops), then pick the column that results in
                             the fewest number of lines total per block.'''))
        group.add_option('--from-spaces', type='int', default=3, metavar='N',
                         help=hfmt('''
                             The number of spaces after the 'from' keyword.
                             (Must be at least 1; default is 3.)'''))
        group.add_option('--separate-from-imports', action='store_true',
                         default=False,
                         help=hfmt('''
                             Separate 'from ... import ...'
                             statements from 'import ...' statements.'''))
        group.add_option('--no-separate-from-imports', action='store_false',
                         dest='separate_from_imports',
                         help=hfmt('''
                             (Default) Don't separate 'from ... import ...'
                             statements from 'import ...' statements.'''))
        group.add_option('--align-future', action='store_true',
                         default=False,
                         help=hfmt('''
                             Align the 'from __future__ import ...' statement
                             like others.'''))
        group.add_option('--no-align-future', action='store_false',
                         dest='align_future',
                         help=hfmt('''
                             (Default) Don't align the 'from __future__ import
                             ...' statement.'''))
        group.add_option('--width', type='int', default=79, metavar='N',
                         help=hfmt('''
                             Maximum line length (default: 79).'''))
        group.add_option('--black', action='store_true', default=False,
                         help=hfmt('''
                             Use black to format imports. If this option is
                             used, all other formatting options are ignored.'''))
        group.add_option('--hanging-indent', type='choice', default='never',
                         choices=['never','auto','always'],
                         metavar='never|auto|always',
                         dest='hanging_indent',
                         help=hfmt('''
                             How to wrap import statements that don't fit on
                             one line.
                             If --hanging-indent=always, then always indent
                             imported tokens at column 4 on the next line.
                             If --hanging-indent=never (default), then align
                             import tokens after "import (" (by default column
                             40); do so even if some symbols are so long that
                             this would exceed the width (by default 79)).
                             If --hanging-indent=auto, then use hanging indent
                             only if it is necessary to prevent exceeding the
                             width (by default 79).
                             '''))
        def uniform_callback(option, opt_str, value, parser):
            parser.values.separate_from_imports = False
            parser.values.from_spaces = 3
            parser.values.align_imports = '32'
        group.add_option('--uniform', '-u', action="callback",
                         callback=uniform_callback,
                         help=hfmt('''
                             (Default) Shortcut for --no-separate-from-imports
                             --from-spaces=3 --align-imports=32.'''))
        def unaligned_callback(option, opt_str, value, parser):
            parser.values.separate_from_imports = True
            parser.values.from_spaces = 1
            parser.values.align_imports = '0'
        group.add_option('--unaligned', '-n', action="callback",
                         callback=unaligned_callback,
                         help=hfmt('''
                             Shortcut for --separate-from-imports
                             --from-spaces=1 --align-imports=0.'''))
        parser.add_option_group(group)
    if addopts is not None:
        addopts(parser)
    # This is the only way to provide a default value for an option with a
    # callback.
    if modify_action_params:
        args = ["--symlinks=error"] + sys.argv[1:]
    else:
        args = None
    options, args = parser.parse_args(args=args)
    if import_format_params:
        # --align-imports accepts "0", "1", a column, or a comma-separated
        # list of candidate columns (tab stops).
        align_imports_args = [int(x.strip())
                              for x in options.align_imports.split(",")]
        if len(align_imports_args) == 1 and align_imports_args[0] == 1:
            align_imports = True
        elif len(align_imports_args) == 1 and align_imports_args[0] == 0:
            align_imports = False
        else:
            align_imports = tuple(sorted(set(align_imports_args)))
        options.params = ImportFormatParams(
            align_imports         =align_imports,
            from_spaces           =options.from_spaces,
            separate_from_imports =options.separate_from_imports,
            max_line_length       =options.width,
            use_black             =options.black,
            align_future          =options.align_future,
            hanging_indent        =options.hanging_indent,
        )
    return options, args
def _default_on_error(filename):
raise SystemExit("bad filename %s" % (filename,))
def filename_args(args, on_error=_default_on_error):
    """
    Return list of filenames given command-line arguments.

    Falls back to stdin when no arguments were given and stdin is piped;
    otherwise prints usage and exits.

    :rtype:
      ``list`` of `Filename`
    """
    if not args:
        if os.isatty(0):
            syntax()
        else:
            return [Filename.STDIN]
    return expand_py_files_from_args(args, on_error)
def print_version_and_exit(extra=None):
    """Print the pyflyby version (plus optional *extra* text) and exit 0."""
    from pyflyby._version import __version__
    progname = os.path.realpath(sys.argv[0])
    msg = "pyflyby %s" % (__version__,)
    if os.path.exists(progname):
        msg += " (%s)" % (os.path.basename(progname),)
    print(msg)
    if extra:
        print(extra)
    raise SystemExit(0)
def syntax(message=None, usage=None):
    """Log *message* (if any), print usage to stderr, and exit with status 1."""
    if message:
        logger.error(message)
    body = usage or maindoc()
    outmsg = body + '\n\nFor usage, see: %s --help' % (sys.argv[0],)
    print(outmsg, file=sys.stderr)
    raise SystemExit(1)
class AbortActions(Exception):
    """Raised by an action to stop running the remaining actions for a file."""
class Modifier(object):
    """
    Holds one input file plus its (lazily computed) transformed content, and
    manages temp files used to expose either content to external commands.
    """
    def __init__(self, modifier, filename):
        # ``modifier`` is a callable mapping the input FileText to new content.
        self.modifier = modifier
        self.filename = filename
        # Temp files kept alive (and on disk) until this object is collected.
        self._tmpfiles = []
    @cached_attribute
    def input_content(self):
        # Original content, read once on first access.
        return read_file(self.filename)
    # TODO: refactor to avoid having these heavy-weight things inside a
    # cached_attribute, which causes annoyance while debugging.
    @cached_attribute
    def output_content(self):
        # Transformed content, computed once on first access.
        return FileText(self.modifier(self.input_content), filename=self.filename)
    def _tempfile(self):
        # Create a NamedTemporaryFile and retain a reference so the file
        # survives until __del__ closes it.
        from tempfile import NamedTemporaryFile
        f = NamedTemporaryFile()
        self._tmpfiles.append(f)
        return f, Filename(f.name)
    @cached_attribute
    def output_content_filename(self):
        # Write the transformed content to a temp file; return its Filename
        # (used by external-command actions such as diff).
        f, fname = self._tempfile()
        if six.PY3:
            f.write(bytes(self.output_content.joined, "utf-8"))
        else:
            f.write(self.output_content.joined.encode('utf-8'))
        f.flush()
        return fname
    @cached_attribute
    def input_content_filename(self):
        # Filename holding the original content.
        if isinstance(self.filename, Filename):
            return self.filename
        # If the input was stdin, and the user wants a diff, then we need to
        # write it to a temp file.
        f, fname = self._tempfile()
        if six.PY3:
            f.write(bytes(self.input_content, "utf-8"))
        else:
            f.write(self.input_content)
        f.flush()
        return fname
    def __del__(self):
        # Close (and thereby delete) any temp files we created.
        for f in self._tmpfiles:
            f.close()
def process_actions(filenames, actions, modify_function,
                    reraise_exceptions=()):
    """
    Run ``modify_function`` over each named file and apply each action in
    ``actions`` to the result; collect errors and exit nonzero at the end if
    any occurred.

    :param reraise_exceptions:
      Exception types to propagate immediately rather than record.
    """
    errors = []
    def on_error_filename_arg(arg):
        # Record a bad filename argument instead of aborting the whole run.
        print("%s: bad filename %s" % (sys.argv[0], arg), file=sys.stderr)
        errors.append("%s: bad filename" % (arg,))
    filenames = filename_args(filenames, on_error=on_error_filename_arg)
    for filename in filenames:
        try:
            m = Modifier(modify_function, filename)
            for action in actions:
                action(m)
        except AbortActions:
            # An action (e.g. action_ifchanged) decided to stop for this file.
            continue
        except reraise_exceptions:
            raise
        except Exception as e:
            errors.append("%s: %s: %s" % (filename, type(e).__name__, e))
            type_e = type(e)
            try:
                tb = sys.exc_info()[2]
                if str(filename) not in str(e):
                    # Re-wrap so the message mentions which file failed.
                    try:
                        e = type_e("While processing %s: %s" % (filename, e))
                        pass
                    except TypeError:
                        # Exception takes more than one argument
                        pass
                if logger.debug_enabled:
                    reraise(type_e, e, tb)
                traceback.print_exception(type(e), e, tb)
            finally:
                tb = None # avoid refcycles involving tb
            continue
    if errors:
        msg = "\n%s: encountered the following problems:\n" % (sys.argv[0],)
        for e in errors:
            lines = e.splitlines()
            # NOTE(review): continuation lines are joined with no newline
            # after lines[0] — verify intended formatting for multi-line
            # error messages.
            msg += " " + lines[0] + '\n'.join(
                (" %s"%line for line in lines[1:]))
        raise SystemExit(msg)
def action_print(m):
    """Action: write the (possibly modified) file content to stdout."""
    text = m.output_content.joined
    sys.stdout.write(text)
def action_ifchanged(m):
    """Action: abort the remaining actions unless the content changed."""
    unchanged = m.output_content.joined == m.input_content.joined
    if unchanged:
        logger.debug("unmodified: %s", m.filename)
        raise AbortActions
def action_replace(m):
    """Action: overwrite the input file in place with the modified content."""
    target = m.filename
    if target == Filename.STDIN:
        raise Exception("Can't replace stdio in-place")
    logger.info("%s: *** modified ***", target)
    atomic_write_file(target, m.output_content)
def action_external_command(command):
    """
    Return an action that runs ``command oldfile newfile`` (e.g. a diff
    tool), with this script's directory appended to $PATH.
    """
    import subprocess
    def action(m):
        script_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
        env = os.environ
        env['PATH'] = env['PATH'] + ":" + script_dir
        fullcmd = "%s %s %s" % (
            command, m.input_content_filename, m.output_content_filename)
        logger.debug("Executing external command: %s", fullcmd)
        retcode = subprocess.call(fullcmd, shell=True, env=env)
        logger.debug("External command returned %d", retcode)
    return action
def action_query(prompt="Proceed?"):
    """
    Return an action that asks the user *prompt* (with ``{filename}``
    substituted) and aborts the remaining actions unless the answer starts
    with 'y'.
    """
    def action(m):
        question = prompt.format(filename=m.filename)
        print()
        print("%s [y/N] " % (question,), end="")
        try:
            answer = input()
        except KeyboardInterrupt:
            print("KeyboardInterrupt", file=sys.stderr)
            raise SystemExit(1)
        if answer.strip().lower().startswith('y'):
            return True
        print("Aborted")
        raise AbortActions
    return action
def symlink_callback(option, opt_str, value, parser):
    """optparse callback for --symlinks: install the chosen symlink policy."""
    # Drop any previously-installed symlink action; the chosen one must run
    # before all other actions, so it is prepended.
    filtered = tuple(a for a in parser.values.actions
                     if a not in symlink_callbacks.values())
    parser.values.actions = filtered
    if value not in symlink_callbacks:
        raise optparse.OptionValueError(
            "--symlinks must be one of 'error', 'follow', 'skip', or 'replace'. Got %r" % value)
    parser.values.actions = (symlink_callbacks[value],) + parser.values.actions
# Help text listing the accepted --symlinks policies.
# Fix: the 'replace' line previously left its parenthesis unclosed.
symlinks_help = """\
--symlinks=error (default; gives an error on symlinks),
--symlinks=follow (follows symlinks),
--symlinks=skip (skips symlinks),
--symlinks=replace (replaces symlinks with the target file)\
"""
# Warning, the symlink actions will only work if they are run first.
# Otherwise, output_content may already be cached.
def symlink_error(m):
    # Symlink policy 'error' (the default): refuse to process a symlink and
    # tell the user which options would allow it.  Stdin is delegated to the
    # 'follow' policy.
    if m.filename == Filename.STDIN:
        return symlink_follow(m)
    if m.filename.islink:
        raise SystemExit("""\
Error: %s appears to be a symlink. Use one of the following options to allow symlinks:
%s
""" % (m.filename, indent(symlinks_help, '  ')))
def symlink_follow(m):
    """Symlink policy 'follow': resolve the link and process its target."""
    if m.filename == Filename.STDIN:
        return
    if not m.filename.islink:
        return
    logger.info("Following symlink %s" % m.filename)
    m.filename = m.filename.realpath
def symlink_skip(m):
    """Symlink policy 'skip': silently skip symlinked files."""
    if m.filename == Filename.STDIN:
        return symlink_follow(m)
    if not m.filename.islink:
        return
    logger.info("Skipping symlink %s" % m.filename)
    raise AbortActions
def symlink_replace(m):
    """Symlink policy 'replace': replace the link with the target file."""
    if m.filename == Filename.STDIN:
        return symlink_follow(m)
    if m.filename.islink:
        logger.info("Replacing symlink %s" % m.filename)
        # The current behavior automatically replaces symlinks, so do nothing
# Map from --symlinks option value to the action implementing that policy.
symlink_callbacks = {
    'error': symlink_error,
    'follow': symlink_follow,
    'skip': symlink_skip,
    'replace': symlink_replace,
}

View file

@ -0,0 +1,136 @@
from __future__ import absolute_import, division, print_function
from pyflyby._log import logger
from pyflyby._imports2s import SourceToSourceFileImportsTransformation
from pyflyby._importstmt import Import
import six
# These are comm targets that the frontend (lab/notebook) is expected to
# open. At this point, we handle only missing imports and
# formatting imports
MISSING_IMPORTS = "pyflyby.missing_imports"
FORMATTING_IMPORTS = "pyflyby.format_imports"
INIT_COMMS = "pyflyby.init_comms"
# Markers wrapped around the auto-generated import cell in the notebook.
PYFLYBY_START_MSG = "# THIS CELL WAS AUTO-GENERATED BY PYFLYBY\n"
PYFLYBY_END_MSG = "# END AUTO-GENERATED BLOCK\n"
pyflyby_comm_targets= [MISSING_IMPORTS, FORMATTING_IMPORTS]
# A map of the comms opened with a given target name.
comms = {}
# TODO: Document the expected contract for the different
# custom comm messages
def in_jupyter():
    """
    Return whether we are running inside a Jupyter kernel that has a comm
    manager (i.e. whether comm targets can be registered).

    Fix: the implicit string concatenation in the first debug message was
    missing a space ("onlybe added"); also "an Jupyter" -> "a Jupyter".
    """
    from IPython.core.getipython import get_ipython
    ip = get_ipython()
    if ip is None:
        logger.debug("get_ipython() doesn't exist. Comm targets can only "
                     "be added in a Jupyter notebook/lab/console environment")
        return False
    try:
        ip.kernel.comm_manager
    except AttributeError:
        logger.debug("Comm targets can only be added in Jupyter "
                     "notebook/lab/console environment")
        return False
    return True
def _register_target(target_name):
    """Register *target_name* with the kernel's comm manager."""
    from IPython.core.getipython import get_ipython
    kernel = get_ipython().kernel
    kernel.comm_manager.register_target(target_name, comm_open_handler)
def initialize_comms():
    """Register pyflyby comm targets and ask the frontend to (re-)open comms."""
    if not in_jupyter():
        return
    for target in pyflyby_comm_targets:
        _register_target(target)
    from ipykernel.comm import Comm
    comm = Comm(target_name=INIT_COMMS)
    logger.debug("Requesting frontend to (re-)initialize comms")
    comm.send({"type": INIT_COMMS})
def remove_comms():
    """Close every open pyflyby comm, logging each closure."""
    for target_name in list(comms):
        comms[target_name].close()
        logger.debug("Closing comm for " + target_name)
def send_comm_message(target_name, msg):
    """Send *msg* over the comm opened for *target_name*, if one exists."""
    if not in_jupyter():
        return
    try:
        comm = comms[target_name]
    except KeyError:
        logger.debug("Comm with target_name " + target_name + " hasn't been opened")
        return
    # Tag the message so the frontend can distinguish between multiple
    # types of custom comm messages.
    msg["type"] = target_name
    comm.send(msg)
    logger.debug("Sending comm message for target " + target_name)
def comm_close_handler(comm, message):
    """
    comm_close handler: forget the registered comm whose id matches.

    Fix: the original iterated ``six.iterkeys(comms)`` while unpacking two
    values (keys are strings, so the first iteration raised ValueError) and
    popped from the dict during iteration (RuntimeError on Python 3).
    Iterate over a snapshot of the items instead.
    """
    comm_id = message["comm_id"]
    for target, open_comm in list(comms.items()):
        if open_comm.comm_id == comm_id:
            comms.pop(target)
def _reformat_helper(input_code, imports):
    """
    Reformat the import statements in *input_code*, optionally adding
    *imports* (a string or list of import statements).  When the pyflyby
    auto-generated-cell markers are present, new imports are added only to
    the marked region; the whole text is then reformatted.
    """
    from pyflyby._imports2s import reformat_import_statements
    before, bmarker, middle = input_code.partition(PYFLYBY_START_MSG)
    if not bmarker:
        # No start marker: treat the entire input as the editable region.
        before, middle = "", input_code
    # str.partition already yields (middle, "", "") when the end marker is
    # absent, which is exactly the fallback we want.
    middle, emarker, after = middle.partition(PYFLYBY_END_MSG)
    if imports is not None:
        transform = SourceToSourceFileImportsTransformation(middle)
        if isinstance(imports, str):
            imports = [imports]
        for imp in imports:
            assert isinstance(imp, str)
            if imp.strip():
                transform.add_import(Import(imp))
        middle = str(transform.output())
    return reformat_import_statements(before + bmarker + middle + emarker + after)
def comm_open_handler(comm, message):
    """
    Handles comm_open message for pyflyby custom comm messages.
    https://jupyter-client.readthedocs.io/en/stable/messaging.html#opening-a-comm.

    Handler for all PYFLYBY custom comm messages that are opened by the frontend
    (at this point, just the jupyterlab frontend does this).
    """
    comm.on_close(comm_close_handler)
    # Remember the comm so send_comm_message can reach this target later.
    comms[message["content"]["target_name"]] = comm
    @comm.on_msg
    def _recv(msg):
        # Dispatch on the custom "type" field the frontend sets.
        data = msg["content"]["data"]
        if data["type"] == FORMATTING_IMPORTS:
            imports = data.get('imports', None)
            fmt_code = _reformat_helper(data["input_code"], imports)
            comm.send({"formatted_code": str(fmt_code), "type": FORMATTING_IMPORTS})

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,382 @@
# pyflyby/_docxref.py.
# Module for checking Epydoc cross-references.
# Portions of the code below are derived from Epydoc, which is distributed
# under the MIT license:
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and any associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the
# following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# The software is provided "as is", without warranty of any kind, express or
# implied, including but not limited to the warranties of merchantability,
# fitness for a particular purpose and noninfringement. In no event shall
# the authors or copyright holders be liable for any claim, damages or other
# liability, whether in an action of contract, tort or otherwise, arising
# from, out of or in connection with the software or the use or other
# dealings in the software.
from __future__ import (absolute_import, division, print_function,
with_statement)
import re
import six
from six.moves import builtins
from textwrap import dedent
from epydoc.apidoc import (ClassDoc, ModuleDoc, PropertyDoc,
RoutineDoc, UNKNOWN, VariableDoc)
from epydoc.docbuilder import build_doc_index
from epydoc.markup.plaintext import ParsedPlaintextDocstring
from pyflyby._file import Filename
from pyflyby._idents import DottedIdentifier
from pyflyby._log import logger
from pyflyby._modules import ModuleHandle
from pyflyby._util import cached_attribute, memoize, prefixes
# If someone references numpy.*, just assume it's OK - it's not worth
# following into numpy because it's too slow.
ASSUME_MODULES_OK = set(['numpy'])
@memoize
def map_strings_to_line_numbers(module):
    """
    Walk ``module.ast``, looking at all string literals.  Return a map from
    string literals to line numbers (1-index).

    :rtype:
      ``dict`` from ``str`` to (``int``, ``str``)
    """
    d = {}
    for field in module.block.string_literals():
        # Dedent because epydoc dedents strings and we need to look up by
        # those. But keep track of original version because we need to count
        # exact line numbers.
        s = dedent(field.s).strip()
        start_lineno = field.startpos.lineno
        d[s] = (start_lineno, field.s)
    return d
def get_string_linenos(module, searchstring, within_string):
    """
    Return the line numbers (1-indexed) within ``filename`` that contain
    ``searchstring``.  Only consider string literals (i.e. not comments).
    First look for exact matches of ``within_string`` (modulo indenting) and
    then search within that.  Only if the ``within_string`` is not found,
    search the entire file.

    [If there's a comment on the same line as a string that also contains the
    searchstring, we'll get confused.]
    """
    module = ModuleHandle(module)
    regexp = re.compile(searchstring)
    map = map_strings_to_line_numbers(module)
    results = []
    def scan_within_string(results, start_lineno, orig_full_string):
        # Append the absolute lineno of each line of the literal that
        # matches the search regexp.
        for i, line in enumerate(orig_full_string.splitlines()):
            if regexp.search(line):
                results.append( start_lineno + i )
    try:
        lineno, orig_full_string = map[within_string.strip()]
    except KeyError:
        pass
    else:
        # We found the larger string exactly within the ast.
        scan_within_string(results, lineno, orig_full_string)
        if results:
            return tuple(results)
        # We could continue down if this ever happened.
        raise Exception(
            "Found superstring in %r but not substring %r within superstring"
            % (module.filename, searchstring))
    # Try a full text search.
    for lineno, orig_full_string in map.values():
        scan_within_string(results, lineno, orig_full_string)
    if results:
        return tuple(sorted(results))
    raise Exception(
        "Could not find %r anywhere in %r" % (searchstring, module.filename))
def describe_xref(identifier, container):
    """
    Locate *identifier* inside *container*'s docstring and return a
    ``(module, linenos, container_name, identifier)`` description tuple.
    """
    module = ModuleHandle(str(container.defining_module.canonical_name))
    assert module.filename == Filename(container.defining_module.filename)
    pattern = "(L{|<)%s" % (identifier,)
    linenos = get_string_linenos(module, pattern, container.docstring)
    return (module, linenos, str(container.canonical_name), identifier)
def safe_build_doc_index(modules):
    """Build an epydoc DocIndex, clearing epydoc's caches first and raising
    instead of returning None on failure."""
    # build_doc_index isn't re-entrant due to crappy caching! >:(
    from epydoc.docintrospecter import clear_cache
    from epydoc.docparser import _moduledoc_cache
    clear_cache()
    _moduledoc_cache.clear()
    # Build a new DocIndex. It swallows exceptions and returns None on error!
    # >:(
    docindex = build_doc_index(modules)
    if docindex is None:
        raise Exception("Failed to build doc index on %r" % (modules,))
    return docindex
class ExpandedDocIndex(object):
    """
    A wrapper around DocIndex that automatically expands with more modules as
    needed.
    """
    # TODO: this is kludgy and inefficient since it re-reads modules.
    def __init__(self, modules):
        self.modules = set([ModuleHandle(m) for m in modules])
    def add_module(self, module):
        """
        Adds ``module`` and recreates the DocIndex with the updated set of
        modules.

        :return:
          Whether anything was added.
        """
        module = ModuleHandle(module)
        for prefix in module.ancestors:
            if prefix in self.modules:
                # The module, or a prefix of it, was already added.
                return False
        # Iterate over a sorted copy so removal during the loop is safe.
        for existing_module in sorted(self.modules):
            if existing_module.startswith(module):
                # This supersedes an existing module.
                assert existing_module != module
                self.modules.remove(existing_module)
        logger.debug("Expanding docindex to include %r", module)
        self.modules.add(module)
        # Invalidate the cached DocIndex; it is rebuilt on next access.
        # NOTE(review): assumes cached_attribute supports deletion even
        # before first access — confirm.
        del self.docindex
        return True
    def find(self, a, b):
        # Delegate to the underlying (lazily built) DocIndex.
        return self.docindex.find(a, b)
    def get_vardoc(self, a):
        return self.docindex.get_vardoc(a)
    @cached_attribute
    def docindex(self):
        return safe_build_doc_index(
            [str(m.name) for m in sorted(self.modules)])
def remove_epydoc_sym_suffix(s):
    """
    Remove trailing "'" that Epydoc annoyingly adds to 'shadowed' names.

    >>> remove_epydoc_sym_suffix("a.b'.c'.d")
    'a.b.c.d'
    """
    # An apostrophe is dropped when immediately followed by a dot or at
    # end-of-string.
    pattern = re.compile(r"'([.]|$)")
    return pattern.sub(r'\1', s)
class XrefScanner(object):
    """
    Scans epydoc parse trees for cross-references ("link" nodes) and records
    the ones that fail to resolve.
    """
    def __init__(self, modules):
        self.modules = modules
        self.docindex = safe_build_doc_index(modules)
    @cached_attribute
    def expanded_docindex(self):
        # DocIndex wrapper that can pull in additional modules on demand.
        return ExpandedDocIndex(self.modules)
    def scan(self):
        """Scan all reachable docs; return a sorted tuple of failed xrefs."""
        self._failed_xrefs = []
        valdocs = sorted(self.docindex.reachable_valdocs(
            imports=False, packages=False, bases=False, submodules=False,
            subclasses=False, private=True
        ))
        for doc in valdocs:
            if isinstance(doc, ClassDoc):
                self.scan_class(doc)
            elif isinstance(doc, ModuleDoc):
                self.scan_module(doc)
        return tuple(sorted(self._failed_xrefs))
    def scan_module(self, doc):
        # Scan the module description, submodules (for packages), and its
        # documented functions/variables.
        self.descr(doc)
        if doc.is_package is True:
            for submodule in doc.submodules:
                self.scan_module(submodule)
        # self.scan_module_list(doc)
        self.scan_details_list(doc, "function")
        self.scan_details_list(doc, "other")
    def scan_class(self, doc):
        self.descr(doc)
        self.scan_details_list(doc, "method")
        self.scan_details_list(doc, "classvariable")
        self.scan_details_list(doc, "instancevariable")
        self.scan_details_list(doc, "property")
    def scan_details_list(self, doc, value_type):
        # Select the documented variables of the given kind and scan each.
        detailed = True
        if isinstance(doc, ClassDoc):
            var_docs = doc.select_variables(value_type=value_type,
                                            imported=False, inherited=False,
                                            public=None,
                                            detailed=detailed)
        else:
            var_docs = doc.select_variables(value_type=value_type,
                                            imported=False,
                                            public=None,
                                            detailed=detailed)
        for var_doc in var_docs:
            self.scan_details(var_doc)
    def scan_details(self, var_doc):
        # Scan all docstring fields relevant to this variable's kind.
        self.descr(var_doc)
        if isinstance(var_doc.value, RoutineDoc):
            self.return_type(var_doc)
            self.return_descr(var_doc)
            for (arg_names, arg_descr) in var_doc.value.arg_descrs:
                self.scan_docstring(arg_descr, var_doc.value)
            for arg in var_doc.value.arg_types:
                self.scan_docstring(
                    var_doc.value.arg_types[arg], var_doc.value)
        elif isinstance(var_doc.value, PropertyDoc):
            prop_doc = var_doc.value
            self.return_type(prop_doc.fget)
            self.return_type(prop_doc.fset)
            self.return_type(prop_doc.fdel)
        else:
            self.type_descr(var_doc)
    def _scan_attr(self, attr, api_doc):
        # Scan one ParsedDocstring-valued attribute of api_doc, if present;
        # fall through to the underlying value for VariableDocs.
        if api_doc in (None, UNKNOWN):
            return ''
        pds = getattr(api_doc, attr, None) # pds = ParsedDocstring.
        if pds not in (None, UNKNOWN):
            self.scan_docstring(pds, api_doc)
        elif isinstance(api_doc, VariableDoc):
            self._scan_attr(attr, api_doc.value)
    def summary(self, api_doc):
        self._scan_attr('summary', api_doc)
    def descr(self, api_doc):
        self._scan_attr('descr', api_doc)
    def type_descr(self, api_doc):
        self._scan_attr('type_descr', api_doc)
    def return_type(self, api_doc):
        self._scan_attr('return_type', api_doc)
    def return_descr(self, api_doc):
        self._scan_attr('return_descr', api_doc)
    def check_xref(self, identifier, container):
        """
        Check that ``identifier`` cross-references a proper symbol.

        Look in modules that we weren't explicitly asked to look in, if
        needed.
        """
        if identifier in builtins.__dict__:
            return True
        def check_container():
            # Resolve the identifier relative to the container, following
            # the overrides chain for routines.
            if self.expanded_docindex.find(identifier, container) is not None:
                return True
            if isinstance(container, RoutineDoc):
                tcontainer = self.expanded_docindex.get_vardoc(
                    container.canonical_name)
                doc = self.expanded_docindex.find(identifier, tcontainer)
                # NOTE(review): the loop walks the overrides chain while a
                # doc IS found — verify the `doc is not None` condition is
                # intended (vs `doc is None`).
                while (doc is not None and tcontainer not in (None, UNKNOWN)
                       and tcontainer.overrides not in (None, UNKNOWN)):
                    tcontainer = tcontainer.overrides
                    doc = self.expanded_docindex.find(identifier, tcontainer)
                return doc is not None
            return False
        def check_defining_module(x):
            # Add x's defining module to the index and retry resolution.
            if x is None:
                return False
            defining_module_name = remove_epydoc_sym_suffix(str(
                x.defining_module.canonical_name))
            if defining_module_name in ASSUME_MODULES_OK:
                return True
            if self.expanded_docindex.add_module(defining_module_name):
                if check_container():
                    return True
            return False
        if check_container():
            return True
        if (isinstance(container, RoutineDoc) and
            identifier in container.all_args()):
            return True
        if check_defining_module(container):
            return True
        # If the user has imported foo.bar.baz as baz and now uses
        # ``baz.quux``, we need to add the module foo.bar.baz.
        for prefix in reversed(list(prefixes(
                DottedIdentifier(remove_epydoc_sym_suffix(identifier))))):
            if check_defining_module(
                    self.docindex.find(str(prefix), container)):
                return True
        try:
            module = ModuleHandle.containing(identifier)
        except ImportError:
            pass
        else:
            if str(module.name) in ASSUME_MODULES_OK:
                return True
            if self.expanded_docindex.add_module(module):
                if check_container():
                    return True
        return False
    def scan_docstring(self, parsed_docstring, container):
        # Walk the parsed docstring tree; record every failed "link" target.
        if parsed_docstring in (None, UNKNOWN): return ''
        if isinstance(parsed_docstring, ParsedPlaintextDocstring):
            return ''
        def scan_tree(tree):
            # Recursively flatten the epytext tree, checking link nodes.
            if isinstance(tree, six.string_types):
                return tree
            variables = [scan_tree(child) for child in tree.children]
            if tree.tag == 'link':
                identifier = variables[1]
                if not self.check_xref(identifier, container):
                    self._failed_xrefs.append(
                        describe_xref(identifier, container) )
                return '?'
            elif tree.tag == 'indexed':
                return '?'
            elif tree.tag in ('epytext', 'section', 'tag', 'arg',
                              'name', 'target', 'html', 'para'):
                return ''.join(variables)
            return '?'
        scan_tree(parsed_docstring._tree)
def find_bad_doc_cross_references(names):
    """
    Find docstring cross references that fail to resolve.

    :type names:
      Sequence of module names or filenames.
    :return:
      Sequence of ``(module, linenos, container_name, identifier)`` tuples.
    """
    # XrefScanner does the heavy lifting; scan() returns the failure records.
    xrs = XrefScanner(names)
    return xrs.scan()

View file

@ -0,0 +1,725 @@
# pyflyby/_file.py.
# Copyright (C) 2011, 2012, 2013, 2014, 2015, 2018 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from functools import total_ordering
import io
import os
import re
import six
import sys
from six import string_types
from pyflyby._util import cached_attribute, cmp, memoize
class UnsafeFilenameError(ValueError):
    """
    Raised for filenames that `Filename` refuses to handle: empty strings,
    names with characters outside a conservative whitelist, or names with a
    "~" path component.
    """
    pass
# TODO: statcache
@total_ordering
class Filename(object):
    """
    A filename, stored as an absolute path.

    Construction is idempotent and validates the name, raising
    `UnsafeFilenameError` for anything outside a conservative character
    whitelist.

    >>> Filename('/etc/passwd')
    Filename('/etc/passwd')
    """
    def __new__(cls, arg):
        # Idempotent: Filename(Filename(...)) returns the same instance.
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, six.string_types):
            return cls._from_filename(arg)
        raise TypeError
    @classmethod
    def _from_filename(cls, filename):
        # Validate and canonicalize (abspath) a string filename.
        if not isinstance(filename, six.string_types):
            raise TypeError
        filename = str(filename)
        if not filename:
            raise UnsafeFilenameError("(empty string)")
        # Whitelist of allowed characters; anything else is "unsafe".
        if re.search("[^a-zA-Z0-9_=+{}/.,~@-]", filename):
            raise UnsafeFilenameError(filename)
        # Reject "~" at the start of any path component (shell expansion).
        if re.search("(^|/)~", filename):
            raise UnsafeFilenameError(filename)
        self = object.__new__(cls)
        self._filename = os.path.abspath(filename)
        return self
    def __str__(self):
        return self._filename
    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, self._filename)
    def __truediv__(self, x):
        # Filename("/a") / "b"  =>  Filename("/a/b") (validated).
        return type(self)(os.path.join(self._filename, x))
    def __hash__(self):
        return hash(self._filename)
    def __eq__(self, o):
        if self is o:
            return True
        if not isinstance(o, Filename):
            return NotImplemented
        return self._filename == o._filename
    def __ne__(self, other):
        return not (self == other)
    # The rest are defined by total_ordering
    def __lt__(self, o):
        if not isinstance(o, Filename):
            return NotImplemented
        return self._filename < o._filename
    def __cmp__(self, o):
        # Python 2 only; Python 3 uses the rich comparisons above.
        if self is o:
            return 0
        if not isinstance(o, Filename):
            return NotImplemented
        return cmp(self._filename, o._filename)
    @cached_attribute
    def ext(self):
        """
        Returns the extension of this filename, including the dot.
        Returns ``None`` if no extension.

        :rtype:
          ``str`` or ``None``
        """
        lhs, dot, rhs = self._filename.rpartition('.')
        if not dot:
            return None
        return dot + rhs
    @cached_attribute
    def base(self):
        # Last path component (basename), cached.
        return os.path.basename(self._filename)
    @cached_attribute
    def dir(self):
        # Parent directory as a Filename, cached.
        return type(self)(os.path.dirname(self._filename))
    @cached_attribute
    def real(self):
        # Symlink-resolved path, cached at first access.
        return type(self)(os.path.realpath(self._filename))
    @property
    def realpath(self):
        # Symlink-resolved path, recomputed on each access.
        return type(self)(os.path.realpath(self._filename))
    @property
    def exists(self):
        return os.path.exists(self._filename)
    @property
    def islink(self):
        return os.path.islink(self._filename)
    @property
    def isdir(self):
        return os.path.isdir(self._filename)
    @property
    def isfile(self):
        return os.path.isfile(self._filename)
    @property
    def isreadable(self):
        return os.access(self._filename, os.R_OK)
    @property
    def iswritable(self):
        return os.access(self._filename, os.W_OK)
    @property
    def isexecutable(self):
        return os.access(self._filename, os.X_OK)
    def startswith(self, prefix):
        # True if self is ``prefix`` itself or lies underneath it.
        prefix = Filename(prefix)
        if self == prefix:
            return True
        return self._filename.startswith("%s/" % (prefix,))
    def list(self, ignore_unsafe=True):
        # Sorted directory entries as Filenames.  Unsafe names are skipped
        # by default, or re-raised when ignore_unsafe is False.
        filenames = [os.path.join(self._filename, f)
                     for f in sorted(os.listdir(self._filename))]
        result = []
        for f in filenames:
            try:
                f = Filename(f)
            except UnsafeFilenameError:
                if ignore_unsafe:
                    continue
                else:
                    raise
            result.append(f)
        return result
    @property
    def ancestors(self):
        """
        Return ancestors of self, from self to /.

        >>> Filename("/aa/bb").ancestors
        (Filename('/aa/bb'), Filename('/aa'), Filename('/'))

        :rtype:
          ``tuple`` of ``Filename`` s
        """
        result = [self]
        while True:
            dir = result[-1].dir
            if dir == result[-1]:
                break
            result.append(dir)
        return tuple(result)
@memoize
def _get_PATH():
    """
    Parse $PATH once (memoized) into a tuple of `Filename` s, skipping
    empty entries and entries with unsafe characters.
    """
    raw_entries = os.environ.get("PATH", "").split(os.pathsep)
    safe_entries = []
    for entry in raw_entries:
        if not entry:
            continue
        try:
            safe_entries.append(Filename(entry))
        except UnsafeFilenameError:
            continue
    return tuple(safe_entries)
def which(program):
    """
    Find ``program`` on $PATH.

    :type program:
      ``str``
    :rtype:
      `Filename`
    :return:
      Program on $PATH, or ``None`` if not found.
    """
    # A readable file in the current directory wins outright.
    local_candidate = Filename(program)
    if local_candidate.isreadable:
        return local_candidate
    # Otherwise scan each $PATH directory for an executable match.
    for directory in _get_PATH():
        candidate = directory / program
        if candidate.isexecutable:
            return candidate
    return None
Filename.STDIN = Filename("/dev/stdin")
@total_ordering
class FilePos(object):
    """
    A (lineno, colno) position within a `FileText`.
    Both lineno and colno are 1-indexed.

    Instances are immutable in practice, hashable, and totally ordered by
    (lineno, colno).
    """
    def __new__(cls, *args):
        """
        Construct a ``FilePos`` from: nothing or ``None`` (=> (1,1)),
        another ``FilePos`` (returned as-is), a ``(lineno, colno)`` tuple,
        or two ints.
        """
        if len(args) == 0:
            return cls._ONE_ONE
        if len(args) == 1:
            arg, = args
            if isinstance(arg, cls):
                return arg
            elif arg is None:
                return cls._ONE_ONE
            elif isinstance(arg, tuple):
                args = arg
                # Fall through
            else:
                raise TypeError
        lineno, colno = cls._intint(args)
        if lineno == colno == 1:
            return cls._ONE_ONE  # space optimization
        if lineno < 1:
            raise ValueError(
                "FilePos: invalid lineno=%d; should be >= 1" % lineno,)
        if colno < 1:
            raise ValueError(
                "FilePos: invalid colno=%d; should be >= 1" % colno,)
        return cls._from_lc(lineno, colno)
    @staticmethod
    def _intint(args):
        # Validate that ``args`` is exactly a pair of ints (exact type check,
        # no bools/subclasses).
        if (type(args) is tuple and
            len(args) == 2 and
            type(args[0]) is type(args[1]) is int):
            return args
        else:
            raise TypeError("Expected (int,int); got %r" % (args,))
    @classmethod
    def _from_lc(cls, lineno, colno):
        # Internal constructor that bypasses validation and singletons.
        self = object.__new__(cls)
        self.lineno = lineno
        self.colno = colno
        return self
    def __add__(self, delta):
        '''
        "Add" a coordinate (line,col) delta to this ``FilePos``.
        Note that addition here may be a non-obvious.  If there is any line
        movement, then the existing column number is ignored, and the new
        column is the new column delta + 1 (to convert into 1-based numbers).

        :rtype:
          `FilePos`
        '''
        ldelta, cdelta = self._intint(delta)
        assert ldelta >= 0 and cdelta >= 0
        if ldelta == 0:
            return FilePos(self.lineno, self.colno + cdelta)
        else:
            return FilePos(self.lineno + ldelta, 1 + cdelta)
    def __str__(self):
        return "(%d,%d)" % (self.lineno, self.colno)
    def __repr__(self):
        return "FilePos%s" % (self,)
    @property
    def _data(self):
        # Canonical tuple used for comparison and hashing.
        return (self.lineno, self.colno)
    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, FilePos):
            return NotImplemented
        return self._data == other._data
    def __ne__(self, other):
        return not (self == other)
    def __cmp__(self, other):
        # Python 2 only; Python 3 uses the rich comparisons below.
        if self is other:
            return 0
        if not isinstance(other, FilePos):
            return NotImplemented
        return cmp(self._data, other._data)
    # The rest are defined by total_ordering
    def __lt__(self, other):
        # Bug fix: previously returned the int 0 for the identity case;
        # comparison dunders should return booleans (0 is merely falsy).
        if self is other:
            return False
        if not isinstance(other, FilePos):
            return NotImplemented
        return self._data < other._data
    def __hash__(self):
        return hash(self._data)
FilePos._ONE_ONE = FilePos._from_lc(1, 1)
@total_ordering
class FileText(object):
    """
    Represents a contiguous sequence of lines from a file.

    A ``FileText`` carries the text itself (as ``joined`` / ``lines``), an
    optional ``filename``, and a ``startpos`` (the 1-indexed position of the
    first character within the original file).
    """
    def __new__(cls, arg, filename=None, startpos=None):
        """
        Return a new ``FileText`` instance.

        :type arg:
          ``FileText``, ``Filename``, ``str``, or tuple of ``str``
        :param arg:
          If a sequence of lines, then each should end with a newline and have
          no other newlines.  Otherwise, something that can be interpreted or
          converted into a sequence of lines.
        :type filename:
          `Filename`
        :param filename:
          Filename to attach to this ``FileText``, if not already given by
          ``arg``.
        :type startpos:
          ``FilePos``
        :param startpos:
          Starting file position (lineno & colno) of this ``FileText``, if not
          already given by ``arg``.
        :rtype:
          ``FileText``
        """
        if isinstance(arg, cls):
            if filename is startpos is None:
                return arg
            return arg.alter(filename=filename, startpos=startpos)
        elif isinstance(arg, Filename):
            # Construct from file contents.
            return cls(read_file(arg), filename=filename, startpos=startpos)
        elif hasattr(arg, "__text__"):
            return FileText(arg.__text__(), filename=filename, startpos=startpos)
        elif isinstance(arg, six.string_types):
            self = object.__new__(cls)
            self.joined = arg
        else:
            raise TypeError("%s: unexpected %s"
                            % (cls.__name__, type(arg).__name__))
        if filename is not None:
            filename = Filename(filename)
        startpos = FilePos(startpos)
        self.filename = filename
        self.startpos = startpos
        return self
    @classmethod
    def _from_lines(cls, lines, filename, startpos):
        # Internal constructor from an already-split tuple of lines
        # (no trailing newline on the last element).
        assert type(lines) is tuple
        assert len(lines) > 0
        assert isinstance(lines[0], string_types)
        assert not lines[-1].endswith("\n")
        self = object.__new__(cls)
        self.lines = lines
        self.filename = filename
        self.startpos = startpos
        return self
    @cached_attribute
    def lines(self):
        r"""
        Lines that have been split by newline.
        These strings do NOT contain '\n'.
        If the input file ended in '\n', then the last item will be the empty
        string.  This is to avoid having to check lines[-1].endswith('\n')
        everywhere.

        :rtype:
          ``tuple`` of ``str``
        """
        # Used if only initialized with 'joined'.
        # We use str.split() instead of str.splitlines() because the latter
        # doesn't distinguish between strings that end in newline or not
        # (or requires extra work to process if we use splitlines(True)).
        return tuple(self.joined.split('\n'))
    @cached_attribute
    def joined(self):  # used if only initialized with 'lines'
        return '\n'.join(self.lines)
    @classmethod
    def from_filename(cls, filename):
        """
        Read ``filename`` and return its contents as a ``FileText``.

        Bug fix: this previously called ``cls.from_lines``, which does not
        exist (only ``_from_lines``, with a different signature), so it
        always raised AttributeError.  Constructing from a `Filename` goes
        through ``__new__``, which reads the file via ``read_file``.
        """
        return cls(Filename(filename))
    def alter(self, filename=None, startpos=None):
        # Return a copy of self with filename and/or startpos replaced
        # (or self itself if nothing would change).
        if filename is not None:
            filename = Filename(filename)
        else:
            filename = self.filename
        if startpos is not None:
            startpos = FilePos(startpos)
        else:
            startpos = self.startpos
        if filename == self.filename and startpos == self.startpos:
            return self
        else:
            result = object.__new__(type(self))
            result.lines = self.lines
            result.joined = self.joined
            result.filename = filename
            result.startpos = startpos
            return result
    @cached_attribute
    def endpos(self):
        """
        The position after the last character in the text.

        :rtype:
          ``FilePos``
        """
        startpos = self.startpos
        lines = self.lines
        lineno = startpos.lineno + len(lines) - 1
        if len(lines) == 1:
            colno = startpos.colno + len(lines[-1])
        else:
            colno = 1 + len(lines[-1])
        return FilePos(lineno, colno)
    def _lineno_to_index(self, lineno):
        lineindex = lineno - self.startpos.lineno
        # Check that the lineindex is in range.  We don't allow pointing at
        # the line after the last line because we already ensured that
        # self.lines contains an extra empty string if necessary, to indicate
        # a trailing newline in the file.
        if not 0 <= lineindex < len(self.lines):
            raise IndexError(
                "Line number %d out of range [%d, %d)"
                % (lineno, self.startpos.lineno, self.endpos.lineno))
        return lineindex
    def _colno_to_index(self, lineindex, colno):
        # The first line's columns start at startpos.colno; later lines at 1.
        coloffset = self.startpos.colno if lineindex == 0 else 1
        colindex = colno - coloffset
        line = self.lines[lineindex]
        # Check that the colindex is in range.  We do allow pointing at the
        # character after the last (non-newline) character in the line.
        if not 0 <= colindex <= len(line):
            raise IndexError(
                "Column number %d on line %d out of range [%d, %d]"
                % (colno, lineindex+self.startpos.lineno,
                   coloffset, coloffset+len(line)))
        return colindex
    def __getitem__(self, arg):
        """
        Return the line(s) with the given line number(s).
        If slicing, returns an instance of ``FileText``.

        Note that line numbers are indexed based on ``self.startpos.lineno``
        (which is 1 at the start of the file).

        >>> FileText("a\\nb\\nc\\nd")[2]
        'b'

        >>> FileText("a\\nb\\nc\\nd")[2:4]
        FileText('b\\nc\\n', startpos=(2,1))

        >>> FileText("a\\nb\\nc\\nd")[0]
        Traceback (most recent call last):
        ...
        IndexError: Line number 0 out of range [1, 4)

        When slicing, the input arguments can also be given as ``FilePos``
        arguments or (lineno,colno) tuples.  These are 1-indexed at the start
        of the file.

        >>> FileText("a\\nb\\nc\\nd")[(2,2):4]
        FileText('\\nc\\n', startpos=(2,2))

        :rtype:
          ``str`` or `FileText`
        """
        L = self._lineno_to_index
        C = self._colno_to_index
        if isinstance(arg, slice):
            if arg.step is not None and arg.step != 1:
                raise ValueError("steps not supported")
            # Interpret start (lineno,colno) into indexes.
            if arg.start is None:
                start_lineindex = 0
                start_colindex = 0
            elif isinstance(arg.start, int):
                start_lineindex = L(arg.start)
                start_colindex = 0
            else:
                startpos = FilePos(arg.start)
                start_lineindex = L(startpos.lineno)
                start_colindex = C(start_lineindex, startpos.colno)
            # Interpret stop (lineno,colno) into indexes.
            if arg.stop is None:
                stop_lineindex = len(self.lines)
                stop_colindex = len(self.lines[-1])
            elif isinstance(arg.stop, int):
                stop_lineindex = L(arg.stop)
                stop_colindex = 0
            else:
                stoppos = FilePos(arg.stop)
                stop_lineindex = L(stoppos.lineno)
                stop_colindex = C(stop_lineindex, stoppos.colno)
            # {start,stop}_{lineindex,colindex} are now 0-indexed
            # [open,closed) ranges.
            assert 0 <= start_lineindex <= stop_lineindex < len(self.lines)
            assert 0 <= start_colindex <= len(self.lines[start_lineindex])
            assert 0 <= stop_colindex <= len(self.lines[stop_lineindex])
            # Optimization: return entire range
            if (start_lineindex == 0 and
                start_colindex == 0 and
                stop_lineindex == len(self.lines)-1 and
                stop_colindex == len(self.lines[-1])):
                return self
            # Get the lines we care about.  We always include an extra entry
            # at the end which we'll chop to the desired number of characters.
            result_split = list(self.lines[start_lineindex:stop_lineindex+1])
            # Clip the starting and ending strings.  We do the end clip first
            # in case the result has only one line.
            result_split[-1] = result_split[-1][:stop_colindex]
            result_split[0] = result_split[0][start_colindex:]
            # Compute the new starting line and column numbers.
            result_lineno = start_lineindex + self.startpos.lineno
            if start_lineindex == 0:
                result_colno = start_colindex + self.startpos.colno
            else:
                result_colno = start_colindex + 1
            result_startpos = FilePos(result_lineno, result_colno)
            return FileText._from_lines(tuple(result_split),
                                        filename=self.filename,
                                        startpos=result_startpos)
        elif isinstance(arg, int):
            # Return a single line.
            lineindex = L(arg)
            return self.lines[lineindex]
        else:
            raise TypeError("bad type %r" % (type(arg),))
    @classmethod
    def concatenate(cls, args):
        """
        Concatenate a bunch of `FileText` arguments.  Uses the ``filename``
        and ``startpos`` from the first argument.

        :rtype:
          `FileText`
        """
        args = [FileText(x) for x in args]
        if len(args) == 1:
            return args[0]
        return FileText(
            ''.join([l.joined for l in args]),
            filename=args[0].filename,
            startpos=args[0].startpos)
    def __repr__(self):
        r = "%s(%r" % (type(self).__name__, self.joined,)
        if self.filename is not None:
            r += ", filename=%r" % (str(self.filename),)
        if self.startpos != FilePos():
            r += ", startpos=%s" % (self.startpos,)
        r += ")"
        return r
    def __str__(self):
        return self.joined
    def __eq__(self, o):
        if self is o:
            return True
        if not isinstance(o, FileText):
            return NotImplemented
        return (self.filename == o.filename and
                self.joined == o.joined and
                self.startpos == o.startpos)
    def __ne__(self, other):
        return not (self == other)
    # The rest are defined by total_ordering
    def __lt__(self, o):
        if not isinstance(o, FileText):
            return NotImplemented
        return ((self.filename, self.joined, self.startpos) <
                (o   .filename, o   .joined, o   .startpos))
    def __cmp__(self, o):
        # Python 2 only; Python 3 uses the rich comparisons above.
        if self is o:
            return 0
        if not isinstance(o, FileText):
            return NotImplemented
        return cmp((self.filename, self.joined, self.startpos),
                   (o   .filename, o   .joined, o   .startpos))
    def __hash__(self):
        # Bug fix: the old code cached by assigning ``self.__hash__``, but
        # hash() looks special methods up on the *type*, not the instance,
        # so that cache was never consulted and the hash was recomputed on
        # every call.  Cache in a regular attribute instead.
        try:
            return self._hash
        except AttributeError:
            h = hash((self.filename, self.joined, self.startpos))
            self._hash = h
            return h
def read_file(filename):
    """
    Read ``filename`` (or stdin, for ``Filename.STDIN``) and return its
    contents as a `FileText`.
    """
    target = Filename(filename)
    if target == Filename.STDIN:
        content = sys.stdin.read()
    else:
        with io.open(str(target), 'r') as handle:
            content = handle.read()
    return FileText(content, filename=target)
def write_file(filename, data):
    """
    Write ``data`` (coerced to `FileText`) to ``filename``.
    """
    target = Filename(filename)
    text = FileText(data)
    with open(str(target), 'w') as handle:
        handle.write(text.joined)
def atomic_write_file(filename, data):
    """
    Write ``data`` to ``filename`` atomically, via a temp file plus rename.

    Mode and group ownership are copied from a pre-existing target file on a
    best-effort basis.  The temp file lives alongside the target, so the
    final os.rename stays on one filesystem.
    """
    filename = Filename(filename)
    data = FileText(data)
    temp_filename = Filename("%s.tmp.%s" % (filename, os.getpid(),))
    write_file(temp_filename, data)
    try:
        st = os.stat(str(filename)) # OSError if file didn't exist before
        os.chmod(str(temp_filename), st.st_mode)
        os.chown(str(temp_filename), -1, st.st_gid) # OSError if not member of group
    except OSError:
        pass
    os.rename(str(temp_filename), str(filename))
def expand_py_files_from_args(pathnames, on_error=lambda filename: None):
    """
    Enumerate ``*.py`` files, recursively.

    Arguments that are files are always included.
    Arguments that are directories are recursively searched for ``*.py`` files.

    :type pathnames:
      ``list`` of `Filename` s
    :type on_error:
      callable
    :param on_error:
      Function that is called for arguments directly specified in ``pathnames``
      that don't exist or are otherwise inaccessible.
    :rtype:
      ``list`` of `Filename` s
    """
    if not isinstance(pathnames, (tuple, list)):
        pathnames = [pathnames]
    roots = [Filename(p) for p in pathnames]
    collected = []
    stack = []
    # Only directly-specified arguments trigger on_error; recursively
    # discovered entries are silently filtered.
    for root in reversed(roots):
        if root.isfile:
            stack.append((root, True))
        elif root.isdir:
            stack.append((root, False))
        else:
            on_error(root)
    # Depth-first traversal preserving sorted order within each directory.
    while stack:
        node, node_is_file = stack.pop()
        if node_is_file:
            collected.append(node)
            continue
        for child in reversed(node.list()):
            # Check inclusions/exclusions at the recursive step, not the base
            # step: a directly-specified ".pyflyby" is included, but we never
            # recurse into dot-directories or __pycache__ ourselves.
            base = child.base
            if base.startswith(".") or base == "__pycache__":
                continue
            if child.isfile:
                if child.ext == ".py":
                    stack.append((child, True))
            elif child.isdir:
                stack.append((child, False))
            # Anything else (sockets, etc.) is silently ignored.
    return collected

View file

@ -0,0 +1,236 @@
# pyflyby/_flags.py.
# Copyright (C) 2011, 2012, 2013, 2014 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
import __future__
import ast
import operator
import six
from six.moves import reduce
import warnings
from pyflyby._util import cached_attribute
# Initialize mappings from compiler_flag to feature name and vice versa.
# _FLAG2NAME: flag bit -> feature name; _NAME2FLAG: the inverse.
_FLAG2NAME = {}
_NAME2FLAG = {}
for name in __future__.all_feature_names:
    flag = getattr(__future__, name).compiler_flag
    _FLAG2NAME[flag] = name
    _NAME2FLAG[name] = flag
# Also map the ast.PyCF_* compile flags (e.g. PyCF_ONLY_AST -> "only_ast").
for name in dir(ast):
    if name.startswith('PyCF'):
        flag_name = name[len('PyCF_'):].lower()
        flag = getattr(ast, name)
        _FLAG2NAME[flag] = flag_name
        _NAME2FLAG[flag_name] = flag
# Sorted (flag, name) pairs, and the bitwise union of all known flags.
_FLAGNAME_ITEMS = sorted(_FLAG2NAME.items())
_ALL_FLAGS = reduce(operator.or_, _FLAG2NAME.keys())
class CompilerFlags(int):
    """
    Representation of Python "compiler flags", i.e. features from __future__.

    >>> print(CompilerFlags(0x18000).__interactive_display__()) # doctest: +SKIP
    CompilerFlags(0x18000) # from __future__ import with_statement, print_function

    >>> print(CompilerFlags(0x10000, 0x8000).__interactive_display__()) # doctest: +SKIP
    CompilerFlags(0x18000) # from __future__ import with_statement, print_function

    >>> print(CompilerFlags('with_statement', 'print_function').__interactive_display__()) # doctest: +SKIP
    CompilerFlags(0x18000) # from __future__ import with_statement, print_function

    This can be used as an argument to the built-in compile() function. For
    instance, in Python 2::

      >>> compile("print('x', file=None)", "?", "exec", flags=0, dont_inherit=1) #doctest:+SKIP
      Traceback (most recent call last):
      ...
      SyntaxError: invalid syntax

    >>> compile("print('x', file=None)", "?", "exec", flags=CompilerFlags("print_function"), dont_inherit=1) #doctest:+ELLIPSIS
    <code object ...>
    """
    def __new__(cls, *args):
        """
        Construct a new ``CompilerFlags`` instance.

        :param args:
          Any number (zero or more) ``CompilerFlags`` s, ``int`` s, or ``str`` s,
          which are bitwise-ORed together.
        :rtype:
          `CompilerFlags`
        """
        if len(args) == 0:
            return cls._ZERO
        elif len(args) == 1:
            arg, = args
            if isinstance(arg, cls):
                return arg
            elif arg is None:
                return cls._ZERO
            elif isinstance(arg, int):
                # Raw ints are deprecated: flag bit values differ between
                # Python versions, so persisted ints are not portable.
                warnings.warn('creating CompilerFlags from integers is deprecated, '
                              ' flags values change between Python versions. If you are sure use .from_int',
                              DeprecationWarning, stacklevel=2)
                return cls.from_int(arg)
            elif isinstance(arg, six.string_types):
                return cls.from_str(arg)
            elif isinstance(arg, ast.AST):
                return cls.from_ast(arg)
            elif isinstance(arg, (tuple, list)):
                # Recurse: CompilerFlags(["a", "b"]) == CompilerFlags("a", "b").
                return cls(*arg)
            else:
                raise TypeError("CompilerFlags: unknown type %s"
                                % (type(arg).__name__,))
        else:
            # Multiple args: normalize each to an int, then OR them together.
            flags = []
            for x in args:
                if isinstance(x, cls):
                    flags.append(int(x))
                elif isinstance(x, int):
                    warnings.warn(
                        "creating CompilerFlags from integers is deprecated, "
                        " flags values change between Python versions. If you are sure use .from_int",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    flags.append(x)
                elif isinstance(x, str):
                    flags.append(int(cls(x)))
                else:
                    raise ValueError
            #assert flags == [0x10000, 0x8000], flags
            return cls.from_int(reduce(operator.or_, flags))
    @classmethod
    def from_int(cls, arg):
        # Validating constructor from a raw bit mask; -1 and 0 return the
        # shared singletons.
        if arg == -1:
            return cls._UNKNOWN # Instance optimization
        if arg == 0:
            return cls._ZERO # Instance optimization
        self = int.__new__(cls, arg)
        # Reject bits that don't correspond to any known feature flag.
        bad_flags = int(self) & ~_ALL_FLAGS
        if bad_flags:
            raise ValueError(
                "CompilerFlags: unknown flag value(s) %s %s" % (bin(bad_flags), hex(bad_flags)))
        return self
    @classmethod
    def from_str(cls, arg):
        # Constructor from a feature name, e.g. "print_function".
        try:
            flag = _NAME2FLAG[arg]
        except KeyError:
            raise ValueError(
                "CompilerFlags: unknown flag %r" % (arg,))
        return cls.from_int(flag)
    @classmethod
    def from_ast(cls, nodes):
        """
        Parse the compiler flags from AST node(s).

        Only the leading run of ``from __future__ import ...`` statements is
        considered; scanning stops at the first other statement.

        :type nodes:
          ``ast.AST`` or sequence thereof
        :rtype:
          ``CompilerFlags``
        """
        if isinstance(nodes, ast.Module):
            nodes = nodes.body
        elif isinstance(nodes, ast.AST):
            nodes = [nodes]
        flags = []
        for node in nodes:
            if not isinstance(node, ast.ImportFrom):
                # Got a non-import; stop looking further.
                break
            if not node.module == "__future__":
                # Got a non-__future__-import; stop looking further.
                break
            # Get the feature names.
            names = [n.name for n in node.names]
            flags.extend(names)
        # Names are validated by the str path of __new__ (from_str).
        return cls(flags)
    @cached_attribute
    def names(self):
        # Feature names for each bit set in this flags value, sorted by bit.
        return tuple(
            n
            for f, n in _FLAGNAME_ITEMS
            if f & self)
    def __or__(self, o):
        if o == 0:
            return self
        if not isinstance(o, CompilerFlags):
            o = CompilerFlags(o)
        if self == 0:
            return o
        return CompilerFlags.from_int(int(self) | int(o))
    def __ror__(self, o):
        return self | o
    def __and__(self, o):
        # Ints (including CompilerFlags) pass through; other types are
        # coerced, e.g. "print_function" & flags.
        if not isinstance(o, int):
            o = CompilerFlags(o)
        return CompilerFlags.from_int(int(self) & int(o))
    def __rand__(self, o):
        return self & o
    def __xor__(self, o):
        if not isinstance(o, CompilerFlags):
            o = CompilerFlags.from_int(o)
        return CompilerFlags.from_int(int(self) ^ int(o))
    def __rxor__(self, o):
        return self ^ o
    def __repr__(self):
        return "CompilerFlags(%s)" % (hex(self),)
    def __str__(self):
        return hex(self)
    def __interactive_display__(self):
        # repr plus a human-readable "from __future__ import ..." comment.
        s = repr(self)
        if self != 0:
            s += " # from __future__ import " + ", ".join(self.names)
        return s
# Shared singletons, built via int.__new__ directly to bypass __new__'s
# validation and deprecation warning.
CompilerFlags._ZERO = int.__new__(CompilerFlags, 0)
CompilerFlags._UNKNOWN = int.__new__(CompilerFlags, -1)
# flags that _may_ exists on future versions.
_future_flags = {
    "nested_scopes",
    "generators",
    "division",
    "absolute_import",
    "with_statement",
    "print_function",
    "unicode_literals",
    "barry_as_FLUFL",
    "generator_stop",
    "annotations",
    "allow_top_level_await",
    "only_ast",
    "type_comments",
}
# First map every potentially-known feature name to _UNKNOWN, then overwrite
# with real values for the features this Python version actually defines.
for k in _future_flags:
    setattr(CompilerFlags, k, CompilerFlags._UNKNOWN)
for k, v in _NAME2FLAG.items():
    setattr(CompilerFlags, k, CompilerFlags.from_int(v))

View file

@ -0,0 +1,180 @@
# pyflyby/_format.py.
# Copyright (C) 2011, 2012, 2013, 2014 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
import six
class FormatParams(object):
    """
    Parameters controlling how Python statements are filled/formatted.
    Unspecified attributes fall back to the class-level defaults below.
    """
    max_line_length = 79
    wrap_paren = True
    indent = 4
    hanging_indent = 'never'
    use_black = False
    def __new__(cls, *args, **kwargs):
        # Identity shortcut: FormatParams(existing_instance) returns it as-is.
        if not kwargs and len(args) == 1 and isinstance(args[0], cls):
            return args[0]
        self = object.__new__(cls)
        # TODO: be more careful here
        # Collect attribute dicts to apply, in order: positional FormatParams
        # instances first, then explicit keyword overrides.
        overrides = []
        for arg in args:
            if arg is None:
                pass
            elif isinstance(arg, cls):
                overrides.append(arg.__dict__)
            else:
                raise TypeError
        if kwargs:
            overrides.append(kwargs)
        for mapping in overrides:
            for key, value in mapping.items():
                if hasattr(self, key):
                    setattr(self, key, value)
                else:
                    raise ValueError("bad kwarg %r" % (key,))
        return self
def fill(tokens, sep=(", ", ""), prefix="", suffix="", newline="\n",
         max_line_length=80):
    r"""
    Given a sequences of strings, fill them into a single string with up to
    ``max_line_length`` characters each.

    >>> fill(["'hello world'", "'hello two'"],
    ...      prefix=("print ", " "), suffix=(" \\", ""),
    ...      max_line_length=25)
    "print 'hello world', \\\n 'hello two'\n"

    :param tokens:
      Sequence of strings to fill.  There must be at least one token.
    :param sep:
      Separator string to append to each token.  If a 2-element tuple, then
      indicates the separator between tokens and the separator after the last
      token.  Trailing whitespace is removed from each line before appending
      the suffix, but not from between tokens on the same line.
    :param prefix:
      String to prepend at the beginning of each line.  If a 2-element tuple,
      then indicates the prefix for the first line and prefix for subsequent
      lines.
    :param suffix:
      String to append to the end of each line.  If a 2-element tuple, then
      indicates the suffix for all lines except the last, and the suffix for
      the last line.
    :return:
      Filled string.
    """
    limit = max_line_length
    assert len(tokens) > 0
    # Normalize the (first, continuation) / (non-terminal, terminal) pairs.
    first_prefix, cont_prefix = (
        prefix if isinstance(prefix, tuple) else (prefix, prefix))
    nonterm_suffix, term_suffix = (
        suffix if isinstance(suffix, tuple) else (suffix, suffix))
    nonterm_sep, term_sep = (
        sep if isinstance(sep, tuple) else (sep, sep))
    lines = [first_prefix + tokens[0]]
    last_index = len(tokens) - 1
    for index, token in enumerate(tokens[1:], start=1):
        at_end = (index == last_index)
        tail_suffix = term_suffix if at_end else nonterm_suffix
        tail_sep = (term_sep if at_end else nonterm_sep).rstrip()
        # Would this token, plus whatever must trail it on the same line,
        # still fit within the limit?
        if len(lines[-1] + nonterm_sep + token + tail_sep + tail_suffix) <= limit:
            lines[-1] += nonterm_sep + token
        else:
            # Close out the current line and start a continuation line.
            lines[-1] += nonterm_sep.rstrip() + nonterm_suffix + newline
            lines.append(cont_prefix + token)
    lines[-1] += term_sep.rstrip() + term_suffix + newline
    return ''.join(lines)
def pyfill(prefix, tokens, params=FormatParams()):
    """
    Fill a Python statement.

    >>> print(pyfill('print ', ["foo.bar", "baz", "quux", "quuuuux"]), end='')
    print foo.bar, baz, quux, quuuuux

    >>> print(pyfill('print ', ["foo.bar", "baz", "quux", "quuuuux"],
    ...              FormatParams(max_line_length=15, hanging_indent='auto')), end='')
    print (foo.bar,
           baz,
           quux,
           quuuuux)

    >>> print(pyfill('print ', ["foo.bar", "baz", "quux", "quuuuux"],
    ...              FormatParams(max_line_length=14, hanging_indent='auto')), end='')
    print (
        foo.bar,
        baz, quux,
        quuuuux)

    :param prefix:
      Prefix for first line.
    :param tokens:
      Sequence of string tokens
    :type params:
      `FormatParams`
    :rtype:
      ``str``
    """
    # NOTE: ``params=FormatParams()`` is a default shared across calls, but
    # it is only read here, never mutated.
    N = params.max_line_length
    if params.wrap_paren:
        # Check how we will break up the tokens.
        # len_full = total token length plus ", " between each adjacent pair.
        len_full = sum(len(tok) for tok in tokens) + 2 * (len(tokens)-1)
        if len(prefix) + len_full <= N:
            # The entire thing fits on one line; no parens needed.  We check
            # this first because breaking into lines adds paren overhead.
            #
            # Output looks like:
            #   from foo import abc, defgh, ijkl, mnopq, rst
            return prefix + ", ".join(tokens) + "\n"
        if params.hanging_indent == "never":
            hanging_indent = False
        elif params.hanging_indent == "always":
            hanging_indent = True
        elif params.hanging_indent == "auto":
            # Decide automatically whether to do hanging-indent mode.  If any
            # line would exceed the max_line_length, then do hanging indent;
            # else don't.
            #
            # In order to use non-hanging-indent mode, the first line would
            # have an overhead of 2 because of "(" and ",".  We check the
            # longest token since even if the first token fits, we still want
            # to avoid later tokens running over N.
            maxtoklen = max(len(token) for token in tokens)
            hanging_indent = (len(prefix) + maxtoklen + 2 > N)
        else:
            raise ValueError("bad params.hanging_indent=%r"
                             % (params.hanging_indent,))
        if hanging_indent:
            # Hanging indent mode.  We need a single opening paren and
            # continue all imports on separate lines.
            #
            # Output looks like:
            #    from foo import (
            #        abc, defgh, ijkl,
            #        mnopq, rst)
            return (prefix + "(\n"
                    + fill(tokens, max_line_length=N,
                           prefix=(" " * params.indent), suffix=("", ")")))
        else:
            # Non-hanging-indent mode.
            #
            # Output looks like:
            #    from foo import (abc, defgh,
            #                     ijkl, mnopq,
            #                     rst)
            pprefix = prefix + "("
            return fill(tokens, max_line_length=N,
                        prefix=(pprefix, " " * len(pprefix)), suffix=("", ")"))
    else:
        # Only paren-wrapped output is implemented so far.
        raise NotImplementedError

View file

@ -0,0 +1,256 @@
# pyflyby/_idents.py.
# Copyright (C) 2011, 2012, 2013, 2014, 2018 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from functools import total_ordering
from keyword import kwlist
import re
import six
from pyflyby._util import cached_attribute, cmp
# Don't consider "print" a keyword, in order to be compatible with user code
# that uses "from __future__ import print_function".
_my_kwlist = list(kwlist)
if six.PY2:
    # "print" is only a keyword on Python 2; drop it for the reason above.
    _my_kwlist.remove("print")
# Fast membership test: _my_iskeyword(word) -> bool.
_my_iskeyword = frozenset(_my_kwlist).__contains__
# TODO: use DottedIdentifier.prefixes
def dotted_prefixes(dotted_name, reverse=False):
    """
    Return the prefixes of a dotted name.

    >>> dotted_prefixes("aa.bb.cc")
    ['aa', 'aa.bb', 'aa.bb.cc']

    >>> dotted_prefixes("aa.bb.cc", reverse=True)
    ['aa.bb.cc', 'aa.bb', 'aa']

    :type dotted_name:
      ``str``
    :param reverse:
      If False (default), return shortest to longest.  If True, return longest
      to shortest.
    :rtype:
      ``list`` of ``str``
    """
    parts = dotted_name.split(".")
    # Empty joins (e.g. a leading "") fall back to "." so that every prefix
    # is a non-empty string.
    prefixes = ['.'.join(parts[:count]) or '.'
                for count in range(1, len(parts) + 1)]
    if reverse:
        prefixes.reverse()
    return prefixes
# A single (undotted) identifier, anchored at end-of-string.
_name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
# A dotted identifier: one or more identifiers joined by single dots.
_dotted_name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*([.][a-zA-Z_][a-zA-Z0-9_]*)*$")
# Like _dotted_name_re, but also permits a single trailing dot (a "prefix").
_dotted_name_prefix_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*([.][a-zA-Z_][a-zA-Z0-9_]*)*[.]?$")
def is_identifier(s, dotted=False, prefix=False):
    """
    Return whether ``s`` is a valid Python identifier name.

    >>> is_identifier("foo")
    True

    >>> is_identifier("foo+bar")
    False

    >>> is_identifier("from")
    False

    By default, we check whether ``s`` is a single valid identifier, meaning
    dots are not allowed.  If ``dotted=True``, then we check each dotted
    component::

    >>> is_identifier("foo.bar")
    False

    >>> is_identifier("foo.bar", dotted=True)
    True

    >>> is_identifier("foo..bar", dotted=True)
    False

    >>> is_identifier("foo.from", dotted=True)
    False

    By default, the string must comprise a valid identifier.  If
    ``prefix=True``, then allow strings that are prefixes of valid identifiers.
    Prefix=False excludes the empty string, strings with a trailing dot, and
    strings with a trailing keyword component, but prefix=True does not
    exclude these.

    >>> is_identifier("foo.bar.", dotted=True)
    False

    >>> is_identifier("foo.bar.", dotted=True, prefix=True)
    True

    >>> is_identifier("foo.or", dotted=True)
    False

    >>> is_identifier("foo.or", dotted=True, prefix=True)
    True

    :type s:
      ``str``
    :param dotted:
      If ``False`` (default), then the input must be a single name such as
      "foo".  If ``True``, then the input can be a single name or a dotted name
      such as "foo.bar.baz".
    :param prefix:
      If ``False`` (Default), then the input must be a valid identifier.  If
      ``True``, then the input can be a valid identifier or the prefix of a
      valid identifier.
    :rtype:
      ``bool``
    """
    if not isinstance(s, six.string_types):
        raise TypeError("is_identifier(): expected a string; got a %s"
                        % (type(s).__name__,))
    if six.PY3:
        # On Python 3, str.isidentifier() does the lexical check; keywords
        # are excluded separately via _my_iskeyword.
        if prefix:
            # A string is an identifier prefix iff appending "_" makes it a
            # full identifier (also makes "" and trailing-dot cases pass).
            return is_identifier(s + '_', dotted=dotted, prefix=False)
        if dotted:
            return all(is_identifier(w, dotted=False) for w in s.split('.'))
        return s.isidentifier() and not _my_iskeyword(s)
    # Python 2 path: regex-based checks.
    if prefix:
        if not s:
            return True
        if dotted:
            # All components except (possibly) the last must be non-keywords.
            return bool(
                _dotted_name_prefix_re.match(s) and
                not any(_my_iskeyword(w) for w in s.split(".")[:-1]))
        else:
            return bool(_name_re.match(s))
    else:
        if dotted:
            # Use a regular expression that works for dotted names.  (As an
            # alternate implementation, one could imagine calling
            # all(is_identifier(w) for w in s.split(".")).  We don't do that
            # because s could be a long text string.)
            return bool(
                _dotted_name_re.match(s) and
                not any(_my_iskeyword(w) for w in s.split(".")))
        else:
            return bool(_name_re.match(s) and not _my_iskeyword(s))
def brace_identifiers(text):
    """
    Parse a string and yield all tokens of the form "{some_token}".

    >>> list(brace_identifiers("{salutation}, {your_name}."))
    ['salutation', 'your_name']

    :type text:
      ``str`` or ``bytes`` (bytes are decoded as UTF-8, permissively)
    """
    # Accept bytes input for convenience; decode permissively before scanning.
    if isinstance(text, bytes):
        text = text.decode('utf-8', errors='replace')
    # A brace identifier is a valid (single-component) Python name in braces.
    pattern = "{([a-zA-Z_][a-zA-Z0-9_]*)}"
    for found in re.finditer(pattern, text):
        yield found.group(1)
class BadDottedIdentifierError(ValueError):
    """
    Raised when a string is not a valid (possibly dotted) Python identifier.
    """
    pass
# TODO: Use in various places, esp where e.g. dotted_prefixes is used.
@total_ordering
class DottedIdentifier(object):
    """
    An immutable dotted Python identifier, e.g. "foo.bar.baz".

    Supports sequence-like access to its dot-separated components and
    ordering/equality by name.
    """

    def __new__(cls, arg):
        """
        Return ``arg`` converted to a `DottedIdentifier`.

        :type arg:
          `DottedIdentifier`, ``str``, or sequence of component strings
        :raises TypeError:
          If ``arg`` is none of the above.
        :raises BadDottedIdentifierError:
          If the resulting name is not a valid dotted identifier.
        """
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, six.string_types):
            return cls._from_name(arg)
        if isinstance(arg, (tuple, list)):
            return cls._from_name(".".join(arg))
        raise TypeError("DottedIdentifier: unexpected %s"
                        % (type(arg).__name__,))

    @classmethod
    def _from_name(cls, name):
        # Construct from a string, validating it as a dotted identifier.
        self = object.__new__(cls)
        self.name = str(name)
        if not is_identifier(self.name, dotted=True):
            if len(self.name) > 20:
                # Don't embed very long strings in the error message.
                raise BadDottedIdentifierError("Invalid python symbol name")
            else:
                raise BadDottedIdentifierError("Invalid python symbol name %r"
                                               % (name,))
        self.parts = tuple(self.name.split('.'))
        return self

    @cached_attribute
    def parent(self):
        # The name minus its last component, or None if single-component.
        if len(self.parts) > 1:
            return DottedIdentifier('.'.join(self.parts[:-1]))
        else:
            return None

    @cached_attribute
    def prefixes(self):
        # All dotted prefixes, shortest first; e.g. for "a.b.c" this is
        # (DottedIdentifier('a'), DottedIdentifier('a.b'),
        #  DottedIdentifier('a.b.c')).
        parts = self.parts
        idxes = range(1, len(parts)+1)
        result = ['.'.join(parts[:i]) for i in idxes]
        return tuple(DottedIdentifier(x) for x in result)

    def startswith(self, o):
        # Component-wise prefix test: "ab.cd" starts with "ab", but not
        # with "a".
        o = type(self)(o)
        return self.parts[:len(o.parts)] == o.parts

    def __getitem__(self, x):
        return type(self)(self.parts[x])

    def __len__(self):
        return len(self.parts)

    def __iter__(self):
        # Iterate over single-component DottedIdentifiers.
        return (type(self)(x) for x in self.parts)

    def __add__(self, suffix):
        # BUGFIX: previously written as type(self)("%s.%s") % (self, suffix),
        # which constructed a DottedIdentifier from the literal "%s.%s" and
        # therefore raised BadDottedIdentifierError on every call.  Format
        # first, then construct.
        return type(self)("%s.%s" % (self, suffix))

    def __str__(self):
        return self.name

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, self.name)

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, DottedIdentifier):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        if self is other:
            return False
        if not isinstance(other, DottedIdentifier):
            return NotImplemented
        return self.name != other.name

    # The rest are defined by total_ordering
    def __lt__(self, other):
        if not isinstance(other, DottedIdentifier):
            return NotImplemented
        return self.name < other.name

    def __cmp__(self, other):
        # Python 2 compatibility; ignored on Python 3.
        if self is other:
            return 0
        if not isinstance(other, DottedIdentifier):
            return NotImplemented
        return cmp(self.name, other.name)

View file

@ -0,0 +1,631 @@
# pyflyby/_importclns.py.
# Copyright (C) 2011, 2012, 2013, 2014 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from collections import defaultdict
from functools import total_ordering
import six
from pyflyby._flags import CompilerFlags
from pyflyby._idents import dotted_prefixes, is_identifier
from pyflyby._importstmt import (Import, ImportFormatParams,
ImportStatement,
NonImportStatementError)
from pyflyby._parse import PythonBlock
from pyflyby._util import (cached_attribute, cmp, partition,
stable_unique)
class NoSuchImportError(ValueError):
    """
    Raised when an import expected to be present is not found.
    """
    pass
class ConflictingImportsError(ValueError):
    """
    Raised when multiple imports bind the same name (``import_as``).
    """
    pass
@total_ordering
class ImportSet(object):
    r"""
    Representation of a set of imports organized into import statements.

    >>> ImportSet('''
    ...     from m1 import f1
    ...     from m2 import f1
    ...     from m1 import f2
    ...     import m3.m4 as m34
    ... ''')
    ImportSet('''
      from m1 import f1, f2
      from m2 import f1
      from m3 import m4 as m34
    ''')

    An ``ImportSet`` is an immutable data structure.
    """

    def __new__(cls, arg, ignore_nonimports=False, ignore_shadowed=False):
        """
        Return as an `ImportSet`.

        :param ignore_nonimports:
          If ``False``, complain about non-imports.  If ``True``, ignore
          non-imports.
        :param ignore_shadowed:
          Whether to ignore shadowed imports.  If ``False``, then keep all
          unique imports, even if they shadow each other.  Note that an
          ``ImportSet`` is unordered; an ``ImportSet`` with conflicts will only
          be useful for very specific cases (e.g. set of imports to forget
          from known-imports database), and not useful for outputting as code.
          If ``ignore_shadowed`` is ``True``, then earlier shadowed imports are
          ignored.
        :rtype:
          `ImportSet`
        """
        if isinstance(arg, cls):
            if ignore_shadowed:
                # Re-run the shadowing filter on the existing set.
                return cls._from_imports(arg._importset, ignore_shadowed=True)
            else:
                return arg
        return cls._from_args(
            arg,
            ignore_nonimports=ignore_nonimports,
            ignore_shadowed=ignore_shadowed)

    @classmethod
    def _from_imports(cls, imports, ignore_shadowed=False):
        """
        Construct from a sequence of imports.

        :type imports:
          Sequence of `Import` s
        :param ignore_shadowed:
          See `ImportSet.__new__`.
        :rtype:
          `ImportSet`
        """
        # Canonicalize inputs.
        imports = [Import(imp) for imp in imports]
        if ignore_shadowed:
            # Filter by overshadowed imports.  Later imports take precedence.
            by_import_as = {}
            for imp in imports:
                if imp.import_as == "*":
                    # Keep all unique star imports.
                    by_import_as[imp] = imp
                else:
                    by_import_as[imp.import_as] = imp
            filtered_imports = by_import_as.values()
        else:
            filtered_imports = imports
        # Construct and return.
        self = object.__new__(cls)
        self._importset = frozenset(filtered_imports)
        return self

    @classmethod
    def _from_args(cls, args, ignore_nonimports=False, ignore_shadowed=False):
        """
        Construct from heterogeneous argument(s).

        :type args:
          ``tuple`` or ``list`` of `ImportStatement` s, `PythonStatement` s,
          `PythonBlock` s, `FileText`, and/or `Filename` s
        :param ignore_nonimports:
          If ``False``, complain about non-imports.  If ``True``, ignore
          non-imports.
        :param ignore_shadowed:
          See `ImportSet.__new__`.
        :rtype:
          `ImportSet`
        """
        if not isinstance(args, (tuple, list)):
            args = [args]
        # Filter empty arguments to allow the subsequent optimizations to work
        # more often.
        args = [a for a in args if a]
        if not args:
            return cls._EMPTY
        # If we only got one ``ImportSet``, just return it.
        if len(args) == 1 and type(args[0]) is cls and not ignore_shadowed:
            return args[0]
        # Collect all `Import` s from arguments.
        imports = []
        for arg in args:
            if isinstance(arg, Import):
                imports.append(arg)
            elif isinstance(arg, ImportSet):
                imports.extend(arg.imports)
            elif isinstance(arg, ImportStatement):
                imports.extend(arg.imports)
            elif isinstance(arg, str) and is_identifier(arg, dotted=True):
                # A bare dotted name, e.g. "foo.bar", is treated as an import.
                imports.append(Import(arg))
            else: # PythonBlock, PythonStatement, Filename, FileText, str
                block = PythonBlock(arg)
                for statement in block.statements:
                    # Ignore comments/blanks.
                    if statement.is_comment_or_blank:
                        pass
                    elif statement.is_import:
                        imports.extend(ImportStatement(statement).imports)
                    elif ignore_nonimports:
                        pass
                    else:
                        raise NonImportStatementError(
                            "Got non-import statement %r" % (statement,))
        return cls._from_imports(imports, ignore_shadowed=ignore_shadowed)

    def with_imports(self, other):
        """
        Return a new `ImportSet` that is the union of ``self`` and
        ``new_imports``.

        >>> impset = ImportSet('from m import t1, t2, t3')
        >>> impset.with_imports('import m.t2a as t2b')
        ImportSet('''
          from m import t1, t2, t2a as t2b, t3
        ''')

        :type other:
          `ImportSet` (or convertible)
        :rtype:
          `ImportSet`
        """
        other = ImportSet(other)
        return type(self)._from_imports(self._importset | other._importset)

    def without_imports(self, removals):
        """
        Return a copy of self without the given imports.

        >>> imports = ImportSet('from m import t1, t2, t3, t4')
        >>> imports.without_imports(['from m import t3'])
        ImportSet('''
          from m import t1, t2, t4
        ''')

        :type removals:
          `ImportSet` (or convertible)
        :rtype:
          `ImportSet`
        """
        removals = ImportSet(removals)
        if not removals:
            return self # Optimization
        # Preprocess star imports to remove: removing "from m import *"
        # removes every import from ``m`` or any submodule of ``m``.
        star_module_removals = set(
            [imp.split.module_name
             for imp in removals if imp.split.member_name == "*"])
        # Filter imports.
        new_imports = []
        for imp in self:
            if imp in removals:
                continue
            if star_module_removals and imp.split.module_name:
                prefixes = dotted_prefixes(imp.split.module_name)
                if any(pfx in star_module_removals for pfx in prefixes):
                    continue
            new_imports.append(imp)
        # Return.
        if len(new_imports) == len(self):
            return self # Space optimization
        return type(self)._from_imports(new_imports)

    @cached_attribute
    def _by_module_name(self):
        """
        Imports bucketed by kind and module name.

        :return:
          (mapping from name to __future__ imports,
           mapping from name to non-'from' imports,
           mapping from name to 'from' imports)
        """
        ftr_imports = defaultdict(set)
        pkg_imports = defaultdict(set)
        frm_imports = defaultdict(set)
        for imp in self._importset:
            module_name, member_name, import_as = imp.split
            if module_name is None:
                # "import foo" / "import foo as bar" style.
                pkg_imports[member_name].add(imp)
            elif module_name == '__future__':
                ftr_imports[module_name].add(imp)
            else:
                frm_imports[module_name].add(imp)
        return tuple(
            dict( (k, frozenset(v))
                  for k, v in six.iteritems(imports))
            for imports in [ftr_imports, pkg_imports, frm_imports])

    def get_statements(self, separate_from_imports=True):
        """
        Canonicalized `ImportStatement` s.
        These have been merged by module and sorted.

        >>> importset = ImportSet('''
        ...     import a, b as B, c, d.dd as DD
        ...     from __future__ import division
        ...     from _hello import there
        ...     from _hello import *
        ...     from _hello import world
        ... ''')
        >>> for s in importset.get_statements(): print(s)
        from __future__ import division
        import a
        import b as B
        import c
        from _hello import *
        from _hello import there, world
        from d import dd as DD

        :rtype:
          ``tuple`` of `ImportStatement` s
        """
        groups = self._by_module_name
        if not separate_from_imports:
            # Merge the non-'from' and 'from' groups into one dict; tag keys
            # with the group label to avoid key collisions across groups.
            def union_dicts(*dicts):
                result = {}
                for label, dict in enumerate(dicts):
                    for k, v in six.iteritems(dict):
                        result[(k, label)] = v
                return result
            groups = [groups[0], union_dicts(*groups[1:])]
        result = []
        for importgroup in groups:
            for _, imports in sorted(importgroup.items()):
                # Star imports get their own statement, ahead of the rest.
                star_imports, nonstar_imports = (
                    partition(imports, lambda imp: imp.import_as == "*"))
                assert len(star_imports) <= 1
                if star_imports:
                    result.append(ImportStatement(star_imports))
                if nonstar_imports:
                    result.append(ImportStatement(sorted(nonstar_imports)))
        return tuple(result)

    @cached_attribute
    def statements(self):
        """
        Canonicalized `ImportStatement` s.
        These have been merged by module and sorted.

        :rtype:
          ``tuple`` of `ImportStatement` s
        """
        return self.get_statements(separate_from_imports=True)

    @cached_attribute
    def imports(self):
        """
        Canonicalized imports, in the same order as ``self.statements``.

        :rtype:
          ``tuple`` of `Import` s
        """
        return tuple(
            imp
            for importgroup in self._by_module_name
            for _, imports in sorted(importgroup.items())
            for imp in sorted(imports))

    @cached_attribute
    def by_import_as(self):
        """
        Map from ``import_as`` to `Import`.

        >>> ImportSet('from aa.bb import cc as dd').by_import_as
        {'dd': (Import('from aa.bb import cc as dd'),)}

        :rtype:
          ``dict`` mapping from ``str`` to tuple of `Import` s
        """
        d = defaultdict(list)
        for imp in self._importset:
            d[imp.import_as].append(imp)
        return dict( (k, tuple(sorted(stable_unique(v))))
                     for k, v in six.iteritems(d) )

    @cached_attribute
    def member_names(self):
        r"""
        Map from parent module/package ``fullname`` to known member names.

        >>> impset = ImportSet("import numpy.linalg.info\nfrom sys import exit as EXIT")
        >>> import pprint
        >>> pprint.pprint(impset.member_names)
        {'': ('EXIT', 'numpy', 'sys'),
         'numpy': ('linalg',),
         'numpy.linalg': ('info',),
         'sys': ('exit',)}

        This is used by the autoimporter module for implementing tab completion.

        :rtype:
          ``dict`` mapping from ``str`` to tuple of ``str``
        """
        d = defaultdict(set)
        for imp in self._importset:
            if '.' not in imp.import_as:
                # Top-level binding: record under the root ("") namespace.
                d[""].add(imp.import_as)
            prefixes = dotted_prefixes(imp.fullname)
            d[""].add(prefixes[0])
            for prefix in prefixes[1:]:
                # Each component is a member of its parent prefix.
                splt = prefix.rsplit(".", 1)
                d[splt[0]].add(splt[1])
        return dict( (k, tuple(sorted(v)))
                     for k, v in six.iteritems(d) )

    @cached_attribute
    def conflicting_imports(self):
        r"""
        Returns imports that conflict with each other.

        >>> ImportSet('import b\nfrom f import a as b\n').conflicting_imports
        ('b',)

        >>> ImportSet('import b\nfrom f import a\n').conflicting_imports
        ()

        :rtype:
          ``bool``
        """
        return tuple(
            k
            for k, v in six.iteritems(self.by_import_as)
            if len(v) > 1 and k != "*")

    @cached_attribute
    def flags(self):
        """
        If this contains __future__ imports, then the bitwise-ORed of the
        compiler_flag values associated with the features.  Otherwise, 0.
        """
        imports = self._by_module_name[0].get("__future__", [])
        return CompilerFlags(*[imp.flags for imp in imports])

    def __repr__(self):
        printed = self.pretty_print(allow_conflicts=True)
        lines = "".join(" "+line for line in printed.splitlines(True))
        return "%s('''\n%s''')" % (type(self).__name__, lines)

    def pretty_print(self, params=None, allow_conflicts=False):
        """
        Pretty-print a block of import statements into a single string.

        :type params:
          `ImportFormatParams`
        :rtype:
          ``str``
        """
        params = ImportFormatParams(params)
        # TODO: instead of complaining about conflicts, just filter out the
        # shadowed imports at construction time.
        if not allow_conflicts and self.conflicting_imports:
            raise ConflictingImportsError(
                "Refusing to pretty-print because of conflicting imports: " +
                '; '.join(
                    "%r imported as %r" % (
                        [imp.fullname for imp in self.by_import_as[i]], i)
                    for i in self.conflicting_imports))
        from_spaces = max(1, params.from_spaces)
        # Whether the statement participates in "import" column alignment.
        def do_align(statement):
            return statement.fromname != '__future__' or params.align_future
        # Pretty-print one statement, aligned or not.
        def pp(statement, import_column):
            if do_align(statement):
                return statement.pretty_print(
                    params=params, import_column=import_column,
                    from_spaces=from_spaces)
            else:
                return statement.pretty_print(
                    params=params, import_column=None, from_spaces=1)
        statements = self.get_statements(
            separate_from_imports=params.separate_from_imports)
        # NOTE(review): ``isint`` is defined but never used below.
        def isint(x): return isinstance(x, int) and not isinstance(x, bool)
        if not statements:
            import_column = None
        elif isinstance(params.align_imports, bool):
            if params.align_imports:
                # Align to the longest 'from' name among aligned statements.
                fromimp_stmts = [
                    s for s in statements if s.fromname and do_align(s)]
                if fromimp_stmts:
                    import_column = (
                        max(len(s.fromname) for s in fromimp_stmts)
                        + from_spaces + 5)
                else:
                    import_column = None
            else:
                import_column = None
        elif isinstance(params.align_imports, int):
            import_column = params.align_imports
        elif isinstance(params.align_imports, (tuple, list, set)):
            # If given a set of candidate alignment columns, then try each
            # alignment column and pick the one that yields the fewest number
            # of output lines.
            if not all(isinstance(x, int) for x in params.align_imports):
                raise TypeError("expected set of integers; got %r"
                                % (params.align_imports,))
            candidates = sorted(set(params.align_imports))
            if len(candidates) == 0:
                raise ValueError("list of zero candidate alignment columns specified")
            elif len(candidates) == 1:
                # Optimization.
                import_column = next(iter(candidates))
            else:
                # Key of the dict entry with the smallest value (ties broken
                # by smaller key, via the sorted iteration order).
                def argmin(map):
                    items = iter(sorted(map.items()))
                    min_k, min_v = next(items)
                    for k, v in items:
                        if v < min_v:
                            min_k = k
                            min_v = v
                    return min_k
                # Total output line count for a given alignment column.
                def count_lines(import_column):
                    return sum(
                        s.pretty_print(
                            params=params, import_column=import_column,
                            from_spaces=from_spaces).count("\n")
                        for s in statements)
                # Construct a map from alignment column to total number of
                # lines.
                col2length = dict((c, count_lines(c)) for c in candidates)
                # Pick the column that yields the fewest lines.  Break ties by
                # picking the smaller column.
                import_column = argmin(col2length)
        else:
            raise TypeError(
                "ImportSet.pretty_print(): unexpected params.align_imports type %s"
                % (type(params.align_imports).__name__,))
        return ''.join(pp(statement, import_column) for statement in statements)

    def __contains__(self, x):
        return x in self._importset

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, ImportSet):
            return NotImplemented
        return self._importset == other._importset

    def __ne__(self, other):
        return not (self == other)

    # The rest are defined by total_ordering
    def __lt__(self, other):
        if not isinstance(other, ImportSet):
            return NotImplemented
        return self._importset < other._importset

    def __cmp__(self, other):
        # Python 2 compatibility; ignored on Python 3.
        if self is other:
            return 0
        if not isinstance(other, ImportSet):
            return NotImplemented
        return cmp(self._importset, other._importset)

    def __hash__(self):
        return hash(self._importset)

    def __len__(self):
        return len(self.imports)

    def __iter__(self):
        return iter(self.imports)
ImportSet._EMPTY = ImportSet._from_imports([])
@total_ordering
class ImportMap(object):
    r"""
    A map from import fullname identifier to fullname identifier.

    >>> ImportMap({'a.b': 'aa.bb', 'a.b.c': 'aa.bb.cc'})
    ImportMap({'a.b': 'aa.bb', 'a.b.c': 'aa.bb.cc'})

    An ``ImportMap`` is an immutable data structure.
    """

    def __new__(cls, arg):
        """
        Return ``arg`` as an `ImportMap`.

        :type arg:
          `ImportMap`, ``dict``, or ``tuple``/``list`` of convertibles
          (merged, with later entries taking precedence)
        """
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, (tuple, list)):
            return cls._merge(arg)
        if isinstance(arg, dict):
            if not len(arg):
                return cls._EMPTY
            return cls._from_map(arg)
        else:
            raise TypeError("ImportMap: expected a dict, not a %s"
                            % (type(arg).__name__,))

    @classmethod
    def _from_map(cls, arg):
        # Canonicalize keys and values to import fullnames.
        data = dict((Import(k).fullname, Import(v).fullname)
                    for k, v in arg.items())
        self = object.__new__(cls)
        self._data = data
        return self

    @classmethod
    def _merge(cls, maps):
        # Merge several maps; later maps override earlier ones.
        maps = [cls(m) for m in maps]
        maps = [m for m in maps if m]
        if not maps:
            return cls._EMPTY
        data = {}
        for map in maps:
            data.update(map._data)
        return cls(data)

    def __getitem__(self, k):
        k = Import(k).fullname
        return self._data.__getitem__(k)

    def __iter__(self):
        return iter(self._data)

    def items(self):
        return self._data.items()

    def iteritems(self):
        return six.iteritems(self._data)

    def iterkeys(self):
        return six.iterkeys(self._data)

    def keys(self):
        return self._data.keys()

    def values(self):
        return self._data.values()

    def __len__(self):
        return len(self._data)

    def without_imports(self, removals):
        """
        Return a copy of self without the given imports.
        Matches both keys and values.
        """
        removals = ImportSet(removals)
        if not removals:
            return self # Optimization
        cls = type(self)
        result = [(k, v) for k, v in self._data.items()
                  if Import(k) not in removals and Import(v) not in removals]
        if len(result) == len(self._data):
            return self # Space optimization
        return cls(dict(result))

    def __repr__(self):
        s = ", ".join("%r: %r" % (k,v) for k,v in sorted(self.items()))
        return "ImportMap({%s})" % s

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, ImportMap):
            return NotImplemented
        return self._data == other._data

    def __ne__(self, other):
        return not (self == other)

    # The rest are defined by total_ordering
    def __lt__(self, other):
        # NOTE(review): dicts do not support ``<`` on Python 3, so ordering
        # comparisons between distinct ImportMaps raise TypeError; kept as-is
        # to avoid changing behavior callers may rely on.
        if not isinstance(other, ImportMap):
            return NotImplemented
        return self._data < other._data

    def __cmp__(self, other):
        # Python 2 compatibility; ignored on Python 3.
        if self is other:
            return 0
        if not isinstance(other, ImportMap):
            return NotImplemented
        return cmp(self._data, other._data)

    def __hash__(self):
        # BUGFIX: previously computed hash(self._data), which always raises
        # TypeError because dicts are unhashable; it also tried to memoize by
        # assigning self.__hash__, which has no effect since special methods
        # are looked up on the type, not the instance.  Hash an
        # order-independent view of the items instead (consistent with
        # __eq__, which compares _data), and memoize in a plain attribute.
        try:
            return self._hash
        except AttributeError:
            h = hash(frozenset(self._data.items()))
            self._hash = h
            return h
ImportMap._EMPTY = ImportMap._from_map({})

View file

@ -0,0 +1,581 @@
# pyflyby/_importdb.py.
# Copyright (C) 2011, 2012, 2013, 2014, 2015 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from collections import defaultdict
import os
import re
import six
from pyflyby._file import Filename, expand_py_files_from_args
from pyflyby._idents import dotted_prefixes
from pyflyby._importclns import ImportMap, ImportSet
from pyflyby._importstmt import Import, ImportStatement
from pyflyby._log import logger
from pyflyby._parse import PythonBlock
from pyflyby._util import cached_attribute, memoize, stable_unique
@memoize
def _find_etc_dirs():
    """
    Locate pyflyby config directories: the first ``etc/pyflyby`` directory
    found under an ancestor of this module's real path (nearest ancestor
    first, excluding the filesystem root), plus ``/etc/pyflyby`` if it
    exists.  Memoized.

    :rtype:
      ``list`` of `Filename`
    """
    found = []
    for ancestor in Filename(__file__).real.dir.ancestors[:-1]:
        candidate = ancestor / "etc/pyflyby"
        if candidate.isdir:
            found.append(candidate)
            break
    system_dir = Filename("/etc/pyflyby")
    if system_dir.exists:
        found.append(system_dir)
    return found
def _get_env_var(env_var_name, default):
    '''
    Get an environment variable and split on ":", replacing ``-`` with the
    default.
    '''
    assert re.match("^[A-Z_]+$", env_var_name)
    assert isinstance(default, (tuple, list))
    raw = os.environ.get(env_var_name, '')
    # Drop empty components (e.g. from "a::b" or a trailing colon).
    parts = [component for component in raw.split(':') if component]
    if not parts:
        # Unset or effectively empty: fall back to the default.
        return default
    if '-' in parts:
        # Splice the default in place of the first "-" component.
        where = parts.index('-')
        parts[where:where+1] = default
    return parts
def _get_python_path(env_var_name, default_path, target_dirname):
    '''
    Expand an environment variable specifying pyflyby input config files.

      - Default to ``default_path`` if the environment variable is undefined.
      - Process colon delimiters.
      - Replace "-" with ``default_path``.
      - Expand triple dots.
      - Recursively traverse directories.

    :rtype:
      ``tuple`` of ``Filename`` s
    '''
    pathnames = _get_env_var(env_var_name, default_path)
    if pathnames == ["EMPTY"]:
        # The special code PYFLYBY_PATH=EMPTY means we intentionally want to
        # use an empty PYFLYBY_PATH (and don't fall back to the default path,
        # nor warn about an empty path).
        return ()
    # Reject components that don't explicitly start with /, ./, .../ or ~/ --
    # this guards against accidentally picking up the current directory.
    for p in pathnames:
        if re.match("/|[.]/|[.][.][.]/|~/", p):
            continue
        raise ValueError(
            "{env_var_name} components should start with / or ./ or ~/ or .../. "
            "Use {env_var_name}=./{p} instead of {env_var_name}={p} if you really "
            "want to use the current directory."
            .format(env_var_name=env_var_name, p=p))
    # Normalize: expand "~", expand ".../" relative to the target directory,
    # dedupe (preserving order), then recursively collect *.py files.
    pathnames = [os.path.expanduser(p) for p in pathnames]
    pathnames = _expand_tripledots(pathnames, target_dirname)
    pathnames = [Filename(fn) for fn in pathnames]
    pathnames = stable_unique(pathnames)
    pathnames = expand_py_files_from_args(pathnames)
    if not pathnames:
        logger.warning(
            "No import libraries found (%s=%r, default=%r)"
            % (env_var_name, os.environ.get(env_var_name), default_path))
    return tuple(pathnames)
# TODO: stop memoizing here after using StatCache. Actually just inline into
# _ancestors_on_same_partition
@memoize
def _get_st_dev(filename):
    """
    Return the device number (``st_dev``) of ``filename``, or ``None`` if the
    file cannot be stat'ed.  Memoized.
    """
    filename = Filename(filename)
    try:
        stat_result = os.stat(str(filename))
    except OSError:
        # Nonexistent or inaccessible.
        return None
    return stat_result.st_dev
def _ancestors_on_same_partition(filename):
    """
    Generate ancestors of ``filename`` that exist and are on the same
    partition as the first existing ancestor of ``filename``.

    For example, suppose a partition is mounted on /u/homer; /u is a
    different partition.  Suppose /u/homer/aa exists but /u/homer/aa/bb does
    not exist.  Then::

      >>> _ancestors_on_same_partition(Filename("/u/homer/aa/bb/cc")) # doctest: +SKIP
      [Filename("/u/homer", Filename("/u/homer/aa")]

    :rtype:
      ``list`` of ``Filename``
    """
    kept = []
    first_dev = None
    for ancestor in filename.ancestors:
        dev = _get_st_dev(ancestor)
        if dev is None:
            # Nonexistent/inaccessible ancestor: skip it.
            continue
        if first_dev is None:
            # Remember the partition of the first existing ancestor.
            first_dev = dev
        if dev != first_dev:
            # Crossed onto a different partition; stop collecting.
            break
        kept.append(ancestor)
    return kept
def _expand_tripledots(pathnames, target_dirname):
    """
    Expand pathnames of the form ``".../foo/bar"`` as "../../foo/bar",
    "../foo/bar", "./foo/bar" etc., up to the oldest ancestor with the same
    st_dev.

    For example, suppose a partition is mounted on /u/homer; /u is a
    different partition.  Then::

      >>> _expand_tripledots(["/foo", ".../tt"], "/u/homer/aa") # doctest: +SKIP
      [Filename("/foo"), Filename("/u/homer/tt"), Filename("/u/homer/aa/tt")]

    :type pathnames:
      sequence of ``str`` (not ``Filename``)
    :type target_dirname:
      `Filename`
    :rtype:
      ``list`` of `Filename`
    """
    target_dirname = Filename(target_dirname)
    if not isinstance(pathnames, (tuple, list)):
        pathnames = [pathnames]
    expanded = []
    for pathname in pathnames:
        if pathname.startswith(".../"):
            # Expand to <ancestor>/<suffix> for each same-partition ancestor
            # of the target directory, oldest ancestor first.
            tail = pathname[len(".../"):]
            candidates = [ancestor / tail
                          for ancestor in
                          _ancestors_on_same_partition(target_dirname)]
            expanded.extend(reversed(candidates))
        else:
            expanded.append(Filename(pathname))
    return expanded
class ImportDB(object):
"""
A database of known, mandatory, canonical imports.
@iattr known_imports:
Set of known imports. For use by tidy-imports and autoimporter.
@iattr mandatory_imports:
Set of imports that must be added by tidy-imports.
@iattr canonical_imports:
Map of imports that tidy-imports transforms on every run.
@iattr forget_imports:
Set of imports to remove from known_imports, mandatory_imports,
canonical_imports.
"""
def __new__(cls, *args):
if len(args) != 1:
raise TypeError
arg, = args
if isinstance(arg, cls):
return arg
if isinstance(arg, ImportSet):
return cls._from_data(arg, [], [], [])
return cls._from_args(arg) # PythonBlock, Filename, etc
_default_cache = {}
@classmethod
def clear_default_cache(cls):
"""
Clear the class cache of default ImportDBs.
Subsequent calls to ImportDB.get_default() will not reuse previously
cached results. Existing ImportDB instances are not affected by this
call.
"""
if cls._default_cache:
if logger.debug_enabled:
allpyfiles = set()
for tup in cls._default_cache:
if tup[0] != 2:
continue
for tup2 in tup[1:]:
for f in tup2:
assert isinstance(f, Filename)
if f.ext == '.py':
allpyfiles.add(f)
nfiles = len(allpyfiles)
logger.debug("ImportDB: Clearing default cache of %d files",
nfiles)
cls._default_cache.clear()
@classmethod
def get_default(cls, target_filename):
"""
Return the default import library for the given target filename.
This will read various .../.pyflyby files as specified by
$PYFLYBY_PATH as well as older deprecated environment variables.
Memoized.
:param target_filename:
The target filename for which to get the import database. Note that
the target filename itself is not read. Instead, the target
filename is relevant because we look for .../.pyflyby based on the
target filename.
:rtype:
`ImportDB`
"""
# We're going to canonicalize target_filenames in a number of steps.
# At each step, see if we've seen the input so far. We do the cache
# checking incrementally since the steps involve syscalls. Since this
# is going to potentially be executed inside the IPython interactive
# loop, we cache as much as possible.
# TODO: Consider refreshing periodically. Check if files have
# been touched, and if so, return new data. Check file timestamps at
# most once every 60 seconds.
cache_keys = []
target_filename = Filename(target_filename or ".")
if target_filename.startswith("/dev"):
target_filename = Filename(".")
target_dirname = target_filename
# TODO: with StatCache
while True:
cache_keys.append((1,
target_dirname,
os.getenv("PYFLYBY_PATH"),
os.getenv("PYFLYBY_KNOWN_IMPORTS_PATH"),
os.getenv("PYFLYBY_MANDATORY_IMPORTS_PATH")))
try:
return cls._default_cache[cache_keys[-1]]
except KeyError:
pass
if target_dirname.isdir:
break
target_dirname = target_dirname.dir
target_dirname = target_dirname.real
if target_dirname != cache_keys[-1][0]:
cache_keys.append((1,
target_dirname,
os.getenv("PYFLYBY_PATH"),
os.getenv("PYFLYBY_KNOWN_IMPORTS_PATH"),
os.getenv("PYFLYBY_MANDATORY_IMPORTS_PATH")))
try:
return cls._default_cache[cache_keys[-1]]
except KeyError:
pass
DEFAULT_PYFLYBY_PATH = []
DEFAULT_PYFLYBY_PATH += [str(p) for p in _find_etc_dirs()]
DEFAULT_PYFLYBY_PATH += [
".../.pyflyby",
"~/.pyflyby",
]
logger.debug("DEFAULT_PYFLYBY_PATH=%s", DEFAULT_PYFLYBY_PATH)
filenames = _get_python_path("PYFLYBY_PATH", DEFAULT_PYFLYBY_PATH,
target_dirname)
mandatory_imports_filenames = ()
if "SUPPORT DEPRECATED BEHAVIOR":
PYFLYBY_PATH = _get_env_var("PYFLYBY_PATH", DEFAULT_PYFLYBY_PATH)
# If the old deprecated environment variables are set, then heed
# them.
if os.getenv("PYFLYBY_KNOWN_IMPORTS_PATH"):
# Use PYFLYBY_PATH as the default for
# PYFLYBY_KNOWN_IMPORTS_PATH. Note that the default is
# relevant even though we only enter this code path when the
# variable is set to anything, because the env var can
# reference "-" to include the default.
# Before pyflyby version 0.8, the default value would have
# been
# [d/"known_imports" for d in PYFLYBY_PATH]
# Instead of using that, we just use PYFLYBY_PATH directly as
# the default. This simplifies things and avoids need for a
# "known_imports=>." symlink for backwards compatibility. It
# means that ~/.pyflyby/**/*.py (as opposed to only
# ~/.pyflyby/known_imports/**/*.py) would be included.
# Although this differs slightly from the old behavior, it
# matches the behavior of the newer PYFLYBY_PATH; matching the
# new behavior seems higher utility than exactly matching the
# old behavior. Files under ~/.pyflyby/mandatory_imports will
# be included in known_imports as well, but that should not
# cause any problems.
default_path = PYFLYBY_PATH
# Expand $PYFLYBY_KNOWN_IMPORTS_PATH.
filenames = _get_python_path(
"PYFLYBY_KNOWN_IMPORTS_PATH", default_path, target_dirname)
logger.debug(
"The environment variable PYFLYBY_KNOWN_IMPORTS_PATH is deprecated. "
"Use PYFLYBY_PATH.")
if os.getenv("PYFLYBY_MANDATORY_IMPORTS_PATH"):
# Compute the "default" path.
# Note that we still calculate the erstwhile default value,
# even though it's no longer the defaults, in order to still
# allow the "-" in the variable.
default_path = [
os.path.join(d,"mandatory_imports") for d in PYFLYBY_PATH]
# Expand $PYFLYBY_MANDATORY_IMPORTS_PATH.
mandatory_imports_filenames = _get_python_path(
"PYFLYBY_MANDATORY_IMPORTS_PATH",
default_path, target_dirname)
logger.debug(
"The environment variable PYFLYBY_MANDATORY_IMPORTS_PATH is deprecated. "
"Use PYFLYBY_PATH and write __mandatory_imports__=['...'] in your files.")
cache_keys.append((2, filenames, mandatory_imports_filenames))
try:
return cls._default_cache[cache_keys[-1]]
except KeyError:
pass
result = cls._from_filenames(filenames, mandatory_imports_filenames)
for k in cache_keys:
cls._default_cache[k] = result
return result
@classmethod
def interpret_arg(cls, arg, target_filename):
if arg is None:
return cls.get_default(target_filename)
else:
return cls(arg)
@classmethod
def _from_data(cls, known_imports, mandatory_imports,
canonical_imports, forget_imports):
self = object.__new__(cls)
self.forget_imports = ImportSet(forget_imports )
self.known_imports = ImportSet(known_imports ).without_imports(forget_imports)
self.mandatory_imports = ImportSet(mandatory_imports).without_imports(forget_imports)
# TODO: provide more fine-grained control about canonical_imports.
self.canonical_imports = ImportMap(canonical_imports).without_imports(forget_imports)
return self
@classmethod
def _from_args(cls, args):
# TODO: support merging input ImportDBs. For now we support
# `PythonBlock` s and convertibles such as `Filename`.
return cls._from_code(args)
@classmethod
def _from_code(cls, blocks,
_mandatory_imports_blocks_deprecated=(),
_forget_imports_blocks_deprecated=(),
):
"""
Load an import database from code.
>>> ImportDB._from_code('''
... import foo, bar as barf
... from xx import yy
... __mandatory_imports__ = ['__future__.division',
... 'import aa . bb . cc as dd']
... __forget_imports__ = ['xx.yy', 'from xx import zz']
... __canonical_imports__ = {'bad.baad': 'good.goood'}
... ''')
ImportDB('''
import bar as barf
import foo
<BLANKLINE>
__mandatory_imports__ = [
'from __future__ import division',
'from aa.bb import cc as dd',
]
<BLANKLINE>
__canonical_imports__ = {
'bad.baad': 'good.goood',
}
<BLANKLINE>
__forget_imports__ = [
'from xx import yy',
'from xx import zz',
]
''')
:rtype:
`ImportDB`
"""
if not isinstance(blocks, (tuple, list)):
blocks = [blocks]
if not isinstance(_mandatory_imports_blocks_deprecated, (tuple, list)):
_mandatory_imports_blocks_deprecated = [_mandatory_imports_blocks_deprecated]
if not isinstance(_forget_imports_blocks_deprecated, (tuple, list)):
_forget_imports_blocks_deprecated = [_forget_imports_blocks_deprecated]
known_imports = []
mandatory_imports = []
canonical_imports = []
forget_imports = []
blocks = [PythonBlock(b) for b in blocks]
for block in blocks:
for statement in block.statements:
if statement.is_comment_or_blank:
continue
if statement.is_import:
known_imports.extend(ImportStatement(statement).imports)
continue
try:
name, value = statement.get_assignment_literal_value()
if name == "__mandatory_imports__":
mandatory_imports.append(cls._parse_import_set(value))
elif name == "__canonical_imports__":
canonical_imports.append(cls._parse_import_map(value))
elif name == "__forget_imports__":
forget_imports.append(cls._parse_import_set(value))
else:
raise ValueError(
"Unknown assignment to %r (expected one of "
"__mandatory_imports__, __canonical_imports__, "
"__forget_imports__)" % (name,))
except ValueError as e:
raise ValueError(
"While parsing %s: error in %r: %s"
% (block.filename, statement, e))
for block in _mandatory_imports_blocks_deprecated:
mandatory_imports.append(ImportSet(block))
for block in _forget_imports_blocks_deprecated:
forget_imports.append(ImportSet(block))
return cls._from_data(known_imports,
mandatory_imports,
canonical_imports,
forget_imports)
@classmethod
def _from_filenames(cls, filenames, _mandatory_filenames_deprecated=[]):
"""
Load an import database from filenames.
This function exists to support deprecated behavior.
When we stop supporting the old behavior, we will delete this function.
:type filenames:
Sequence of `Filename` s
:param filenames:
Filenames of files to read.
:rtype:
`ImportDB`
"""
if not isinstance(filenames, (tuple, list)):
filenames = [filenames]
filenames = [Filename(f) for f in filenames]
logger.debug("ImportDB: loading [%s], mandatory=[%s]",
', '.join(map(str, filenames)),
', '.join(map(str, _mandatory_filenames_deprecated)))
if "SUPPORT DEPRECATED BEHAVIOR":
# Before 2014-10, pyflyby read the following:
# * known_imports from $PYFLYBY_PATH/known_imports/**/*.py or
# $PYFLYBY_KNOWN_IMPORTS_PATH/**/*.py,
# * mandatory_imports from $PYFLYBY_PATH/mandatory_imports/**/*.py or
# $PYFLYBY_MANDATORY_IMPORTS_PATH/**/*.py, and
# * forget_imports from $PYFLYBY_PATH/known_imports/**/__remove__.py
# After 2014-10, pyflyby reads the following:
# * $PYFLYBY_PATH/**/*.py
# (with directives inside the file)
# For backwards compatibility, for now we continue supporting the
# old, deprecated behavior.
blocks = []
mandatory_imports_blocks = [
Filename(f) for f in _mandatory_filenames_deprecated]
forget_imports_blocks = []
for filename in filenames:
if filename.base == "__remove__.py":
forget_imports_blocks.append(filename)
elif "mandatory_imports" in str(filename).split("/"):
mandatory_imports_blocks.append(filename)
else:
blocks.append(filename)
return cls._from_code(
blocks, mandatory_imports_blocks, forget_imports_blocks)
else:
return cls._from_code(filenames)
@classmethod
def _parse_import_set(cls, arg):
if isinstance(arg, six.string_types):
arg = [arg]
if not isinstance(arg, (tuple, list)):
raise ValueError("Expected a list, not a %s" % (type(arg).__name__,))
for item in arg:
if not isinstance(item, six.string_types):
raise ValueError(
"Expected a list of str, not %s" % (type(item).__name__,))
return ImportSet(arg)
@classmethod
def _parse_import_map(cls, arg):
if isinstance(arg, six.string_types):
arg = [arg]
if not isinstance(arg, dict):
raise ValueError("Expected a dict, not a %s" % (type(arg).__name__,))
for k, v in arg.items():
if not isinstance(k, six.string_types):
raise ValueError(
"Expected a dict of str, not %s" % (type(k).__name__,))
if not isinstance(v, six.string_types):
raise ValueError(
"Expected a dict of str, not %s" % (type(v).__name__,))
return ImportMap(arg)
    @cached_attribute
    def by_fullname_or_import_as(self):
        """
        Map from ``fullname`` and ``import_as`` to `Import` s.

        >>> import pprint
        >>> db = ImportDB('from aa.bb import cc as dd')
        >>> pprint.pprint(db.by_fullname_or_import_as)
        {'aa': (Import('import aa'),),
         'aa.bb': (Import('import aa.bb'),),
         'dd': (Import('from aa.bb import cc as dd'),)}

        :rtype:
          ``dict`` mapping from ``str`` to tuple of `Import` s
        """
        # TODO: make known_imports take into account the below forget_imports,
        # then move this function into ImportSet
        d = defaultdict(set)
        for imp in self.known_imports.imports:
            # Given an import like "from foo.bar import quux as QUUX", add the
            # following entries:
            # - "QUUX" => "from foo.bar import quux as QUUX"
            # - "foo.bar" => "import foo.bar"
            # - "foo" => "import foo"
            # We don't include an entry labeled "quux" because the user has
            # implied he doesn't want to pollute the global namespace with
            # "quux", only "QUUX".
            d[imp.import_as].add(imp)
            for prefix in dotted_prefixes(imp.fullname)[:-1]:
                d[prefix].add(Import.from_parts(prefix, prefix))
        # Drop any imports listed in forget_imports, and sort each value
        # tuple so the resulting mapping is deterministic.
        return dict( (k, tuple(sorted(v - set(self.forget_imports.imports))))
                     for k, v in six.iteritems(d))
def __repr__(self):
printed = self.pretty_print()
lines = "".join(" "+line for line in printed.splitlines(True))
return "%s('''\n%s''')" % (type(self).__name__, lines)
def pretty_print(self):
s = self.known_imports.pretty_print()
if self.mandatory_imports:
s += "\n__mandatory_imports__ = [\n"
for imp in self.mandatory_imports.imports:
s += " '%s',\n" % imp
s += "]\n"
if self.canonical_imports:
s += "\n__canonical_imports__ = {\n"
for k, v in sorted(self.canonical_imports.items()):
s += " '%s': '%s',\n" % (k, v)
s += "}\n"
if self.forget_imports:
s += "\n__forget_imports__ = [\n"
for imp in self.forget_imports.imports:
s += " '%s',\n" % imp
s += "]\n"
return s

View file

@ -0,0 +1,606 @@
# pyflyby/_imports2s.py.
# Copyright (C) 2011-2018 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from pyflyby._autoimp import scan_for_import_issues
from pyflyby._file import FileText, Filename
from pyflyby._flags import CompilerFlags
from pyflyby._importclns import ImportSet, NoSuchImportError
from pyflyby._importdb import ImportDB
from pyflyby._importstmt import ImportFormatParams, ImportStatement
from pyflyby._log import logger
from pyflyby._parse import PythonBlock
from pyflyby._util import ImportPathCtx, Inf, NullCtx, memoize
import re
from six import exec_
class SourceToSourceTransformationBase(object):
    """
    Base class for source-to-source transformations.

    Subclasses override `preprocess` (parse/prepare state from ``self.input``)
    and `pretty_print` (render the transformed text); `output` wraps the
    pretty-printed result back into a `PythonBlock`.
    """

    def __new__(cls, arg):
        # Idempotent constructor: an existing instance passes through as-is.
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, (PythonBlock, FileText, Filename, str)):
            return cls._from_source_code(arg)
        raise TypeError("%s: got unexpected %s"
                        % (cls.__name__, type(arg).__name__))

    @classmethod
    def _from_source_code(cls, codeblock):
        self = object.__new__(cls)
        self.input = PythonBlock(codeblock)
        self.preprocess()
        return self

    def preprocess(self):
        # Hook for subclasses; the default does nothing.
        pass

    def pretty_print(self, params=None):
        raise NotImplementedError

    def output(self, params=None):
        """
        Pretty-print and return as a `PythonBlock`.

        :rtype:
          `PythonBlock`
        """
        text = self.pretty_print(params=params)
        return PythonBlock(text, filename=self.input.filename)
class SourceToSourceTransformation(SourceToSourceTransformationBase):
    # Identity transformation: the input block passes through unchanged.
    def preprocess(self):
        # NOTE: this deliberately shadows the base class's output() *method*
        # with an attribute holding a PythonBlock; pretty_print() below and
        # transform_imports() elsewhere read/assign it as an attribute.
        self.output = self.input
    def pretty_print(self, params=None):
        return self.output.text
class SourceToSourceImportBlockTransformation(SourceToSourceTransformationBase):
    # Transformation over one contiguous block of import statements.
    # Transformers mutate ``self.importset``; pretty_print re-renders it.
    def preprocess(self):
        # ignore_shadowed=True: a later import of the same name wins.
        self.importset = ImportSet(self.input, ignore_shadowed=True)
    def pretty_print(self, params=None):
        params = ImportFormatParams(params)
        return self.importset.pretty_print(params)
class LineNumberNotFoundError(Exception):
    """No import block contains the given line number."""
    pass
class LineNumberAmbiguousError(Exception):
    """More than one import block matched the given line number."""
    pass
class NoImportBlockError(Exception):
    """No suitable import block was found to add an import to."""
    pass
class ImportAlreadyExistsError(Exception):
    """The import to be added is already present in the target block."""
    pass
class SourceToSourceFileImportsTransformation(SourceToSourceTransformationBase):
    # Whole-file transformation.  The file is partitioned into alternating
    # import blocks and non-import blocks; transformers mutate the import
    # blocks in place, and pretty_print() reassembles the file in order.

    def preprocess(self):
        # Group into blocks of imports and non-imports.  Get a sequence of all
        # imports for the transformers to operate on.
        self.blocks = []           # all blocks, in file order
        self.import_blocks = []    # the subset that are import blocks
        for is_imports, subblock in self.input.groupby(lambda ps: ps.is_import):
            if is_imports:
                trans = SourceToSourceImportBlockTransformation(subblock)
                self.import_blocks.append(trans)
            else:
                trans = SourceToSourceTransformation(subblock)
            self.blocks.append(trans)

    def pretty_print(self, params=None):
        params = ImportFormatParams(params)
        result = [block.pretty_print(params=params) for block in self.blocks]
        return FileText.concatenate(result)

    def find_import_block_by_lineno(self, lineno):
        """
        Find the import block containing the given line number.

        :type lineno:
          ``int``
        :raises LineNumberNotFoundError:
          If no import block spans ``lineno``.
        :raises LineNumberAmbiguousError:
          If more than one import block spans ``lineno``.
        :rtype:
          `SourceToSourceImportBlockTransformation`
        """
        results = [
            b
            for b in self.import_blocks
            if b.input.startpos.lineno <= lineno <= b.input.endpos.lineno]
        if len(results) == 0:
            raise LineNumberNotFoundError(lineno)
        if len(results) > 1:
            raise LineNumberAmbiguousError(lineno)
        return results[0]

    def remove_import(self, imp, lineno):
        """
        Remove the given import.

        :type imp:
          `Import`
        :type lineno:
          ``int``
        """
        block = self.find_import_block_by_lineno(lineno)
        try:
            imports = block.importset.by_import_as[imp.import_as]
        except KeyError:
            raise NoSuchImportError
        assert len(imports)
        if len(imports) > 1:
            raise Exception("Multiple imports to remove: %r" % (imports,))
        imp = imports[0]
        # Removal is done by rebuilding the block's ImportSet without imp.
        block.importset = block.importset.without_imports([imp])
        return imp

    def select_import_block_by_closest_prefix_match(self, imp, max_lineno):
        """
        Heuristically pick an import block that ``imp`` "fits" best into.  The
        selection is based on the block that contains the import with the
        longest common prefix.

        :type imp:
          `Import`
        :param max_lineno:
          Only return import blocks earlier than ``max_lineno``.
        :rtype:
          `SourceToSourceImportBlockTransformation`
        """
        # Create a data structure that annotates blocks with data by which
        # we'll sort.  Sort key is (longest prefix match, end lineno), so
        # ties between equally-good matches prefer the later block.
        annotated_blocks = [
            ( (max([0] + [len(imp.prefix_match(oimp))
                          for oimp in block.importset.imports]),
               block.input.endpos.lineno),
              block )
            for block in self.import_blocks
            if block.input.endpos.lineno <= max_lineno ]
        if not annotated_blocks:
            raise NoImportBlockError()
        annotated_blocks.sort()
        if imp.split.module_name == '__future__':
            # For __future__ imports, only add to an existing block that
            # already contains __future__ import(s).  If there are no existing
            # import blocks containing __future__, don't return any result
            # here, so that we will add a new one at the top.
            if not annotated_blocks[-1][0][0] > 0:
                raise NoImportBlockError
        return annotated_blocks[-1][1]

    def insert_new_blocks_after_comments(self, blocks):
        # Splice ``blocks`` into self.blocks just after any leading
        # comment/blank/docstring run at the top of the file.
        blocks = [SourceToSourceTransformationBase(block) for block in blocks]
        if isinstance(self.blocks[0], SourceToSourceImportBlockTransformation):
            # Kludge.  We should add an "output" attribute to
            # SourceToSourceImportBlockTransformation and enumerate over that,
            # instead of enumerating over the input below.
            self.blocks[0:0] = blocks
            return
        # Get the "statements" in the first block.
        statements = self.blocks[0].input.statements
        # Find the insertion point.
        for idx, statement in enumerate(statements):
            if not statement.is_comment_or_blank_or_string_literal:
                if idx == 0:
                    # First block starts with a noncomment, so insert before
                    # it.
                    self.blocks[0:0] = blocks
                else:
                    # Found a non-comment after comment, so break it up and
                    # insert in the middle.
                    self.blocks[:1] = (
                        [SourceToSourceTransformation(
                            PythonBlock.concatenate(statements[:idx],
                                                    assume_contiguous=True))] +
                        blocks +
                        [SourceToSourceTransformation(
                            PythonBlock.concatenate(statements[idx:],
                                                    assume_contiguous=True))])
                break
        else:
            # First block is entirely comments, so just insert after it.
            self.blocks[1:1] = blocks

    def insert_new_import_block(self):
        """
        Adds a new empty imports block.  It is added before the first
        non-comment statement.  Intended to be used when the input contains no
        import blocks (before uses).
        """
        block = SourceToSourceImportBlockTransformation("")
        sepblock = SourceToSourceTransformation("")
        sepblock.output = PythonBlock("\n")
        self.insert_new_blocks_after_comments([block, sepblock])
        self.import_blocks.insert(0, block)
        return block

    def add_import(self, imp, lineno=Inf):
        """
        Add the specified import.  Picks an existing global import block to
        add to, or if none found, creates a new one near the beginning of the
        module.

        :type imp:
          `Import`
        :param lineno:
          Line before which to add the import.  ``Inf`` means no constraint.
        """
        try:
            block = self.select_import_block_by_closest_prefix_match(
                imp, lineno)
        except NoImportBlockError:
            block = self.insert_new_import_block()
        if imp in block.importset.imports:
            raise ImportAlreadyExistsError(imp)
        block.importset = block.importset.with_imports([imp])
def reformat_import_statements(codeblock, params=None):
    r"""
    Reformat each top-level block of import statements within a block of code.

    Blank lines, comments, etc. are left alone and separate blocks of imports.

    Parse the entire code block into an ast, group into consecutive import
    statements and other lines.  Each import block consists entirely of
    'import' (or 'from ... import') statements.  Other lines, including blanks
    and comment lines, are not touched.

    >>> print(reformat_import_statements(
    ...     'from foo import bar2 as bar2x, bar1\n'
    ...     'import foo.bar3 as bar3x\n'
    ...     'import foo.bar4\n'
    ...     '\n'
    ...     'import foo.bar0 as bar0\n').text.joined)
    import foo.bar4
    from foo import bar1, bar2 as bar2x, bar3 as bar3x
    <BLANKLINE>
    from foo import bar0
    <BLANKLINE>

    :type codeblock:
      `PythonBlock` or convertible (``str``)
    :type params:
      `ImportFormatParams`
    :rtype:
      `PythonBlock`
    """
    # Validate the formatting parameters first, then parse and re-render.
    fmt_params = ImportFormatParams(params)
    xform = SourceToSourceFileImportsTransformation(codeblock)
    return xform.output(params=fmt_params)
def ImportPathForRelativeImportsCtx(codeblock):
    """
    Context manager that temporarily modifies ``sys.path`` so that relative
    imports for the given ``codeblock`` work as expected.

    :type codeblock:
      `PythonBlock`
    """
    codeblock = PythonBlock(codeblock)
    # Nothing to do when there is no filename to anchor relative imports,
    # or when absolute_import is in effect.  (Short-circuit `or` preserves
    # the original evaluation order: flags are only consulted when a
    # filename exists.)
    if not codeblock.filename or codeblock.flags & CompilerFlags("absolute_import"):
        return NullCtx()
    return ImportPathCtx(str(codeblock.filename.dir))
def fix_unused_and_missing_imports(codeblock,
                                   add_missing=True,
                                   remove_unused="AUTOMATIC",
                                   add_mandatory=True,
                                   db=None,
                                   params=None):
    r"""
    Check for unused and missing imports, and fix them automatically.

    Also formats imports.

    In the example below, ``m1`` and ``m3`` are unused, so are automatically
    removed.  ``np`` was undefined, so an ``import numpy as np`` was
    automatically added.

    >>> codeblock = PythonBlock(
    ...     'from foo import m1, m2, m3, m4\n'
    ...     'm2, m4, np.foo', filename="/tmp/foo.py")

    >>> print(fix_unused_and_missing_imports(codeblock, add_mandatory=False))
    [PYFLYBY] /tmp/foo.py: removed unused 'from foo import m1'
    [PYFLYBY] /tmp/foo.py: removed unused 'from foo import m3'
    [PYFLYBY] /tmp/foo.py: added 'import numpy as np'
    import numpy as np
    from foo import m2, m4
    m2, m4, np.foo

    :type codeblock:
      `PythonBlock` or convertible (``str``)
    :rtype:
      `PythonBlock`
    """
    codeblock = PythonBlock(codeblock)
    if remove_unused == "AUTOMATIC":
        # Heuristic: don't remove "unused" imports from __init__.py files or
        # pyflyby config files, where re-exporting imports is intentional.
        fn = codeblock.filename
        remove_unused = not (fn and
                             (fn.base == "__init__.py"
                              or ".pyflyby" in str(fn).split("/")))
    elif remove_unused is True or remove_unused is False:
        pass
    else:
        raise ValueError("Invalid remove_unused=%r" % (remove_unused,))
    params = ImportFormatParams(params)
    db = ImportDB.interpret_arg(db, target_filename=codeblock.filename)
    # Do a first pass reformatting the imports to get rid of repeated or
    # shadowed imports, e.g. L1 here:
    #   import foo  # L1
    #   import foo  # L2
    #   foo         # L3
    codeblock = reformat_import_statements(codeblock, params=params)
    filename = codeblock.filename
    transformer = SourceToSourceFileImportsTransformation(codeblock)
    missing_imports, unused_imports = scan_for_import_issues(
        codeblock, find_unused_imports=remove_unused, parse_docstrings=True)
    logger.debug("missing_imports = %r", missing_imports)
    logger.debug("unused_imports = %r", unused_imports)
    if remove_unused and unused_imports:
        # Go through imports to remove.  [This used to be organized by going
        # through import blocks and removing all relevant blocks from there,
        # but if one removal caused problems the whole thing would fail.  The
        # CPU cost of calling without_imports() multiple times isn't worth
        # that.]
        # TODO: don't remove unused mandatory imports.  [This isn't
        # implemented yet because this isn't necessary for __future__ imports
        # since they aren't reported as unused, and those are the only ones we
        # have by default right now.]
        for lineno, imp in unused_imports:
            try:
                imp = transformer.remove_import(imp, lineno)
            except NoSuchImportError:
                logger.error(
                    "%s: couldn't remove import %r", filename, imp,)
            except LineNumberNotFoundError as e:
                logger.error(
                    "%s: unused import %r on line %d not global",
                    filename, str(imp), e.args[0])
            else:
                logger.info("%s: removed unused '%s'", filename, imp)
    if add_missing and missing_imports:
        # Sort by (identifier, lineno) so output is deterministic.
        missing_imports.sort(key=lambda k: (k[1], k[0]))
        known = db.known_imports.by_import_as
        # Decide on where to put each import to be added.  Find the import
        # block with the longest common prefix.  Tie-break by preferring later
        # blocks.
        added_imports = set()
        for lineno, ident in missing_imports:
            # Look up the known import by the first dotted component used.
            import_as = ident.parts[0]
            try:
                imports = known[import_as]
            except KeyError:
                logger.warning(
                    "%s:%s: undefined name %r and no known import for it",
                    filename, lineno, import_as)
                continue
            if len(imports) != 1:
                logger.error("%s: don't know which of %r to use",
                             filename, imports)
                continue
            imp_to_add = imports[0]
            if imp_to_add in added_imports:
                continue
            transformer.add_import(imp_to_add, lineno)
            added_imports.add(imp_to_add)
            logger.info("%s: added %r", filename,
                        imp_to_add.pretty_print().strip())
    if add_mandatory:
        # Todo: allow not adding to empty __init__ files?
        mandatory = db.mandatory_imports.imports
        for imp in mandatory:
            try:
                transformer.add_import(imp)
            except ImportAlreadyExistsError:
                pass
            else:
                logger.info("%s: added mandatory %r",
                            filename, imp.pretty_print().strip())
    return transformer.output(params=params)
def remove_broken_imports(codeblock, params=None):
    """
    Try to execute each import, and remove the ones that don't work.

    Also formats imports.

    :type codeblock:
      `PythonBlock` or convertible (``str``)
    :rtype:
      `PythonBlock`
    """
    codeblock = PythonBlock(codeblock)
    params = ImportFormatParams(params)
    filename = codeblock.filename
    transformer = SourceToSourceFileImportsTransformation(codeblock)
    for import_block in transformer.import_blocks:
        failed = []
        for imp in list(import_block.importset.imports):
            # Execute the import statement in a throwaway namespace; any
            # exception at all marks the import as broken.
            scratch_ns = {}
            try:
                exec_(imp.pretty_print(), scratch_ns)
            except Exception as e:
                logger.info("%s: Could not import %r; removing it: %s: %s",
                            filename, imp.fullname, type(e).__name__, e)
                failed.append(imp)
        import_block.importset = import_block.importset.without_imports(failed)
    return transformer.output(params=params)
def replace_star_imports(codeblock, params=None):
    r"""
    Replace lines such as::

        from foo.bar import *

    with::

        from foo.bar import f1, f2, f3

    Note that this requires actually importing ``foo.bar``, which may
    have side effects.  (TODO: rewrite to avoid this?)

    The replacement includes all names exported by the target module.
    Imports that are shadowed by a subsequent import of the same name are
    excluded from the result.

    >>> codeblock = PythonBlock('from keyword import *', filename="/tmp/x.py")
    >>> print(replace_star_imports(codeblock)) # doctest: +SKIP
    [PYFLYBY] /tmp/x.py: replaced 'from keyword import *' with 2 imports
    from keyword import iskeyword, kwlist
    <BLANKLINE>

    Usually you'll want to remove unused imports after replacing star imports.

    :type codeblock:
      `PythonBlock` or convertible (``str``)
    :rtype:
      `PythonBlock`
    """
    from pyflyby._modules import ModuleHandle
    params = ImportFormatParams(params)
    codeblock = PythonBlock(codeblock)
    filename = codeblock.filename
    transformer = SourceToSourceFileImportsTransformation(codeblock)
    for block in transformer.import_blocks:
        # Iterate over the import statements in ``block.input``.  We do this
        # instead of using ``block.importset`` because the latter doesn't
        # preserve the order of inputs.  The order is important for
        # determining what's shadowed.
        imports = [
            imp
            for s in block.input.statements
            for imp in ImportStatement(s).imports
        ]
        # Process "from ... import *" statements.
        new_imports = []
        for imp in imports:
            if imp.split.member_name != "*":
                # Not a star import; keep as-is.
                new_imports.append(imp)
            elif imp.split.module_name.startswith("."):
                # The source contains e.g. "from .foo import *".  Right now we
                # don't have a good way to figure out the absolute module
                # name, so we can't get at foo.  That said, there's a decent
                # chance that this is inside an __init__ anyway, which is one
                # of the few justifiable use cases for star imports in library
                # code.
                logger.warning("%s: can't replace star imports in relative import: %s",
                               filename, imp.pretty_print().strip())
                new_imports.append(imp)
            else:
                module = ModuleHandle(imp.split.module_name)
                try:
                    with ImportPathForRelativeImportsCtx(codeblock):
                        exports = module.exports
                except Exception as e:
                    logger.warning(
                        "%s: couldn't import '%s' to enumerate exports, "
                        "leaving unchanged: '%s'. %s: %s",
                        filename, module.name, imp, type(e).__name__, e)
                    new_imports.append(imp)
                    continue
                if not exports:
                    # We found nothing in the target module.  This probably
                    # means that module itself is just importing things from
                    # other modules.  Currently we intentionally exclude those
                    # imports since usually we don't want them.  TODO: do
                    # something better here.
                    #
                    # BUGFIX: the format string used to be split by a stray
                    # comma into two positional arguments, so the logging
                    # call's arguments didn't match its format string and the
                    # message was never rendered correctly.
                    logger.warning("%s: found nothing to import from %s, "
                                   "leaving unchanged: '%s'",
                                   filename, module, imp)
                    new_imports.append(imp)
                else:
                    new_imports.extend(exports)
                    logger.info("%s: replaced %r with %d imports", filename,
                                imp.pretty_print().strip(), len(exports))
        block.importset = ImportSet(new_imports, ignore_shadowed=True)
    return transformer.output(params=params)
def transform_imports(codeblock, transformations, params=None):
    """
    Transform imports as specified by ``transformations``.

    transform_imports() perfectly replaces all imports in top-level import
    blocks.

    For the rest of the code body, transform_imports() does a crude textual
    string replacement.  This is imperfect but handles most cases.  There may
    be some false positives, but this is difficult to avoid.  Generally we do
    want to do replacements even within strings and comments.

    >>> result = transform_imports("from m import x", {"m.x": "m.y.z"})
    >>> print(result.text.joined.strip())
    from m.y import z as x

    :type codeblock:
      `PythonBlock` or convertible (``str``)
    :type transformations:
      ``dict`` from ``str`` to ``str``
    :param transformations:
      A map of import prefixes to replace, e.g. {"aa.bb": "xx.yy"}
    :rtype:
      `PythonBlock`
    """
    codeblock = PythonBlock(codeblock)
    params = ImportFormatParams(params)
    transformer = SourceToSourceFileImportsTransformation(codeblock)
    # Memoized: the same Import object appearing in several blocks is only
    # transformed once.
    @memoize
    def transform_import(imp):
        # Transform a block of imports.
        # TODO: optimize
        # TODO: handle transformations containing both a.b=>x and a.b.c=>y
        for k, v in transformations.items():
            imp = imp.replace(k, v)
        return imp
    def transform_block(block):
        # Do a crude string replacement in the PythonBlock.
        # \b word boundaries keep us from rewriting inside longer identifiers.
        block = PythonBlock(block)
        s = block.text.joined
        for k, v in transformations.items():
            s = re.sub("\\b%s\\b" % (re.escape(k)), v, s)
        return PythonBlock(s, flags=block.flags)
    # Loop over transformer blocks: import blocks get the precise
    # ImportSet-level rewrite; everything else gets the textual one.
    for block in transformer.blocks:
        if isinstance(block, SourceToSourceImportBlockTransformation):
            input_imports = block.importset.imports
            output_imports = [ transform_import(imp) for imp in input_imports ]
            block.importset = ImportSet(output_imports, ignore_shadowed=True)
        else:
            block.output = transform_block(block.input)
    return transformer.output(params=params)
def canonicalize_imports(codeblock, params=None, db=None):
    """
    Transform ``codeblock`` as specified by ``__canonical_imports__`` in the
    global import library.

    :type codeblock:
      `PythonBlock` or convertible (``str``)
    :rtype:
      `PythonBlock`
    """
    codeblock = PythonBlock(codeblock)
    fmt_params = ImportFormatParams(params)
    # Resolve the import database relative to the target file, then delegate
    # to transform_imports with the canonicalization map.
    import_db = ImportDB.interpret_arg(db, target_filename=codeblock.filename)
    return transform_imports(codeblock, import_db.canonical_imports,
                             params=fmt_params)

View file

@ -0,0 +1,534 @@
# pyflyby/_importstmt.py.
# Copyright (C) 2011, 2012, 2013, 2014 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
import ast
from collections import namedtuple
from functools import total_ordering
from pyflyby._flags import CompilerFlags
from pyflyby._format import FormatParams, pyfill
from pyflyby._idents import is_identifier
from pyflyby._parse import PythonStatement
from pyflyby._util import (Inf, cached_attribute, cmp,
longest_common_prefix)
class ImportFormatParams(FormatParams):
    """
    Formatting parameters specific to rendering import statements,
    extending `FormatParams`.
    """
    align_imports = True
    """
    Whether and how to align 'from modulename import aliases...'.  If ``True``,
    then the 'import' keywords will be aligned within a block.  If an integer,
    then the 'import' keyword will always be at that column.  They will be
    wrapped if necessary.
    """
    from_spaces = 1
    """
    The number of spaces after the 'from' keyword.  (Must be at least 1.)
    """
    separate_from_imports = True
    """
    Whether all 'from ... import ...' in an import block should come after
    'import ...' statements.  ``separate_from_imports = False`` works well with
    ``from_spaces = 3``.  ('from __future__ import ...' always comes first.)
    """
    align_future = False
    """
    Whether 'from __future__ import ...' statements should be aligned with
    others.  If False, uses a single space after the 'from' and 'import'
    keywords.
    """
class NonImportStatementError(TypeError):
    """
    Unexpectedly got a statement that wasn't an import.

    Raised when constructing an `ImportStatement` from an AST node that is
    neither ``ast.Import`` nor ``ast.ImportFrom``.
    """
ImportSplit = namedtuple("ImportSplit",
"module_name member_name import_as")
"""
Representation of a single import at the token level::
from [...]<module_name> import <member_name> as <import_as>
If <module_name> is ``None``, then there is no "from" clause; instead just::
import <member_name> as <import_as>
"""
@total_ordering
class Import(object):
    """
    Representation of the desire to import a single name into the current
    namespace.

    >>> Import.from_parts(".foo.bar", "bar")
    Import('from .foo import bar')

    >>> Import("from . import foo")
    Import('from . import foo')

    >>> Import("from . import foo").fullname
    '.foo'

    >>> Import("import foo . bar")
    Import('import foo.bar')

    >>> Import("import foo . bar as baz")
    Import('from foo import bar as baz')

    >>> Import("import foo . bar as bar")
    Import('from foo import bar')

    >>> Import("foo.bar")
    Import('from foo import bar')
    """

    def __new__(cls, arg):
        # Polymorphic constructor: dispatch on the argument type.
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, ImportSplit):
            return cls.from_split(arg)
        if isinstance(arg, (ImportStatement, PythonStatement)):
            return cls._from_statement(arg)
        if isinstance(arg, str):
            return cls._from_identifier_or_statement(arg)
        raise TypeError

    @classmethod
    def from_parts(cls, fullname, import_as):
        # Primary constructor: canonical (fullname, import_as) pair.
        if not isinstance(fullname, str):
            raise TypeError
        if not isinstance(import_as, str):
            raise TypeError
        self = object.__new__(cls)
        self.fullname = fullname
        self.import_as = import_as
        return self

    @classmethod
    def _from_statement(cls, statement):
        """
        :type statement:
          `ImportStatement` or convertible (`PythonStatement`, ``str``)
        :rtype:
          `Import`
        """
        statement = ImportStatement(statement)
        imports = statement.imports
        # An Import represents exactly one name; reject multi-alias statements.
        if len(imports) != 1:
            raise ValueError(
                "Got %d imports instead of 1 in %r" % (len(imports), statement))
        return imports[0]

    @classmethod
    def _from_identifier_or_statement(cls, arg):
        """
        Parse either a raw identifier or a statement.

        >>> Import._from_identifier_or_statement('foo.bar.baz')
        Import('from foo.bar import baz')

        >>> Import._from_identifier_or_statement('import foo.bar.baz')
        Import('import foo.bar.baz')

        :rtype:
          `Import`
        """
        if is_identifier(arg, dotted=True):
            # Bare dotted name: import it, binding the last component.
            return cls.from_parts(arg, arg.split('.')[-1])
        else:
            return cls._from_statement(arg)

    @cached_attribute
    def split(self):
        """
        Split this `Import` into a ``ImportSplit`` which represents the
        token-level ``module_name``, ``member_name``, ``import_as``.

        Note that at the token level, ``import_as`` can be ``None`` to represent
        that the import statement doesn't have an "as ..." clause, whereas the
        ``import_as`` attribute on an ``Import`` object is never ``None``.

        >>> Import.from_parts(".foo.bar", "bar").split
        ImportSplit(module_name='.foo', member_name='bar', import_as=None)

        >>> Import("from . import foo").split
        ImportSplit(module_name='.', member_name='foo', import_as=None)

        >>> Import.from_parts(".foo", "foo").split
        ImportSplit(module_name='.', member_name='foo', import_as=None)

        >>> Import.from_parts("foo.bar", "foo.bar").split
        ImportSplit(module_name=None, member_name='foo.bar', import_as=None)

        :rtype:
          `ImportSplit`
        """
        if self.import_as == self.fullname:
            # "import foo.bar" form: no "from" clause, no "as" clause.
            return ImportSplit(None, self.fullname, None)
        level = 0
        qname = self.fullname
        # Count leading dots (relative-import level); ``level`` ends at the
        # index of the first non-dot character.
        for level, char in enumerate(qname):
            if char != '.':
                break
        prefix = qname[:level]
        qname = qname[level:]
        if '.' in qname:
            module_name, member_name = qname.rsplit(".", 1)
        else:
            module_name = ''
            member_name = qname
        module_name = prefix + module_name
        import_as = self.import_as
        if import_as == member_name:
            # Alias equals the member name: omit the "as" clause.
            import_as = None
        return ImportSplit(module_name or None, member_name, import_as)

    @classmethod
    def from_split(cls, impsplit):
        """
        Construct an `Import` instance from ``module_name``, ``member_name``,
        ``import_as``.

        :rtype:
          `Import`
        """
        impsplit = ImportSplit(*impsplit)
        module_name, member_name, import_as = impsplit
        if import_as is None:
            import_as = member_name
        if module_name is None:
            result = cls.from_parts(member_name, import_as)
        else:
            fullname = "%s%s%s" % (
                module_name,
                "" if module_name.endswith(".") else ".",
                member_name)
            result = cls.from_parts(fullname, import_as)
        # result.split will usually be the same as impsplit, but could be
        # different if the input was 'import foo.bar as baz', which we
        # canonicalize to 'from foo import bar as baz'.
        return result

    def prefix_match(self, imp):
        """
        Return the longest common prefix between ``self`` and ``imp``.

        >>> Import("import ab.cd.ef").prefix_match(Import("import ab.cd.xy"))
        ('ab', 'cd')

        :type imp:
          `Import`
        :rtype:
          ``tuple`` of ``str``
        """
        imp = Import(imp)
        n1 = self.fullname.split('.')
        n2 = imp.fullname.split('.')
        return tuple(longest_common_prefix(n1, n2))

    def replace(self, prefix, replacement):
        """
        Return a new ``Import`` that replaces ``prefix`` with ``replacement``.

        >>> Import("from aa.bb import cc").replace("aa.bb", "xx.yy")
        Import('from xx.yy import cc')

        >>> Import("from aa import bb").replace("aa.bb", "xx.yy")
        Import('from xx import yy as bb')

        :rtype:
          ``Import``
        """
        # Prefix matching is done component-wise, so "aa.b" never matches
        # "aa.bb".
        prefix_parts = prefix.split('.')
        replacement_parts = replacement.split('.')
        fullname_parts = self.fullname.split('.')
        if fullname_parts[:len(prefix_parts)] != prefix_parts:
            # No prefix match.
            return self
        fullname_parts[:len(prefix_parts)] = replacement_parts
        import_as_parts = self.import_as.split('.')
        if import_as_parts[:len(prefix_parts)] == prefix_parts:
            import_as_parts[:len(prefix_parts)] = replacement_parts
        return self.from_parts('.'.join(fullname_parts),
                               '.'.join(import_as_parts))

    @cached_attribute
    def flags(self):
        """
        If this is a __future__ import, then the compiler_flag associated with
        it.  Otherwise, 0.
        """
        if self.split.module_name == "__future__":
            return CompilerFlags(self.split.member_name)
        else:
            return CompilerFlags.from_int(0)

    @property
    def _data(self):
        # Canonical value tuple used by equality, ordering, and hashing.
        return (self.fullname, self.import_as)

    def pretty_print(self, params=FormatParams()):
        return ImportStatement([self]).pretty_print(params)

    def __str__(self):
        return self.pretty_print(FormatParams(max_line_length=Inf)).rstrip()

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, str(self))

    def __hash__(self):
        return hash(self._data)

    def __cmp__(self, other):
        # Python 2 legacy comparison; Python 3 uses __eq__/__lt__ plus
        # @total_ordering instead.
        if self is other:
            return 0
        if not isinstance(other, Import):
            return NotImplemented
        return cmp(self._data, other._data)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, Import):
            return NotImplemented
        return self._data == other._data

    def __ne__(self, other):
        return not (self == other)

    # The rest are defined by total_ordering
    def __lt__(self, other):
        if self is other:
            return False
        if not isinstance(other, Import):
            return NotImplemented
        return self._data < other._data
@total_ordering
class ImportStatement(object):
    """
    Token-level representation of an import statement containing multiple
    imports from a single module.  Corresponds to an ``ast.ImportFrom`` or
    ``ast.Import``.
    """
    def __new__(cls, arg):
        # Polymorphic constructor: accept an existing ImportStatement, a
        # statement (string or PythonStatement), an ast import node, a
        # single Import, or a nonempty sequence of Imports.
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, (PythonStatement, str)):
            return cls._from_statement(arg)
        if isinstance(arg, (ast.ImportFrom, ast.Import)):
            return cls._from_ast_node(arg)
        if isinstance(arg, Import):
            return cls._from_imports([arg])
        if isinstance(arg, (tuple, list)) and len(arg):
            if isinstance(arg[0], Import):
                return cls._from_imports(arg)
        raise TypeError
    @classmethod
    def from_parts(cls, fromname, aliases):
        """
        Construct from a ``fromname`` (``None`` for plain ``import``
        statements, possibly-dotted string for ``from ... import``) and a
        nonempty sequence of aliases, each either a name or a
        ``(name, asname)`` pair.
        """
        self = object.__new__(cls)
        self.fromname = fromname
        if not len(aliases):
            raise ValueError
        def interpret_alias(arg):
            # Normalize each alias to a (name, asname-or-None) tuple.
            if isinstance(arg, str):
                return (arg, None)
            if not isinstance(arg, tuple):
                raise TypeError
            if not len(arg) == 2:
                raise TypeError
            if not isinstance(arg[0], str):
                raise TypeError
            if not (arg[1] is None or isinstance(arg[1], str)):
                raise TypeError
            return arg
        self.aliases = tuple(interpret_alias(a) for a in aliases)
        return self
    @classmethod
    def _from_statement(cls, statement):
        """
        Construct an `ImportStatement` by parsing a single statement.

          >>> ImportStatement._from_statement("from foo import bar, bar2, bar")
          ImportStatement('from foo import bar, bar2, bar')

          >>> ImportStatement._from_statement("from foo import bar as bar")
          ImportStatement('from foo import bar as bar')

          >>> ImportStatement._from_statement("from foo.bar import baz")
          ImportStatement('from foo.bar import baz')

          >>> ImportStatement._from_statement("import foo.bar")
          ImportStatement('import foo.bar')

          >>> ImportStatement._from_statement("from .foo import bar")
          ImportStatement('from .foo import bar')

          >>> ImportStatement._from_statement("from . import bar, bar2")
          ImportStatement('from . import bar, bar2')

        :type statement:
          `PythonStatement`
        :rtype:
          `ImportStatement`
        """
        statement = PythonStatement(statement)
        return cls._from_ast_node(statement.ast_node)
    @classmethod
    def _from_ast_node(cls, node):
        """
        Construct an `ImportStatement` from an `ast` node.

        :rtype:
          `ImportStatement`
        :raises NonImportStatementError:
          If ``node`` is not an ``ast.Import`` or ``ast.ImportFrom``.
        """
        if isinstance(node, ast.ImportFrom):
            if isinstance(node.module, str):
                module = node.module
            elif node.module is None:
                # In python2.7, ast.parse("from . import blah") yields
                # node.module = None.  In python2.6, it's the empty string.
                module = ''
            else:
                raise TypeError("unexpected node.module=%s"
                                % type(node.module).__name__)
            # Leading dots encode the relative-import level.
            fromname = '.' * node.level + module
        elif isinstance(node, ast.Import):
            fromname = None
        else:
            raise NonImportStatementError
        aliases = [ (alias.name, alias.asname) for alias in node.names ]
        return cls.from_parts(fromname, aliases)
    @classmethod
    def _from_imports(cls, imports):
        """
        Construct an `ImportStatement` from a sequence of ``Import`` s.  They
        must all have the same ``fromname``.

        :type imports:
          Sequence of `Import` s
        :rtype:
          `ImportStatement`
        """
        if not all(isinstance(imp, Import) for imp in imports):
            raise TypeError
        if not len(imports) > 0:
            raise ValueError
        module_names = set(imp.split.module_name for imp in imports)
        if len(module_names) > 1:
            raise Exception(
                "Inconsistent module names %r" % (sorted(module_names),))
        fromname = list(module_names)[0]
        aliases = [ imp.split[1:] for imp in imports ]
        return cls.from_parts(fromname, aliases)
    @cached_attribute
    def imports(self):
        """
        Return a sequence of `Import` s.

        :rtype:
          ``tuple`` of `Import` s
        """
        return tuple(
            Import.from_split((self.fromname, alias[0], alias[1]))
            for alias in self.aliases)
    @cached_attribute
    def flags(self):
        """
        If this is a __future__ import, then the bitwise-ORed of the
        compiler_flag values associated with the features.  Otherwise, 0.
        """
        return CompilerFlags(*[imp.flags for imp in self.imports])
    def pretty_print(self, params=FormatParams(),
                     import_column=None, from_spaces=1):
        """
        Pretty-print into a single string.

        :type params:
          `FormatParams`
        :param import_column:
          Column at which to left-align the ``import`` keyword, or ``None``
          for no alignment.
        :param from_spaces:
          Number of spaces (>= 1) between ``from`` and the module name.
        :rtype:
          ``str``
        """
        s0 = ''
        s = ''
        assert from_spaces >= 1
        if self.fromname is not None:
            s += "from%s%s " % (' ' * from_spaces, self.fromname)
            if import_column is not None:
                if len(s) > import_column:
                    # The caller wants the 'import' statement lined up left of
                    # where the current end of the line is.  So wrap it
                    # specially like this::
                    #     from foo          import ...
                    #     from foo.bar.baz \
                    #                       import ...
                    s0 = s + '\\\n'
                    s = ' ' * import_column
                else:
                    s = s.ljust(import_column)
        s += "import "
        tokens = []
        for importname, asname in self.aliases:
            if asname is not None:
                t = "%s as %s" % (importname, asname)
            else:
                t = "%s" % (importname,)
            tokens.append(t)
        res = s0 + pyfill(s, tokens, params=params)
        if params.use_black:
            # Optionally post-process the result through black.
            import black
            mode = black.FileMode()
            return black.format_str(res, mode=mode)
        return res
    @property
    def _data(self):
        # Canonical value tuple used by __hash__, __eq__ and ordering.
        return (self.fromname, self.aliases)
    def __str__(self):
        # Render with unlimited line length, then strip trailing newline.
        return self.pretty_print(FormatParams(max_line_length=Inf)).rstrip()
    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, str(self))
    def __cmp__(self, other):
        # Python 2 only; Python 3 never calls __cmp__ (where ``cmp`` is
        # also undefined) and uses the rich comparisons below instead.
        if self is other:
            return 0
        if not isinstance(other, ImportStatement):
            return NotImplemented
        return cmp(self._data, other._data)
    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, ImportStatement):
            return NotImplemented
        return self._data == other._data
    def __ne__(self, other):
        # Explicit for Python 2, which does not derive != from ==.
        return not (self == other)
    # The rest are defined by total_ordering
    def __lt__(self, other):
        if not isinstance(other, ImportStatement):
            return NotImplemented
        return self._data < other._data
    def __hash__(self):
        return hash(self._data)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,807 @@
# pyflyby/_livepatch.py
# Copyright (C) 2011, 2012, 2013, 2014, 2015 Karl Chen.
r"""
livepatch/xreload: Alternative to reload().
xreload performs a "live patch" of the modules/classes/functions/etc that have
already been loaded in memory. It does so by executing the module in a
scratch namespace, and then patching classes, methods and functions in-place.
New objects are copied into the target namespace.
This addresses cases where one module imported functions from another
module.
For example, suppose m1.py contains::
from m2 import foo
def print_foo():
return foo()
and m2.py contains::
def foo():
return 42
If you edit m2.py and modify ``foo``, then reload(m2) on its own would not do
what you want. You would also need to reload(m1) after reload(m2). This is
because the built-in reload affects the module being reloaded, but references
to the old module remain. On the other hand, xreload() patches the existing
m2.foo, so that live references to it are updated.
In table form::
Undesired effect: reload(m2)
Undesired effect: reload(m1); reload(m2)
Desired effect: reload(m2); reload(m1)
Desired effect: xreload(m2)
Desired effect: xreload(m1); xreload(m2)
Desired effect: xreload(m2); xreload(m1)
Even with just two modules, we can see that xreload() is an improvement. When
working with a large set of interdependent modules, it becomes infeasible to
know the precise sequence of reload() calls that would be necessary.
xreload() really shines in that case.
This implementation of xreload() was originally based on the following
mailing-list post by Guido van Rossum:
https://mail.python.org/pipermail/edu-sig/2007-February/007787.html
Customizing behavior
====================
If a class/function/module/etc has an attribute __livepatch__, then this
function is called *instead* of performing the regular livepatch mechanism.
The __livepatch__() function is called with the following arguments:
- ``old`` : The object to be updated with contents of ``new``
- ``new`` : The object whose contents to put into ``old``
- ``do_livepatch``: A function that can be called to do the standard
livepatch, replacing the contents of ``old`` with ``new``.
If it's not possible to livepatch ``old``, it returns
``new``. The ``do_livepatch`` function takes no arguments.
Calling the ``do_livepatch`` function is roughly
equivalent to calling ``pyflyby.livepatch(old, new,
modname=modname, heed_hook=False)``.
- ``modname`` : The module currently being updated. Recursively called
updates should keep track of the module being updated to
avoid touching other modules.
These arguments are matched by *name* and are passed only if the
``__livepatch__`` function is declared to take such named arguments or it takes
\**kwargs. If the ``__livepatch__`` function takes \**kwargs, it should ignore
unknown arguments, in case new parameters are added in the future.
If the object being updated is an object instance, and ``__livepatch__`` is a
method, then the function is bound to the new object, i.e. the ``self``
parameter is the same as ``new``.
If the ``__livepatch__`` function successfully patched the ``old`` object, then
it should return ``old``. If it is unable to patch, it should return ``new``.
Examples
--------
By default, any attributes on an existing function are updated with ones from
the new function. If you want a memoized function to keep its cache across
xreload, you could implement that like this::
def memoize(function):
cache = {}
def wrapped_fn(*args):
try:
return cache[args]
except KeyError:
result = function(*args)
cache[args] = result
return result
wrapped_fn.cache = cache
def my_livepatch(old, new, do_livepatch):
keep_cache = dict(old.cache)
result = do_livepatch()
result.cache.update(keep_cache)
return result
wrapped_fn.__livepatch__ = my_livepatch
return wrapped_fn
XXX change example b/c cache is already cleared by default
XXX maybe global cache
class MyObj(...):
def __livepatch__(self, old):
self.__dict__.update(old.__dict__)
return self
class MyObj(...):
def __init__(self):
self._my_cache = {}
def __livepatch__(self, old, do_livepatch):
keep_cache = dict(old._my_cache)
result = do_livepatch()
result._my_cache.update(keep_cache)
return result
XXX test
"""
from __future__ import (absolute_import, division, print_function,
with_statement)
import ast
import os
import re
import six
import sys
import time
import types
from six import PY2
from six.moves import reload_module
import inspect
from pyflyby._log import logger
# Keep track of when the process was started.
# _xreload_module() uses this as a conservative lower bound for a module's
# load time when the module has no __loadtime__ attribute.
if os.uname()[0] == 'Linux':
    # On Linux, the ctime of /proc/<pid> approximates the process start time.
    _PROCESS_START_TIME = os.stat("/proc/%d"%os.getpid()).st_ctime
else:
    try:
        import psutil
    except ImportError:
        # Todo: better fallback
        # Without psutil we only know when this module was imported, which is
        # later than the true process start; reloads may be over-eager.
        _PROCESS_START_TIME = time.time()
    else:
        _PROCESS_START_TIME = psutil.Process(os.getpid()).create_time()
class UnknownModuleError(ImportError):
    """Raised when an argument cannot be resolved to a loaded module."""
def livepatch(old, new, modname=None,
              visit_stack=(), cache=None, assume_type=None,
              heed_hook=True):
    """
    Livepatch ``old`` with contents of ``new``.

    If ``old`` can't be livepatched, then return ``new``.

    :param old:
      The object to be updated
    :param new:
      The object used as the source for the update.
    :type modname:
      ``str``
    :param modname:
      Only livepatch ``old`` if it was defined in the given fully-qualified
      module name.  If ``None``, then update regardless of module.
    :param assume_type:
      Update as if both ``old`` and ``new`` were of type ``assume_type``.  If
      ``None``, then ``old`` and ``new`` must have the same type.
      For internal use.
    :param cache:
      Cache of already-updated objects.  Map from (id(old), id(new)) to result.
      For internal use.
    :param visit_stack:
      Ids of objects that are currently being updated.
      Used to deal with reference cycles.
      For internal use.
    :param heed_hook:
      If ``True``, heed the ``__livepatch__`` hook on ``new``, if any.
      If ``False``, ignore any ``__livepatch__`` hook on ``new``.
    :return:
      Either live-patched ``old``, or ``new``.
    """
    if old is new:
        return new
    # If we're already visiting this object (due to a reference cycle), then
    # don't recurse again.
    if id(old) in visit_stack:
        return old
    if cache is None:
        cache = {}
    cachekey = (id(old), id(new))
    try:
        return cache[cachekey]
    except KeyError:
        pass
    # Tuple concatenation gives each recursion level its own stack copy.
    visit_stack += (id(old),)
    def do_livepatch():
        # The "standard" livepatch, also exposed to __livepatch__ hooks so
        # they can wrap the default behavior.
        new_modname = _get_definition_module(new)
        if modname and new_modname and new_modname != modname:
            # Ignore objects that have been imported from another module.
            # Just update their references.
            return new
        if assume_type is not None:
            use_type = assume_type
        else:
            oldtype = type(old)
            newtype = type(new)
            if oldtype is newtype:
                # Easy, common case: Type didn't change.
                use_type = oldtype
            elif (oldtype.__name__ == newtype.__name__ and
                  oldtype.__module__ == newtype.__module__ == modname and
                  getattr(sys.modules[modname],
                          newtype.__name__, None) is newtype and
                  oldtype is livepatch(
                      oldtype, newtype, modname=modname,
                      visit_stack=visit_stack, cache=cache)):
                # Type of this object was defined in this module.  This
                # includes metaclasses defined in the same module.
                use_type = oldtype
            else:
                # If the type changed, then give up.
                return new
        try:
            mro = type.mro(use_type)
        except TypeError:
            mro = [use_type, object] # old-style class
        # Dispatch on type.  Include parent classes (in C3 linearized
        # method resolution order), in particular so that this works on
        # classes with custom metaclasses that subclass ``type``.
        for t in mro:
            try:
                update = _LIVEPATCH_DISPATCH_TABLE[t]
                break
            except KeyError:
                pass
        else:
            # We should have found at least ``object``
            raise AssertionError("unreachable")
        # Dispatch.
        return update(old, new, modname=modname,
                      cache=cache, visit_stack=visit_stack)
    if heed_hook:
        # __reload_update__ is honored as a legacy spelling of the hook.
        hook = (getattr(new, "__livepatch__", None) or
                getattr(new, "__reload_update__", None))
        # XXX if unbound method or a descriptor, then we should ignore it.
        # XXX test for that.
    else:
        hook = None
    if hook is None:
        # No hook is defined or the caller instructed us to ignore it.
        # Do the standard livepatch.
        result = do_livepatch()
    else:
        # Call a hook for updating.
        # Build dict of optional kwargs.
        avail_kwargs = dict(
            old=old,
            new=new,
            do_livepatch=do_livepatch,
            modname=modname,
            cache=cache,
            visit_stack=visit_stack)
        # Find out which optional kwargs the hook wants.
        kwargs = {}
        if PY2:
            argspec = inspect.getargspec(hook)
        else:
            argspec = inspect.getfullargspec(hook)
        argnames = argspec.args
        if hasattr(hook, "__func__"):
            # Skip 'self' arg.
            argnames = argnames[1:]
        # Pick kwargs that are wanted and available.
        args = []
        kwargs = {}
        for n in argnames:
            try:
                kwargs[n] = avail_kwargs[n]
                # If the hook takes **kwargs, we'll pass all available kwargs
                # below anyway, so stop picking them individually.
                if argspec.keywords if PY2 else argspec.varkw:
                    break
            except KeyError:
                # For compatibility, allow first argument to be 'old' with any
                # name, as long as there's no other arg 'old'.
                # We intentionally allow this even if the user specified
                # **kwargs.
                if not args and not kwargs and 'old' not in argnames:
                    args.append(old)
                else:
                    # Rely on default being set.  If a default isn't set, the
                    # user will get a TypeError.
                    pass
        if argspec.keywords if PY2 else argspec.varkw:
            # Use all available kwargs.
            kwargs = avail_kwargs
        # Call hook.
        result = hook(*args, **kwargs)
    cache[cachekey] = result
    return result
def _livepatch__module(old_mod, new_mod, modname, cache, visit_stack):
    """
    Livepatch a module by patching its ``__dict__`` in place.
    """
    patched = livepatch(old_mod.__dict__, new_mod.__dict__,
                        modname=modname,
                        cache=cache, visit_stack=visit_stack)
    # Dicts are always patched in place, never replaced.
    assert patched is old_mod.__dict__
    return old_mod
def _livepatch__dict(old_dict, new_dict, modname, cache, visit_stack):
    """
    Livepatch a dict in place: add new keys, drop stale keys, and
    recursively livepatch entries present in both.
    """
    old_keys = set(old_dict)
    new_keys = set(new_dict)
    # Copy in entries that only exist in the new dict.
    for key in new_keys - old_keys:
        old_dict[key] = new_dict[key]
    # Remove entries that disappeared.
    for key in old_keys - new_keys:
        del old_dict[key]
    # Recursively patch common entries (sorted for deterministic order).
    for key in sorted(old_keys & new_keys, key=str):
        prev = old_dict[key]
        patched = livepatch(prev, new_dict[key],
                            modname=modname,
                            cache=cache, visit_stack=visit_stack)
        if patched is not prev:
            old_dict[key] = patched
    return old_dict
def _livepatch__function(old_func, new_func, modname, cache, visit_stack):
    """
    Livepatch a function in place by replacing its code, defaults, doc and
    attribute dict, and by livepatching its closure cells' values.

    Returns ``old_func`` on success, or ``new_func`` if the functions are
    incompatible (different name or incompatible closures).
    """
    # If the name differs, then don't update the existing function - this
    # is probably a reassigned function.
    if old_func.__name__ != new_func.__name__:
        return new_func
    # Check if the function's closure is compatible.  If not, then return the
    # new function without livepatching.  Note that cell closures can't be
    # modified; we can only livepatch cell values.
    old_closure = old_func.__closure__ or ()
    new_closure = new_func.__closure__ or ()
    if len(old_closure) != len(new_closure):
        return new_func
    if old_func.__code__.co_freevars != new_func.__code__.co_freevars:
        return new_func
    for oldcell, newcell in zip(old_closure, new_closure):
        oldcellv = oldcell.cell_contents
        newcellv = newcell.cell_contents
        if type(oldcellv) != type(newcellv):
            return new_func
        if isinstance(oldcellv, (
                types.FunctionType, types.MethodType, six.class_types, dict)):
            # Updateable type.  (Todo: make this configured globally.)
            continue
        try:
            if oldcellv is newcellv or oldcellv == newcellv:
                continue
        except Exception:
            # A broken/exotic __eq__ must not abort the livepatch decision.
            pass
        # Non-updateable and not the same as before.
        return new_func
    # Update function code, defaults, doc.
    old_func.__code__ = new_func.__code__
    old_func.__defaults__ = new_func.__defaults__
    old_func.__doc__ = new_func.__doc__
    # Update dict.
    livepatch(old_func.__dict__, new_func.__dict__,
              modname=modname, cache=cache, visit_stack=visit_stack)
    # Update the __closure__.  We can't set __closure__ because it's a
    # read-only attribute; we can only livepatch its cells' values.
    for oldcell, newcell in zip(old_closure, new_closure):
        oldcellv = oldcell.cell_contents
        newcellv = newcell.cell_contents
        livepatch(oldcellv, newcellv,
                  modname=modname, cache=cache, visit_stack=visit_stack)
    return old_func
def _livepatch__method(old_method, new_method, modname, cache, visit_stack):
    """
    Livepatch a method by patching its underlying function object; the
    method wrapper itself is kept.
    """
    _livepatch__function(old_method.__func__, new_method.__func__,
                         modname=modname,
                         cache=cache, visit_stack=visit_stack)
    return old_method
def _livepatch__setattr(oldobj, newobj, name, modname, cache, visit_stack):
    """
    Livepatch one attribute, i.e.::

      oldobj.{name} = livepatch(oldobj.{name}, newobj.{name}, ...)

    Returns ``None``; the effect is on ``oldobj``.
    """
    newval = getattr(newobj, name)
    assert type(newval) is not types.MemberDescriptorType
    try:
        oldval = getattr(oldobj, name)
    except AttributeError:
        # This shouldn't happen, but just ignore it.
        setattr(oldobj, name, newval)
        return
    # If it's the same object, then skip.  Note that if even if 'newval ==
    # oldval', as long as they're not the same object instance, we still
    # livepatch.  We want mutable data structures get livepatched instead of
    # replaced.  Avoiding calling '==' also avoids the risk of user code
    # having defined '==' to do something unexpected.
    if newval is oldval:
        return
    # Livepatch the member object.
    newval = livepatch(
        oldval, newval, modname=modname, cache=cache, visit_stack=visit_stack)
    # If the livepatch succeeded then we don't need to setattr.  It should be
    # a no-op but we avoid it just to minimize any chance of setattr causing
    # problems in corner cases.
    if newval is oldval:
        return
    # Livepatch failed, so we have to update the container with the new member
    # value.
    setattr(oldobj, name, newval)
def _livepatch__class(oldclass, newclass, modname, cache, visit_stack):
    """
    Livepatch a class.

    This is similar to _livepatch__dict(oldclass.__dict__, newclass.__dict__).
    However, we can't just operate on the dict, because class dictionaries are
    special objects that don't allow setitem, even though we can setattr on
    the class.

    Returns ``oldclass`` patched in place, or ``newclass`` if the class's
    ``__slots__`` changed (not patchable).
    """
    # Collect the names to update.
    olddict = oldclass.__dict__
    newdict = newclass.__dict__
    # Make sure slottiness hasn't changed -- i.e. if class was changed to have
    # slots, or changed to not have slots, or if the slot names changed in any
    # way, then we can't livepatch the class.
    # Note that this is about whether instances of this class are affected by
    # __slots__ or not.  The class type itself will always use a __dict__.
    if olddict.get("__slots__") != newdict.get("__slots__"):
        return newclass
    oldnames = set(olddict)
    newnames = set(newdict)
    # Drop removed attributes, copy in brand-new ones.
    for name in oldnames - newnames:
        delattr(oldclass, name)
    for name in newnames - oldnames:
        setattr(oldclass, name, newdict[name])
    oldclass.__bases__ = newclass.__bases__
    names = oldnames & newnames
    # Slot descriptors and __dict__/__slots__ themselves must not be copied.
    names.difference_update(olddict.get("__slots__", []))
    names.discard("__slots__")
    names.discard("__dict__")
    # Python < 3.3 doesn't support modifying __doc__ on classes with
    # non-custom metaclasses.  Attempt to do it and ignore failures.
    # http://bugs.python.org/issue12773
    names.discard("__doc__")
    try:
        oldclass.__doc__ = newclass.__doc__
    except AttributeError:
        pass
    # Loop over attributes to be updated.
    for name in sorted(names):
        _livepatch__setattr(
            oldclass, newclass, name, modname, cache, visit_stack)
    return oldclass
def _livepatch__object(oldobj, newobj, modname, cache, visit_stack):
    """
    Livepatch a general object instance.

    Returns ``oldobj`` patched in place when its class was defined in
    ``modname`` (slots or ``__dict__`` updated); otherwise returns ``newobj``.
    """
    # It's not obvious whether ``oldobj`` and ``newobj`` are actually supposed
    # to represent the same object.  For now, we take a middle ground of
    # livepatching iff the class was also defined in the same module.  In that
    # case at least we know that the object was defined in this module and
    # therefore more likely that we should livepatch.
    if modname and _get_definition_module(type(oldobj)) != modname:
        return newobj
    if hasattr(type(oldobj), "__slots__"):
        assert oldobj.__slots__ == newobj.__slots__
        for name in newobj.__slots__:
            hasold = hasattr(oldobj, name)
            hasnew = hasattr(newobj, name)
            if hasold and hasnew:
                _livepatch__setattr(oldobj, newobj, name,
                                    modname, cache, visit_stack)
            elif hasold and not hasnew:
                delattr(oldobj, name)
            elif not hasold and hasnew:
                # BUGFIX: previously ``setattr(oldobj, getattr(newobj, name))``,
                # which passed only 2 of setattr's 3 required arguments and
                # raised TypeError whenever a slot was newly assigned.
                setattr(oldobj, name, getattr(newobj, name))
            elif not hasold and not hasnew:
                pass
            else:
                raise AssertionError
        return oldobj
    elif type(getattr(oldobj, "__dict__", None)) is dict:
        livepatch(
            oldobj.__dict__, newobj.__dict__,
            modname=modname, cache=cache, visit_stack=visit_stack)
        return oldobj
    else:
        # No slots and no plain __dict__: nothing we know how to patch.
        return newobj
# Dispatch table mapping a type to the function that livepatches instances of
# that type.  livepatch() walks the MRO of the object's type and uses the
# first matching entry; ``object`` is the catch-all.
if six.PY2:
    _LIVEPATCH_DISPATCH_TABLE = {
        object : _livepatch__object,
        dict : _livepatch__dict,
        type : _livepatch__class,
        types.ClassType : _livepatch__class,
        types.FunctionType: _livepatch__function,
        types.MethodType : _livepatch__method,
        types.ModuleType : _livepatch__module,
    }
elif six.PY3:
    # Python 3 has no old-style classes, hence no types.ClassType entry.
    _LIVEPATCH_DISPATCH_TABLE = {
        object : _livepatch__object,
        dict : _livepatch__dict,
        type : _livepatch__class,
        types.FunctionType: _livepatch__function,
        types.MethodType : _livepatch__method,
        types.ModuleType : _livepatch__module,
    }
def _get_definition_module(obj):
    """
    Get the name of the module that an object is defined in, or ``None`` if
    unknown.

    For classes and functions, this returns the ``__module__`` attribute.

    For object instances, this returns ``None``, ignoring the ``__module__``
    attribute.  The reason is that the ``__module__`` attribute on an instance
    just gives the module that the class was defined in, which is not
    necessarily the module where the instance was constructed.

    :rtype:
      ``str``
    """
    if isinstance(obj, (type, six.class_types, types.FunctionType,
                        types.MethodType)):
        return getattr(obj, "__module__", None)
    else:
        return None
def _format_age(t):
secs = time.time() - t
if secs > 120:
return "%dm%ds" %(secs//60, secs%60)
else:
return "%ds" %(secs,)
def _interpret_module(arg):
    """
    Resolve ``arg`` to a loaded module.

    ``arg`` may be a module object, a module name, an absolute path, or a
    ``*.py`` basename.  Path/basename lookups match against the files of
    modules already in ``sys.modules``.

    :raises UnknownModuleError:
      If a string argument matches no (or more than one) loaded module.
    :raises TypeError:
      If ``arg`` is neither a string nor a module.
    """
    def mod_fn(module):
        # Defensive: some modules (builtins) have no __file__.
        return getattr(module, "__file__", None)
    if isinstance(arg, six.string_types):
        try:
            return sys.modules[arg]
        except KeyError:
            pass
        if arg.startswith("/"):
            fn = os.path.realpath(arg)
            # Normalize *.pyc / *.pyo to the source filename.
            if fn.endswith(".pyc") or fn.endswith(".pyo"):
                fn = fn[:-1]
            if fn.endswith(".py"):
                relevant_fns = set([fn, fn+"c", fn+"o"])
            else:
                relevant_fns = set([fn])
            found_modules = [
                m for _,m in sorted(sys.modules.items())
                if os.path.realpath(mod_fn(m) or "/") in relevant_fns ]
            if not found_modules:
                raise UnknownModuleError(
                    "No loaded module uses path %s" % (fn,))
            if len(found_modules) > 1:
                raise UnknownModuleError(
                    "Multiple loaded modules use path %s: %r"
                    % (fn, found_modules))
            return found_modules[0]
        if arg.endswith(".py") and "/" not in arg:
            # A bare basename like "foo.py": match by module name or by the
            # basename of the module's file.
            name = arg[:-3]
            relevant_bns = set([arg, arg+"c", arg+"o"])
            found_modules = [
                m for n,m in sorted(sys.modules.items())
                if (n==name or
                    os.path.basename(mod_fn(m) or "/") in relevant_bns)]
            if not found_modules:
                raise UnknownModuleError(
                    "No loaded module named %s" % (name,))
            if len(found_modules) > 1:
                raise UnknownModuleError(
                    "Multiple loaded modules named %s: %r"
                    % (name, found_modules))
            return found_modules[0]
        raise UnknownModuleError(arg)
    if isinstance(arg, types.ModuleType):
        return arg
    try:
        # Allow fake modules.
        if sys.modules[arg.__name__] is arg:
            return arg
    except Exception:
        pass
    raise TypeError("Expected module, module name, or filename; got %s"
                    % (type(arg).__name__))
def _xreload_module(module, filename, force=False):
    """
    Reload a module in place, using livepatch.

    :type module:
      ``ModuleType``
    :param module:
      Module to reload.
    :param filename:
      Path to the module's ``*.py`` source file (may be ``None``; non-``.py``
      modules fall back to the built-in reload).
    :param force:
      Whether to reload even if the module has not been modified since the
      previous load.  If ``False``, then do nothing.  If ``True``, then reload.
    :return:
      The (patched) module, or ``None`` if nothing was reloaded.
    """
    import linecache
    if not filename or not filename.endswith(".py"):
        # If there's no *.py source file for this module, then fallback to
        # built-in reload().
        return reload_module(module)
    # Compare mtime of the file with the load time of the module.  If the file
    # wasn't touched, we don't need to do anything.
    try:
        mtime = os.stat(filename).st_mtime
    except OSError:
        logger.info("Can't find %s", filename)
        return None
    if not force:
        try:
            old_loadtime = module.__loadtime__
        except AttributeError:
            # We only have a __loadtime__ attribute if we were the ones that
            # loaded it.  Otherwise, fall back to the process start time as a
            # conservative bound.
            old_loadtime = _PROCESS_START_TIME
        if old_loadtime > mtime:
            logger.debug(
                "NOT reloading %s (file %s modified %s ago but loaded %s ago)",
                module.__name__, filename, _format_age(mtime),
                _format_age(old_loadtime))
            return None
        # Keep track of previously imported source.  If the file's timestamp
        # was touched, but the content unchanged, we can avoid reloading.
        cached_lines = linecache.cache.get(filename, (None,None,None,None))[2]
    else:
        cached_lines = None
    # Re-read source for module from disk, and update the linecache.
    source = ''.join(linecache.updatecache(filename))
    # Skip reload if the content didn't change.
    if cached_lines is not None and source == ''.join(cached_lines):
        logger.debug(
            "NOT reloading %s (file %s touched %s ago but content unchanged)",
            module.__name__, filename, _format_age(mtime))
        return module
    logger.info("Reloading %s (modified %s ago) from %s",
                module.__name__, _format_age(mtime), filename)
    # Compile into AST.  We do this as a separate step from compiling to byte
    # code so that we can get the module docstring.
    astnode = compile(source, filename, "exec", ast.PyCF_ONLY_AST, 1)
    # Get the new docstring.
    try:
        doc = astnode.body[0].value.s
    except (AttributeError, IndexError):
        doc = None
    # Compile into code.
    code = compile(astnode, filename, "exec", 0, 1)
    # Execute the code.  We do so in a temporary namespace so that if this
    # fails, nothing changes.  It's important to set __name__ so that relative
    # imports work correctly.
    new_mod = types.ModuleType(module.__name__)
    new_mod.__file__ = filename
    new_mod.__doc__ = doc
    if hasattr(module, "__path__"):
        new_mod.__path__ = module.__path__
    MISSING = object()
    saved_mod = sys.modules.get(module.__name__, MISSING)
    try:
        # Temporarily put the temporary module in sys.modules, in case the
        # code references sys.modules[__name__] for some reason.  Normally on
        # success, we will revert this to what was there before (which
        # normally should be ``module``).  If an error occurs, we'll also
        # revert.  If the user has defined a __livepatch__ hook at the module
        # level, it's possible for result to not be the old module.
        sys.modules[module.__name__] = new_mod
        # *** Execute new code ***
        exec(code, new_mod.__dict__)
        # Normally ``module`` is of type ``ModuleType``.  However, in some
        # cases, the module might have done a "proxy module" trick where the
        # module is replaced by a proxy object of some other type.  Regardless
        # of the actual type, we do the update as ``module`` were of type
        # ``ModuleType``.
        assume_type = types.ModuleType
        # Livepatch the module.
        result = livepatch(module, new_mod, module.__name__,
                           assume_type=assume_type)
        sys.modules[module.__name__] = result
    except:
        # Either the module failed executing or the livepatch failed.
        # Revert to previous state.
        # Note that this isn't perfect because it's possible that the module
        # modified some global state in other modules.
        if saved_mod is MISSING:
            del sys.modules[module.__name__]
        else:
            sys.modules[module.__name__] = saved_mod
        raise
    # Update the time we last loaded the module.  We intentionally use mtime
    # here instead of time.time().  If we are on NFS, it's possible for the
    # filer's mtime and time.time() to not be synchronized.  We will be
    # comparing to mtime next time, so if we use only mtime, we'll be fine.
    module.__loadtime__ = mtime
    return module
def _get_module_py_file(module):
filename = getattr(module, "__file__", None)
if not filename:
return None
filename = re.sub("[.]py[co]$", ".py", filename)
return filename
def xreload(*args):
    """
    Reload module(s) using a "live patch" approach that modifies existing
    functions, classes, and objects in-place, so that live references to the
    old module's contents are updated (unlike the built-in ``reload()``).

    :type args:
      ``str`` s and/or ``ModuleType`` s
    :param args:
      Module(s) to reload, given as module objects, module names, or
      filenames.  With no arguments, reload all recently modified modules.
    """
    if not args:
        # No arguments: sweep all loaded modules with a *.py source file;
        # _xreload_module skips the unmodified ones.
        for name, module in sorted(sys.modules.items()):
            if name == "__main__":
                continue
            filename = _get_module_py_file(module)
            if filename:
                _xreload_module(module, filename)
        return
    # Treat xreload(list_of_modules) like xreload(*list_of_modules).  We
    # intentionally do this after the no-arg check so that xreload([]) does
    # nothing.
    if len(args) == 1 and isinstance(args[0], (tuple, list)):
        args = args[0]
    for arg in args:
        target = _interpret_module(arg)
        _xreload_module(target, _get_module_py_file(target))

View file

@ -0,0 +1,240 @@
# pyflyby/_log.py.
# Copyright (C) 2011, 2012, 2013, 2014, 2015, 2018 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from contextlib import contextmanager
import logging
from logging import Handler, Logger
import os
from six.moves import builtins
import sys
class _PyflybyHandler(Handler):
    """
    Logging handler that writes "[PYFLYBY]"-prefixed messages to stderr,
    with ANSI coloring when interactive, and optional pre/post hooks around
    the first emission within a `HookCtx` context.
    """
    # Callable invoked before the first record emitted inside HookCtx.
    _pre_log_function = None
    # Whether anything was emitted during the current HookCtx context.
    _logged_anything_during_context = False
    # Yellow "[PYFLYBY] " prefix for interactive terminals.
    _interactive_prefix = "\033[0m\033[33m[PYFLYBY]\033[0m "
    _noninteractive_prefix = "[PYFLYBY] "
    def emit(self, record):
        """
        Emit a log record.
        """
        try:
            # Call pre-log hook.
            if self._pre_log_function is not None:
                if not self._logged_anything_during_context:
                    self._pre_log_function()
                    self._logged_anything_during_context = True
            # Format (currently a no-op).
            msg = self.format(record)
            # Add prefix per line.
            if _is_ipython() or _is_interactive(sys.stderr):
                prefix = self._interactive_prefix
            else:
                prefix = self._noninteractive_prefix
            msg = ''.join(["%s%s\n" % (prefix, line) for line in msg.splitlines()])
            # First, flush stdout, to make sure that stdout and stderr don't get
            # interleaved.  Normally this is automatic, but when stdout is piped,
            # it can be necessary to force a flush to avoid interleaving.
            sys.stdout.flush()
            # Write log message.
            if sys.stderr.__class__.__module__.startswith("prompt_toolkit"):
                # prompt_toolkit proxies stderr; bypass its escape-stripping.
                with _PromptToolkitStdoutProxyRawCtx(sys.stderr):
                    sys.stderr.write(msg)
                    sys.stderr.flush()
            else:
                sys.stderr.write(msg)
                # Flush now - we don't want any interleaving of stdout/stderr.
                sys.stderr.flush()
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            # Standard logging.Handler behavior: report, never propagate.
            self.handleError(record)
    @contextmanager
    def HookCtx(self, pre, post):
        """
        Enter a context where:

          * ``pre`` is called before the first time a log record is emitted
            during the context, and
          * ``post`` is called at the end of the context, if any log records
            were emitted during the context.

        :type pre:
          ``callable``
        :param pre:
          Function to call before the first time something is logged during
          this context.
        :type post:
          ``callable``
        :param post:
          Function to call before returning from the context, if anything was
          logged during the context.
        """
        assert self._pre_log_function is None
        self._pre_log_function = pre
        try:
            yield
        finally:
            if self._logged_anything_during_context:
                post()
                self._logged_anything_during_context = False
            self._pre_log_function = None
def _is_interactive(file):
filemod = type(file).__module__
if filemod.startswith("IPython.") or filemod.startswith("prompt_toolkit."):
# Inside IPython notebook/kernel
return True
try:
fileno = file.fileno()
except Exception:
return False # dunno
return os.isatty(fileno)
def _is_ipython():
"""
Returns true if we're currently running inside IPython.
"""
# This currently only works for versions of IPython that are modern enough
# to install 'builtins.get_ipython()'.
if 'IPython' not in sys.modules:
return False
if not hasattr(builtins, "get_ipython"):
return False
ip = builtins.get_ipython()
if ip is None:
return False
return True
@contextmanager
def _PromptToolkitStdoutProxyRawCtx(proxy):
"""
Hack to defeat the "feature" where
prompt_toolkit.interface._StdoutProxy(sys.stderr) causes ANSI escape codes
to not be written.
"""
# prompt_toolkit replaces sys.stderr with a proxy object. This proxy
# object replaces ESC (\xb1) with '?'. That breaks our colorization of
# the [PYFLYBY] log prefix. To work around this, we need to temporarily
# set _StdoutProxy._raw to True during the write() call. However, the
# write() call actually just stores a lambda to be executed later, and
# that lambda references self._raw by reference. So we can't just set
# _raw before we call sys.stderr.write(), since the _raw variable is not
# read yet at that point. We need to hook the internals so that we store
# a wrapped lambda which temporarily sets _raw to True. Yuck, this is so
# brittle. Tested with prompt_toolkit 1.0.15.
if not hasattr(type(proxy), '_do') or not hasattr(proxy, '_raw'):
yield
return
MISSING = object()
prev = proxy.__dict__.get('_do', MISSING)
original_do = proxy._do
def wrapped_do_raw(self, func):
def wrapped_func():
prev_raw = self._raw
try:
self._raw = True
func()
finally:
self._raw = prev_raw
original_do(wrapped_func)
try:
proxy._do = wrapped_do_raw.__get__(proxy)
yield
finally:
if prev is MISSING:
proxy.__dict__.pop('_do', None)
else:
proxy.__dict__ = prev
@contextmanager
def _NoRegisterLoggerHandlerInHandlerListCtx():
"""
Work around a bug in the ``logging`` module for Python 2.x-3.2.
The Python stdlib ``logging`` module has a bug where you sometimes get the
following warning at exit::
Exception TypeError: "'NoneType' object is not callable" in <function
_removeHandlerRef at 0x10a1b3f50> ignored
This is caused by shutdown ordering affecting which globals in the logging
module are available to the _removeHandlerRef function.
Python 3.3 fixes this.
For earlier versions of Python, this context manager works around the
issue by avoiding registering a handler in the _handlerList. This means
that we no longer call "flush()" from the atexit callback. However, that
was a no-op anyway, and even if we needed it, we could call it ourselves
atexit.
:see:
http://bugs.python.org/issue9501
"""
if not hasattr(logging, "_handlerList"):
yield
return
if sys.version_info >= (3, 3):
yield
return
try:
orig_handlerList = logging._handlerList[:]
yield
finally:
logging._handlerList[:] = orig_handlerList
class PyflybyLogger(Logger):
_LEVELS = dict( (k, getattr(logging, k))
for k in ['DEBUG', 'INFO', 'WARNING', 'ERROR'] )
def __init__(self, name, level):
Logger.__init__(self, name)
with _NoRegisterLoggerHandlerInHandlerListCtx():
handler = _PyflybyHandler()
self.addHandler(handler)
self.set_level(level)
def set_level(self, level):
"""
Set the pyflyby logger's level to ``level``.
:type level:
``str``
"""
if isinstance(level, int):
level_num = level
else:
try:
level_num = self._LEVELS[level.upper()]
except KeyError:
raise ValueError("Bad log level %r" % (level,))
Logger.setLevel(self, level_num)
@property
def debug_enabled(self):
return self.level <= logging.DEBUG
@property
def info_enabled(self):
return self.level <= logging.INFO
def HookCtx(self, pre, post):
return self.handlers[0].HookCtx(pre, post)
# The singleton logger used throughout pyflyby.  Level defaults to INFO and
# can be overridden via the PYFLYBY_LOG_LEVEL environment variable.
logger = PyflybyLogger('pyflyby', os.getenv("PYFLYBY_LOG_LEVEL") or "INFO")

View file

@ -0,0 +1,429 @@
# pyflyby/_modules.py.
# Copyright (C) 2011, 2012, 2013, 2014, 2015 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from functools import total_ordering
import os
import re
import six
from six import reraise
import sys
import types
from pyflyby._file import FileText, Filename
from pyflyby._idents import DottedIdentifier, is_identifier
from pyflyby._log import logger
from pyflyby._util import (ExcludeImplicitCwdFromPathCtx,
cached_attribute, cmp, memoize,
prefixes)
class ErrorDuringImportError(ImportError):
    """
    Exception raised by import_module if the module exists but an exception
    occurred while attempting to import it.  That nested exception could be
    ImportError, e.g. if a module tries to import another module that doesn't
    exist.
    """
@memoize
def import_module(module_name):
    """
    Import the module named ``module_name`` and return the module object.

    :raise ImportError:
      The module doesn't exist.
    :raise ErrorDuringImportError:
      The module exists, but an exception occurred while importing it.
    """
    module_name = str(module_name)
    logger.debug("Importing %r", module_name)
    try:
        result = __import__(module_name, fromlist=['dummy'])
        if result.__name__ != module_name:
            logger.debug("Note: import_module(%r).__name__ == %r",
                         module_name, result.__name__)
        return result
    except ImportError as e:
        # We got an ImportError.  Figure out whether this is due to the module
        # not existing, or whether the module exists but caused an ImportError
        # (perhaps due to trying to import another problematic module).
        # Do this by looking at the exception traceback.  If the previous
        # frame in the traceback is this function (because locals match), then
        # it should be the internal import machinery reporting that the module
        # doesn't exist.  Re-raise the exception as-is.
        # If some sys.meta_path or other import hook isn't compatible with
        # such a check, here are some things we could do:
        #   - Use pkgutil.find_loader() after the fact to check if the module
        #     is supposed to exist.  Note that we shouldn't rely solely on
        #     this before attempting to import, because find_loader() doesn't
        #     work with meta_path.
        #   - Write a memoized global function that compares in the current
        #     environment the difference between attempting to import a
        #     non-existent module vs a problematic module, and returns a
        #     function that uses the working discriminators.
        real_importerror1 = type(e) is ImportError
        real_importerror2 = (sys.exc_info()[2].tb_frame.f_locals is locals())
        m = re.match("^No module named (.*)$", str(e))
        # BUGFIX: the original expression was ``(m and A) or B``, which
        # dereferenced ``m.group(1)`` even when ``m`` was None (e.g. for
        # "cannot import name ..." messages), raising AttributeError.
        # Also strip the quotes that Python 3 puts around the module name
        # ("No module named 'foo'") so the comparison works on both 2 and 3.
        if m is None:
            real_importerror3 = False
        else:
            missing_name = m.group(1).strip("'")
            real_importerror3 = (missing_name == module_name
                                 or module_name.endswith("." + missing_name))
        logger.debug("import_module(%r): real ImportError: %s %s %s",
                     module_name,
                     real_importerror1, real_importerror2, real_importerror3)
        if real_importerror1 and real_importerror2 and real_importerror3:
            raise
        reraise(ErrorDuringImportError(
            "Error while attempting to import %s: %s: %s"
            % (module_name, type(e).__name__, e)), None, sys.exc_info()[2])
    except Exception as e:
        reraise(ErrorDuringImportError(
            "Error while attempting to import %s: %s: %s"
            % (module_name, type(e).__name__, e)), None, sys.exc_info()[2])
def _my_iter_modules(path, prefix=''):
# Modified version of pkgutil.ImpImporter.iter_modules(), patched to
# handle inaccessible subdirectories.
if path is None:
return
try:
filenames = os.listdir(path)
except OSError:
return # silently ignore inaccessible paths
filenames.sort() # handle packages before same-named modules
yielded = {}
import inspect
for fn in filenames:
modname = inspect.getmodulename(fn)
if modname=='__init__' or modname in yielded:
continue
subpath = os.path.join(path, fn)
ispkg = False
try:
if not modname and os.path.isdir(path) and '.' not in fn:
modname = fn
for fn in os.listdir(subpath):
subname = inspect.getmodulename(fn)
if subname=='__init__':
ispkg = True
break
else:
continue # not a package
except OSError:
continue # silently ignore inaccessible subdirectories
if modname and '.' not in modname:
yielded[modname] = 1
yield prefix + modname, ispkg
def pyc_to_py(filename):
    """
    Map a compiled-module filename (``.pyc``/``.pyo``) to its source (``.py``)
    filename.  Any other filename is returned unchanged.
    """
    for compiled_ext in (".pyc", ".pyo"):
        if filename.endswith(compiled_ext):
            # Drop the trailing 'c'/'o' to get the .py name.
            return filename[:-1]
    return filename
@total_ordering
class ModuleHandle(object):
    """
    A handle to a module.

    Handles are interned per module name, and are lazy: the underlying module
    is not imported until an attribute that needs it (such as ``module``) is
    accessed.
    """

    def __new__(cls, arg):
        # Accept an existing handle (returned as-is), a Filename, a module
        # name (str/DottedIdentifier), or an actual module object.
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, Filename):
            return cls._from_filename(arg)
        if isinstance(arg, (six.string_types, DottedIdentifier)):
            return cls._from_modulename(arg)
        if isinstance(arg, types.ModuleType):
            return cls._from_module(arg)
        raise TypeError("ModuleHandle: unexpected %s" % (type(arg).__name__,))

    # Intern table: DottedIdentifier -> ModuleHandle.
    _cls_cache = {}

    @classmethod
    def _from_modulename(cls, modulename):
        # Return the interned handle for ``modulename``, creating it if needed.
        modulename = DottedIdentifier(modulename)
        try:
            return cls._cls_cache[modulename]
        except KeyError:
            pass
        self = object.__new__(cls)
        self.name = modulename
        cls._cls_cache[modulename] = self
        return self

    @classmethod
    def _from_module(cls, module):
        if not isinstance(module, types.ModuleType):
            raise TypeError
        self = cls._from_modulename(module.__name__)
        assert self.module is module
        return self

    @classmethod
    def _from_filename(cls, filename):
        filename = Filename(filename)
        raise NotImplementedError(
            "TODO: look at sys.path to guess module name")

    @cached_attribute
    def parent(self):
        # Handle for the parent package, or None for a top-level module.
        if not self.name.parent:
            return None
        return ModuleHandle(self.name.parent)

    @cached_attribute
    def ancestors(self):
        # Handles for all dotted prefixes of this module name, shortest first.
        return tuple(ModuleHandle(m) for m in self.name.prefixes)

    @cached_attribute
    def module(self):
        """
        Return the module instance.

        :rtype:
          ``types.ModuleType``
        :raise ErrorDuringImportError:
          The module should exist but an error occurred while attempting to
          import it.
        :raise ImportError:
          The module doesn't exist.
        """
        # First check if prefix component is importable.
        if self.parent:
            self.parent.module
        # Import.
        return import_module(self.name)

    @cached_attribute
    def exists(self):
        """
        Return whether the module exists, according to pkgutil.

        Note that this doesn't work for things that are only known by using
        sys.meta_path.
        """
        name = str(self.name)
        if name in sys.modules:
            return True
        if self.parent and not self.parent.exists:
            return False
        import pkgutil
        try:
            loader = pkgutil.find_loader(name)
        except Exception:
            # Catch all exceptions, not just ImportError.  If the __init__.py
            # for the parent package of the module raises an exception, it'll
            # propagate to here.
            loader = None
        return loader is not None

    @cached_attribute
    def filename(self):
        """
        Return the filename, if appropriate.

        The module itself will not be imported, but if the module is not a
        top-level module/package, accessing this attribute may cause the
        parent package to be imported.

        :rtype:
          `Filename`
        """
        # Use the loader mechanism to find the filename.  We do so instead of
        # using self.module.__file__, because the latter forces importing a
        # module, which may be undesirable.
        import pkgutil
        try:
            loader = pkgutil.get_loader(str(self.name))
        except ImportError:
            return None
        if not loader:
            return None
        # Get the filename using loader.get_filename().  Note that this does
        # more than just loader.filename: for example, it adds /__init__.py
        # for packages.
        filename = loader.get_filename()
        if not filename:
            return None
        return Filename(pyc_to_py(filename))

    @cached_attribute
    def text(self):
        # FileText of this module's source file.
        return FileText(self.filename)

    def __text__(self):
        return self.text

    @cached_attribute
    def block(self):
        # Parsed PythonBlock of this module's source.
        from pyflyby._parse import PythonBlock
        return PythonBlock(self.text)

    @staticmethod
    @memoize
    def list():
        """
        Enumerate all top-level packages/modules.

        :rtype:
          ``tuple`` of `ModuleHandle` s
        """
        import pkgutil
        # Get the list of top-level packages/modules using pkgutil.
        # We exclude "." from sys.path while doing so.  Python includes "." in
        # sys.path by default, but this is undesirable for autoimporting.  If
        # we autoimported random python scripts in the current directory, we
        # could accidentally execute code with side effects.  If the current
        # working directory is /tmp, trying to enumerate modules there also
        # causes problems, because there are typically directories there not
        # readable by the current user.
        with ExcludeImplicitCwdFromPathCtx():
            modlist = pkgutil.iter_modules(None)
            module_names = [t[1] for t in modlist]
        # pkgutil includes all *.py even if the name isn't a legal python
        # module name, e.g. if a directory in $PYTHONPATH has files named
        # "try.py" or "123.py", pkgutil will return entries named "try" or
        # "123".  Filter those out.
        module_names = [m for m in module_names if is_identifier(m)]
        # Canonicalize.
        return tuple(ModuleHandle(m) for m in sorted(set(module_names)))

    @cached_attribute
    def submodules(self):
        """
        Enumerate the importable submodules of this module.

          >>> ModuleHandle("email").submodules       # doctest:+ELLIPSIS
          (..., ModuleHandle('email.encoders'), ..., ModuleHandle('email.mime'), ...)

        :rtype:
          ``tuple`` of `ModuleHandle` s
        """
        import pkgutil
        module = self.module
        try:
            path = module.__path__
        except AttributeError:
            # Not a package; no submodules.
            return ()
        # Enumerate the modules at a given path.  Prefer to use ``pkgutil`` if
        # we can.  However, if it fails due to OSError, use our own version
        # which is robust to that.
        try:
            submodule_names = [t[1] for t in pkgutil.iter_modules(path)]
        except OSError:
            submodule_names = [t[0] for p in path for t in _my_iter_modules(p)]
        return tuple(ModuleHandle("%s.%s" % (self.name, m))
                     for m in sorted(set(submodule_names)))

    @cached_attribute
    def exports(self):
        """
        Get symbols exported by this module.

        Note that this involves actually importing this module, which may
        have side effects.  (TODO: rewrite to avoid this?)

        :rtype:
          `ImportSet` or ``None``
        :return:
          Exports, or ``None`` if nothing exported.
        """
        from pyflyby._importclns import ImportStatement, ImportSet
        module = self.module
        try:
            members = module.__all__
        except AttributeError:
            # No __all__: fall back to dir() and heuristics.
            members = dir(module)
            # Filter by non-private.
            members = [n for n in members if not n.startswith("_")]
            # Filter by definition in the module.
            def from_this_module(name):
                # TODO: could do this more robustly by parsing the AST and
                # looking for STOREs (definitions/assignments/etc).
                x = getattr(module, name)
                m = getattr(x, "__module__", None)
                if not m:
                    return False
                return DottedIdentifier(m).startswith(self.name)
            members = [n for n in members if from_this_module(n)]
        else:
            if not all(type(s) == str for s in members):
                raise Exception(
                    "Module %r contains non-string entries in __all__"
                    % (str(self.name),))
        # Filter out artificially added "deep" members.
        members = [n for n in members if "." not in n]
        if not members:
            return None
        return ImportSet(
            [ ImportStatement.from_parts(str(self.name), members) ])

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, str(self.name))

    def __hash__(self):
        return hash(self.name)

    def __cmp__(self, o):
        # Python 2 three-way comparison; harmless on Python 3.
        if self is o:
            return 0
        if not isinstance(o, ModuleHandle):
            return NotImplemented
        return cmp(self.name, o.name)

    def __eq__(self, o):
        if self is o:
            return True
        if not isinstance(o, ModuleHandle):
            return NotImplemented
        return self.name == o.name

    def __ne__(self, other):
        return not (self == other)

    # The rest are defined by total_ordering
    def __lt__(self, o):
        if not isinstance(o, ModuleHandle):
            return NotImplemented
        return self.name < o.name

    def __getitem__(self, x):
        # Slicing a handle slices its dotted name, e.g. handle[:1].
        if isinstance(x, slice):
            return type(self)(self.name[x])
        raise TypeError

    @classmethod
    def containing(cls, identifier):
        """
        Try to find the module that defines a name such as ``a.b.c`` by trying
        to import ``a``, ``a.b``, and ``a.b.c``.

        :return:
          The name of the 'deepest' module (most commonly it would be ``a.b``
          in this example).
        :rtype:
          `Module`
        """
        # In the code below we catch "Exception" rather than just ImportError
        # or AttributeError since importing and __getattr__ing can raise other
        # exceptions.
        identifier = DottedIdentifier(identifier)
        try:
            module = ModuleHandle(identifier[:1])
            result = module.module
        except Exception as e:
            raise ImportError(e)
        # BUGFIX: zip() returns an iterator on Python 3, which is not
        # subscriptable; materialize it before slicing off the first pair.
        for part, prefix in list(zip(identifier, prefixes(identifier)))[1:]:
            try:
                result = getattr(result, str(part))
            except Exception:
                try:
                    module = cls(prefix)
                    result = module.module
                except Exception as e:
                    raise ImportError(e)
            else:
                if isinstance(result, types.ModuleType):
                    module = cls(result)
        logger.debug("Imported %r to get %r", module, identifier)
        return module

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,506 @@
# pyflyby/_util.py.
# Copyright (C) 2011, 2012, 2013, 2014, 2015, 2018 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
from __future__ import (absolute_import, division, print_function,
with_statement)
from contextlib import contextmanager
import inspect
import os
import six
from six import PY3, reraise
import sys
import types
# Python 2/3 compatibility.
# Type of a class's __dict__ proxy ("mappingproxy" on Python 3, "dictproxy"
# on Python 2); used to detect read-only class dicts.
DictProxyType = type(object.__dict__)
class Object(object):
    """Trivial class used as an attribute bag."""
    pass
def memoize(function):
    """
    Decorator that caches ``function``'s results by (args, kwargs).

    The cache dict is exposed as the ``cache`` attribute of the returned
    wrapper.  Arguments must be hashable.
    """
    cache = {}
    _ABSENT = object()
    def wrapped_fn(*args, **kwargs):
        # Kwargs are sorted so keyword order doesn't affect the cache key.
        key = (args, tuple(sorted(kwargs.items())))
        value = cache.get(key, _ABSENT)
        if value is _ABSENT:
            value = function(*args, **kwargs)
            cache[key] = value
        return value
    wrapped_fn.cache = cache
    return wrapped_fn
class WrappedAttributeError(Exception):
    """
    Raised by `cached_attribute` in place of an AttributeError from the
    wrapped method, so the descriptor protocol doesn't mask the real error.
    """
    pass
class cached_attribute(object):
    """
    Non-data descriptor that computes an attribute value on first access and
    caches it on the instance.

    Example::

        class MyClass(object):
            @cached_attribute
            def myMethod(self):
                # ...

    Use ``del inst.myMethod`` to clear the cached value.
    """
    # Based on http://code.activestate.com/recipes/276643/

    def __init__(self, method, name=None):
        self.method = method
        self.name = name or method.__name__

    def __get__(self, inst, cls):
        if inst is None:
            # Accessed on the class itself: return the descriptor.
            return self
        try:
            value = self.method(inst)
        except AttributeError as e:
            # Re-raise as a non-AttributeError so the descriptor protocol
            # doesn't swallow it and report a misleading missing attribute.
            reraise(WrappedAttributeError, WrappedAttributeError(str(e)),
                    sys.exc_info()[2])
        # Storing on the instance shadows this descriptor for future reads.
        setattr(inst, self.name, value)
        return value
def stable_unique(items):
    """
    Return a copy of ``items`` without duplicates.  The order of other items
    is unchanged.

      >>> stable_unique([1,4,6,4,6,5,7])
      [1, 4, 6, 5, 7]
    """
    seen = set()
    deduped = []
    for item in items:
        if item not in seen:
            seen.add(item)
            deduped.append(item)
    return deduped
def longest_common_prefix(items1, items2):
    """
    Return the longest common prefix.

      >>> longest_common_prefix("abcde", "abcxy")
      'abc'

    :rtype:
      ``type(items1)``
    """
    match_len = 0
    for x1, x2 in zip(items1, items2):
        if x1 != x2:
            # First mismatch: the common prefix ends here.
            return items1[:match_len]
        match_len += 1
    return items1[:match_len]
def prefixes(parts):
    """
    Yield successive prefixes of ``parts``, shortest first.

      >>> list(prefixes("abcd"))
      ['a', 'ab', 'abc', 'abcd']
    """
    end = 1
    while end <= len(parts):
        yield parts[:end]
        end += 1
def indent(lines, prefix):
r"""
>>> indent('hello\nworld\n', '@@')
'@@hello\n@@world\n'
"""
return "".join("%s%s\n"%(prefix,line) for line in lines.splitlines(False))
def partition(iterable, predicate):
    """
    Split ``iterable`` into ``(trues, falses)`` according to ``predicate``,
    preserving the original order within each list.

      >>> partition('12321233221', lambda c: int(c) % 2 == 0)
      (['2', '2', '2', '2', '2'], ['1', '3', '1', '3', '3', '1'])
    """
    trues = []
    falses = []
    for item in iterable:
        # Route each item to the matching bucket.
        (trues if predicate(item) else falses).append(item)
    return trues, falses
# Positive infinity, as a float.
Inf = float('Inf')
@contextmanager
def NullCtx():
    """
    Context manager that does nothing on entry or exit.
    """
    yield None
@contextmanager
def ImportPathCtx(path_additions):
    """
    Context manager that temporarily prepends ``sys.path`` with
    ``path_additions``.
    """
    if not isinstance(path_additions, (tuple, list)):
        path_additions = [path_additions]
    saved_path = list(sys.path)
    # Mutate sys.path in place so cached references see the change.
    sys.path[:0] = list(path_additions)
    try:
        yield
    finally:
        sys.path[:] = saved_path
@contextmanager
def CwdCtx(path):
    """
    Context manager that temporarily changes the working directory to
    ``path``, restoring the previous directory on exit.
    """
    saved_cwd = os.getcwd()
    os.chdir(str(path))
    try:
        yield
    finally:
        os.chdir(saved_cwd)
@contextmanager
def EnvVarCtx(**kwargs):
    """
    Context manager that temporarily sets the given os.environ variables,
    restoring their previous values (or unsetting them) on exit.
    """
    ABSENT = object()
    saved = {}
    try:
        # Setting happens inside the try so that a failure partway through
        # still restores whatever was already modified.
        for name, value in kwargs.items():
            saved[name] = os.environ.get(name, ABSENT)
            os.environ[name] = value
        yield
    finally:
        for name, value in saved.items():
            if value is ABSENT:
                del os.environ[name]
            else:
                os.environ[name] = value
@contextmanager
def ExcludeImplicitCwdFromPathCtx():
    """
    Context manager that temporarily removes "." and "" from ``sys.path``.
    """
    # BUGFIX/robustness: the original rebound ``sys.path`` to a brand-new
    # filtered list, permanently changing the identity of sys.path and
    # leaving any cached references pointing at a stale list.  Mutate the
    # existing list in place instead (consistent with ImportPathCtx).
    old_path = sys.path[:]
    try:
        sys.path[:] = [p for p in sys.path if p not in (".", "")]
        yield
    finally:
        sys.path[:] = old_path
class FunctionWithGlobals(object):
    """
    A callable that at runtime adds extra variables to the target function's
    global namespace.

    This is written as a class with a __call__ method.  We do so rather than
    using a metafunction, so that we can also implement __getattr__ to look
    through to the target.
    """

    def __init__(self, function, **variables):
        # Note: class-private (name-mangled) attributes, so that __getattr__
        # below never shadows or recurses on them.
        self.__function = function
        self.__variables = variables
        try:
            self.__original__ = variables["__original__"]
        except KeyError:
            pass

    def __call__(self, *args, **kwargs):
        function = self.__function
        variables = self.__variables
        undecorated = function
        # Walk down any decorator chain (via the 'undecorated' convention) to
        # find the innermost function whose __globals__ should be patched.
        while True:
            try:
                undecorated = undecorated.undecorated
            except AttributeError:
                break
        globals = undecorated.__globals__
        UNSET = object()
        old = {}
        # Remember the previous values (or absence) of each injected global.
        for k in variables:
            old[k] = globals.get(k, UNSET)
        try:
            for k, v in six.iteritems(variables):
                globals[k] = v
            return function(*args, **kwargs)
        finally:
            # Restore the previous globals, deleting any we added.
            for k, v in six.iteritems(old):
                if v is UNSET:
                    del globals[k]
                else:
                    globals[k] = v

    def __getattr__(self, k):
        # Delegate unknown attribute access to the original function.
        return getattr(self.__original__, k)

    def __get__(self, inst, cls=None):
        # Descriptor protocol so instances can be used as methods.
        if PY3:
            if inst is None:
                return self
            return types.MethodType(self, inst)
        else:
            return types.MethodType(self, inst, cls)
class _WritableDictProxy(object):
"""
Writable equivalent of cls.__dict__.
"""
# We need to implement __getitem__ differently from __setitem__. The
# reason is because of an asymmetry in the mechanics of classes:
# - getattr(cls, k) does NOT in general do what we want because it
# returns unbound methods. It's actually equivalent to
# cls.__dict__[k].__get__(cls).
# - setattr(cls, k, v) does do what we want.
# - cls.__dict__[k] does do what we want.
# - cls.__dict__[k] = v does not work, because dictproxy is read-only.
def __init__(self, cls):
self._cls = cls
def __getitem__(self, k):
return self._cls.__dict__[k]
def get(self, k, default=None):
return self._cls.__dict__.get(k, default)
def __setitem__(self, k, v):
setattr(self._cls, k, v)
def __delitem__(self, k):
delattr(self._cls, k)
# Sentinel used to distinguish "no value stored" from a stored None.
_UNSET = object()
class Aspect(object):
    """
    Monkey-patch a target method (joinpoint) with "around" advice.

    The advice can call "__original__(...)".  At run time, a global named
    "__original__" will magically be available to the wrapped function.
    This refers to the original function.

    Suppose someone else wrote Foo.bar()::

        >>> class Foo(object):
        ...     def __init__(self, x):
        ...         self.x = x
        ...     def bar(self, y):
        ...         return "bar(self.x=%s,y=%s)" % (self.x,y)

        >>> foo = Foo(42)

    To monkey patch ``foo.bar``, decorate the wrapper with ``"@advise(foo.bar)"``::

        >>> @advise(foo.bar)
        ... def addthousand(y):
        ...     return "advised foo.bar(y=%s): %s" % (y, __original__(y+1000))

        >>> foo.bar(100)
        'advised foo.bar(y=100): bar(self.x=42,y=1100)'

    You can uninstall the advice and get the original behavior back::

        >>> addthousand.unadvise()

        >>> foo.bar(100)
        'bar(self.x=42,y=100)'

    :see:
      http://en.wikipedia.org/wiki/Aspect-oriented_programming
    """
    # Optional callable (property/classmethod) used to re-wrap the advised
    # function before installing it; set per joinpoint kind in __init__.
    _wrapper = None

    def __init__(self, joinpoint):
        # ``spec`` is the joinpoint exactly as given; ``joinpoint`` is
        # unwrapped through any previously-installed advice.
        spec = joinpoint
        while hasattr(joinpoint, "__joinpoint__"):
            joinpoint = joinpoint.__joinpoint__
        self._joinpoint = joinpoint
        # Case 1: a plain module-level function or a class object.
        if (isinstance(joinpoint, (types.FunctionType, six.class_types, type))
            and not (PY3 and joinpoint.__name__ != joinpoint.__qualname__)):
            self._qname = "%s.%s" % (
                joinpoint.__module__,
                joinpoint.__name__)
            self._container = sys.modules[joinpoint.__module__].__dict__
            self._name = joinpoint.__name__
            self._original = spec
            assert spec == self._container[self._name], joinpoint
        # Case 2: a method (bound, or Py3 function with dotted __qualname__)
        # or a property.
        elif isinstance(joinpoint, types.MethodType) or (PY3 and isinstance(joinpoint,
            types.FunctionType) and joinpoint.__name__ !=
            joinpoint.__qualname__) or isinstance(joinpoint, property):
            if isinstance(joinpoint, property):
                # Advise the getter; re-wrap as a property on install.
                joinpoint = joinpoint.fget
                self._wrapper = property
            if PY3:
                self._qname = '%s.%s' % (joinpoint.__module__,
                                         joinpoint.__qualname__)
                self._name = joinpoint.__name__
            else:
                self._qname = "%s.%s.%s" % (
                    joinpoint.__self__.__class__.__module__,
                    joinpoint.__self__.__class__.__name__,
                    joinpoint.__func__.__name__)
                self._name = joinpoint.__func__.__name__
            if getattr(joinpoint, '__self__', None) is None:
                # Unbound method in Python 2 only. In Python 3, there are no unbound methods
                # (they are just functions).
                if PY3:
                    # Resolve the defining class from the __qualname__.
                    container_obj = getattr(inspect.getmodule(joinpoint),
                        joinpoint.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0])
                else:
                    container_obj = joinpoint.im_class
                self._container = _WritableDictProxy(container_obj)
                # __func__ gives the function for the Python 2 unbound method.
                # In Python 3, spec is already a function.
                self._original = getattr(spec, '__func__', spec)
            else:
                # Instance method.
                container_obj = joinpoint.__self__
                self._container = container_obj.__dict__
                self._original = spec
                assert spec == getattr(container_obj, self._name), (container_obj, self._qname)
            assert self._original == self._container.get(self._name, self._original)
        # Case 3: an explicit (container, name) pair.
        elif isinstance(joinpoint, tuple) and len(joinpoint) == 2:
            container, name = joinpoint
            if isinstance(container, dict):
                # Plain namespace dict.
                self._original = container[name]
                self._container = container
                self._qname = name
            elif name in container.__dict__.get('_trait_values', ()):
                # traitlet stuff from IPython
                self._container = container._trait_values
                self._original = self._container[name]
                self._qname = name
            elif isinstance(container.__dict__, DictProxyType):
                # Class with a read-only __dict__ proxy; go through setattr.
                original = getattr(container, name)
                if hasattr(original, "__func__"):
                    # TODO: generalize this to work for all cases, not just classmethod
                    original = original.__func__
                    self._wrapper = classmethod
                self._original = original
                self._container = _WritableDictProxy(container)
                self._qname = "%s.%s.%s" % (
                    container.__module__, container.__name__, name)
            else:
                # Keep track of the original.  We use getattr on the
                # container, instead of getitem on container.__dict__, so that
                # it works even if it's a class dict proxy that inherits the
                # value from a super class.
                self._original = getattr(container, name)
                self._container = container.__dict__
                self._qname = "%s.%s.%s" % (
                    container.__class__.__module__,
                    container.__class__.__name__,
                    name)
            self._name = name
            # TODO: unbound method
        else:
            raise TypeError("JoinPoint: unexpected type %s"
                            % (type(joinpoint).__name__,))
        self._wrapped = None

    def advise(self, hook, once=False):
        """
        Install ``hook`` as around-advice.  Returns ``self`` so callers can
        later ``unadvise()``; returns None if ``once`` and already advised.
        """
        from pyflyby._log import logger
        self._previous = self._container.get(self._name, _UNSET)
        if once and getattr(self._previous, "__aspect__", None):
            # TODO: check that it's the same hook - at least check the name.
            logger.debug("already advised %s", self._qname)
            return None
        logger.debug("advising %s", self._qname)
        assert self._previous is _UNSET or self._previous == self._original
        assert self._wrapped is None
        # Create the wrapped function.
        wrapped = FunctionWithGlobals(hook, __original__=self._original)
        wrapped.__joinpoint__ = self._joinpoint
        wrapped.__original__ = self._original
        wrapped.__name__ = "%s__advised__%s" % (self._name, hook.__name__)
        wrapped.__doc__ = "%s.\n\nAdvice %s:\n%s" % (
            self._original.__doc__, hook.__name__, hook.__doc__)
        wrapped.__aspect__ = self
        if self._wrapper is not None:
            # Re-wrap (property/classmethod) to match the original's kind.
            wrapped = self._wrapper(wrapped)
        self._wrapped = wrapped
        # Install the wrapped function!
        self._container[self._name] = wrapped
        return self

    def unadvise(self):
        """
        Remove the installed advice, restoring the original (unless someone
        else has since replaced the advised attribute).
        """
        if self._wrapped is None:
            return
        cur = self._container.get(self._name, _UNSET)
        if cur is self._wrapped:
            from pyflyby._log import logger
            logger.debug("unadvising %s", self._qname)
            if self._previous is _UNSET:
                del self._container[self._name]
            else:
                self._container[self._name] = self._previous
        elif cur == self._previous:
            # Already restored; nothing to do.
            pass
        else:
            from pyflyby._log import logger
            logger.debug("%s seems modified; not unadvising it", self._name)
        self._wrapped = None
def advise(joinpoint):
    """
    Advise ``joinpoint``.

    Use as a decorator: the decorated function becomes the "around" advice
    hook.  See `Aspect`.
    """
    return Aspect(joinpoint).advise
@contextmanager
def AdviceCtx(joinpoint, hook):
    """
    Context manager that installs ``hook`` as advice on ``joinpoint`` for the
    duration of the context, and uninstalls it on exit.
    """
    advice = Aspect(joinpoint).advise(hook)
    try:
        yield
    finally:
        advice.unadvise()
# For Python 2/3 compatibility. cmp isn't included with six.
def cmp(a, b):
    """
    Three-way comparison: -1 if ``a < b``, 1 if ``a > b``, else 0.
    """
    if a < b:
        return -1
    if a > b:
        return 1
    return 0
# Create a context manager with an arbitrary number of contexts. This is
# the same as Py2 contextlib.nested, but that one is removed in Py3.
if six.PY2:
    from contextlib import nested
else:
    from contextlib import ExitStack
    @contextmanager
    def nested(*mgrs):
        # Enter all managers left-to-right, yield the list of their
        # __enter__ results; ExitStack exits them in reverse order.
        with ExitStack() as stack:
            ctxes = [stack.enter_context(mgr) for mgr in mgrs]
            yield ctxes

View file

@ -0,0 +1,9 @@
# pyflyby/_version.py.
# License for THIS FILE ONLY: CC0 Public Domain Dedication
# http://creativecommons.org/publicdomain/zero/1.0/
from __future__ import (absolute_import, division, print_function,
with_statement)
__version__ = '1.7.4'

View file

@ -0,0 +1,21 @@
# pyflyby/autoimport.py.
# Copyright (C) 2011, 2012, 2013, 2014 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
# Deprecated stub for backwards compatibility.
#
# Change your old code from:
# import pyflyby.autoimport
# pyflyby.autoimport.install_auto_importer()
# to:
# import pyflyby
# pyflyby.enable_auto_importer()
from __future__ import (absolute_import, division, print_function,
with_statement)
from pyflyby._interactive import enable_auto_importer

# Deprecated alias, kept for backwards compatibility.
install_auto_importer = enable_auto_importer

# BUGFIX: __all__ must contain attribute *names* (strings).  Listing the
# function object itself made ``from pyflyby.autoimport import *`` raise
# TypeError.
__all__ = ['install_auto_importer']

View file

@ -0,0 +1,20 @@
# pyflyby/importdb.py.
# Copyright (C) 2011, 2012, 2013, 2014 Karl Chen.
# License: MIT http://opensource.org/licenses/MIT
# Deprecated stub for backwards compatibility.
from __future__ import (absolute_import, division, print_function,
with_statement)
from pyflyby._importdb import ImportDB
def global_known_imports():
    """
    Deprecated stub for backwards compatibility.

    Return the known-imports database resolved for the current directory.
    """
    return ImportDB.get_default(".").known_imports
def global_mandatory_imports():
    """
    Deprecated stub for backwards compatibility.

    Return the mandatory-imports database resolved for the current directory.
    """
    return ImportDB.get_default(".").mandatory_imports