From 1da718a254617633473d8c4de47efe7a3b7152ac Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 27 Aug 2025 15:08:48 +0200 Subject: [PATCH] Improve match subject inference --- mypy/checker.py | 9 +- mypy/checkpattern.py | 150 +++++++++++++++++++++------- mypy/literals.py | 2 +- test-data/unit/check-python310.test | 137 +++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 37 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ae6ae591ed8c..c42a91477e09 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5567,7 +5567,10 @@ def visit_match_stmt(self, s: MatchStmt) -> None: # capture variable may depend on multiple patterns (it # will be a union of all capture types). This pass ignores # guard expressions. - pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] + pattern_types = [ + self.pattern_checker.accept(p, subject_type, [unwrapped_subject]) + for p in s.patterns + ] type_maps: list[TypeMap] = [t.captures for t in pattern_types] inferred_types = self.infer_variable_types_from_type_maps(type_maps) @@ -5577,7 +5580,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None: current_subject_type = self.expr_checker.narrow_type_from_binder( named_subject, subject_type ) - pattern_type = self.pattern_checker.accept(p, current_subject_type) + pattern_type = self.pattern_checker.accept( + p, current_subject_type, [unwrapped_subject] + ) with self.binder.frame_context(can_skip=True, fall_through=2): if b.is_unreachable or isinstance( get_proper_type(pattern_type.type), UninhabitedType diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index f81684d2f44a..7f3f61b4a5ae 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -10,11 +10,25 @@ from mypy.checkmember import analyze_member_access from mypy.expandtype import expand_type_by_instance from mypy.join import join_types -from mypy.literals import literal_hash +from mypy.literals import Key, literal_hash from mypy.maptype import map_instance_to_supertype from mypy.meet import narrow_declared_type from mypy.messages import MessageBuilder -from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var +from mypy.nodes import ( + ARG_POS, + Context, + Expression, + IndexExpr, + IntExpr, + ListExpr, + MemberExpr, + NameExpr, + TupleExpr, + TypeAlias, + TypeInfo, + UnaryExpr, + Var, +) from mypy.options import Options from mypy.patterns import ( AsPattern, @@ -96,10 +110,8 @@ class PatternChecker(PatternVisitor[PatternType]): msg: MessageBuilder # Currently unused plugin: Plugin - # The expression being matched against the pattern - subject: Expression - - subject_type: Type + # The expressions being matched against the (sub)pattern + subject_context: list[list[Expression]] # Type of the subject to check the (sub)pattern against type_context: list[Type] # Types that match against self instead of their __match_args__ if used as a class pattern @@ -118,6 +130,7 @@ def __init__( self.msg = msg self.plugin = plugin + self.subject_context = [] self.type_context = [] self.self_match_types = self.generate_types_from_names(self_match_type_names) self.non_sequence_match_types = self.generate_types_from_names( @@ -125,17 +138,20 @@ def __init__( ) self.options = options - def accept(self, o: Pattern, type_context: Type) -> PatternType: + def accept(self, o: Pattern, type_context: Type, subject: list[Expression]) -> PatternType: + self.subject_context.append(subject) self.type_context.append(type_context) result = o.accept(self) + self.subject_context.pop() self.type_context.pop() return result def visit_as_pattern(self, o: AsPattern) -> PatternType: + current_subject = self.subject_context[-1] current_type = self.type_context[-1] if o.pattern is not None: - pattern_type = self.accept(o.pattern, current_type) + pattern_type = self.accept(o.pattern, current_type, current_subject) typ, rest_type, type_map = pattern_type else: typ, rest_type, type_map = current_type, UninhabitedType(), {} @@ -150,14 +166,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType: return PatternType(typ, rest_type, type_map) def visit_or_pattern(self, o: OrPattern) -> PatternType: + current_subject = self.subject_context[-1] current_type = self.type_context[-1] # # Check all the subpatterns # - pattern_types = [] + pattern_types: list[PatternType] = [] for pattern in o.patterns: - pattern_type = self.accept(pattern, current_type) + pattern_type = self.accept(pattern, current_type, current_subject) pattern_types.append(pattern_type) if not is_uninhabited(pattern_type.type): current_type = pattern_type.rest_type @@ -173,28 +190,42 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType: # # Check the capture types # - capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list) + capture_types: dict[Var, dict[Key | None, list[tuple[Expression, Type]]]] = defaultdict( + lambda: defaultdict(list) + ) + capture_expr_keys: set[Key | None] = set() # Collect captures from the first subpattern for expr, typ in pattern_types[0].captures.items(): - node = get_var(expr) - capture_types[node].append((expr, typ)) + if (node := get_var(expr)) is None: + continue + key = literal_hash(expr) + capture_types[node][key].append((expr, typ)) + if isinstance(expr, NameExpr): + capture_expr_keys.add(key) # Check if other subpatterns capture the same names for i, pattern_type in enumerate(pattern_types[1:]): - vars = {get_var(expr) for expr, _ in pattern_type.captures.items()} - if capture_types.keys() != vars: + vars = { + literal_hash(expr) for expr in pattern_type.captures if isinstance(expr, NameExpr) + } + if capture_expr_keys != vars: + # Only fail for directly captured names (with NameExpr) self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i]) for expr, typ in pattern_type.captures.items(): - node = get_var(expr) - capture_types[node].append((expr, typ)) + if (node := get_var(expr)) is None: + continue + key = literal_hash(expr) + capture_types[node][key].append((expr, typ)) captures: dict[Expression, Type] = {} - for capture_list in capture_types.values(): - typ = UninhabitedType() - for _, other in capture_list: - typ = make_simplified_union([typ, other]) + for expressions in capture_types.values(): + for key, capture_list in expressions.items(): + if other_types := [entry[1] for entry in capture_list]: + typ = make_simplified_union(other_types) + else: + typ = UninhabitedType() - captures[capture_list[0][0]] = typ + captures[capture_list[0][0]] = typ union_type = make_simplified_union(types) return PatternType(union_type, current_type, captures) @@ -284,12 +315,37 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: contracted_inner_types = self.contract_starred_pattern_types( inner_types, star_position, required_patterns ) - for p, t in zip(o.patterns, contracted_inner_types): - pattern_type = self.accept(p, t) + current_subjects: list[list[Expression]] = [[] for _ in range(len(contracted_inner_types))] + end_pos = len(contracted_inner_types) if star_position is None else star_position + for subject in self.subject_context[-1]: + if isinstance(subject, (ListExpr, TupleExpr)): + # For list and tuple expressions, lookup expression in items + for i in range(end_pos): + if i < len(subject.items): + current_subjects[i].append(subject.items[i]) + if star_position is not None: + for i in range(star_position + 1, len(contracted_inner_types)): + offset = len(contracted_inner_types) - i + if offset <= len(subject.items): + current_subjects[i].append(subject.items[-offset]) + else: + # Support x[0], x[1], ... lookup until wildcard + for i in range(end_pos): + current_subjects[i].append(IndexExpr(subject, IntExpr(i))) + # For everything after wildcard use x[-2], x[-1] + for i in range((star_position or -1) + 1, len(contracted_inner_types)): + offset = len(contracted_inner_types) - i + current_subjects[i].append(IndexExpr(subject, UnaryExpr("-", IntExpr(offset)))) + for p, t, s in zip(o.patterns, contracted_inner_types, current_subjects): + pattern_type = self.accept(p, t, s) typ, rest, type_map = pattern_type contracted_new_inner_types.append(typ) contracted_rest_inner_types.append(rest) self.update_type_map(captures, type_map) + if s: + self.update_type_map( + captures, {subject: typ for subject in s}, fail_multiple_assignments=False + ) new_inner_types = self.expand_starred_pattern_types( contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None @@ -473,11 +529,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: if inner_type is None: can_match = False inner_type = self.chk.named_type("builtins.object") - pattern_type = self.accept(value, inner_type) + current_subjects: list[Expression] = [ + IndexExpr(s, key) for s in self.subject_context[-1] + ] + pattern_type = self.accept(value, inner_type, current_subjects) if is_uninhabited(pattern_type.type): can_match = False else: self.update_type_map(captures, pattern_type.captures) + if current_subjects: + self.update_type_map( + captures, {subject: pattern_type.type for subject in current_subjects} + ) if o.rest is not None: mapping = self.chk.named_type("typing.Mapping") @@ -581,7 +644,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: if self.should_self_match(typ): if len(o.positionals) > 1: self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o) - pattern_type = self.accept(o.positionals[0], narrowed_type) + pattern_type = self.accept(o.positionals[0], narrowed_type, []) if not is_uninhabited(pattern_type.type): return PatternType( pattern_type.type, @@ -681,11 +744,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: elif keyword is not None: new_type = self.chk.add_any_attribute_to_type(new_type, keyword) - inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type) + current_subjects: list[Expression] = [] + if keyword is not None: + current_subjects = [MemberExpr(s, keyword) for s in self.subject_context[-1]] + inner_type, inner_rest_type, inner_captures = self.accept( + pattern, key_type, current_subjects + ) if is_uninhabited(inner_type): can_match = False else: self.update_type_map(captures, inner_captures) + if current_subjects: + self.update_type_map( + captures, {subject: inner_type for subject in current_subjects} + ) if not is_uninhabited(inner_rest_type): rest_type = current_type @@ -732,17 +804,22 @@ def generate_types_from_names(self, type_names: list[str]) -> list[Type]: return types def update_type_map( - self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type] + self, + original_type_map: dict[Expression, Type], + extra_type_map: dict[Expression, Type], + fail_multiple_assignments: bool = True, ) -> None: # Calculating this would not be needed if TypeMap directly used literal hashes instead of # expressions, as suggested in the TODO above it's definition already_captured = {literal_hash(expr) for expr in original_type_map} for expr, typ in extra_type_map.items(): if literal_hash(expr) in already_captured: - node = get_var(expr) - self.msg.fail( - message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr - ) + if (node := get_var(expr)) is None: + continue + if fail_multiple_assignments: + self.msg.fail( + message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr + ) else: original_type_map[expr] = typ @@ -794,12 +871,17 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]: return args -def get_var(expr: Expression) -> Var: +def get_var(expr: Expression) -> Var | None: """ Warning: this in only true for expressions captured by a match statement. Don't call it from anywhere else """ - assert isinstance(expr, NameExpr), expr + if isinstance(expr, MemberExpr): + return get_var(expr.expr) + if isinstance(expr, IndexExpr): + return get_var(expr.base) + if not isinstance(expr, NameExpr): + return None node = expr.node assert isinstance(node, Var), node return node diff --git a/mypy/literals.py b/mypy/literals.py index 5b0c46f4bee8..a7dcd6b2ea72 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -228,7 +228,7 @@ def visit_set_expr(self, e: SetExpr) -> Key | None: return self.seq_expr(e, "Set") def visit_index_expr(self, e: IndexExpr) -> Key | None: - if literal(e.index) == LITERAL_YES: + if literal(e.index) != LITERAL_NO: return ("Index", literal_hash(e.base), literal_hash(e.index)) return None diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 5c495d2ed863..0bd2f5dd9f96 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -2975,3 +2975,140 @@ val: int = 8 match val: case FOO: # E: Cannot assign to final name "FOO" pass + +[case testMatchSubjectInferenceSequence] +m: object + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" + reveal_type(m[0]) # N: Revealed type is "Literal[1]" + reveal_type(m[-2]) # N: Revealed type is "Literal[1]" + reveal_type(m[1]) # N: Revealed type is "Literal[True]" + reveal_type(m[-1]) # N: Revealed type is "Literal[True]" + case [1, *_, False]: + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.object]" + reveal_type(m[-1]) # N: Revealed type is "Literal[False]" + case [[1], [True]]: + reveal_type(m[0][0]) # N: Revealed type is "Literal[1]" + reveal_type(m[-2][0]) # N: Revealed type is "Literal[1]" + reveal_type(m[1][0]) # N: Revealed type is "Literal[True]" + reveal_type(m[-1][0]) # N: Revealed type is "Literal[True]" +[builtins fixtures/tuple.pyi] + +[case testMatchSubjectInferenceMapping] +from typing import Any +m: Any + +match m: + case {"key": 1}: + reveal_type(m["key"]) # N: Revealed type is "Literal[1]" + +[case testMatchSubjectInferenceClass] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str | None + b: int | None + +m: A + +match m: + case A("Hello", 2): + reveal_type(m.a) # N: Revealed type is "Literal['Hello']" + reveal_type(m.b) # N: Revealed type is "Literal[2]" + case A(a="Hello", b=2): + reveal_type(m.a) # N: Revealed type is "Literal['Hello']" + reveal_type(m.b) # N: Revealed type is "Literal[2]" + case A(a=str()) | A(a=None): + reveal_type(m.a) # N: Revealed type is "Union[builtins.str, None]" + case object(some_attr=str()): + reveal_type(m.some_attr) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] + +[case testMatchSubjectInferenceOR] +m: object + +match m: + case [1, 2, 3] | [8, 9]: + reveal_type(m[0]) # N: Revealed type is "Union[Literal[1], Literal[8]]" + reveal_type(m[1]) # N: Revealed type is "Union[Literal[2], Literal[9]]" + reveal_type(m[2]) # N: Revealed type is "Literal[3]" + +[case testMatchSubjectNested] +from typing import Any +class A: + a: str | None + b: int | None + +m: Any + +match m: + case {"key": [0, A(a="Hello")]}: + reveal_type(m) # N: Revealed type is "Any" + reveal_type(m["key"]) # N: Revealed type is "Any" + reveal_type(m["key"][0]) # N: Revealed type is "Literal[0]" + reveal_type(m["key"][1]) # N: Revealed type is "__main__.A" + reveal_type(m["key"][1].a) # N: Revealed type is "Literal['Hello']" + case [0, {"key": 2}]: + reveal_type(m[1]) # N: Revealed type is "Any" + reveal_type(m[1]["key"]) # N: Revealed type is "Literal[2]" + case object(a=[A(a="Hello") | A(a="World")]): + reveal_type(m.a) # N: Revealed type is "Any" + reveal_type(m.a[0]) # N: Revealed type is "__main__.A" + reveal_type(m.a[0].a) # N: Revealed type is "Union[Literal['Hello'], Literal['World']]" + +[case testMatchSubjectExpression] +# flags: --warn-unreachable +m: object +n: object +o: object +def func(): ... + +match (m, n, o): + case [1, 2, 3] | [2, 3, 4]: + reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]" + reveal_type(n) # N: Revealed type is "Union[Literal[2], Literal[3]]" + reveal_type(o) # N: Revealed type is "Union[Literal[3], Literal[4]]" + case [1, 2, 3, 4] | [2, 3, 4, 5]: + # No match -> don't crash + reveal_type(m) # E: Statement is unreachable + case [1, *_, 3] | [2, *_, 4]: + reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]" + reveal_type(n) # N: Revealed type is "builtins.object" + reveal_type(o) # N: Revealed type is "Union[Literal[3], Literal[4]]" + case [1, *_, 3, 4, 5] | [2, *_, 3, 4, 5]: + # No match -> don't crash + reveal_type(m) # E: Statement is unreachable + case [m, *_]: + # This will always match and bind the variables to itself + # Although it doesn't make much sense, make sure it doesn't raise an error + reveal_type(m) # N: Revealed type is "builtins.object" + +match [m, n, o]: + case [1, 2, 3] | [2, 3, 4]: + reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]" + reveal_type(n) # N: Revealed type is "Union[Literal[2], Literal[3]]" + reveal_type(o) # N: Revealed type is "Union[Literal[3], Literal[4]]" + case [1, 2, 3, 4] | [2, 3, 4, 5]: + # No match, but mypy can't detect that yet -> don't crash + reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]" + case [1, *_, 3] | [2, *_, 4]: + reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]" + reveal_type(n) # N: Revealed type is "builtins.object" + reveal_type(o) # N: Revealed type is "Union[Literal[3], Literal[4]]" + case [1, *_, 3, 4, 5] | [2, *_, 3, 4, 5]: + # No match, but mypy can't detect that yet -> don't crash + reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]" + +match a := m: + case [1, 2] | [3, 4]: + reveal_type(a) # N: Revealed type is "typing.Sequence[builtins.int]" + reveal_type(a[0]) # N: Revealed type is "Union[Literal[1], Literal[3]]" + +match func(): + # Don't crash for subject expressions which can't be narrowed + case [1, 2] | [3, 4]: + ... +[builtins fixtures/tuple.pyi]