Skip to content

Commit e6b927a

Browse files
authored
[red-knot] Add a convenience method for constructing a union from a list of elements (#13315)
1 parent acab1f4 commit e6b927a

File tree

4 files changed

+45
-70
lines changed

4 files changed

+45
-70
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,7 @@ pub(crate) fn definitions_ty<'db>(
144144
.expect("definitions_ty should never be called with zero definitions and no unbound_ty.");
145145

146146
if let Some(second) = all_types.next() {
147-
let mut builder = UnionBuilder::new(db);
148-
builder = builder.add(first).add(second);
149-
150-
for variant in all_types {
151-
builder = builder.add(variant);
152-
}
153-
154-
builder.build()
147+
UnionType::from_elements(db, [first, second].into_iter().chain(all_types))
155148
} else {
156149
first
157150
}
@@ -410,13 +403,7 @@ impl<'db> Type<'db> {
410403
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
411404
if let Type::Tuple(tuple_type) = self {
412405
return IterationOutcome::Iterable {
413-
element_ty: tuple_type
414-
.elements(db)
415-
.iter()
416-
.fold(UnionBuilder::new(db), |builder, element| {
417-
builder.add(*element)
418-
})
419-
.build(),
406+
element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)),
420407
};
421408
}
422409

@@ -497,6 +484,12 @@ impl<'db> Type<'db> {
497484
}
498485
}
499486

487+
impl<'db> From<&Type<'db>> for Type<'db> {
488+
fn from(value: &Type<'db>) -> Self {
489+
*value
490+
}
491+
}
492+
500493
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
501494
enum IterationOutcome<'db> {
502495
Iterable { element_ty: Type<'db> },
@@ -636,20 +629,29 @@ impl<'db> UnionType<'db> {
636629
self.elements(db).contains(&ty)
637630
}
638631

639-
/// Apply a transformation function to all elements of the union,
640-
/// and create a new union from the resulting set of types
641-
pub fn map(
642-
&self,
632+
/// Create a union from a list of elements
633+
/// (which may be eagerly simplified into a different variant of [`Type`] altogether)
634+
pub fn from_elements<T: Into<Type<'db>>>(
643635
db: &'db dyn Db,
644-
mut transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
636+
elements: impl IntoIterator<Item = T>,
645637
) -> Type<'db> {
646-
self.elements(db)
638+
elements
647639
.into_iter()
648640
.fold(UnionBuilder::new(db), |builder, element| {
649-
builder.add(transform_fn(element))
641+
builder.add(element.into())
650642
})
651643
.build()
652644
}
645+
646+
/// Apply a transformation function to all elements of the union,
647+
/// and create a new union from the resulting set of types
648+
pub fn map(
649+
&self,
650+
db: &'db dyn Db,
651+
transform_fn: impl Fn(&Type<'db>) -> Type<'db>,
652+
) -> Type<'db> {
653+
Self::from_elements(db, self.elements(db).into_iter().map(transform_fn))
654+
}
653655
}
654656

655657
#[salsa::interned]

crates/red_knot_python_semantic/src/types/builder.rs

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,12 @@ impl<'db> IntersectionBuilder<'db> {
169169
if self.intersections.len() == 1 {
170170
self.intersections.pop().unwrap().build(self.db)
171171
} else {
172-
let mut builder = UnionBuilder::new(self.db);
173-
for inner in self.intersections {
174-
builder = builder.add(inner.build(self.db));
175-
}
176-
builder.build()
172+
UnionType::from_elements(
173+
self.db,
174+
self.intersections
175+
.into_iter()
176+
.map(|inner| inner.build(self.db)),
177+
)
177178
}
178179
}
179180
}
@@ -271,11 +272,11 @@ impl<'db> InnerIntersectionBuilder<'db> {
271272

272273
#[cfg(test)]
273274
mod tests {
274-
use super::{IntersectionBuilder, IntersectionType, Type, UnionBuilder, UnionType};
275+
use super::{IntersectionBuilder, IntersectionType, Type, UnionType};
275276
use crate::db::tests::TestDb;
276277
use crate::program::{Program, SearchPathSettings};
277278
use crate::python_version::PythonVersion;
278-
use crate::types::builtins_symbol_ty;
279+
use crate::types::{builtins_symbol_ty, UnionBuilder};
279280
use crate::ProgramSettings;
280281
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
281282

@@ -310,11 +311,7 @@ mod tests {
310311
let db = setup_db();
311312
let t0 = Type::IntLiteral(0);
312313
let t1 = Type::IntLiteral(1);
313-
let union = UnionBuilder::new(&db)
314-
.add(t0)
315-
.add(t1)
316-
.build()
317-
.expect_union();
314+
let union = UnionType::from_elements(&db, [t0, t1]).expect_union();
318315

319316
assert_eq!(union.elements_vec(&db), &[t0, t1]);
320317
}
@@ -323,25 +320,22 @@ mod tests {
323320
fn build_union_single() {
324321
let db = setup_db();
325322
let t0 = Type::IntLiteral(0);
326-
let ty = UnionBuilder::new(&db).add(t0).build();
327-
323+
let ty = UnionType::from_elements(&db, [t0]);
328324
assert_eq!(ty, t0);
329325
}
330326

331327
#[test]
332328
fn build_union_empty() {
333329
let db = setup_db();
334330
let ty = UnionBuilder::new(&db).build();
335-
336331
assert_eq!(ty, Type::Never);
337332
}
338333

339334
#[test]
340335
fn build_union_never() {
341336
let db = setup_db();
342337
let t0 = Type::IntLiteral(0);
343-
let ty = UnionBuilder::new(&db).add(t0).add(Type::Never).build();
344-
338+
let ty = UnionType::from_elements(&db, [t0, Type::Never]);
345339
assert_eq!(ty, t0);
346340
}
347341

@@ -355,21 +349,10 @@ mod tests {
355349
let t2 = Type::BooleanLiteral(false);
356350
let t3 = Type::IntLiteral(17);
357351

358-
let union = UnionBuilder::new(&db)
359-
.add(t0)
360-
.add(t1)
361-
.add(t3)
362-
.build()
363-
.expect_union();
352+
let union = UnionType::from_elements(&db, [t0, t1, t3]).expect_union();
364353
assert_eq!(union.elements_vec(&db), &[t0, t3]);
365-
let union = UnionBuilder::new(&db)
366-
.add(t0)
367-
.add(t1)
368-
.add(t2)
369-
.add(t3)
370-
.build()
371-
.expect_union();
372354

355+
let union = UnionType::from_elements(&db, [t0, t1, t2, t3]).expect_union();
373356
assert_eq!(union.elements_vec(&db), &[bool_ty, t3]);
374357
}
375358

@@ -379,12 +362,8 @@ mod tests {
379362
let t0 = Type::IntLiteral(0);
380363
let t1 = Type::IntLiteral(1);
381364
let t2 = Type::IntLiteral(2);
382-
let u1 = UnionBuilder::new(&db).add(t0).add(t1).build();
383-
let union = UnionBuilder::new(&db)
384-
.add(u1)
385-
.add(t2)
386-
.build()
387-
.expect_union();
365+
let u1 = UnionType::from_elements(&db, [t0, t1]);
366+
let union = UnionType::from_elements(&db, [u1, t2]).expect_union();
388367

389368
assert_eq!(union.elements_vec(&db), &[t0, t1, t2]);
390369
}
@@ -460,7 +439,7 @@ mod tests {
460439
let t0 = Type::IntLiteral(0);
461440
let t1 = Type::IntLiteral(1);
462441
let ta = Type::Any;
463-
let u0 = UnionBuilder::new(&db).add(t0).add(t1).build();
442+
let u0 = UnionType::from_elements(&db, [t0, t1]);
464443

465444
let union = IntersectionBuilder::new(&db)
466445
.add_positive(ta)

crates/red_knot_python_semantic/src/types/display.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ mod tests {
253253
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
254254

255255
use crate::db::tests::TestDb;
256-
use crate::types::{global_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionBuilder};
256+
use crate::types::{global_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionType};
257257
use crate::{Program, ProgramSettings, PythonVersion, SearchPathSettings};
258258

259259
fn setup_db() -> TestDb {
@@ -295,7 +295,7 @@ mod tests {
295295
)?;
296296
let mod_file = system_path_to_file(&db, "src/main.py").expect("Expected file to exist.");
297297

298-
let vec: Vec<Type<'_>> = vec![
298+
let union_elements = &[
299299
Type::Unknown,
300300
Type::IntLiteral(-1),
301301
global_symbol_ty(&db, mod_file, "A"),
@@ -311,10 +311,7 @@ mod tests {
311311
Type::BooleanLiteral(true),
312312
Type::None,
313313
];
314-
let builder = vec.iter().fold(UnionBuilder::new(&db), |builder, literal| {
315-
builder.add(*literal)
316-
});
317-
let union = builder.build().expect_union();
314+
let union = UnionType::from_elements(&db, union_elements).expect_union();
318315
let display = format!("{}", union.display(&db));
319316
assert_eq!(
320317
display,

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ use crate::stdlib::builtins_module_scope;
4949
use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
5050
use crate::types::{
5151
builtins_symbol_ty, definitions_ty, global_symbol_ty, symbol_ty, BytesLiteralType, ClassType,
52-
FunctionType, StringLiteralType, TupleType, Type, UnionBuilder,
52+
FunctionType, StringLiteralType, TupleType, Type, UnionType,
5353
};
5454
use crate::Db;
5555

@@ -1827,10 +1827,7 @@ impl<'db> TypeInferenceBuilder<'db> {
18271827
let body_ty = self.infer_expression(body);
18281828
let orelse_ty = self.infer_expression(orelse);
18291829

1830-
UnionBuilder::new(self.db)
1831-
.add(body_ty)
1832-
.add(orelse_ty)
1833-
.build()
1830+
UnionType::from_elements(self.db, [body_ty, orelse_ty])
18341831
}
18351832

18361833
fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {

0 commit comments

Comments
 (0)