Skip to content

Commit 859cc50

Browse files
authored
Added support for bidirectional type inference when assigning an expression to an unpacked tuple literal and all of the items in the tuple have a declared type. This addresses #10481. (#10585)
1 parent 4887eb7 commit 859cc50

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

packages/pyright-internal/src/analyzer/typeEvaluator.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2905,6 +2905,37 @@ export function createTypeEvaluator(
29052905
}
29062906
break;
29072907
}
2908+
2909+
case ParseNodeType.Tuple: {
2910+
// If this is a tuple expression with at least one item and no
2911+
// unpacked items, and all of the items have declared types,
2912+
// we can assume a declared type for the resulting tuple. This
2913+
// is needed to enable bidirectional type inference when assigning
2914+
// to an unpacked tuple.
2915+
if (
2916+
expression.d.items.length > 0 &&
2917+
!expression.d.items.some((item) => item.nodeType === ParseNodeType.Unpack)
2918+
) {
2919+
const itemTypes: Type[] = [];
2920+
expression.d.items.forEach((expr) => {
2921+
const itemType = getDeclaredTypeForExpression(expr, usage);
2922+
if (itemType) {
2923+
itemTypes.push(itemType);
2924+
}
2925+
});
2926+
2927+
if (itemTypes.length === expression.d.items.length) {
2928+
// If all items have a declared type, return a tuple of those types.
2929+
return makeTupleObject(
2930+
evaluatorInterface,
2931+
itemTypes.map((t) => {
2932+
return { type: t, isUnbounded: false };
2933+
})
2934+
);
2935+
}
2936+
}
2937+
break;
2938+
}
29082939
}
29092940

29102941
if (symbol) {
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# This sample tests the case where an expression is assigned to an unpacked
2+
# tuple, and the correctly-inferred type of the expression depends on
3+
# bidirectional type inference.
4+
5+
from typing import Literal, TypedDict
6+
7+
8+
def func1[S, T](v: S | T, s: type[S], t: type[T]) -> tuple[S | None, T | None]: ...
9+
10+
11+
def test1():
12+
a: int | None
13+
b: str | None
14+
15+
a, b = func1(1, int, str)
16+
17+
18+
class TD1(TypedDict):
19+
a: int
20+
21+
22+
def test2():
23+
a: TD1
24+
b: TD1
25+
26+
a, b = ({"a": 1}, {"a": 2})
27+
28+
29+
def test3():
30+
a: Literal[1]
31+
b: Literal[2]
32+
33+
a, b = (1, 2)

packages/pyright-internal/src/tests/typeEvaluator2.test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,12 @@ test('Solver44', () => {
843843
TestUtils.validateResults(analysisResults, 0);
844844
});
845845

846+
test('Solver45', () => {
847+
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['solver45.py']);
848+
849+
TestUtils.validateResults(analysisResults, 0);
850+
});
851+
846852
test('SolverScoring1', () => {
847853
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['solverScoring1.py']);
848854

0 commit comments

Comments
 (0)