init
This commit is contained in:
commit
38355d2442
9083 changed files with 1225834 additions and 0 deletions
706
.venv/lib/python3.8/site-packages/mypy/checkpattern.py
Normal file
706
.venv/lib/python3.8/site-packages/mypy/checkpattern.py
Normal file
|
|
@ -0,0 +1,706 @@
|
|||
"""Pattern checker. This file is conceptually part of TypeChecker."""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union
|
||||
from typing_extensions import Final
|
||||
|
||||
import mypy.checker
|
||||
from mypy.checkmember import analyze_member_access
|
||||
from mypy.expandtype import expand_type_by_instance
|
||||
from mypy.join import join_types
|
||||
from mypy.literals import literal_hash
|
||||
from mypy.maptype import map_instance_to_supertype
|
||||
from mypy.meet import narrow_declared_type
|
||||
from mypy import message_registry
|
||||
from mypy.messages import MessageBuilder
|
||||
from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr
|
||||
from mypy.patterns import (
|
||||
Pattern, AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern,
|
||||
ClassPattern, SingletonPattern
|
||||
)
|
||||
from mypy.plugin import Plugin
|
||||
from mypy.subtypes import is_subtype
|
||||
from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union, \
|
||||
coerce_to_literal
|
||||
from mypy.types import (
|
||||
ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type,
|
||||
TypedDictType, TupleType, NoneType, UnionType
|
||||
)
|
||||
from mypy.typevars import fill_typevars
|
||||
from mypy.visitor import PatternVisitor
|
||||
|
||||
self_match_type_names: Final = [
|
||||
"builtins.bool",
|
||||
"builtins.bytearray",
|
||||
"builtins.bytes",
|
||||
"builtins.dict",
|
||||
"builtins.float",
|
||||
"builtins.frozenset",
|
||||
"builtins.int",
|
||||
"builtins.list",
|
||||
"builtins.set",
|
||||
"builtins.str",
|
||||
"builtins.tuple",
|
||||
]
|
||||
|
||||
non_sequence_match_type_names: Final = [
|
||||
"builtins.str",
|
||||
"builtins.bytes",
|
||||
"builtins.bytearray"
|
||||
]
|
||||
|
||||
|
||||
# For every Pattern a PatternType can be calculated. This requires recursively calculating
|
||||
# the PatternTypes of the sub-patterns first.
|
||||
# Using the data in the PatternType the match subject and captured names can be narrowed/inferred.
|
||||
PatternType = NamedTuple(
|
||||
'PatternType',
|
||||
[
|
||||
('type', Type), # The type the match subject can be narrowed to
|
||||
('rest_type', Type), # The remaining type if the pattern didn't match
|
||||
('captures', Dict[Expression, Type]), # The variables captured by the pattern
|
||||
])
|
||||
|
||||
|
||||
class PatternChecker(PatternVisitor[PatternType]):
|
||||
"""Pattern checker.
|
||||
|
||||
This class checks if a pattern can match a type, what the type can be narrowed to, and what
|
||||
type capture patterns should be inferred as.
|
||||
"""
|
||||
|
||||
# Some services are provided by a TypeChecker instance.
|
||||
chk: 'mypy.checker.TypeChecker'
|
||||
# This is shared with TypeChecker, but stored also here for convenience.
|
||||
msg: MessageBuilder
|
||||
# Currently unused
|
||||
plugin: Plugin
|
||||
# The expression being matched against the pattern
|
||||
subject: Expression
|
||||
|
||||
subject_type: Type
|
||||
# Type of the subject to check the (sub)pattern against
|
||||
type_context: List[Type]
|
||||
# Types that match against self instead of their __match_args__ if used as a class pattern
|
||||
# Filled in from self_match_type_names
|
||||
self_match_types: List[Type]
|
||||
# Types that are sequences, but don't match sequence patterns. Filled in from
|
||||
# non_sequence_match_type_names
|
||||
non_sequence_match_types: List[Type]
|
||||
|
||||
def __init__(self,
|
||||
chk: 'mypy.checker.TypeChecker',
|
||||
msg: MessageBuilder, plugin: Plugin
|
||||
) -> None:
|
||||
self.chk = chk
|
||||
self.msg = msg
|
||||
self.plugin = plugin
|
||||
|
||||
self.type_context = []
|
||||
self.self_match_types = self.generate_types_from_names(self_match_type_names)
|
||||
self.non_sequence_match_types = self.generate_types_from_names(
|
||||
non_sequence_match_type_names
|
||||
)
|
||||
|
||||
def accept(self, o: Pattern, type_context: Type) -> PatternType:
|
||||
self.type_context.append(type_context)
|
||||
result = o.accept(self)
|
||||
self.type_context.pop()
|
||||
|
||||
return result
|
||||
|
||||
def visit_as_pattern(self, o: AsPattern) -> PatternType:
|
||||
current_type = self.type_context[-1]
|
||||
if o.pattern is not None:
|
||||
pattern_type = self.accept(o.pattern, current_type)
|
||||
typ, rest_type, type_map = pattern_type
|
||||
else:
|
||||
typ, rest_type, type_map = current_type, UninhabitedType(), {}
|
||||
|
||||
if not is_uninhabited(typ) and o.name is not None:
|
||||
typ, _ = self.chk.conditional_types_with_intersection(current_type,
|
||||
[get_type_range(typ)],
|
||||
o,
|
||||
default=current_type)
|
||||
if not is_uninhabited(typ):
|
||||
type_map[o.name] = typ
|
||||
|
||||
return PatternType(typ, rest_type, type_map)
|
||||
|
||||
def visit_or_pattern(self, o: OrPattern) -> PatternType:
|
||||
current_type = self.type_context[-1]
|
||||
|
||||
#
|
||||
# Check all the subpatterns
|
||||
#
|
||||
pattern_types = []
|
||||
for pattern in o.patterns:
|
||||
pattern_type = self.accept(pattern, current_type)
|
||||
pattern_types.append(pattern_type)
|
||||
current_type = pattern_type.rest_type
|
||||
|
||||
#
|
||||
# Collect the final type
|
||||
#
|
||||
types = []
|
||||
for pattern_type in pattern_types:
|
||||
if not is_uninhabited(pattern_type.type):
|
||||
types.append(pattern_type.type)
|
||||
|
||||
#
|
||||
# Check the capture types
|
||||
#
|
||||
capture_types: Dict[Var, List[Tuple[Expression, Type]]] = defaultdict(list)
|
||||
# Collect captures from the first subpattern
|
||||
for expr, typ in pattern_types[0].captures.items():
|
||||
node = get_var(expr)
|
||||
capture_types[node].append((expr, typ))
|
||||
|
||||
# Check if other subpatterns capture the same names
|
||||
for i, pattern_type in enumerate(pattern_types[1:]):
|
||||
vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
|
||||
if capture_types.keys() != vars:
|
||||
self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
|
||||
for expr, typ in pattern_type.captures.items():
|
||||
node = get_var(expr)
|
||||
capture_types[node].append((expr, typ))
|
||||
|
||||
captures: Dict[Expression, Type] = {}
|
||||
for var, capture_list in capture_types.items():
|
||||
typ = UninhabitedType()
|
||||
for _, other in capture_list:
|
||||
typ = join_types(typ, other)
|
||||
|
||||
captures[capture_list[0][0]] = typ
|
||||
|
||||
union_type = make_simplified_union(types)
|
||||
return PatternType(union_type, current_type, captures)
|
||||
|
||||
def visit_value_pattern(self, o: ValuePattern) -> PatternType:
|
||||
current_type = self.type_context[-1]
|
||||
typ = self.chk.expr_checker.accept(o.expr)
|
||||
typ = coerce_to_literal(typ)
|
||||
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
|
||||
current_type,
|
||||
[get_type_range(typ)],
|
||||
o,
|
||||
default=current_type
|
||||
)
|
||||
return PatternType(narrowed_type, rest_type, {})
|
||||
|
||||
def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
|
||||
current_type = self.type_context[-1]
|
||||
value: Union[bool, None] = o.value
|
||||
if isinstance(value, bool):
|
||||
typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool")
|
||||
elif value is None:
|
||||
typ = NoneType()
|
||||
else:
|
||||
assert False
|
||||
|
||||
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
|
||||
current_type,
|
||||
[get_type_range(typ)],
|
||||
o,
|
||||
default=current_type
|
||||
)
|
||||
return PatternType(narrowed_type, rest_type, {})
|
||||
|
||||
def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
|
||||
#
|
||||
# check for existence of a starred pattern
|
||||
#
|
||||
current_type = get_proper_type(self.type_context[-1])
|
||||
if not self.can_match_sequence(current_type):
|
||||
return self.early_non_match()
|
||||
star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
|
||||
star_position: Optional[int] = None
|
||||
if len(star_positions) == 1:
|
||||
star_position = star_positions[0]
|
||||
elif len(star_positions) >= 2:
|
||||
assert False, "Parser should prevent multiple starred patterns"
|
||||
required_patterns = len(o.patterns)
|
||||
if star_position is not None:
|
||||
required_patterns -= 1
|
||||
|
||||
#
|
||||
# get inner types of original type
|
||||
#
|
||||
if isinstance(current_type, TupleType):
|
||||
inner_types = current_type.items
|
||||
size_diff = len(inner_types) - required_patterns
|
||||
if size_diff < 0:
|
||||
return self.early_non_match()
|
||||
elif size_diff > 0 and star_position is None:
|
||||
return self.early_non_match()
|
||||
else:
|
||||
inner_type = self.get_sequence_type(current_type)
|
||||
if inner_type is None:
|
||||
inner_type = self.chk.named_type("builtins.object")
|
||||
inner_types = [inner_type] * len(o.patterns)
|
||||
|
||||
#
|
||||
# match inner patterns
|
||||
#
|
||||
contracted_new_inner_types: List[Type] = []
|
||||
contracted_rest_inner_types: List[Type] = []
|
||||
captures: Dict[Expression, Type] = {}
|
||||
|
||||
contracted_inner_types = self.contract_starred_pattern_types(inner_types,
|
||||
star_position,
|
||||
required_patterns)
|
||||
can_match = True
|
||||
for p, t in zip(o.patterns, contracted_inner_types):
|
||||
pattern_type = self.accept(p, t)
|
||||
typ, rest, type_map = pattern_type
|
||||
if is_uninhabited(typ):
|
||||
can_match = False
|
||||
else:
|
||||
contracted_new_inner_types.append(typ)
|
||||
contracted_rest_inner_types.append(rest)
|
||||
self.update_type_map(captures, type_map)
|
||||
new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types,
|
||||
star_position,
|
||||
len(inner_types))
|
||||
rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types,
|
||||
star_position,
|
||||
len(inner_types))
|
||||
|
||||
#
|
||||
# Calculate new type
|
||||
#
|
||||
new_type: Type
|
||||
rest_type: Type = current_type
|
||||
if not can_match:
|
||||
new_type = UninhabitedType()
|
||||
elif isinstance(current_type, TupleType):
|
||||
narrowed_inner_types = []
|
||||
inner_rest_types = []
|
||||
for inner_type, new_inner_type in zip(inner_types, new_inner_types):
|
||||
narrowed_inner_type, inner_rest_type = \
|
||||
self.chk.conditional_types_with_intersection(
|
||||
new_inner_type,
|
||||
[get_type_range(inner_type)],
|
||||
o,
|
||||
default=new_inner_type
|
||||
)
|
||||
narrowed_inner_types.append(narrowed_inner_type)
|
||||
inner_rest_types.append(inner_rest_type)
|
||||
if all(not is_uninhabited(typ) for typ in narrowed_inner_types):
|
||||
new_type = TupleType(narrowed_inner_types, current_type.partial_fallback)
|
||||
else:
|
||||
new_type = UninhabitedType()
|
||||
|
||||
if all(is_uninhabited(typ) for typ in inner_rest_types):
|
||||
# All subpatterns always match, so we can apply negative narrowing
|
||||
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
|
||||
else:
|
||||
new_inner_type = UninhabitedType()
|
||||
for typ in new_inner_types:
|
||||
new_inner_type = join_types(new_inner_type, typ)
|
||||
new_type = self.construct_sequence_child(current_type, new_inner_type)
|
||||
if is_subtype(new_type, current_type):
|
||||
new_type, _ = self.chk.conditional_types_with_intersection(
|
||||
current_type,
|
||||
[get_type_range(new_type)],
|
||||
o,
|
||||
default=current_type
|
||||
)
|
||||
else:
|
||||
new_type = current_type
|
||||
return PatternType(new_type, rest_type, captures)
|
||||
|
||||
def get_sequence_type(self, t: Type) -> Optional[Type]:
|
||||
t = get_proper_type(t)
|
||||
if isinstance(t, AnyType):
|
||||
return AnyType(TypeOfAny.from_another_any, t)
|
||||
if isinstance(t, UnionType):
|
||||
items = [self.get_sequence_type(item) for item in t.items]
|
||||
not_none_items = [item for item in items if item is not None]
|
||||
if len(not_none_items) > 0:
|
||||
return make_simplified_union(not_none_items)
|
||||
else:
|
||||
return None
|
||||
|
||||
if self.chk.type_is_iterable(t) and isinstance(t, Instance):
|
||||
return self.chk.iterable_item_type(t)
|
||||
else:
|
||||
return None
|
||||
|
||||
def contract_starred_pattern_types(self,
|
||||
types: List[Type],
|
||||
star_pos: Optional[int],
|
||||
num_patterns: int
|
||||
) -> List[Type]:
|
||||
"""
|
||||
Contracts a list of types in a sequence pattern depending on the position of a starred
|
||||
capture pattern.
|
||||
|
||||
For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str,
|
||||
bytes] the contracted types are [bool, Union[int, str], bytes].
|
||||
|
||||
If star_pos in None the types are returned unchanged.
|
||||
"""
|
||||
if star_pos is None:
|
||||
return types
|
||||
new_types = types[:star_pos]
|
||||
star_length = len(types) - num_patterns
|
||||
new_types.append(make_simplified_union(types[star_pos:star_pos+star_length]))
|
||||
new_types += types[star_pos+star_length:]
|
||||
|
||||
return new_types
|
||||
|
||||
def expand_starred_pattern_types(self,
|
||||
types: List[Type],
|
||||
star_pos: Optional[int],
|
||||
num_types: int
|
||||
) -> List[Type]:
|
||||
"""Undoes the contraction done by contract_starred_pattern_types.
|
||||
|
||||
For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
|
||||
to lenght 4 the result is [bool, int, int, str].
|
||||
"""
|
||||
if star_pos is None:
|
||||
return types
|
||||
new_types = types[:star_pos]
|
||||
star_length = num_types - len(types) + 1
|
||||
new_types += [types[star_pos]] * star_length
|
||||
new_types += types[star_pos+1:]
|
||||
|
||||
return new_types
|
||||
|
||||
def visit_starred_pattern(self, o: StarredPattern) -> PatternType:
|
||||
captures: Dict[Expression, Type] = {}
|
||||
if o.capture is not None:
|
||||
list_type = self.chk.named_generic_type('builtins.list', [self.type_context[-1]])
|
||||
captures[o.capture] = list_type
|
||||
return PatternType(self.type_context[-1], UninhabitedType(), captures)
|
||||
|
||||
def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
|
||||
current_type = get_proper_type(self.type_context[-1])
|
||||
can_match = True
|
||||
captures: Dict[Expression, Type] = {}
|
||||
for key, value in zip(o.keys, o.values):
|
||||
inner_type = self.get_mapping_item_type(o, current_type, key)
|
||||
if inner_type is None:
|
||||
can_match = False
|
||||
inner_type = self.chk.named_type("builtins.object")
|
||||
pattern_type = self.accept(value, inner_type)
|
||||
if is_uninhabited(pattern_type.type):
|
||||
can_match = False
|
||||
else:
|
||||
self.update_type_map(captures, pattern_type.captures)
|
||||
|
||||
if o.rest is not None:
|
||||
mapping = self.chk.named_type("typing.Mapping")
|
||||
if is_subtype(current_type, mapping) and isinstance(current_type, Instance):
|
||||
mapping_inst = map_instance_to_supertype(current_type, mapping.type)
|
||||
dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict")
|
||||
dict_type = fill_typevars(dict_typeinfo)
|
||||
rest_type = expand_type_by_instance(dict_type, mapping_inst)
|
||||
else:
|
||||
object_type = self.chk.named_type("builtins.object")
|
||||
rest_type = self.chk.named_generic_type("builtins.dict",
|
||||
[object_type, object_type])
|
||||
|
||||
captures[o.rest] = rest_type
|
||||
|
||||
if can_match:
|
||||
# We can't narrow the type here, as Mapping key is invariant.
|
||||
new_type = self.type_context[-1]
|
||||
else:
|
||||
new_type = UninhabitedType()
|
||||
return PatternType(new_type, current_type, captures)
|
||||
|
||||
def get_mapping_item_type(self,
|
||||
pattern: MappingPattern,
|
||||
mapping_type: Type,
|
||||
key: Expression
|
||||
) -> Optional[Type]:
|
||||
local_errors = self.msg.clean_copy()
|
||||
local_errors.disable_count = 0
|
||||
mapping_type = get_proper_type(mapping_type)
|
||||
if isinstance(mapping_type, TypedDictType):
|
||||
result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr(
|
||||
mapping_type, key, local_errors=local_errors)
|
||||
# If we can't determine the type statically fall back to treating it as a normal
|
||||
# mapping
|
||||
if local_errors.is_errors():
|
||||
local_errors = self.msg.clean_copy()
|
||||
local_errors.disable_count = 0
|
||||
result = self.get_simple_mapping_item_type(pattern,
|
||||
mapping_type,
|
||||
key,
|
||||
local_errors)
|
||||
|
||||
if local_errors.is_errors():
|
||||
result = None
|
||||
else:
|
||||
result = self.get_simple_mapping_item_type(pattern,
|
||||
mapping_type,
|
||||
key,
|
||||
local_errors)
|
||||
return result
|
||||
|
||||
def get_simple_mapping_item_type(self,
|
||||
pattern: MappingPattern,
|
||||
mapping_type: Type,
|
||||
key: Expression,
|
||||
local_errors: MessageBuilder
|
||||
) -> Type:
|
||||
result, _ = self.chk.expr_checker.check_method_call_by_name('__getitem__',
|
||||
mapping_type,
|
||||
[key],
|
||||
[ARG_POS],
|
||||
pattern,
|
||||
local_errors=local_errors)
|
||||
return result
|
||||
|
||||
def visit_class_pattern(self, o: ClassPattern) -> PatternType:
|
||||
current_type = get_proper_type(self.type_context[-1])
|
||||
|
||||
#
|
||||
# Check class type
|
||||
#
|
||||
type_info = o.class_ref.node
|
||||
if type_info is None:
|
||||
return PatternType(AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error), {})
|
||||
if isinstance(type_info, TypeAlias) and not type_info.no_args:
|
||||
self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
|
||||
return self.early_non_match()
|
||||
if isinstance(type_info, TypeInfo):
|
||||
any_type = AnyType(TypeOfAny.implementation_artifact)
|
||||
typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
|
||||
elif isinstance(type_info, TypeAlias):
|
||||
typ = type_info.target
|
||||
else:
|
||||
if isinstance(type_info, Var):
|
||||
name = str(type_info.type)
|
||||
else:
|
||||
name = type_info.name
|
||||
self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o.class_ref)
|
||||
return self.early_non_match()
|
||||
|
||||
new_type, rest_type = self.chk.conditional_types_with_intersection(
|
||||
current_type, [get_type_range(typ)], o, default=current_type
|
||||
)
|
||||
if is_uninhabited(new_type):
|
||||
return self.early_non_match()
|
||||
# TODO: Do I need this?
|
||||
narrowed_type = narrow_declared_type(current_type, new_type)
|
||||
|
||||
#
|
||||
# Convert positional to keyword patterns
|
||||
#
|
||||
keyword_pairs: List[Tuple[Optional[str], Pattern]] = []
|
||||
match_arg_set: Set[str] = set()
|
||||
|
||||
captures: Dict[Expression, Type] = {}
|
||||
|
||||
if len(o.positionals) != 0:
|
||||
if self.should_self_match(typ):
|
||||
if len(o.positionals) > 1:
|
||||
self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
|
||||
pattern_type = self.accept(o.positionals[0], narrowed_type)
|
||||
if not is_uninhabited(pattern_type.type):
|
||||
return PatternType(pattern_type.type,
|
||||
join_types(rest_type, pattern_type.rest_type),
|
||||
pattern_type.captures)
|
||||
captures = pattern_type.captures
|
||||
else:
|
||||
local_errors = self.msg.clean_copy()
|
||||
match_args_type = analyze_member_access("__match_args__", typ, o,
|
||||
False, False, False,
|
||||
local_errors,
|
||||
original_type=typ,
|
||||
chk=self.chk)
|
||||
|
||||
if local_errors.is_errors():
|
||||
self.msg.fail(message_registry.MISSING_MATCH_ARGS.format(typ), o)
|
||||
return self.early_non_match()
|
||||
|
||||
proper_match_args_type = get_proper_type(match_args_type)
|
||||
if isinstance(proper_match_args_type, TupleType):
|
||||
match_arg_names = get_match_arg_names(proper_match_args_type)
|
||||
|
||||
if len(o.positionals) > len(match_arg_names):
|
||||
self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
|
||||
return self.early_non_match()
|
||||
else:
|
||||
match_arg_names = [None] * len(o.positionals)
|
||||
|
||||
for arg_name, pos in zip(match_arg_names, o.positionals):
|
||||
keyword_pairs.append((arg_name, pos))
|
||||
if arg_name is not None:
|
||||
match_arg_set.add(arg_name)
|
||||
|
||||
#
|
||||
# Check for duplicate patterns
|
||||
#
|
||||
keyword_arg_set = set()
|
||||
has_duplicates = False
|
||||
for key, value in zip(o.keyword_keys, o.keyword_values):
|
||||
keyword_pairs.append((key, value))
|
||||
if key in match_arg_set:
|
||||
self.msg.fail(
|
||||
message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key),
|
||||
value
|
||||
)
|
||||
has_duplicates = True
|
||||
elif key in keyword_arg_set:
|
||||
self.msg.fail(message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key),
|
||||
value)
|
||||
has_duplicates = True
|
||||
keyword_arg_set.add(key)
|
||||
|
||||
if has_duplicates:
|
||||
return self.early_non_match()
|
||||
|
||||
#
|
||||
# Check keyword patterns
|
||||
#
|
||||
can_match = True
|
||||
for keyword, pattern in keyword_pairs:
|
||||
key_type: Optional[Type] = None
|
||||
local_errors = self.msg.clean_copy()
|
||||
if keyword is not None:
|
||||
key_type = analyze_member_access(keyword,
|
||||
narrowed_type,
|
||||
pattern,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
local_errors,
|
||||
original_type=new_type,
|
||||
chk=self.chk)
|
||||
else:
|
||||
key_type = AnyType(TypeOfAny.from_error)
|
||||
if local_errors.is_errors() or key_type is None:
|
||||
key_type = AnyType(TypeOfAny.from_error)
|
||||
self.msg.fail(message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(typ, keyword),
|
||||
pattern)
|
||||
|
||||
inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
|
||||
if is_uninhabited(inner_type):
|
||||
can_match = False
|
||||
else:
|
||||
self.update_type_map(captures, inner_captures)
|
||||
if not is_uninhabited(inner_rest_type):
|
||||
rest_type = current_type
|
||||
|
||||
if not can_match:
|
||||
new_type = UninhabitedType()
|
||||
return PatternType(new_type, rest_type, captures)
|
||||
|
||||
def should_self_match(self, typ: Type) -> bool:
|
||||
typ = get_proper_type(typ)
|
||||
if isinstance(typ, Instance) and typ.type.is_named_tuple:
|
||||
return False
|
||||
for other in self.self_match_types:
|
||||
if is_subtype(typ, other):
|
||||
return True
|
||||
return False
|
||||
|
||||
def can_match_sequence(self, typ: ProperType) -> bool:
|
||||
if isinstance(typ, UnionType):
|
||||
return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
|
||||
for other in self.non_sequence_match_types:
|
||||
# We have to ignore promotions, as memoryview should match, but bytes,
|
||||
# which it can be promoted to, shouldn't
|
||||
if is_subtype(typ, other, ignore_promotions=True):
|
||||
return False
|
||||
sequence = self.chk.named_type("typing.Sequence")
|
||||
# If the static type is more general than sequence the actual type could still match
|
||||
return is_subtype(typ, sequence) or is_subtype(sequence, typ)
|
||||
|
||||
def generate_types_from_names(self, type_names: List[str]) -> List[Type]:
|
||||
types: List[Type] = []
|
||||
for name in type_names:
|
||||
try:
|
||||
types.append(self.chk.named_type(name))
|
||||
except KeyError as e:
|
||||
# Some built in types are not defined in all test cases
|
||||
if not name.startswith('builtins.'):
|
||||
raise e
|
||||
pass
|
||||
|
||||
return types
|
||||
|
||||
def update_type_map(self,
|
||||
original_type_map: Dict[Expression, Type],
|
||||
extra_type_map: Dict[Expression, Type]
|
||||
) -> None:
|
||||
# Calculating this would not be needed if TypeMap directly used literal hashes instead of
|
||||
# expressions, as suggested in the TODO above it's definition
|
||||
already_captured = set(literal_hash(expr) for expr in original_type_map)
|
||||
for expr, typ in extra_type_map.items():
|
||||
if literal_hash(expr) in already_captured:
|
||||
node = get_var(expr)
|
||||
self.msg.fail(message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name),
|
||||
expr)
|
||||
else:
|
||||
original_type_map[expr] = typ
|
||||
|
||||
def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
|
||||
"""
|
||||
If outer_type is a child class of typing.Sequence returns a new instance of
|
||||
outer_type, that is a Sequence of inner_type. If outer_type is not a child class of
|
||||
typing.Sequence just returns a Sequence of inner_type
|
||||
|
||||
For example:
|
||||
construct_sequence_child(List[int], str) = List[str]
|
||||
"""
|
||||
proper_type = get_proper_type(outer_type)
|
||||
if isinstance(proper_type, UnionType):
|
||||
types = [
|
||||
self.construct_sequence_child(item, inner_type) for item in proper_type.items
|
||||
if self.can_match_sequence(get_proper_type(item))
|
||||
]
|
||||
return make_simplified_union(types)
|
||||
sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
|
||||
if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
|
||||
proper_type = get_proper_type(outer_type)
|
||||
assert isinstance(proper_type, Instance)
|
||||
empty_type = fill_typevars(proper_type.type)
|
||||
partial_type = expand_type_by_instance(empty_type, sequence)
|
||||
return expand_type_by_instance(partial_type, proper_type)
|
||||
else:
|
||||
return sequence
|
||||
|
||||
def early_non_match(self) -> PatternType:
|
||||
return PatternType(UninhabitedType(), self.type_context[-1], {})
|
||||
|
||||
|
||||
def get_match_arg_names(typ: TupleType) -> List[Optional[str]]:
|
||||
args: List[Optional[str]] = []
|
||||
for item in typ.items:
|
||||
values = try_getting_str_literals_from_type(item)
|
||||
if values is None or len(values) != 1:
|
||||
args.append(None)
|
||||
else:
|
||||
args.append(values[0])
|
||||
return args
|
||||
|
||||
|
||||
def get_var(expr: Expression) -> Var:
|
||||
"""
|
||||
Warning: this in only true for expressions captured by a match statement.
|
||||
Don't call it from anywhere else
|
||||
"""
|
||||
assert isinstance(expr, NameExpr)
|
||||
node = expr.node
|
||||
assert isinstance(node, Var)
|
||||
return node
|
||||
|
||||
|
||||
def get_type_range(typ: Type) -> 'mypy.checker.TypeRange':
|
||||
typ = get_proper_type(typ)
|
||||
if (isinstance(typ, Instance)
|
||||
and typ.last_known_value
|
||||
and isinstance(typ.last_known_value.value, bool)):
|
||||
typ = typ.last_known_value
|
||||
return mypy.checker.TypeRange(typ, is_upper_bound=False)
|
||||
|
||||
|
||||
def is_uninhabited(typ: Type) -> bool:
|
||||
return isinstance(get_proper_type(typ), UninhabitedType)
|
||||
Loading…
Add table
Add a link
Reference in a new issue