This commit is contained in:
Waylon Walker 2022-03-31 20:20:07 -05:00
commit 38355d2442
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
9083 changed files with 1225834 additions and 0 deletions

View file

@ -0,0 +1,807 @@
"""Plugin for supporting the attrs library (http://www.attrs.org)"""
from mypy.backports import OrderedDict
from typing import Optional, Dict, List, cast, Tuple, Iterable
from typing_extensions import Final
import mypy.plugin # To avoid circular imports.
from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError
from mypy.lookup import lookup_fully_qualified
from mypy.nodes import (
Context, Argument, Var, ARG_OPT, ARG_POS, TypeInfo, AssignmentStmt,
TupleExpr, ListExpr, NameExpr, CallExpr, RefExpr, FuncDef,
is_class_var, TempNode, Decorator, MemberExpr, Expression,
SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef, ARG_NAMED_OPT, ARG_NAMED,
TypeVarExpr, PlaceholderNode
)
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
_get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method,
deserialize_and_fixup_type, add_attribute_to_class,
)
from mypy.types import (
TupleType, Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarType,
Overloaded, UnionType, FunctionLike, Instance, get_proper_type,
LiteralType,
)
from mypy.typeops import make_simplified_union, map_type_from_supertype
from mypy.typevars import fill_typevars
from mypy.util import unmangle
from mypy.server.trigger import make_wildcard_trigger
KW_ONLY_PYTHON_2_UNSUPPORTED: Final = "kw_only is not supported in Python 2"
# The names of the different functions that create classes or arguments.
attr_class_makers: Final = {
'attr.s',
'attr.attrs',
'attr.attributes',
}
attr_dataclass_makers: Final = {
'attr.dataclass',
}
attr_frozen_makers: Final = {"attr.frozen"}
attr_define_makers: Final = {"attr.define", "attr.mutable"}
attr_attrib_makers: Final = {
'attr.ib',
'attr.attrib',
'attr.attr',
'attr.field',
}
SELF_TVAR_NAME: Final = "_AT"
MAGIC_ATTR_NAME: Final = "__attrs_attrs__"
MAGIC_ATTR_CLS_NAME: Final = "_AttrsAttributes" # The namedtuple subclass name.
class Converter:
"""Holds information about a `converter=` argument"""
def __init__(self,
name: Optional[str] = None,
is_attr_converters_optional: bool = False) -> None:
self.name = name
self.is_attr_converters_optional = is_attr_converters_optional
class Attribute:
"""The value of an attr.ib() call."""
def __init__(self, name: str, info: TypeInfo,
has_default: bool, init: bool, kw_only: bool, converter: Converter,
context: Context,
init_type: Optional[Type]) -> None:
self.name = name
self.info = info
self.has_default = has_default
self.init = init
self.kw_only = kw_only
self.converter = converter
self.context = context
self.init_type = init_type
def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument:
"""Return this attribute as an argument to __init__."""
assert self.init
init_type = self.init_type or self.info[self.name].type
if self.converter.name:
# When a converter is set the init_type is overridden by the first argument
# of the converter method.
converter = lookup_fully_qualified(self.converter.name, ctx.api.modules,
raise_on_missing=False)
if not converter:
# The converter may be a local variable. Check there too.
converter = ctx.api.lookup_qualified(self.converter.name, self.info, True)
# Get the type of the converter.
converter_type: Optional[Type] = None
if converter and isinstance(converter.node, TypeInfo):
from mypy.checkmember import type_object_type # To avoid import cycle.
converter_type = type_object_type(converter.node, ctx.api.named_type)
elif converter and isinstance(converter.node, OverloadedFuncDef):
converter_type = converter.node.type
elif converter and converter.type:
converter_type = converter.type
init_type = None
converter_type = get_proper_type(converter_type)
if isinstance(converter_type, CallableType) and converter_type.arg_types:
init_type = ctx.api.anal_type(converter_type.arg_types[0])
elif isinstance(converter_type, Overloaded):
types: List[Type] = []
for item in converter_type.items:
# Walk the overloads looking for methods that can accept one argument.
num_arg_types = len(item.arg_types)
if not num_arg_types:
continue
if num_arg_types > 1 and any(kind == ARG_POS for kind in item.arg_kinds[1:]):
continue
types.append(item.arg_types[0])
# Make a union of all the valid types.
if types:
args = make_simplified_union(types)
init_type = ctx.api.anal_type(args)
if self.converter.is_attr_converters_optional and init_type:
# If the converter was attr.converter.optional(type) then add None to
# the allowed init_type.
init_type = UnionType.make_union([init_type, NoneType()])
if not init_type:
ctx.api.fail("Cannot determine __init__ type from converter", self.context)
init_type = AnyType(TypeOfAny.from_error)
elif self.converter.name == '':
# This means we had a converter but it's not of a type we can infer.
# Error was shown in _get_converter_name
init_type = AnyType(TypeOfAny.from_error)
if init_type is None:
if ctx.api.options.disallow_untyped_defs:
# This is a compromise. If you don't have a type here then the
# __init__ will be untyped. But since the __init__ is added it's
# pointing at the decorator. So instead we also show the error in the
# assignment, which is where you would fix the issue.
node = self.info[self.name].node
assert node is not None
ctx.api.msg.need_annotation_for_var(node, self.context)
# Convert type not set to Any.
init_type = AnyType(TypeOfAny.unannotated)
if self.kw_only:
arg_kind = ARG_NAMED_OPT if self.has_default else ARG_NAMED
else:
arg_kind = ARG_OPT if self.has_default else ARG_POS
# Attrs removes leading underscores when creating the __init__ arguments.
return Argument(Var(self.name.lstrip("_"), init_type), init_type,
None,
arg_kind)
def serialize(self) -> JsonDict:
"""Serialize this object so it can be saved and restored."""
return {
'name': self.name,
'has_default': self.has_default,
'init': self.init,
'kw_only': self.kw_only,
'converter_name': self.converter.name,
'converter_is_attr_converters_optional': self.converter.is_attr_converters_optional,
'context_line': self.context.line,
'context_column': self.context.column,
'init_type': self.init_type.serialize() if self.init_type else None,
}
@classmethod
def deserialize(cls, info: TypeInfo,
data: JsonDict,
api: SemanticAnalyzerPluginInterface) -> 'Attribute':
"""Return the Attribute that was serialized."""
raw_init_type = data['init_type']
init_type = deserialize_and_fixup_type(raw_init_type, api) if raw_init_type else None
return Attribute(data['name'],
info,
data['has_default'],
data['init'],
data['kw_only'],
Converter(data['converter_name'], data['converter_is_attr_converters_optional']),
Context(line=data['context_line'], column=data['context_column']),
init_type)
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type."""
if not isinstance(self.init_type, TypeVarType):
return
self.init_type = map_type_from_supertype(self.init_type, sub_type, self.info)
def _determine_eq_order(ctx: 'mypy.plugin.ClassDefContext') -> bool:
"""
Validate the combination of *cmp*, *eq*, and *order*. Derive the effective
value of order.
"""
cmp = _get_decorator_optional_bool_argument(ctx, 'cmp')
eq = _get_decorator_optional_bool_argument(ctx, 'eq')
order = _get_decorator_optional_bool_argument(ctx, 'order')
if cmp is not None and any((eq is not None, order is not None)):
ctx.api.fail('Don\'t mix "cmp" with "eq" and "order"', ctx.reason)
# cmp takes precedence due to bw-compatibility.
if cmp is not None:
return cmp
# If left None, equality is on and ordering mirrors equality.
if eq is None:
eq = True
if order is None:
order = eq
if eq is False and order is True:
ctx.api.fail('eq must be True if order is True', ctx.reason)
return order
def _get_decorator_optional_bool_argument(
ctx: 'mypy.plugin.ClassDefContext',
name: str,
default: Optional[bool] = None,
) -> Optional[bool]:
"""Return the Optional[bool] argument for the decorator.
This handles both @decorator(...) and @decorator.
"""
if isinstance(ctx.reason, CallExpr):
attr_value = _get_argument(ctx.reason, name)
if attr_value:
if isinstance(attr_value, NameExpr):
if attr_value.fullname == 'builtins.True':
return True
if attr_value.fullname == 'builtins.False':
return False
if attr_value.fullname == 'builtins.None':
return None
ctx.api.fail('"{}" argument must be True or False.'.format(name), ctx.reason)
return default
return default
else:
return default
def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext',
auto_attribs_default: Optional[bool] = False,
frozen_default: bool = False) -> None:
"""Add necessary dunder methods to classes decorated with attr.s.
attrs is a package that lets you define classes without writing dull boilerplate code.
At a quick glance, the decorator searches the class body for assignments of `attr.ib`s (or
annotated variables if auto_attribs=True), then depending on how the decorator is called,
it will add an __init__ or all the __cmp__ methods. For frozen=True it will turn the attrs
into properties.
See http://www.attrs.org/en/stable/how-does-it-work.html for information on how attrs works.
"""
info = ctx.cls.info
init = _get_decorator_bool_argument(ctx, 'init', True)
frozen = _get_frozen(ctx, frozen_default)
order = _determine_eq_order(ctx)
slots = _get_decorator_bool_argument(ctx, 'slots', False)
auto_attribs = _get_decorator_optional_bool_argument(ctx, 'auto_attribs', auto_attribs_default)
kw_only = _get_decorator_bool_argument(ctx, 'kw_only', False)
match_args = _get_decorator_bool_argument(ctx, 'match_args', True)
if ctx.api.options.python_version[0] < 3:
if auto_attribs:
ctx.api.fail("auto_attribs is not supported in Python 2", ctx.reason)
return
if not info.defn.base_type_exprs:
# Note: This will not catch subclassing old-style classes.
ctx.api.fail("attrs only works with new-style classes", info.defn)
return
if kw_only:
ctx.api.fail(KW_ONLY_PYTHON_2_UNSUPPORTED, ctx.reason)
return
attributes = _analyze_class(ctx, auto_attribs, kw_only)
# Check if attribute types are ready.
for attr in attributes:
node = info.get(attr.name)
if node is None:
# This name is likely blocked by a star import. We don't need to defer because
# defer() is already called by mark_incomplete().
return
if node.type is None and not ctx.api.final_iteration:
ctx.api.defer()
return
_add_attrs_magic_attribute(ctx, [(attr.name, info[attr.name].type) for attr in attributes])
if slots:
_add_slots(ctx, attributes)
if match_args and ctx.api.options.python_version[:2] >= (3, 10):
# `.__match_args__` is only added for python3.10+, but the argument
# exists for earlier versions as well.
_add_match_args(ctx, attributes)
# Save the attributes so that subclasses can reuse them.
ctx.cls.info.metadata['attrs'] = {
'attributes': [attr.serialize() for attr in attributes],
'frozen': frozen,
}
adder = MethodAdder(ctx)
if init:
_add_init(ctx, attributes, adder)
if order:
_add_order(ctx, adder)
if frozen:
_make_frozen(ctx, attributes)
def _get_frozen(ctx: 'mypy.plugin.ClassDefContext', frozen_default: bool) -> bool:
"""Return whether this class is frozen."""
if _get_decorator_bool_argument(ctx, 'frozen', frozen_default):
return True
# Subclasses of frozen classes are frozen so check that.
for super_info in ctx.cls.info.mro[1:-1]:
if 'attrs' in super_info.metadata and super_info.metadata['attrs']['frozen']:
return True
return False
def _analyze_class(ctx: 'mypy.plugin.ClassDefContext',
auto_attribs: Optional[bool],
kw_only: bool) -> List[Attribute]:
"""Analyze the class body of an attr maker, its parents, and return the Attributes found.
auto_attribs=True means we'll generate attributes from type annotations also.
auto_attribs=None means we'll detect which mode to use.
kw_only=True means that all attributes created here will be keyword only args in __init__.
"""
own_attrs: OrderedDict[str, Attribute] = OrderedDict()
if auto_attribs is None:
auto_attribs = _detect_auto_attribs(ctx)
# Walk the body looking for assignments and decorators.
for stmt in ctx.cls.defs.body:
if isinstance(stmt, AssignmentStmt):
for attr in _attributes_from_assignment(ctx, stmt, auto_attribs, kw_only):
# When attrs are defined twice in the same body we want to use the 2nd definition
# in the 2nd location. So remove it from the OrderedDict.
# Unless it's auto_attribs in which case we want the 2nd definition in the
# 1st location.
if not auto_attribs and attr.name in own_attrs:
del own_attrs[attr.name]
own_attrs[attr.name] = attr
elif isinstance(stmt, Decorator):
_cleanup_decorator(stmt, own_attrs)
for attribute in own_attrs.values():
# Even though these look like class level assignments we want them to look like
# instance level assignments.
if attribute.name in ctx.cls.info.names:
node = ctx.cls.info.names[attribute.name].node
if isinstance(node, PlaceholderNode):
# This node is not ready yet.
continue
assert isinstance(node, Var)
node.is_initialized_in_class = False
# Traverse the MRO and collect attributes from the parents.
taken_attr_names = set(own_attrs)
super_attrs = []
for super_info in ctx.cls.info.mro[1:-1]:
if 'attrs' in super_info.metadata:
# Each class depends on the set of attributes in its attrs ancestors.
ctx.api.add_plugin_dependency(make_wildcard_trigger(super_info.fullname))
for data in super_info.metadata['attrs']['attributes']:
# Only add an attribute if it hasn't been defined before. This
# allows for overwriting attribute definitions by subclassing.
if data['name'] not in taken_attr_names:
a = Attribute.deserialize(super_info, data, ctx.api)
a.expand_typevar_from_subtype(ctx.cls.info)
super_attrs.append(a)
taken_attr_names.add(a.name)
attributes = super_attrs + list(own_attrs.values())
# Check the init args for correct default-ness. Note: This has to be done after all the
# attributes for all classes have been read, because subclasses can override parents.
last_default = False
for i, attribute in enumerate(attributes):
if not attribute.init:
continue
if attribute.kw_only:
# Keyword-only attributes don't care whether they are default or not.
continue
# If the issue comes from merging different classes, report it
# at the class definition point.
context = attribute.context if i >= len(super_attrs) else ctx.cls
if not attribute.has_default and last_default:
ctx.api.fail(
"Non-default attributes not allowed after default attributes.",
context)
last_default |= attribute.has_default
return attributes
def _detect_auto_attribs(ctx: 'mypy.plugin.ClassDefContext') -> bool:
"""Return whether auto_attribs should be enabled or disabled.
It's disabled if there are any unannotated attribs()
"""
for stmt in ctx.cls.defs.body:
if isinstance(stmt, AssignmentStmt):
for lvalue in stmt.lvalues:
lvalues, rvalues = _parse_assignments(lvalue, stmt)
if len(lvalues) != len(rvalues):
# This means we have some assignment that isn't 1 to 1.
# It can't be an attrib.
continue
for lhs, rvalue in zip(lvalues, rvalues):
# Check if the right hand side is a call to an attribute maker.
if (isinstance(rvalue, CallExpr)
and isinstance(rvalue.callee, RefExpr)
and rvalue.callee.fullname in attr_attrib_makers
and not stmt.new_syntax):
# This means we have an attrib without an annotation and so
# we can't do auto_attribs=True
return False
return True
def _attributes_from_assignment(ctx: 'mypy.plugin.ClassDefContext',
stmt: AssignmentStmt, auto_attribs: bool,
kw_only: bool) -> Iterable[Attribute]:
"""Return Attribute objects that are created by this assignment.
The assignments can look like this:
x = attr.ib()
x = y = attr.ib()
x, y = attr.ib(), attr.ib()
or if auto_attribs is enabled also like this:
x: type
x: type = default_value
"""
for lvalue in stmt.lvalues:
lvalues, rvalues = _parse_assignments(lvalue, stmt)
if len(lvalues) != len(rvalues):
# This means we have some assignment that isn't 1 to 1.
# It can't be an attrib.
continue
for lhs, rvalue in zip(lvalues, rvalues):
# Check if the right hand side is a call to an attribute maker.
if (isinstance(rvalue, CallExpr)
and isinstance(rvalue.callee, RefExpr)
and rvalue.callee.fullname in attr_attrib_makers):
attr = _attribute_from_attrib_maker(ctx, auto_attribs, kw_only, lhs, rvalue, stmt)
if attr:
yield attr
elif auto_attribs and stmt.type and stmt.new_syntax and not is_class_var(lhs):
yield _attribute_from_auto_attrib(ctx, kw_only, lhs, rvalue, stmt)
def _cleanup_decorator(stmt: Decorator, attr_map: Dict[str, Attribute]) -> None:
"""Handle decorators in class bodies.
`x.default` will set a default value on x
`x.validator` and `x.default` will get removed to avoid throwing a type error.
"""
remove_me = []
for func_decorator in stmt.decorators:
if (isinstance(func_decorator, MemberExpr)
and isinstance(func_decorator.expr, NameExpr)
and func_decorator.expr.name in attr_map):
if func_decorator.name == 'default':
attr_map[func_decorator.expr.name].has_default = True
if func_decorator.name in ('default', 'validator'):
# These are decorators on the attrib object that only exist during
# class creation time. In order to not trigger a type error later we
# just remove them. This might leave us with a Decorator with no
# decorators (Emperor's new clothes?)
# TODO: It would be nice to type-check these rather than remove them.
# default should be Callable[[], T]
# validator should be Callable[[Any, 'Attribute', T], Any]
# where T is the type of the attribute.
remove_me.append(func_decorator)
for dec in remove_me:
stmt.decorators.remove(dec)
def _attribute_from_auto_attrib(ctx: 'mypy.plugin.ClassDefContext',
kw_only: bool,
lhs: NameExpr,
rvalue: Expression,
stmt: AssignmentStmt) -> Attribute:
"""Return an Attribute for a new type assignment."""
name = unmangle(lhs.name)
# `x: int` (without equal sign) assigns rvalue to TempNode(AnyType())
has_rhs = not isinstance(rvalue, TempNode)
sym = ctx.cls.info.names.get(name)
init_type = sym.type if sym else None
return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt, init_type)
def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
auto_attribs: bool,
kw_only: bool,
lhs: NameExpr,
rvalue: CallExpr,
stmt: AssignmentStmt) -> Optional[Attribute]:
"""Return an Attribute from the assignment or None if you can't make one."""
if auto_attribs and not stmt.new_syntax:
# auto_attribs requires an annotation on *every* attr.ib.
assert lhs.node is not None
ctx.api.msg.need_annotation_for_var(lhs.node, stmt)
return None
if len(stmt.lvalues) > 1:
ctx.api.fail("Too many names for one attribute", stmt)
return None
# This is the type that belongs in the __init__ method for this attrib.
init_type = stmt.type
# Read all the arguments from the call.
init = _get_bool_argument(ctx, rvalue, 'init', True)
# Note: If the class decorator says kw_only=True the attribute is ignored.
# See https://github.com/python-attrs/attrs/issues/481 for explanation.
kw_only |= _get_bool_argument(ctx, rvalue, 'kw_only', False)
if kw_only and ctx.api.options.python_version[0] < 3:
ctx.api.fail(KW_ONLY_PYTHON_2_UNSUPPORTED, stmt)
return None
# TODO: Check for attr.NOTHING
attr_has_default = bool(_get_argument(rvalue, 'default'))
attr_has_factory = bool(_get_argument(rvalue, 'factory'))
if attr_has_default and attr_has_factory:
ctx.api.fail('Can\'t pass both "default" and "factory".', rvalue)
elif attr_has_factory:
attr_has_default = True
# If the type isn't set through annotation but is passed through `type=` use that.
type_arg = _get_argument(rvalue, 'type')
if type_arg and not init_type:
try:
un_type = expr_to_unanalyzed_type(type_arg, ctx.api.options, ctx.api.is_stub_file)
except TypeTranslationError:
ctx.api.fail('Invalid argument to type', type_arg)
else:
init_type = ctx.api.anal_type(un_type)
if init_type and isinstance(lhs.node, Var) and not lhs.node.type:
# If there is no annotation, add one.
lhs.node.type = init_type
lhs.is_inferred_def = False
# Note: convert is deprecated but works the same as converter.
converter = _get_argument(rvalue, 'converter')
convert = _get_argument(rvalue, 'convert')
if convert and converter:
ctx.api.fail('Can\'t pass both "convert" and "converter".', rvalue)
elif convert:
ctx.api.fail("convert is deprecated, use converter", rvalue)
converter = convert
converter_info = _parse_converter(ctx, converter)
name = unmangle(lhs.name)
return Attribute(name, ctx.cls.info, attr_has_default, init,
kw_only, converter_info, stmt, init_type)
def _parse_converter(ctx: 'mypy.plugin.ClassDefContext',
converter: Optional[Expression]) -> Converter:
"""Return the Converter object from an Expression."""
# TODO: Support complex converters, e.g. lambdas, calls, etc.
if converter:
if isinstance(converter, RefExpr) and converter.node:
if (isinstance(converter.node, FuncDef)
and converter.node.type
and isinstance(converter.node.type, FunctionLike)):
return Converter(converter.node.fullname)
elif (isinstance(converter.node, OverloadedFuncDef)
and is_valid_overloaded_converter(converter.node)):
return Converter(converter.node.fullname)
elif isinstance(converter.node, TypeInfo):
return Converter(converter.node.fullname)
if (isinstance(converter, CallExpr)
and isinstance(converter.callee, RefExpr)
and converter.callee.fullname == "attr.converters.optional"
and converter.args
and converter.args[0]):
# Special handling for attr.converters.optional(type)
# We extract the type and add make the init_args Optional in Attribute.argument
argument = _parse_converter(ctx, converter.args[0])
argument.is_attr_converters_optional = True
return argument
# Signal that we have an unsupported converter.
ctx.api.fail(
"Unsupported converter, only named functions and types are currently supported",
converter
)
return Converter('')
return Converter(None)
def is_valid_overloaded_converter(defn: OverloadedFuncDef) -> bool:
return all((not isinstance(item, Decorator) or isinstance(item.func.type, FunctionLike))
for item in defn.items)
def _parse_assignments(
lvalue: Expression,
stmt: AssignmentStmt) -> Tuple[List[NameExpr], List[Expression]]:
"""Convert a possibly complex assignment expression into lists of lvalues and rvalues."""
lvalues: List[NameExpr] = []
rvalues: List[Expression] = []
if isinstance(lvalue, (TupleExpr, ListExpr)):
if all(isinstance(item, NameExpr) for item in lvalue.items):
lvalues = cast(List[NameExpr], lvalue.items)
if isinstance(stmt.rvalue, (TupleExpr, ListExpr)):
rvalues = stmt.rvalue.items
elif isinstance(lvalue, NameExpr):
lvalues = [lvalue]
rvalues = [stmt.rvalue]
return lvalues, rvalues
def _add_order(ctx: 'mypy.plugin.ClassDefContext', adder: 'MethodAdder') -> None:
"""Generate all the ordering methods for this class."""
bool_type = ctx.api.named_type('builtins.bool')
object_type = ctx.api.named_type('builtins.object')
# Make the types be:
# AT = TypeVar('AT')
# def __lt__(self: AT, other: AT) -> bool
# This way comparisons with subclasses will work correctly.
tvd = TypeVarType(SELF_TVAR_NAME, ctx.cls.info.fullname + '.' + SELF_TVAR_NAME,
-1, [], object_type)
self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, ctx.cls.info.fullname + '.' + SELF_TVAR_NAME,
[], object_type)
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
args = [Argument(Var('other', tvd), tvd, None, ARG_POS)]
for method in ['__lt__', '__le__', '__gt__', '__ge__']:
adder.add_method(method, args, bool_type, self_type=tvd, tvd=tvd)
def _make_frozen(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute]) -> None:
"""Turn all the attributes into properties to simulate frozen classes."""
for attribute in attributes:
if attribute.name in ctx.cls.info.names:
# This variable belongs to this class so we can modify it.
node = ctx.cls.info.names[attribute.name].node
assert isinstance(node, Var)
node.is_property = True
else:
# This variable belongs to a super class so create new Var so we
# can modify it.
var = Var(attribute.name, ctx.cls.info[attribute.name].type)
var.info = ctx.cls.info
var._fullname = '%s.%s' % (ctx.cls.info.fullname, var.name)
ctx.cls.info.names[var.name] = SymbolTableNode(MDEF, var)
var.is_property = True
def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute],
adder: 'MethodAdder') -> None:
"""Generate an __init__ method for the attributes and add it to the class."""
# Convert attributes to arguments with kw_only arguments at the end of
# the argument list
pos_args = []
kw_only_args = []
for attribute in attributes:
if not attribute.init:
continue
if attribute.kw_only:
kw_only_args.append(attribute.argument(ctx))
else:
pos_args.append(attribute.argument(ctx))
args = pos_args + kw_only_args
if all(
# We use getattr rather than instance checks because the variable.type
# might be wrapped into a Union or some other type, but even non-Any
# types reliably track the fact that the argument was not annotated.
getattr(arg.variable.type, "type_of_any", None) == TypeOfAny.unannotated
for arg in args
):
# This workaround makes --disallow-incomplete-defs usable with attrs,
# but is definitely suboptimal as a long-term solution.
# See https://github.com/python/mypy/issues/5954 for discussion.
for a in args:
a.variable.type = AnyType(TypeOfAny.implementation_artifact)
a.type_annotation = AnyType(TypeOfAny.implementation_artifact)
adder.add_method('__init__', args, NoneType())
def _add_attrs_magic_attribute(ctx: 'mypy.plugin.ClassDefContext',
attrs: 'List[Tuple[str, Optional[Type]]]') -> None:
any_type = AnyType(TypeOfAny.explicit)
attributes_types: 'List[Type]' = [
ctx.api.named_type_or_none('attr.Attribute', [attr_type or any_type]) or any_type
for _, attr_type in attrs
]
fallback_type = ctx.api.named_type('builtins.tuple', [
ctx.api.named_type_or_none('attr.Attribute', [any_type]) or any_type,
])
ti = ctx.api.basic_new_typeinfo(MAGIC_ATTR_CLS_NAME, fallback_type, 0)
ti.is_named_tuple = True
for (name, _), attr_type in zip(attrs, attributes_types):
var = Var(name, attr_type)
var.is_property = True
proper_type = get_proper_type(attr_type)
if isinstance(proper_type, Instance):
var.info = proper_type.type
ti.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
attributes_type = Instance(ti, [])
# TODO: refactor using `add_attribute_to_class`
var = Var(name=MAGIC_ATTR_NAME, type=TupleType(attributes_types, fallback=attributes_type))
var.info = ctx.cls.info
var.is_classvar = True
var._fullname = f"{ctx.cls.fullname}.{MAGIC_ATTR_CLS_NAME}"
var.allow_incompatible_override = True
ctx.cls.info.names[MAGIC_ATTR_NAME] = SymbolTableNode(
kind=MDEF,
node=var,
plugin_generated=True,
no_serialize=True,
)
def _add_slots(ctx: 'mypy.plugin.ClassDefContext',
attributes: List[Attribute]) -> None:
# Unlike `@dataclasses.dataclass`, `__slots__` is rewritten here.
ctx.cls.info.slots = {attr.name for attr in attributes}
def _add_match_args(ctx: 'mypy.plugin.ClassDefContext',
attributes: List[Attribute]) -> None:
if ('__match_args__' not in ctx.cls.info.names
or ctx.cls.info.names['__match_args__'].plugin_generated):
str_type = ctx.api.named_type('builtins.str')
match_args = TupleType(
[
str_type.copy_modified(
last_known_value=LiteralType(attr.name, fallback=str_type),
)
for attr in attributes
if not attr.kw_only and attr.init
],
fallback=ctx.api.named_type('builtins.tuple'),
)
add_attribute_to_class(
api=ctx.api,
cls=ctx.cls,
name='__match_args__',
typ=match_args,
)
class MethodAdder:
"""Helper to add methods to a TypeInfo.
ctx: The ClassDefCtx we are using on which we will add methods.
"""
# TODO: Combine this with the code build_namedtuple_typeinfo to support both.
def __init__(self, ctx: 'mypy.plugin.ClassDefContext') -> None:
self.ctx = ctx
self.self_type = fill_typevars(ctx.cls.info)
def add_method(self,
method_name: str, args: List[Argument], ret_type: Type,
self_type: Optional[Type] = None,
tvd: Optional[TypeVarType] = None) -> None:
"""Add a method: def <method_name>(self, <args>) -> <ret_type>): ... to info.
self_type: The type to use for the self argument or None to use the inferred self type.
tvd: If the method is generic these should be the type variables.
"""
self_type = self_type if self_type is not None else self.self_type
add_method(self.ctx, method_name, args, ret_type, self_type, tvd)

View file

@ -0,0 +1,202 @@
from typing import List, Optional, Union
from mypy.nodes import (
ARG_POS, MDEF, Argument, Block, CallExpr, ClassDef, Expression, SYMBOL_FUNCBASE_TYPES,
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict,
)
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.semanal import set_callable_name, ALLOW_INCOMPATIBLE_OVERRIDE
from mypy.types import (
CallableType, Overloaded, Type, TypeVarType, deserialize_type, get_proper_type,
)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API
from mypy.fixup import TypeFixer
def _get_decorator_bool_argument(
ctx: ClassDefContext,
name: str,
default: bool,
) -> bool:
"""Return the bool argument for the decorator.
This handles both @decorator(...) and @decorator.
"""
if isinstance(ctx.reason, CallExpr):
return _get_bool_argument(ctx, ctx.reason, name, default)
else:
return default
def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr,
name: str, default: bool) -> bool:
"""Return the boolean value for an argument to a call or the
default if it's not found.
"""
attr_value = _get_argument(expr, name)
if attr_value:
ret = ctx.api.parse_bool(attr_value)
if ret is None:
ctx.api.fail('"{}" argument must be True or False.'.format(name), expr)
return default
return ret
return default
def _get_argument(call: CallExpr, name: str) -> Optional[Expression]:
"""Return the expression for the specific argument."""
# To do this we use the CallableType of the callee to find the FormalArgument,
# then walk the actual CallExpr looking for the appropriate argument.
#
# Note: I'm not hard-coding the index so that in the future we can support other
# attrib and class makers.
if not isinstance(call.callee, RefExpr):
return None
callee_type = None
callee_node = call.callee.node
if (isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES))
and callee_node.type):
callee_node_type = get_proper_type(callee_node.type)
if isinstance(callee_node_type, Overloaded):
# We take the last overload.
callee_type = callee_node_type.items[-1]
elif isinstance(callee_node_type, CallableType):
callee_type = callee_node_type
if not callee_type:
return None
argument = callee_type.argument_by_name(name)
if not argument:
return None
assert argument.name
for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)):
if argument.pos is not None and not attr_name and i == argument.pos:
return attr_value
if attr_name == argument.name:
return attr_value
return None
def add_method(
ctx: ClassDefContext,
name: str,
args: List[Argument],
return_type: Type,
self_type: Optional[Type] = None,
tvar_def: Optional[TypeVarType] = None,
) -> None:
"""
Adds a new method to a class.
Deprecated, use add_method_to_class() instead.
"""
add_method_to_class(ctx.api, ctx.cls,
name=name,
args=args,
return_type=return_type,
self_type=self_type,
tvar_def=tvar_def)
def add_method_to_class(
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface],
cls: ClassDef,
name: str,
args: List[Argument],
return_type: Type,
self_type: Optional[Type] = None,
tvar_def: Optional[TypeVarType] = None,
) -> None:
"""Adds a new method to a class definition."""
info = cls.info
# First remove any previously generated methods with the same name
# to avoid clashes and problems in the semantic analyzer.
if name in info.names:
sym = info.names[name]
if sym.plugin_generated and isinstance(sym.node, FuncDef):
cls.defs.body.remove(sym.node)
self_type = self_type or fill_typevars(info)
if isinstance(api, SemanticAnalyzerPluginInterface):
function_type = api.named_type('builtins.function')
else:
function_type = api.named_generic_type('builtins.function', [])
args = [Argument(Var('self'), self_type, None, ARG_POS)] + args
arg_types, arg_names, arg_kinds = [], [], []
for arg in args:
assert arg.type_annotation, 'All arguments must be fully typed.'
arg_types.append(arg.type_annotation)
arg_names.append(arg.variable.name)
arg_kinds.append(arg.kind)
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
if tvar_def:
signature.variables = [tvar_def]
func = FuncDef(name, args, Block([PassStmt()]))
func.info = info
func.type = set_callable_name(signature, func)
func._fullname = info.fullname + '.' + name
func.line = info.line
# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]
info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True)
info.defn.defs.body.append(func)
def add_attribute_to_class(
api: SemanticAnalyzerPluginInterface,
cls: ClassDef,
name: str,
typ: Type,
final: bool = False,
no_serialize: bool = False,
override_allow_incompatible: bool = False,
) -> None:
"""
Adds a new attribute to a class definition.
This currently only generates the symbol table entry and no corresponding AssignmentStatement
"""
info = cls.info
# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]
node = Var(name, typ)
node.info = info
node.is_final = final
if name in ALLOW_INCOMPATIBLE_OVERRIDE:
node.allow_incompatible_override = True
else:
node.allow_incompatible_override = override_allow_incompatible
node._fullname = info.fullname + '.' + name
info.names[name] = SymbolTableNode(
MDEF,
node,
plugin_generated=True,
no_serialize=no_serialize,
)
def deserialize_and_fixup_type(
data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
) -> Type:
typ = deserialize_type(data)
typ.accept(TypeFixer(api.modules, allow_missing=False))
return typ

View file

@ -0,0 +1,229 @@
"""Plugin to provide accurate types for some parts of the ctypes module."""
from typing import List, Optional
# Fully qualified instead of "from mypy.plugin import ..." to avoid circular import problems.
import mypy.plugin
from mypy import nodes
from mypy.maptype import map_instance_to_supertype
from mypy.messages import format_type
from mypy.subtypes import is_subtype
from mypy.types import (
AnyType, CallableType, Instance, NoneType, Type, TypeOfAny, UnionType,
union_items, ProperType, get_proper_type
)
from mypy.typeops import make_simplified_union
def _get_bytes_type(api: 'mypy.plugin.CheckerPluginInterface') -> Instance:
"""Return the type corresponding to bytes on the current Python version.
This is bytes in Python 3, and str in Python 2.
"""
return api.named_generic_type(
'builtins.bytes' if api.options.python_version >= (3,) else 'builtins.str', [])
def _get_text_type(api: 'mypy.plugin.CheckerPluginInterface') -> Instance:
"""Return the type corresponding to Text on the current Python version.
This is str in Python 3, and unicode in Python 2.
"""
return api.named_generic_type(
'builtins.str' if api.options.python_version >= (3,) else 'builtins.unicode', [])
def _find_simplecdata_base_arg(tp: Instance, api: 'mypy.plugin.CheckerPluginInterface'
) -> Optional[ProperType]:
"""Try to find a parametrized _SimpleCData in tp's bases and return its single type argument.
None is returned if _SimpleCData appears nowhere in tp's (direct or indirect) bases.
"""
if tp.type.has_base('ctypes._SimpleCData'):
simplecdata_base = map_instance_to_supertype(tp,
api.named_generic_type('ctypes._SimpleCData', [AnyType(TypeOfAny.special_form)]).type)
assert len(simplecdata_base.args) == 1, '_SimpleCData takes exactly one type argument'
return get_proper_type(simplecdata_base.args[0])
return None
def _autoconvertible_to_cdata(tp: Type, api: 'mypy.plugin.CheckerPluginInterface') -> Type:
"""Get a type that is compatible with all types that can be implicitly converted to the given
CData type.
Examples:
* c_int -> Union[c_int, int]
* c_char_p -> Union[c_char_p, bytes, int, NoneType]
* MyStructure -> MyStructure
"""
allowed_types = []
# If tp is a union, we allow all types that are convertible to at least one of the union
# items. This is not quite correct - strictly speaking, only types convertible to *all* of the
# union items should be allowed. This may be worth changing in the future, but the more
# correct algorithm could be too strict to be useful.
for t in union_items(tp):
# Every type can be converted from itself (obviously).
allowed_types.append(t)
if isinstance(t, Instance):
unboxed = _find_simplecdata_base_arg(t, api)
if unboxed is not None:
# If _SimpleCData appears in tp's (direct or indirect) bases, its type argument
# specifies the type's "unboxed" version, which can always be converted back to
# the original "boxed" type.
allowed_types.append(unboxed)
if t.type.has_base('ctypes._PointerLike'):
# Pointer-like _SimpleCData subclasses can also be converted from
# an int or None.
allowed_types.append(api.named_generic_type('builtins.int', []))
allowed_types.append(NoneType())
return make_simplified_union(allowed_types)
def _autounboxed_cdata(tp: Type) -> ProperType:
"""Get the auto-unboxed version of a CData type, if applicable.
For *direct* _SimpleCData subclasses, the only type argument of _SimpleCData in the bases list
is returned.
For all other CData types, including indirect _SimpleCData subclasses, tp is returned as-is.
"""
tp = get_proper_type(tp)
if isinstance(tp, UnionType):
return make_simplified_union([_autounboxed_cdata(t) for t in tp.items])
elif isinstance(tp, Instance):
for base in tp.type.bases:
if base.type.fullname == 'ctypes._SimpleCData':
# If tp has _SimpleCData as a direct base class,
# the auto-unboxed type is the single type argument of the _SimpleCData type.
assert len(base.args) == 1
return get_proper_type(base.args[0])
# If tp is not a concrete type, or if there is no _SimpleCData in the bases,
# the type is not auto-unboxed.
return tp
def _get_array_element_type(tp: Type) -> Optional[ProperType]:
"""Get the element type of the Array type tp, or None if not specified."""
tp = get_proper_type(tp)
if isinstance(tp, Instance):
assert tp.type.fullname == 'ctypes.Array'
if len(tp.args) == 1:
return get_proper_type(tp.args[0])
return None
def array_constructor_callback(ctx: 'mypy.plugin.FunctionContext') -> Type:
"""Callback to provide an accurate signature for the ctypes.Array constructor."""
# Extract the element type from the constructor's return type, i. e. the type of the array
# being constructed.
et = _get_array_element_type(ctx.default_return_type)
if et is not None:
allowed = _autoconvertible_to_cdata(et, ctx.api)
assert len(ctx.arg_types) == 1, \
"The stub of the ctypes.Array constructor should have a single vararg parameter"
for arg_num, (arg_kind, arg_type) in enumerate(zip(ctx.arg_kinds[0], ctx.arg_types[0]), 1):
if arg_kind == nodes.ARG_POS and not is_subtype(arg_type, allowed):
ctx.api.msg.fail(
'Array constructor argument {} of type {}'
' is not convertible to the array element type {}'
.format(arg_num, format_type(arg_type), format_type(et)), ctx.context)
elif arg_kind == nodes.ARG_STAR:
ty = ctx.api.named_generic_type("typing.Iterable", [allowed])
if not is_subtype(arg_type, ty):
it = ctx.api.named_generic_type("typing.Iterable", [et])
ctx.api.msg.fail(
'Array constructor argument {} of type {}'
' is not convertible to the array element type {}'
.format(arg_num, format_type(arg_type), format_type(it)), ctx.context)
return ctx.default_return_type
def array_getitem_callback(ctx: 'mypy.plugin.MethodContext') -> Type:
"""Callback to provide an accurate return type for ctypes.Array.__getitem__."""
et = _get_array_element_type(ctx.type)
if et is not None:
unboxed = _autounboxed_cdata(et)
assert len(ctx.arg_types) == 1, \
'The stub of ctypes.Array.__getitem__ should have exactly one parameter'
assert len(ctx.arg_types[0]) == 1, \
"ctypes.Array.__getitem__'s parameter should not be variadic"
index_type = get_proper_type(ctx.arg_types[0][0])
if isinstance(index_type, Instance):
if index_type.type.has_base('builtins.int'):
return unboxed
elif index_type.type.has_base('builtins.slice'):
return ctx.api.named_generic_type('builtins.list', [unboxed])
return ctx.default_return_type
def array_setitem_callback(ctx: 'mypy.plugin.MethodSigContext') -> CallableType:
"""Callback to provide an accurate signature for ctypes.Array.__setitem__."""
et = _get_array_element_type(ctx.type)
if et is not None:
allowed = _autoconvertible_to_cdata(et, ctx.api)
assert len(ctx.default_signature.arg_types) == 2
index_type = get_proper_type(ctx.default_signature.arg_types[0])
if isinstance(index_type, Instance):
arg_type = None
if index_type.type.has_base('builtins.int'):
arg_type = allowed
elif index_type.type.has_base('builtins.slice'):
arg_type = ctx.api.named_generic_type('builtins.list', [allowed])
if arg_type is not None:
# Note: arg_type can only be None if index_type is invalid, in which case we use
# the default signature and let mypy report an error about it.
return ctx.default_signature.copy_modified(
arg_types=ctx.default_signature.arg_types[:1] + [arg_type],
)
return ctx.default_signature
def array_iter_callback(ctx: 'mypy.plugin.MethodContext') -> Type:
"""Callback to provide an accurate return type for ctypes.Array.__iter__."""
et = _get_array_element_type(ctx.type)
if et is not None:
unboxed = _autounboxed_cdata(et)
return ctx.api.named_generic_type('typing.Iterator', [unboxed])
return ctx.default_return_type
def array_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
"""Callback to provide an accurate type for ctypes.Array.value."""
et = _get_array_element_type(ctx.type)
if et is not None:
types: List[Type] = []
for tp in union_items(et):
if isinstance(tp, AnyType):
types.append(AnyType(TypeOfAny.from_another_any, source_any=tp))
elif isinstance(tp, Instance) and tp.type.fullname == 'ctypes.c_char':
types.append(_get_bytes_type(ctx.api))
elif isinstance(tp, Instance) and tp.type.fullname == 'ctypes.c_wchar':
types.append(_get_text_type(ctx.api))
else:
ctx.api.msg.fail(
'Array attribute "value" is only available'
' with element type "c_char" or "c_wchar", not {}'
.format(format_type(et)), ctx.context)
return make_simplified_union(types)
return ctx.default_attr_type
def array_raw_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
"""Callback to provide an accurate type for ctypes.Array.raw."""
et = _get_array_element_type(ctx.type)
if et is not None:
types: List[Type] = []
for tp in union_items(et):
if (isinstance(tp, AnyType)
or isinstance(tp, Instance) and tp.type.fullname == 'ctypes.c_char'):
types.append(_get_bytes_type(ctx.api))
else:
ctx.api.msg.fail(
'Array attribute "raw" is only available'
' with element type "c_char", not {}'
.format(format_type(et)), ctx.context)
return make_simplified_union(types)
return ctx.default_attr_type

View file

@ -0,0 +1,548 @@
"""Plugin that provides support for dataclasses."""
from typing import Dict, List, Set, Tuple, Optional
from typing_extensions import Final
from mypy.nodes import (
ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, ARG_STAR, ARG_STAR2, MDEF,
Argument, AssignmentStmt, CallExpr, Context, Expression, JsonDict,
NameExpr, RefExpr, SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr,
PlaceholderNode
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, add_attribute_to_class,
)
from mypy.typeops import map_type_from_supertype
from mypy.types import (
Type, Instance, NoneType, TypeVarType, CallableType, TupleType, LiteralType,
get_proper_type, AnyType, TypeOfAny,
)
from mypy.server.trigger import make_wildcard_trigger
# The set of decorators that generate dataclasses.
dataclass_makers: Final = {
'dataclass',
'dataclasses.dataclass',
}
# The set of functions that generate dataclass fields.
field_makers: Final = {
'dataclasses.field',
}
SELF_TVAR_NAME: Final = "_DT"
class DataclassAttribute:
def __init__(
self,
name: str,
is_in_init: bool,
is_init_var: bool,
has_default: bool,
line: int,
column: int,
type: Optional[Type],
info: TypeInfo,
kw_only: bool,
) -> None:
self.name = name
self.is_in_init = is_in_init
self.is_init_var = is_init_var
self.has_default = has_default
self.line = line
self.column = column
self.type = type
self.info = info
self.kw_only = kw_only
def to_argument(self) -> Argument:
arg_kind = ARG_POS
if self.kw_only and self.has_default:
arg_kind = ARG_NAMED_OPT
elif self.kw_only and not self.has_default:
arg_kind = ARG_NAMED
elif not self.kw_only and self.has_default:
arg_kind = ARG_OPT
return Argument(
variable=self.to_var(),
type_annotation=self.type,
initializer=None,
kind=arg_kind,
)
def to_var(self) -> Var:
return Var(self.name, self.type)
def serialize(self) -> JsonDict:
assert self.type
return {
'name': self.name,
'is_in_init': self.is_in_init,
'is_init_var': self.is_init_var,
'has_default': self.has_default,
'line': self.line,
'column': self.column,
'type': self.type.serialize(),
'kw_only': self.kw_only,
}
@classmethod
def deserialize(
cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface
) -> 'DataclassAttribute':
data = data.copy()
if data.get('kw_only') is None:
data['kw_only'] = False
typ = deserialize_and_fixup_type(data.pop('type'), api)
return cls(type=typ, info=info, **data)
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type."""
if not isinstance(self.type, TypeVarType):
return
self.type = map_type_from_supertype(self.type, sub_type, self.info)
class DataclassTransformer:
def __init__(self, ctx: ClassDefContext) -> None:
self._ctx = ctx
def transform(self) -> None:
"""Apply all the necessary transformations to the underlying
dataclass so as to ensure it is fully type checked according
to the rules in PEP 557.
"""
ctx = self._ctx
info = self._ctx.cls.info
attributes = self.collect_attributes()
if attributes is None:
# Some definitions are not ready, defer() should be already called.
return
for attr in attributes:
if attr.type is None:
ctx.api.defer()
return
decorator_arguments = {
'init': _get_decorator_bool_argument(self._ctx, 'init', True),
'eq': _get_decorator_bool_argument(self._ctx, 'eq', True),
'order': _get_decorator_bool_argument(self._ctx, 'order', False),
'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False),
'slots': _get_decorator_bool_argument(self._ctx, 'slots', False),
'match_args': _get_decorator_bool_argument(self._ctx, 'match_args', True),
}
py_version = self._ctx.api.options.python_version
# If there are no attributes, it may be that the semantic analyzer has not
# processed them yet. In order to work around this, we can simply skip generating
# __init__ if there are no attributes, because if the user truly did not define any,
# then the object default __init__ with an empty signature will be present anyway.
if (decorator_arguments['init'] and
('__init__' not in info.names or info.names['__init__'].plugin_generated) and
attributes):
args = [attr.to_argument() for attr in attributes if attr.is_in_init
and not self._is_kw_only_type(attr.type)]
if info.fallback_to_any:
# Make positional args optional since we don't know their order.
# This will at least allow us to typecheck them if they are called
# as kwargs
for arg in args:
if arg.kind == ARG_POS:
arg.kind = ARG_OPT
nameless_var = Var('')
args = [Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR),
*args,
Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR2),
]
add_method(
ctx,
'__init__',
args=args,
return_type=NoneType(),
)
if (decorator_arguments['eq'] and info.get('__eq__') is None or
decorator_arguments['order']):
# Type variable for self types in generated methods.
obj_type = ctx.api.named_type('builtins.object')
self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME,
[], obj_type)
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
# Add <, >, <=, >=, but only if the class has an eq method.
if decorator_arguments['order']:
if not decorator_arguments['eq']:
ctx.api.fail('eq must be True if order is True', ctx.cls)
for method_name in ['__lt__', '__gt__', '__le__', '__ge__']:
# Like for __eq__ and __ne__, we want "other" to match
# the self type.
obj_type = ctx.api.named_type('builtins.object')
order_tvar_def = TypeVarType(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME,
-1, [], obj_type)
order_return_type = ctx.api.named_type('builtins.bool')
order_args = [
Argument(Var('other', order_tvar_def), order_tvar_def, None, ARG_POS)
]
existing_method = info.get(method_name)
if existing_method is not None and not existing_method.plugin_generated:
assert existing_method.node
ctx.api.fail(
'You may not have a custom %s method when order=True' % method_name,
existing_method.node,
)
add_method(
ctx,
method_name,
args=order_args,
return_type=order_return_type,
self_type=order_tvar_def,
tvar_def=order_tvar_def,
)
if decorator_arguments['frozen']:
self._freeze(attributes)
else:
self._propertize_callables(attributes)
if decorator_arguments['slots']:
self.add_slots(info, attributes, correct_version=py_version >= (3, 10))
self.reset_init_only_vars(info, attributes)
if (decorator_arguments['match_args'] and
('__match_args__' not in info.names or
info.names['__match_args__'].plugin_generated) and
attributes):
str_type = ctx.api.named_type("builtins.str")
literals: List[Type] = [LiteralType(attr.name, str_type)
for attr in attributes if attr.is_in_init]
match_args_type = TupleType(literals, ctx.api.named_type("builtins.tuple"))
add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type)
self._add_dataclass_fields_magic_attribute()
info.metadata['dataclass'] = {
'attributes': [attr.serialize() for attr in attributes],
'frozen': decorator_arguments['frozen'],
}
def add_slots(self,
info: TypeInfo,
attributes: List[DataclassAttribute],
*,
correct_version: bool) -> None:
if not correct_version:
# This means that version is lower than `3.10`,
# it is just a non-existent argument for `dataclass` function.
self._ctx.api.fail(
'Keyword argument "slots" for "dataclass" '
'is only valid in Python 3.10 and higher',
self._ctx.reason,
)
return
generated_slots = {attr.name for attr in attributes}
if ((info.slots is not None and info.slots != generated_slots)
or info.names.get('__slots__')):
# This means we have a slots conflict.
# Class explicitly specifies a different `__slots__` field.
# And `@dataclass(slots=True)` is used.
# In runtime this raises a type error.
self._ctx.api.fail(
'"{}" both defines "__slots__" and is used with "slots=True"'.format(
self._ctx.cls.name,
),
self._ctx.cls,
)
return
info.slots = generated_slots
def reset_init_only_vars(self, info: TypeInfo, attributes: List[DataclassAttribute]) -> None:
"""Remove init-only vars from the class and reset init var declarations."""
for attr in attributes:
if attr.is_init_var:
if attr.name in info.names:
del info.names[attr.name]
else:
# Nodes of superclass InitVars not used in __init__ cannot be reached.
assert attr.is_init_var
for stmt in info.defn.defs.body:
if isinstance(stmt, AssignmentStmt) and stmt.unanalyzed_type:
lvalue = stmt.lvalues[0]
if isinstance(lvalue, NameExpr) and lvalue.name == attr.name:
# Reset node so that another semantic analysis pass will
# recreate a symbol node for this attribute.
lvalue.node = None
def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
"""Collect all attributes declared in the dataclass and its parents.
All assignments of the form
a: SomeType
b: SomeOtherType = ...
are collected.
"""
# First, collect attributes belonging to the current class.
ctx = self._ctx
cls = self._ctx.cls
attrs: List[DataclassAttribute] = []
known_attrs: Set[str] = set()
kw_only = _get_decorator_bool_argument(ctx, 'kw_only', False)
for stmt in cls.defs.body:
# Any assignment that doesn't use the new type declaration
# syntax can be ignored out of hand.
if not (isinstance(stmt, AssignmentStmt) and stmt.new_syntax):
continue
# a: int, b: str = 1, 'foo' is not supported syntax so we
# don't have to worry about it.
lhs = stmt.lvalues[0]
if not isinstance(lhs, NameExpr):
continue
sym = cls.info.names.get(lhs.name)
if sym is None:
# This name is likely blocked by a star import. We don't need to defer because
# defer() is already called by mark_incomplete().
continue
node = sym.node
if isinstance(node, PlaceholderNode):
# This node is not ready yet.
return None
assert isinstance(node, Var)
# x: ClassVar[int] is ignored by dataclasses.
if node.is_classvar:
continue
# x: InitVar[int] is turned into x: int and is removed from the class.
is_init_var = False
node_type = get_proper_type(node.type)
if (isinstance(node_type, Instance) and
node_type.type.fullname == 'dataclasses.InitVar'):
is_init_var = True
node.type = node_type.args[0]
if self._is_kw_only_type(node_type):
kw_only = True
has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)
is_in_init_param = field_args.get('init')
if is_in_init_param is None:
is_in_init = True
else:
is_in_init = bool(ctx.api.parse_bool(is_in_init_param))
has_default = False
# Ensure that something like x: int = field() is rejected
# after an attribute with a default.
if has_field_call:
has_default = 'default' in field_args or 'default_factory' in field_args
# All other assignments are already type checked.
elif not isinstance(stmt.rvalue, TempNode):
has_default = True
if not has_default:
# Make all non-default attributes implicit because they are de-facto set
# on self in the generated __init__(), not in the class body.
sym.implicit = True
is_kw_only = kw_only
# Use the kw_only field arg if it is provided. Otherwise use the
# kw_only value from the decorator parameter.
field_kw_only_param = field_args.get('kw_only')
if field_kw_only_param is not None:
is_kw_only = bool(ctx.api.parse_bool(field_kw_only_param))
known_attrs.add(lhs.name)
attrs.append(DataclassAttribute(
name=lhs.name,
is_in_init=is_in_init,
is_init_var=is_init_var,
has_default=has_default,
line=stmt.line,
column=stmt.column,
type=sym.type,
info=cls.info,
kw_only=is_kw_only,
))
# Next, collect attributes belonging to any class in the MRO
# as long as those attributes weren't already collected. This
# makes it possible to overwrite attributes in subclasses.
# copy() because we potentially modify all_attrs below and if this code requires debugging
# we'll have unmodified attrs laying around.
all_attrs = attrs.copy()
for info in cls.info.mro[1:-1]:
if 'dataclass' not in info.metadata:
continue
super_attrs = []
# Each class depends on the set of attributes in its dataclass ancestors.
ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname))
for data in info.metadata["dataclass"]["attributes"]:
name: str = data["name"]
if name not in known_attrs:
attr = DataclassAttribute.deserialize(info, data, ctx.api)
attr.expand_typevar_from_subtype(ctx.cls.info)
known_attrs.add(name)
super_attrs.append(attr)
elif all_attrs:
# How early in the attribute list an attribute appears is determined by the
# reverse MRO, not simply MRO.
# See https://docs.python.org/3/library/dataclasses.html#inheritance for
# details.
for attr in all_attrs:
if attr.name == name:
all_attrs.remove(attr)
super_attrs.append(attr)
break
all_attrs = super_attrs + all_attrs
all_attrs.sort(key=lambda a: a.kw_only)
# Ensure that arguments without a default don't follow
# arguments that have a default.
found_default = False
# Ensure that the KW_ONLY sentinel is only provided once
found_kw_sentinel = False
for attr in all_attrs:
# If we find any attribute that is_in_init, not kw_only, and that
# doesn't have a default after one that does have one,
# then that's an error.
if found_default and attr.is_in_init and not attr.has_default and not attr.kw_only:
# If the issue comes from merging different classes, report it
# at the class definition point.
context = (Context(line=attr.line, column=attr.column) if attr in attrs
else ctx.cls)
ctx.api.fail(
'Attributes without a default cannot follow attributes with one',
context,
)
found_default = found_default or (attr.has_default and attr.is_in_init)
if found_kw_sentinel and self._is_kw_only_type(attr.type):
context = (Context(line=attr.line, column=attr.column) if attr in attrs
else ctx.cls)
ctx.api.fail(
'There may not be more than one field with the KW_ONLY type',
context,
)
found_kw_sentinel = found_kw_sentinel or self._is_kw_only_type(attr.type)
return all_attrs
def _freeze(self, attributes: List[DataclassAttribute]) -> None:
"""Converts all attributes to @property methods in order to
emulate frozen classes.
"""
info = self._ctx.cls.info
for attr in attributes:
sym_node = info.names.get(attr.name)
if sym_node is not None:
var = sym_node.node
assert isinstance(var, Var)
var.is_property = True
else:
var = attr.to_var()
var.info = info
var.is_property = True
var._fullname = info.fullname + '.' + var.name
info.names[var.name] = SymbolTableNode(MDEF, var)
def _propertize_callables(self, attributes: List[DataclassAttribute]) -> None:
"""Converts all attributes with callable types to @property methods.
This avoids the typechecker getting confused and thinking that
`my_dataclass_instance.callable_attr(foo)` is going to receive a
`self` argument (it is not).
"""
info = self._ctx.cls.info
for attr in attributes:
if isinstance(get_proper_type(attr.type), CallableType):
var = attr.to_var()
var.info = info
var.is_property = True
var.is_settable_property = True
var._fullname = info.fullname + '.' + var.name
info.names[var.name] = SymbolTableNode(MDEF, var)
def _is_kw_only_type(self, node: Optional[Type]) -> bool:
"""Checks if the type of the node is the KW_ONLY sentinel value."""
if node is None:
return False
node_type = get_proper_type(node)
if not isinstance(node_type, Instance):
return False
return node_type.type.fullname == 'dataclasses.KW_ONLY'
def _add_dataclass_fields_magic_attribute(self) -> None:
attr_name = '__dataclass_fields__'
any_type = AnyType(TypeOfAny.explicit)
field_type = self._ctx.api.named_type_or_none('dataclasses.Field', [any_type]) or any_type
attr_type = self._ctx.api.named_type('builtins.dict', [
self._ctx.api.named_type('builtins.str'),
field_type,
])
var = Var(name=attr_name, type=attr_type)
var.info = self._ctx.cls.info
var._fullname = self._ctx.cls.info.fullname + '.' + attr_name
self._ctx.cls.info.names[attr_name] = SymbolTableNode(
kind=MDEF,
node=var,
plugin_generated=True,
)
def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
"""Hooks into the class typechecking process to add support for dataclasses.
"""
transformer = DataclassTransformer(ctx)
transformer.transform()
def _collect_field_args(expr: Expression,
ctx: ClassDefContext) -> Tuple[bool, Dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
"""
if (
isinstance(expr, CallExpr) and
isinstance(expr.callee, RefExpr) and
expr.callee.fullname in field_makers
):
# field() only takes keyword arguments.
args = {}
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
if not kind.is_named():
if kind.is_named(star=True):
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
message = 'Unpacking **kwargs in "field()" is not supported'
else:
message = '"field()" does not accept positional arguments'
ctx.api.fail(message, expr)
return True, {}
assert name is not None
args[name] = arg
return True, args
return False, {}

View file

@ -0,0 +1,431 @@
from functools import partial
from typing import Callable, Optional, List
from mypy import message_registry
from mypy.nodes import StrExpr, IntExpr, DictExpr, UnaryExpr
from mypy.plugin import (
Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext
)
from mypy.plugins.common import try_getting_str_literals
from mypy.types import (
FunctionLike, Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType,
TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType, TupleType
)
from mypy.subtypes import is_subtype
from mypy.typeops import make_simplified_union
from mypy.checkexpr import is_literal_type_like
class DefaultPlugin(Plugin):
"""Type checker plugin that is enabled by default."""
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
from mypy.plugins import ctypes, singledispatch
if fullname in ('contextlib.contextmanager', 'contextlib.asynccontextmanager'):
return contextmanager_callback
elif fullname == 'ctypes.Array':
return ctypes.array_constructor_callback
elif fullname == 'functools.singledispatch':
return singledispatch.create_singledispatch_function_callback
return None
def get_method_signature_hook(self, fullname: str
) -> Optional[Callable[[MethodSigContext], FunctionLike]]:
from mypy.plugins import ctypes, singledispatch
if fullname == 'typing.Mapping.get':
return typed_dict_get_signature_callback
elif fullname in set(n + '.setdefault' for n in TPDICT_FB_NAMES):
return typed_dict_setdefault_signature_callback
elif fullname in set(n + '.pop' for n in TPDICT_FB_NAMES):
return typed_dict_pop_signature_callback
elif fullname in set(n + '.update' for n in TPDICT_FB_NAMES):
return typed_dict_update_signature_callback
elif fullname == 'ctypes.Array.__setitem__':
return ctypes.array_setitem_callback
elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_callback
return None
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
from mypy.plugins import ctypes, singledispatch
if fullname == 'typing.Mapping.get':
return typed_dict_get_callback
elif fullname == 'builtins.int.__pow__':
return int_pow_callback
elif fullname == 'builtins.int.__neg__':
return int_neg_callback
elif fullname in ('builtins.tuple.__mul__', 'builtins.tuple.__rmul__'):
return tuple_mul_callback
elif fullname in set(n + '.setdefault' for n in TPDICT_FB_NAMES):
return typed_dict_setdefault_callback
elif fullname in set(n + '.pop' for n in TPDICT_FB_NAMES):
return typed_dict_pop_callback
elif fullname in set(n + '.__delitem__' for n in TPDICT_FB_NAMES):
return typed_dict_delitem_callback
elif fullname == 'ctypes.Array.__getitem__':
return ctypes.array_getitem_callback
elif fullname == 'ctypes.Array.__iter__':
return ctypes.array_iter_callback
elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
return singledispatch.singledispatch_register_callback
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_after_register_argument
return None
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
from mypy.plugins import ctypes
from mypy.plugins import enums
if fullname == 'ctypes.Array.value':
return ctypes.array_value_callback
elif fullname == 'ctypes.Array.raw':
return ctypes.array_raw_callback
elif fullname in enums.ENUM_NAME_ACCESS:
return enums.enum_name_callback
elif fullname in enums.ENUM_VALUE_ACCESS:
return enums.enum_value_callback
return None
def get_class_decorator_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
from mypy.plugins import attrs
from mypy.plugins import dataclasses
from mypy.plugins import functools
if fullname in attrs.attr_class_makers:
return attrs.attr_class_maker_callback
elif fullname in attrs.attr_dataclass_makers:
return partial(
attrs.attr_class_maker_callback,
auto_attribs_default=True,
)
elif fullname in attrs.attr_frozen_makers:
return partial(
attrs.attr_class_maker_callback,
auto_attribs_default=None,
frozen_default=True,
)
elif fullname in attrs.attr_define_makers:
return partial(
attrs.attr_class_maker_callback,
auto_attribs_default=None,
)
elif fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback
elif fullname in functools.functools_total_ordering_makers:
return functools.functools_total_ordering_maker_callback
return None
def contextmanager_callback(ctx: FunctionContext) -> Type:
"""Infer a better return type for 'contextlib.contextmanager'."""
# Be defensive, just in case.
if ctx.arg_types and len(ctx.arg_types[0]) == 1:
arg_type = get_proper_type(ctx.arg_types[0][0])
default_return = get_proper_type(ctx.default_return_type)
if (isinstance(arg_type, CallableType)
and isinstance(default_return, CallableType)):
# The stub signature doesn't preserve information about arguments so
# add them back here.
return default_return.copy_modified(
arg_types=arg_type.arg_types,
arg_kinds=arg_type.arg_kinds,
arg_names=arg_type.arg_names,
variables=arg_type.variables,
is_ellipsis_args=arg_type.is_ellipsis_args)
return ctx.default_return_type
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.get.
This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
if (isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(signature.variables) == 1
and len(ctx.args[1]) == 1):
key = ctx.args[0][0].value
value_type = get_proper_type(ctx.type.items.get(key))
ret_type = signature.ret_type
if value_type:
default_arg = ctx.args[1][0]
if (isinstance(value_type, TypedDictType)
and isinstance(default_arg, DictExpr)
and len(default_arg.items) == 0):
# Caller has empty dict {} as default for typed dict.
value_type = value_type.copy_modified(required_keys=set())
# Tweak the signature to include the value type as context. It's
# only needed for type inference since there's a union with a type
# variable that accepts everything.
tv = signature.variables[0]
assert isinstance(tv, TypeVarType)
return signature.copy_modified(
arg_types=[signature.arg_types[0],
make_simplified_union([value_type, tv])],
ret_type=ret_type)
return signature
def typed_dict_get_callback(ctx: MethodContext) -> Type:
"""Infer a precise return type for TypedDict.get with literal first argument."""
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
return ctx.default_return_type
output_types: List[Type] = []
for key in keys:
value_type = get_proper_type(ctx.type.items.get(key))
if value_type is None:
return ctx.default_return_type
if len(ctx.arg_types) == 1:
output_types.append(value_type)
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
default_arg = ctx.args[1][0]
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)):
# Special case '{}' as the default for a typed dict type.
output_types.append(value_type.copy_modified(required_keys=set()))
else:
output_types.append(value_type)
output_types.append(ctx.arg_types[1][0])
if len(ctx.arg_types) == 1:
output_types.append(NoneType())
return make_simplified_union(output_types)
return ctx.default_return_type
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.pop.
This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
str_type = ctx.api.named_generic_type('builtins.str', [])
if (isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(signature.variables) == 1
and len(ctx.args[1]) == 1):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
# Tweak the signature to include the value type as context. It's
# only needed for type inference since there's a union with a type
# variable that accepts everything.
tv = signature.variables[0]
assert isinstance(tv, TypeVarType)
typ = make_simplified_union([value_type, tv])
return signature.copy_modified(
arg_types=[str_type, typ],
ret_type=typ)
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
"""Type check and infer a precise return type for TypedDict.pop."""
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)
value_types = []
for key in keys:
if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
value_type = ctx.type.items.get(key)
if value_type:
value_types.append(value_type)
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
if len(ctx.args[1]) == 0:
return make_simplified_union(value_types)
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
return make_simplified_union([*value_types, ctx.arg_types[1][0]])
return ctx.default_return_type
def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.setdefault.
This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
str_type = ctx.api.named_generic_type('builtins.str', [])
if (isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(ctx.args[1]) == 1):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
return signature.copy_modified(arg_types=[str_type, value_type])
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
"""Type check TypedDict.setdefault and infer a precise return type."""
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) == 2
and len(ctx.arg_types[0]) == 1
and len(ctx.arg_types[1]) == 1):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)
default_type = ctx.arg_types[1][0]
value_types = []
for key in keys:
value_type = ctx.type.items.get(key)
if value_type is None:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
# The signature_callback above can't always infer the right signature
# (e.g. when the expression is a variable that happens to be a Literal str)
# so we need to handle the check ourselves here and make sure the provided
# default can be assigned to all key-value pairs we're updating.
if not is_subtype(default_type, value_type):
ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
default_type, value_type, ctx.context)
return AnyType(TypeOfAny.from_error)
value_types.append(value_type)
return make_simplified_union(value_types)
return ctx.default_return_type
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
"""Type check TypedDict.__delitem__."""
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) == 1
and len(ctx.arg_types[0]) == 1):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)
for key in keys:
if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
elif key not in ctx.type.items:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return ctx.default_return_type
def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.update."""
signature = ctx.default_signature
if (isinstance(ctx.type, TypedDictType)
and len(signature.arg_types) == 1):
arg_type = get_proper_type(signature.arg_types[0])
assert isinstance(arg_type, TypedDictType)
arg_type = arg_type.as_anonymous()
arg_type = arg_type.copy_modified(required_keys=set())
return signature.copy_modified(arg_types=[arg_type])
return signature
def int_pow_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__pow__."""
# int.__pow__ has an optional modulo argument,
# so we expect 2 argument positions
if (len(ctx.arg_types) == 2
and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0):
arg = ctx.args[0][0]
if isinstance(arg, IntExpr):
exponent = arg.value
elif isinstance(arg, UnaryExpr) and arg.op == '-' and isinstance(arg.expr, IntExpr):
exponent = -arg.expr.value
else:
# Right operand not an int literal or a negated literal -- give up.
return ctx.default_return_type
if exponent >= 0:
return ctx.api.named_generic_type('builtins.int', [])
else:
return ctx.api.named_generic_type('builtins.float', [])
return ctx.default_return_type
def int_neg_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__neg__.
This is mainly used to infer the return type as LiteralType
if the original underlying object is a LiteralType object
"""
if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
value = ctx.type.last_known_value.value
fallback = ctx.type.last_known_value.fallback
if isinstance(value, int):
if is_literal_type_like(ctx.api.type_context[-1]):
return LiteralType(value=-value, fallback=fallback)
else:
return ctx.type.copy_modified(last_known_value=LiteralType(
value=-value,
fallback=ctx.type,
line=ctx.type.line,
column=ctx.type.column,
))
elif isinstance(ctx.type, LiteralType):
value = ctx.type.value
fallback = ctx.type.fallback
if isinstance(value, int):
return LiteralType(value=-value, fallback=fallback)
return ctx.default_return_type
def tuple_mul_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.
This is used to return a specific sized tuple if multiplied by Literal int
"""
if not isinstance(ctx.type, TupleType):
return ctx.default_return_type
arg_type = get_proper_type(ctx.arg_types[0][0])
if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
value = arg_type.last_known_value.value
if isinstance(value, int):
return ctx.type.copy_modified(items=ctx.type.items * value)
elif isinstance(ctx.type, LiteralType):
value = arg_type.value
if isinstance(value, int):
return ctx.type.copy_modified(items=ctx.type.items * value)
return ctx.default_return_type

View file

@ -0,0 +1,255 @@
"""
This file contains a variety of plugins for refining how mypy infers types of
expressions involving Enums.
Currently, this file focuses on providing better inference for expressions like
'SomeEnum.FOO.name' and 'SomeEnum.FOO.value'. Note that the type of both expressions
will vary depending on exactly which instance of SomeEnum we're looking at.
Note that this file does *not* contain all special-cased logic related to enums:
we actually bake some of it directly in to the semantic analysis layer (see
semanal_enum.py).
"""
from typing import Iterable, Optional, Sequence, TypeVar, cast
from typing_extensions import Final
import mypy.plugin # To avoid circular imports.
from mypy.types import Type, Instance, LiteralType, CallableType, ProperType, get_proper_type
from mypy.typeops import make_simplified_union
from mypy.nodes import TypeInfo
from mypy.subtypes import is_equivalent
from mypy.semanal_enum import ENUM_BASES
ENUM_NAME_ACCESS: Final = {"{}.name".format(prefix) for prefix in ENUM_BASES} | {
"{}._name_".format(prefix) for prefix in ENUM_BASES
}
ENUM_VALUE_ACCESS: Final = {"{}.value".format(prefix) for prefix in ENUM_BASES} | {
"{}._value_".format(prefix) for prefix in ENUM_BASES
}
def enum_name_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
"""This plugin refines the 'name' attribute in enums to act as if
they were declared to be final.
For example, the expression 'MyEnum.FOO.name' normally is inferred
to be of type 'str'.
This plugin will instead make the inferred type be a 'str' where the
last known value is 'Literal["FOO"]'. This means it would be legal to
use 'MyEnum.FOO.name' in contexts that expect a Literal type, just like
any other Final variable or attribute.
This plugin assumes that the provided context is an attribute access
matching one of the strings found in 'ENUM_NAME_ACCESS'.
"""
enum_field_name = _extract_underlying_field_name(ctx.type)
if enum_field_name is None:
return ctx.default_attr_type
else:
str_type = ctx.api.named_generic_type('builtins.str', [])
literal_type = LiteralType(enum_field_name, fallback=str_type)
return str_type.copy_modified(last_known_value=literal_type)
_T = TypeVar('_T')
def _first(it: Iterable[_T]) -> Optional[_T]:
"""Return the first value from any iterable.
Returns ``None`` if the iterable is empty.
"""
for val in it:
return val
return None
def _infer_value_type_with_auto_fallback(
ctx: 'mypy.plugin.AttributeContext',
proper_type: Optional[ProperType]) -> Optional[Type]:
"""Figure out the type of an enum value accounting for `auto()`.
This method is a no-op for a `None` proper_type and also in the case where
the type is not "enum.auto"
"""
if proper_type is None:
return None
if not ((isinstance(proper_type, Instance) and
proper_type.type.fullname == 'enum.auto')):
return proper_type
assert isinstance(ctx.type, Instance), 'An incorrect ctx.type was passed.'
info = ctx.type.type
# Find the first _generate_next_value_ on the mro. We need to know
# if it is `Enum` because `Enum` types say that the return-value of
# `_generate_next_value_` is `Any`. In reality the default `auto()`
# returns an `int` (presumably the `Any` in typeshed is to make it
# easier to subclass and change the returned type).
type_with_gnv = _first(
ti for ti in info.mro if ti.names.get('_generate_next_value_'))
if type_with_gnv is None:
return ctx.default_attr_type
stnode = type_with_gnv.names['_generate_next_value_']
# This should be a `CallableType`
node_type = get_proper_type(stnode.type)
if isinstance(node_type, CallableType):
if type_with_gnv.fullname == 'enum.Enum':
int_type = ctx.api.named_generic_type('builtins.int', [])
return int_type
return get_proper_type(node_type.ret_type)
return ctx.default_attr_type
def _implements_new(info: TypeInfo) -> bool:
"""Check whether __new__ comes from enum.Enum or was implemented in a
subclass. In the latter case, we must infer Any as long as mypy can't infer
the type of _value_ from assignments in __new__.
"""
type_with_new = _first(
ti
for ti in info.mro
if ti.names.get('__new__') and not ti.fullname.startswith('builtins.')
)
if type_with_new is None:
return False
return type_with_new.fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.StrEnum')
def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
"""This plugin refines the 'value' attribute in enums to refer to
the original underlying value. For example, suppose we have the
following:
class SomeEnum:
FOO = A()
BAR = B()
By default, mypy will infer that 'SomeEnum.FOO.value' and
'SomeEnum.BAR.value' both are of type 'Any'. This plugin refines
this inference so that mypy understands the expressions are
actually of types 'A' and 'B' respectively. This better reflects
the actual runtime behavior.
This plugin works simply by looking up the original value assigned
to the enum. For example, when this plugin sees 'SomeEnum.BAR.value',
it will look up whatever type 'BAR' had in the SomeEnum TypeInfo and
use that as the inferred type of the overall expression.
This plugin assumes that the provided context is an attribute access
matching one of the strings found in 'ENUM_VALUE_ACCESS'.
"""
enum_field_name = _extract_underlying_field_name(ctx.type)
if enum_field_name is None:
# We do not know the enum field name (perhaps it was passed to a
# function and we only know that it _is_ a member). All is not lost
# however, if we can prove that the all of the enum members have the
# same value-type, then it doesn't matter which member was passed in.
# The value-type is still known.
if isinstance(ctx.type, Instance):
info = ctx.type.type
# As long as mypy doesn't understand attribute creation in __new__,
# there is no way to predict the value type if the enum class has a
# custom implementation
if _implements_new(info):
return ctx.default_attr_type
stnodes = (info.get(name) for name in info.names)
# Enums _can_ have methods and instance attributes.
# Omit methods and attributes created by assigning to self.*
# for our value inference.
node_types = (
get_proper_type(n.type) if n else None
for n in stnodes
if n is None or not n.implicit)
proper_types = list(
_infer_value_type_with_auto_fallback(ctx, t)
for t in node_types
if t is None or not isinstance(t, CallableType))
underlying_type = _first(proper_types)
if underlying_type is None:
return ctx.default_attr_type
# At first we try to predict future `value` type if all other items
# have the same type. For example, `int`.
# If this is the case, we simply return this type.
# See https://github.com/python/mypy/pull/9443
all_same_value_type = all(
proper_type is not None and proper_type == underlying_type
for proper_type in proper_types)
if all_same_value_type:
if underlying_type is not None:
return underlying_type
# But, after we started treating all `Enum` values as `Final`,
# we start to infer types in
# `item = 1` as `Literal[1]`, not just `int`.
# So, for example types in this `Enum` will all be different:
#
# class Ordering(IntEnum):
# one = 1
# two = 2
# three = 3
#
# We will infer three `Literal` types here.
# They are not the same, but they are equivalent.
# So, we unify them to make sure `.value` prediction still works.
# Result will be `Literal[1] | Literal[2] | Literal[3]` for this case.
all_equivalent_types = all(
proper_type is not None and is_equivalent(proper_type, underlying_type)
for proper_type in proper_types)
if all_equivalent_types:
return make_simplified_union(cast(Sequence[Type], proper_types))
return ctx.default_attr_type
assert isinstance(ctx.type, Instance)
info = ctx.type.type
# As long as mypy doesn't understand attribute creation in __new__,
# there is no way to predict the value type if the enum class has a
# custom implementation
if _implements_new(info):
return ctx.default_attr_type
stnode = info.get(enum_field_name)
if stnode is None:
return ctx.default_attr_type
underlying_type = _infer_value_type_with_auto_fallback(
ctx, get_proper_type(stnode.type))
if underlying_type is None:
return ctx.default_attr_type
return underlying_type
def _extract_underlying_field_name(typ: Type) -> Optional[str]:
"""If the given type corresponds to some Enum instance, returns the
original name of that enum. For example, if we receive in the type
corresponding to 'SomeEnum.FOO', we return the string "SomeEnum.Foo".
This helper takes advantage of the fact that Enum instances are valid
to use inside Literal[...] types. An expression like 'SomeEnum.FOO' is
actually represented by an Instance type with a Literal enum fallback.
We can examine this Literal fallback to retrieve the string.
"""
typ = get_proper_type(typ)
if not isinstance(typ, Instance):
return None
if not typ.type.is_enum:
return None
underlying_literal = typ.last_known_value
if underlying_literal is None:
return None
# The checks above have verified this LiteralType is representing an enum value,
# which means the 'value' field is guaranteed to be the name of the enum field
# as a string.
assert isinstance(underlying_literal.value, str)
return underlying_literal.value

View file

@ -0,0 +1,106 @@
"""Plugin for supporting the functools standard library module."""
from typing import Dict, NamedTuple, Optional
from typing_extensions import Final
import mypy.plugin
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType
functools_total_ordering_makers: Final = {
'functools.total_ordering',
}
_ORDERING_METHODS: Final = {
'__lt__',
'__le__',
'__gt__',
'__ge__',
}
_MethodInfo = NamedTuple('_MethodInfo', [('is_static', bool), ('type', CallableType)])
def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext,
auto_attribs_default: bool = False) -> None:
"""Add dunder methods to classes decorated with functools.total_ordering."""
if ctx.api.options.python_version < (3,):
# This plugin is not supported in Python 2 mode (it's a no-op).
return
comparison_methods = _analyze_class(ctx)
if not comparison_methods:
ctx.api.fail(
'No ordering operation defined when using "functools.total_ordering": < > <= >=',
ctx.reason)
return
# prefer __lt__ to __le__ to __gt__ to __ge__
root = max(comparison_methods, key=lambda k: (comparison_methods[k] is None, k))
root_method = comparison_methods[root]
if not root_method:
# None of the defined comparison methods can be analysed
return
other_type = _find_other_type(root_method)
bool_type = ctx.api.named_type('builtins.bool')
ret_type: Type = bool_type
if root_method.type.ret_type != ctx.api.named_type('builtins.bool'):
proper_ret_type = get_proper_type(root_method.type.ret_type)
if not (isinstance(proper_ret_type, UnboundType)
and proper_ret_type.name.split('.')[-1] == 'bool'):
ret_type = AnyType(TypeOfAny.implementation_artifact)
for additional_op in _ORDERING_METHODS:
# Either the method is not implemented
# or has an unknown signature that we can now extrapolate.
if not comparison_methods.get(additional_op):
args = [Argument(Var('other', other_type), other_type, None, ARG_POS)]
add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type)
def _find_other_type(method: _MethodInfo) -> Type:
"""Find the type of the ``other`` argument in a comparison method."""
first_arg_pos = 0 if method.is_static else 1
cur_pos_arg = 0
other_arg = None
for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
if arg_kind.is_positional():
if cur_pos_arg == first_arg_pos:
other_arg = arg_type
break
cur_pos_arg += 1
elif arg_kind != ARG_STAR2:
other_arg = arg_type
break
if other_arg is None:
return AnyType(TypeOfAny.implementation_artifact)
return other_arg
def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> Dict[str, Optional[_MethodInfo]]:
"""Analyze the class body, its parents, and return the comparison methods found."""
# Traverse the MRO and collect ordering methods.
comparison_methods: Dict[str, Optional[_MethodInfo]] = {}
# Skip object because total_ordering does not use methods from object
for cls in ctx.cls.info.mro[:-1]:
for name in _ORDERING_METHODS:
if name in cls.names and name not in comparison_methods:
node = cls.names[name].node
if isinstance(node, FuncItem) and isinstance(node.type, CallableType):
comparison_methods[name] = _MethodInfo(node.is_static, node.type)
continue
if isinstance(node, Var):
proper_type = get_proper_type(node.type)
if isinstance(proper_type, CallableType):
comparison_methods[name] = _MethodInfo(node.is_staticmethod, proper_type)
continue
comparison_methods[name] = None
return comparison_methods

View file

@ -0,0 +1,211 @@
from mypy.messages import format_type
from mypy.plugins.common import add_method_to_class
from mypy.nodes import (
ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context
)
from mypy.subtypes import is_subtype
from mypy.types import (
AnyType, CallableType, Instance, NoneType, Overloaded, Type, TypeOfAny, get_proper_type,
FunctionLike
)
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext
from typing import List, NamedTuple, Optional, Sequence, TypeVar, Union
from typing_extensions import Final
SingledispatchTypeVars = NamedTuple('SingledispatchTypeVars', [
('return_type', Type),
('fallback', CallableType),
])
RegisterCallableInfo = NamedTuple('RegisterCallableInfo', [
('register_type', Type),
('singledispatch_obj', Instance),
])
SINGLEDISPATCH_TYPE: Final = 'functools._SingleDispatchCallable'
SINGLEDISPATCH_REGISTER_METHOD: Final = '{}.register'.format(SINGLEDISPATCH_TYPE)
SINGLEDISPATCH_CALLABLE_CALL_METHOD: Final = '{}.__call__'.format(SINGLEDISPATCH_TYPE)
def get_singledispatch_info(typ: Instance) -> Optional[SingledispatchTypeVars]:
if len(typ.args) == 2:
return SingledispatchTypeVars(*typ.args) # type: ignore
return None
T = TypeVar('T')
def get_first_arg(args: List[List[T]]) -> Optional[T]:
"""Get the element that corresponds to the first argument passed to the function"""
if args and args[0]:
return args[0][0]
return None
REGISTER_RETURN_CLASS: Final = '_SingleDispatchRegisterCallable'
REGISTER_CALLABLE_CALL_METHOD: Final = 'functools.{}.__call__'.format(
REGISTER_RETURN_CLASS
)
def make_fake_register_class_instance(api: CheckerPluginInterface, type_args: Sequence[Type]
) -> Instance:
defn = ClassDef(REGISTER_RETURN_CLASS, Block([]))
defn.fullname = 'functools.{}'.format(REGISTER_RETURN_CLASS)
info = TypeInfo(SymbolTable(), defn, "functools")
obj_type = api.named_generic_type('builtins.object', []).type
info.bases = [Instance(obj_type, [])]
info.mro = [info, obj_type]
defn.info = info
func_arg = Argument(Var('name'), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS)
add_method_to_class(api, defn, '__call__', [func_arg], NoneType())
return Instance(info, type_args)
PluginContext = Union[FunctionContext, MethodContext]
def fail(ctx: PluginContext, msg: str, context: Optional[Context]) -> None:
"""Emit an error message.
This tries to emit an error message at the location specified by `context`, falling back to the
location specified by `ctx.context`. This is helpful when the only context information about
where you want to put the error message may be None (like it is for `CallableType.definition`)
and falling back to the location of the calling function is fine."""
# TODO: figure out if there is some more reliable way of getting context information, so this
# function isn't necessary
if context is not None:
err_context = context
else:
err_context = ctx.context
ctx.api.fail(msg, err_context)
def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
"""Called for functools.singledispatch"""
func_type = get_proper_type(get_first_arg(ctx.arg_types))
if isinstance(func_type, CallableType):
if len(func_type.arg_kinds) < 1:
fail(
ctx,
'Singledispatch function requires at least one argument',
func_type.definition,
)
return ctx.default_return_type
elif not func_type.arg_kinds[0].is_positional(star=True):
fail(
ctx,
'First argument to singledispatch function must be a positional argument',
func_type.definition,
)
return ctx.default_return_type
# singledispatch returns an instance of functools._SingleDispatchCallable according to
# typeshed
singledispatch_obj = get_proper_type(ctx.default_return_type)
assert isinstance(singledispatch_obj, Instance)
singledispatch_obj.args += (func_type,)
return ctx.default_return_type
def singledispatch_register_callback(ctx: MethodContext) -> Type:
"""Called for functools._SingleDispatchCallable.register"""
assert isinstance(ctx.type, Instance)
# TODO: check that there's only one argument
first_arg_type = get_proper_type(get_first_arg(ctx.arg_types))
if isinstance(first_arg_type, (CallableType, Overloaded)) and first_arg_type.is_type_obj():
# HACK: We received a class as an argument to register. We need to be able
# to access the function that register is being applied to, and the typeshed definition
# of register has it return a generic Callable, so we create a new
# SingleDispatchRegisterCallable class, define a __call__ method, and then add a
# plugin hook for that.
# is_subtype doesn't work when the right type is Overloaded, so we need the
# actual type
register_type = first_arg_type.items[0].ret_type
type_args = RegisterCallableInfo(register_type, ctx.type)
register_callable = make_fake_register_class_instance(
ctx.api,
type_args
)
return register_callable
elif isinstance(first_arg_type, CallableType):
# TODO: do more checking for registered functions
register_function(ctx, ctx.type, first_arg_type)
# The typeshed stubs for register say that the function returned is Callable[..., T], even
# though the function returned is the same as the one passed in. We return the type of the
# function so that mypy can properly type check cases where the registered function is used
# directly (instead of through singledispatch)
return first_arg_type
# fallback in case we don't recognize the arguments
return ctx.default_return_type
def register_function(ctx: PluginContext, singledispatch_obj: Instance, func: Type,
register_arg: Optional[Type] = None) -> None:
"""Register a function"""
func = get_proper_type(func)
if not isinstance(func, CallableType):
return
metadata = get_singledispatch_info(singledispatch_obj)
if metadata is None:
# if we never added the fallback to the type variables, we already reported an error, so
# just don't do anything here
return
dispatch_type = get_dispatch_type(func, register_arg)
if dispatch_type is None:
# TODO: report an error here that singledispatch requires at least one argument
# (might want to do the error reporting in get_dispatch_type)
return
fallback = metadata.fallback
fallback_dispatch_type = fallback.arg_types[0]
if not is_subtype(dispatch_type, fallback_dispatch_type):
fail(ctx, 'Dispatch type {} must be subtype of fallback function first argument {}'.format(
format_type(dispatch_type), format_type(fallback_dispatch_type)
), func.definition)
return
return
def get_dispatch_type(func: CallableType, register_arg: Optional[Type]) -> Optional[Type]:
if register_arg is not None:
return register_arg
if func.arg_types:
return func.arg_types[0]
return None
def call_singledispatch_function_after_register_argument(ctx: MethodContext) -> Type:
"""Called on the function after passing a type to register"""
register_callable = ctx.type
if isinstance(register_callable, Instance):
type_args = RegisterCallableInfo(*register_callable.args) # type: ignore
func = get_first_arg(ctx.arg_types)
if func is not None:
register_function(ctx, type_args.singledispatch_obj, func, type_args.register_type)
# see call to register_function in the callback for register
return func
return ctx.default_return_type
def call_singledispatch_function_callback(ctx: MethodSigContext) -> FunctionLike:
"""Called for functools._SingleDispatchCallable.__call__"""
if not isinstance(ctx.type, Instance):
return ctx.default_signature
metadata = get_singledispatch_info(ctx.type)
if metadata is None:
return ctx.default_signature
return metadata.fallback