Skip to content

Commit 960fb52

Browse files
authored
Fixed bug that results in a false negative when attempting to assign an IntEnum or StrEnum literal to the Literal type corresponding to its value type. This addresses #10552. (#10558)
1 parent 9148e3d commit 960fb52

File tree

5 files changed

+68
-28
lines changed

5 files changed

+68
-28
lines changed

packages/pyright-internal/src/analyzer/checker.ts

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ import {
129129
TypeResult,
130130
} from './typeEvaluatorTypes';
131131
import {
132+
enumerateLiteralsForType,
132133
getElementTypeForContainerNarrowing,
133134
getIsInstanceClassTypes,
134135
narrowTypeForContainerElementType,
@@ -2142,8 +2143,8 @@ export class Checker extends ParseTreeWalker {
21422143
rightExpression = rightExpression.d.leftExpr;
21432144
}
21442145

2145-
const leftType = this._evaluator.getType(node.d.leftExpr);
2146-
const rightType = this._evaluator.getType(rightExpression);
2146+
let leftType = this._evaluator.getType(node.d.leftExpr);
2147+
let rightType = this._evaluator.getType(rightExpression);
21472148

21482149
if (!leftType || !rightType) {
21492150
return;
@@ -2159,6 +2160,44 @@ export class Checker extends ParseTreeWalker {
21592160
: LocMessage.comparisonAlwaysTrue();
21602161
};
21612162

2163+
const replaceEnumTypeWithLiteralValue = (type: Type) => {
2164+
return mapSubtypes(type, (subtype) => {
2165+
if (
2166+
!isClassInstance(subtype) ||
2167+
!ClassType.isEnumClass(subtype) ||
2168+
!subtype.shared.mro.some(
2169+
(base) => isClass(base) && ClassType.isBuiltIn(base, ['int', 'str', 'bytes'])
2170+
)
2171+
) {
2172+
return subtype;
2173+
}
2174+
2175+
// If this is an enum literal, replace it with its literal value.
2176+
if (subtype.priv.literalValue instanceof EnumLiteral) {
2177+
return subtype.priv.literalValue.itemType;
2178+
}
2179+
2180+
// If this is an enum class, replace it with the type of its members.
2181+
const literalValues = enumerateLiteralsForType(this._evaluator, subtype);
2182+
if (literalValues && literalValues.length > 0) {
2183+
return combineTypes(
2184+
literalValues.map((literalClass) => {
2185+
const literalValue = literalClass.priv.literalValue;
2186+
assert(literalValue instanceof EnumLiteral);
2187+
return literalValue.itemType;
2188+
})
2189+
);
2190+
}
2191+
2192+
return subtype;
2193+
});
2194+
};
2195+
2196+
// Handle enum literals that are assignable to another (non-Enum) literal.
2197+
// This can happen for IntEnum and StrEnum members.
2198+
leftType = replaceEnumTypeWithLiteralValue(leftType);
2199+
rightType = replaceEnumTypeWithLiteralValue(rightType);
2200+
21622201
// Check for the special case where the LHS and RHS are both literals.
21632202
if (isLiteralTypeOrUnion(rightType) && isLiteralTypeOrUnion(leftType)) {
21642203
if (
@@ -2171,13 +2210,13 @@ export class Checker extends ParseTreeWalker {
21712210
let isPossiblyTrue = false;
21722211

21732212
doForEachSubtype(leftType, (leftSubtype) => {
2174-
if (this._evaluator.assignType(rightType, leftSubtype)) {
2213+
if (this._evaluator.assignType(rightType!, leftSubtype)) {
21752214
isPossiblyTrue = true;
21762215
}
21772216
});
21782217

21792218
doForEachSubtype(rightType, (rightSubtype) => {
2180-
if (this._evaluator.assignType(leftType, rightSubtype)) {
2219+
if (this._evaluator.assignType(leftType!, rightSubtype)) {
21812220
isPossiblyTrue = true;
21822221
}
21832222
});
@@ -2201,7 +2240,7 @@ export class Checker extends ParseTreeWalker {
22012240
return;
22022241
}
22032242

2204-
this._evaluator.mapSubtypesExpandTypeVars(rightType, {}, (rightSubtype) => {
2243+
this._evaluator.mapSubtypesExpandTypeVars(rightType!, {}, (rightSubtype) => {
22052244
if (isComparable) {
22062245
return;
22072246
}

packages/pyright-internal/src/analyzer/typeEvaluator.ts

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24903,21 +24903,6 @@ export function createTypeEvaluator(
2490324903
);
2490424904
}
2490524905

24906-
// Handle enum literals that are assignable to another (non-Enum) literal.
24907-
// This can happen for IntEnum and StrEnum members.
24908-
if (
24909-
ClassType.isEnumClass(concreteSrcType) &&
24910-
concreteSrcType.priv.literalValue instanceof EnumLiteral &&
24911-
concreteSrcType.shared.mro.some(
24912-
(base) => isClass(base) && ClassType.isBuiltIn(base, ['int', 'str', 'bytes'])
24913-
) &&
24914-
isClassInstance(concreteSrcType.priv.literalValue.itemType) &&
24915-
isLiteralType(concreteSrcType.priv.literalValue.itemType) &&
24916-
assignType(destType, concreteSrcType.priv.literalValue.itemType)
24917-
) {
24918-
return true;
24919-
}
24920-
2492124906
if (
2492224907
destType.priv.literalValue !== undefined &&
2492324908
ClassType.isSameGenericClass(destType, concreteSrcType)

packages/pyright-internal/src/tests/samples/comparison2.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# when applied to functions that appear within a conditional expression.
33

44

5+
from enum import Enum
56
from typing import Any, Callable, Coroutine, Protocol
67
from dataclasses import dataclass
78

@@ -140,3 +141,18 @@ def func11(a: A, b: SupportsBool, c: object):
140141

141142
def func12(a: object, b: Callable[..., int]) -> bool:
142143
return a is b
144+
145+
146+
class IntVal(int, Enum):
147+
one = 1
148+
two = 2
149+
three = 3
150+
151+
152+
def func13(x: IntVal):
153+
if x == 1:
154+
pass
155+
156+
# This should generate an error if reportUnnecessaryComparison is enabled.
157+
if x == 4:
158+
pass

packages/pyright-internal/src/tests/samples/enum13.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,31 @@ class IntEnum1(IntEnum):
99
MEMBER_2 = 2
1010

1111

12-
i1: Literal[1] = IntEnum1.MEMBER_1
12+
i1: Literal[1] = IntEnum1.MEMBER_1.value
1313

1414
# This should generate an error.
15-
i2: Literal[1] = IntEnum1.MEMBER_2
15+
i2: Literal[1] = IntEnum1.MEMBER_2.value
1616

1717

1818
class StrEnum1(StrEnum):
1919
MEMBER_1 = "a"
2020
MEMBER_2 = "b"
2121

2222

23-
s1: Literal["a"] = StrEnum1.MEMBER_1
23+
s1: Literal["a"] = StrEnum1.MEMBER_1.value
2424

2525
# This should generate an error.
26-
s2: Literal["b"] = StrEnum1.MEMBER_1
26+
s2: Literal["b"] = StrEnum1.MEMBER_1.value
2727

28-
s3: LiteralString = StrEnum1.MEMBER_1
28+
s3: LiteralString = StrEnum1.MEMBER_1.value
2929

3030

3131
class BytesEnum(bytes, ReprEnum):
3232
MEMBER_1 = b"1"
3333
MEMBER_2 = b"2"
3434

3535

36-
b1: Literal[b"1"] = BytesEnum.MEMBER_1
36+
b1: Literal[b"1"] = BytesEnum.MEMBER_1.value
3737

3838
# This should generate an error.
39-
b2: Literal[b"2"] = BytesEnum.MEMBER_1
39+
b2: Literal[b"2"] = BytesEnum.MEMBER_1.value

packages/pyright-internal/src/tests/typeEvaluator6.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ test('Comparison2', () => {
629629

630630
configOptions.diagnosticRuleSet.reportUnnecessaryComparison = 'error';
631631
const analysisResults2 = TestUtils.typeAnalyzeSampleFiles(['comparison2.py'], configOptions);
632-
TestUtils.validateResults(analysisResults2, 17);
632+
TestUtils.validateResults(analysisResults2, 18);
633633
});
634634

635635
test('EmptyContainers1', () => {

0 commit comments

Comments
 (0)