Skip to content

Commit d54d25a

Browse files
1. Disallow direct item access of NotRequired TypedDict properties
2. When using .get() on a typeddict, the result type will now be a union of the dict[key] type and the type of the default parameter, instead of `object` Fixes python#12094 - replaces python#12095 which is now bitrotted
1 parent 3d78a2f commit d54d25a

File tree

12 files changed

+193
-65
lines changed

12 files changed

+193
-65
lines changed

mypy/checkexpr.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3684,6 +3684,8 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type:
36843684
def visit_index_expr(self, e: IndexExpr) -> Type:
36853685
"""Type check an index expression (base[index]).
36863686
3687+
This function is only used for expressions (rvalues) not for setitem statement (lvalues).
3688+
36873689
It may also represent type application.
36883690
"""
36893691
result = self.visit_index_expr_helper(e)
@@ -3748,7 +3750,7 @@ def visit_index_with_type(
37483750
else:
37493751
return self.nonliteral_tuple_index_helper(left_type, index)
37503752
elif isinstance(left_type, TypedDictType):
3751-
return self.visit_typeddict_index_expr(left_type, e.index)
3753+
return self.visit_typeddict_index_expr(left_type, e.index, is_rvalue=True)
37523754
elif (
37533755
isinstance(left_type, CallableType)
37543756
and left_type.is_type_obj()
@@ -3837,7 +3839,7 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)
38373839
return union
38383840

38393841
def visit_typeddict_index_expr(
3840-
self, td_type: TypedDictType, index: Expression, setitem: bool = False
3842+
self, td_type: TypedDictType, index: Expression, setitem: bool = False, *, is_rvalue: bool
38413843
) -> Type:
38423844
if isinstance(index, StrExpr):
38433845
key_names = [index.value]
@@ -3870,6 +3872,8 @@ def visit_typeddict_index_expr(
38703872
self.msg.typeddict_key_not_found(td_type, key_name, index, setitem)
38713873
return AnyType(TypeOfAny.from_error)
38723874
else:
3875+
if is_rvalue and not td_type.is_required(key_name):
3876+
self.msg.typeddict_key_not_required(td_type, key_name, index)
38733877
value_types.append(value_type)
38743878
return make_simplified_union(value_types)
38753879

mypy/checkmember.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,7 @@ def analyze_typeddict_access(
10741074
# Since we can get this during `a['key'] = ...`
10751075
# it is safe to assume that the context is `IndexExpr`.
10761076
item_type = mx.chk.expr_checker.visit_typeddict_index_expr(
1077-
typ, mx.context.index, setitem=True
1077+
typ, mx.context.index, setitem=True, is_rvalue=False
10781078
)
10791079
else:
10801080
# It can also be `a.__setitem__(...)` direct call.

mypy/checkpattern.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def get_mapping_item_type(
416416
if isinstance(mapping_type, TypedDictType):
417417
with self.msg.filter_errors() as local_errors:
418418
result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
419-
mapping_type, key
419+
mapping_type, key, is_rvalue=False
420420
)
421421
has_local_errors = local_errors.has_new_errors()
422422
# If we can't determine the type statically fall back to treating it as a normal

mypy/errorcodes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def __str__(self) -> str:
8484
TYPEDDICT_ITEM: Final = ErrorCode(
8585
"typeddict-item", "Check items when constructing TypedDict", "General"
8686
)
87+
TYPEDDICT_ITEM_ACCESS: Final = ErrorCode(
88+
"typeddict-item-access", "Check NotRequired item access when using TypedDict", "General"
89+
)
8790
TYPEDDICT_UNKNOWN_KEY: Final = ErrorCode(
8891
"typeddict-unknown-key",
8992
"Check unknown keys when constructing TypedDict",

mypy/messages.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,6 +1703,18 @@ def typeddict_key_not_found(
17031703
"Did you mean {}?".format(pretty_seq(matches, "or")), context, code=err_code
17041704
)
17051705

1706+
def typeddict_key_not_required(
1707+
self, typ: TypedDictType, item_name: str, context: Context
1708+
) -> None:
1709+
type_name: str = ""
1710+
if not typ.is_anonymous():
1711+
type_name = format_type(typ) + " "
1712+
self.fail(
1713+
f'TypedDict {type_name}key "{item_name}" is not required and might not be present.',
1714+
context,
1715+
code=codes.TYPEDDICT_ITEM_ACCESS,
1716+
)
1717+
17061718
def typeddict_context_ambiguous(self, types: list[TypedDictType], context: Context) -> None:
17071719
formatted_types = ", ".join(list(format_type_distinctly(*types)))
17081720
self.fail(

mypy/plugins/default.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -189,41 +189,53 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
189189

190190
def typed_dict_get_callback(ctx: MethodContext) -> Type:
191191
"""Infer a precise return type for TypedDict.get with literal first argument."""
192-
if (
192+
if not (
193193
isinstance(ctx.type, TypedDictType)
194194
and len(ctx.arg_types) >= 1
195195
and len(ctx.arg_types[0]) == 1
196196
):
197-
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
198-
if keys is None:
199-
return ctx.default_return_type
197+
return ctx.default_return_type
200198

201-
output_types: list[Type] = []
202-
for key in keys:
203-
value_type = get_proper_type(ctx.type.items.get(key))
204-
if value_type is None:
205-
return ctx.default_return_type
206-
207-
if len(ctx.arg_types) == 1:
208-
output_types.append(value_type)
209-
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
210-
default_arg = ctx.args[1][0]
211-
if (
212-
isinstance(default_arg, DictExpr)
213-
and len(default_arg.items) == 0
214-
and isinstance(value_type, TypedDictType)
215-
):
216-
# Special case '{}' as the default for a typed dict type.
217-
output_types.append(value_type.copy_modified(required_keys=set()))
218-
else:
219-
output_types.append(value_type)
220-
output_types.append(ctx.arg_types[1][0])
221-
222-
if len(ctx.arg_types) == 1:
223-
output_types.append(NoneType())
224-
225-
return make_simplified_union(output_types)
226-
return ctx.default_return_type
199+
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
200+
if keys is None:
201+
return ctx.default_return_type
202+
203+
default_type: Type = NoneType()
204+
if len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1:
205+
default_type = ctx.arg_types[1][0]
206+
elif len(ctx.arg_types) > 1:
207+
default_type = ctx.default_return_type
208+
209+
output_types: list[Type] = []
210+
for key in keys:
211+
value_type = get_proper_type(ctx.type.items.get(key))
212+
if value_type is None:
213+
# It would be nice to issue a "TypedDict has no key {key}" failure here. However,
214+
# we don't do this because in the case where you have a union of typed dicts, and
215+
# one of them has the key but the others don't, an error message is incorrect, and
216+
# the plugin API has no mechanism to distinguish these cases.
217+
output_types.append(default_type)
218+
continue
219+
220+
if ctx.type.is_required(key):
221+
output_types.append(value_type)
222+
continue
223+
224+
if len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
225+
default_arg = ctx.args[1][0]
226+
if (
227+
isinstance(default_arg, DictExpr)
228+
and len(default_arg.items) == 0
229+
and isinstance(value_type, TypedDictType)
230+
):
231+
# Special case '{}' as the default for a typed dict type.
232+
output_types.append(value_type.copy_modified(required_keys=set()))
233+
continue
234+
235+
output_types.append(value_type)
236+
output_types.append(default_type)
237+
238+
return make_simplified_union(output_types)
227239

228240

229241
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:

mypy/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,6 +2336,9 @@ def __init__(
23362336
def accept(self, visitor: TypeVisitor[T]) -> T:
23372337
return visitor.visit_typeddict_type(self)
23382338

2339+
def is_required(self, key: str) -> bool:
2340+
return key in self.required_keys
2341+
23392342
def __hash__(self) -> int:
23402343
return hash((frozenset(self.items.items()), self.fallback, frozenset(self.required_keys)))
23412344

mypyc/test-data/run-misc.test

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ TypeError
640640
10
641641

642642
[case testClassBasedTypedDict]
643+
[typing fixtures/typing-full.pyi]
643644
from typing_extensions import TypedDict
644645

645646
class TD(TypedDict):
@@ -670,8 +671,11 @@ def test_inherited_typed_dict() -> None:
670671
def test_non_total_typed_dict() -> None:
671672
d3 = TD3(c=3)
672673
d4 = TD4(a=1, b=2, c=3, d=4)
673-
assert d3['c'] == 3
674-
assert d4['d'] == 4
674+
assert d3['c'] == 3 # type: ignore[typeddict-item-access]
675+
assert d4['d'] == 4 # type: ignore[typeddict-item-access]
676+
assert d3.get('c') == 3
677+
assert d3.get('d') == 4
678+
assert d3.get('z') is None
675679

676680
[case testClassBasedNamedTuple]
677681
from typing import NamedTuple

test-data/unit/check-literal.test

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,12 +1898,14 @@ c_key: Literal["c"]
18981898
d: Outer
18991899

19001900
reveal_type(d[a_key]) # N: Revealed type is "builtins.int"
1901-
reveal_type(d[b_key]) # N: Revealed type is "builtins.str"
1901+
reveal_type(d[b_key]) # N: Revealed type is "builtins.str" \
1902+
# E: TypedDict "Outer" key "b" is not required and might not be present.
1903+
reveal_type(d.get(b_key)) # N: Revealed type is "builtins.str"
19021904
d[c_key] # E: TypedDict "Outer" has no key "c"
19031905

1904-
reveal_type(d.get(a_key, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
1906+
reveal_type(d.get(a_key, u)) # N: Revealed type is "builtins.int"
19051907
reveal_type(d.get(b_key, u)) # N: Revealed type is "Union[builtins.str, __main__.Unrelated]"
1906-
reveal_type(d.get(c_key, u)) # N: Revealed type is "builtins.object"
1908+
reveal_type(d.get(c_key, u)) # N: Revealed type is "__main__.Unrelated"
19071909

19081910
reveal_type(d.pop(a_key)) # E: Key "a" of TypedDict "Outer" cannot be deleted \
19091911
# N: Revealed type is "builtins.int"
@@ -1946,8 +1948,8 @@ u: Unrelated
19461948
reveal_type(a[int_key_good]) # N: Revealed type is "builtins.int"
19471949
reveal_type(b[int_key_good]) # N: Revealed type is "builtins.int"
19481950
reveal_type(c[str_key_good]) # N: Revealed type is "builtins.int"
1949-
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
1950-
reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "builtins.object"
1951+
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "builtins.int"
1952+
reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "__main__.Unrelated"
19511953

19521954
a[int_key_bad] # E: Tuple index out of range
19531955
b[int_key_bad] # E: Tuple index out of range
@@ -1987,6 +1989,7 @@ tup2[idx_bad] # E: Tuple index out of range
19871989
[out]
19881990

19891991
[case testLiteralIntelligentIndexingTypedDictUnions]
1992+
# flags: --strict-optional
19901993
from typing_extensions import Literal, Final
19911994
from mypy_extensions import TypedDict
19921995

@@ -2014,12 +2017,12 @@ bad_keys: Literal["a", "bad"]
20142017

20152018
reveal_type(test[good_keys]) # N: Revealed type is "Union[__main__.A, __main__.B]"
20162019
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B]"
2017-
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?, __main__.B]"
2020+
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B]"
20182021
reveal_type(test.pop(optional_keys)) # N: Revealed type is "Union[__main__.D, __main__.E]"
20192022
reveal_type(test.pop(optional_keys, 3)) # N: Revealed type is "Union[__main__.D, __main__.E, Literal[3]?]"
20202023
reveal_type(test.setdefault(good_keys, AAndB())) # N: Revealed type is "Union[__main__.A, __main__.B]"
2021-
reveal_type(test.get(bad_keys)) # N: Revealed type is "builtins.object"
2022-
reveal_type(test.get(bad_keys, 3)) # N: Revealed type is "builtins.object"
2024+
reveal_type(test.get(bad_keys)) # N: Revealed type is "Union[__main__.A, None]"
2025+
reveal_type(test.get(bad_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?]"
20232026
del test[optional_keys]
20242027

20252028

@@ -2039,6 +2042,7 @@ del test[bad_keys] # E: Key "a" of TypedDict "Test" cannot be delet
20392042
[out]
20402043

20412044
[case testLiteralIntelligentIndexingMultiTypedDict]
2045+
# flags: --strict-optional
20422046
from typing import Union
20432047
from typing_extensions import Literal
20442048
from mypy_extensions import TypedDict
@@ -2067,9 +2071,9 @@ x[bad_keys] # E: TypedDict "D1" has no key "d" \
20672071

20682072
reveal_type(x[good_keys]) # N: Revealed type is "Union[__main__.B, __main__.C]"
20692073
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C]"
2070-
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, Literal[3]?, __main__.C]"
2071-
reveal_type(x.get(bad_keys)) # N: Revealed type is "builtins.object"
2072-
reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "builtins.object"
2074+
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, __main__.C]"
2075+
reveal_type(x.get(bad_keys)) # N: Revealed type is "Union[__main__.A, __main__.B, __main__.C, None, __main__.D]"
2076+
reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B, __main__.C, Literal[3]?, __main__.D]"
20732077

20742078
[builtins fixtures/dict.pyi]
20752079
[typing fixtures/typing-typeddict.pyi]

test-data/unit/check-narrowing.test

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,17 +283,20 @@ class TypedDict2(TypedDict):
283283
key: Literal['B', 'C']
284284

285285
x: Union[TypedDict1, TypedDict2]
286-
if x['key'] == 'A':
286+
287+
# NOTE: we ignore typeddict-item-access errors here because the narrowing doesn't work with .get().
288+
289+
if x['key'] == 'A': # type: ignore[typeddict-item-access]
287290
reveal_type(x) # N: Revealed type is "TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]})"
288291
else:
289292
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]"
290293

291-
if x['key'] == 'C':
294+
if x['key'] == 'C': # type: ignore[typeddict-item-access]
292295
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]"
293296
else:
294297
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]"
295298

296-
if x['key'] == 'D':
299+
if x['key'] == 'D': # type: ignore[typeddict-item-access]
297300
reveal_type(x) # E: Statement is unreachable
298301
else:
299302
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]"
@@ -310,17 +313,17 @@ class TypedDict2(TypedDict, total=False):
310313
key: Literal['B', 'C']
311314

312315
x: Union[TypedDict1, TypedDict2]
313-
if x['key'] == 'A':
316+
if x['key'] == 'A': # type: ignore[typeddict-item-access]
314317
reveal_type(x) # N: Revealed type is "TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]})"
315318
else:
316319
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]"
317320

318-
if x['key'] == 'C':
321+
if x['key'] == 'C': # type: ignore[typeddict-item-access]
319322
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]"
320323
else:
321324
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]"
322325

323-
if x['key'] == 'D':
326+
if x['key'] == 'D': # type: ignore[typeddict-item-access]
324327
reveal_type(x) # E: Statement is unreachable
325328
else:
326329
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]"

0 commit comments

Comments
 (0)