@@ -16,7 +16,7 @@ use syntax::{
1616 edit_in_place:: { AttrsOwnerEdit , Indent } ,
1717 make, HasName ,
1818 } ,
19- ted, AstNode , NodeOrToken , SyntaxNode , T ,
19+ ted, AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
2020} ;
2121use text_edit:: TextRange ;
2222
@@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
4040// ```
4141// ->
4242// ```
43- // fn main() {
44- // #[derive(PartialEq, Eq)]
45- // enum Bool { True, False }
43+ // #[derive(PartialEq, Eq)]
44+ // enum Bool { True, False }
4645//
46+ // fn main() {
4747// let bool = Bool::True;
4848//
4949// if bool == Bool::True {
@@ -270,6 +270,15 @@ fn replace_usages(
270270 }
271271 _ => ( ) ,
272272 }
273+ } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & new_name)
274+ {
275+ edit. replace ( ty_annotation. syntax ( ) . text_range ( ) , "Bool" ) ;
276+ replace_bool_expr ( edit, initializer) ;
277+ } else if let Some ( receiver) = find_method_call_expr_usage ( & new_name) {
278+ edit. replace (
279+ receiver. syntax ( ) . text_range ( ) ,
280+ format ! ( "({} == Bool::True)" , receiver) ,
281+ ) ;
273282 } else if new_name. syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
274283 // for any other usage in an expression, replace it with a check that it is the true variant
275284 if let Some ( ( record_field, expr) ) = new_name
@@ -413,6 +422,26 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
413422 }
414423}
415424
425+ fn find_assoc_const_usage ( name : & ast:: NameLike ) -> Option < ( ast:: Type , ast:: Expr ) > {
426+ let const_ = name. syntax ( ) . parent ( ) . and_then ( ast:: Const :: cast) ?;
427+ if const_. syntax ( ) . parent ( ) . and_then ( ast:: AssocItemList :: cast) . is_none ( ) {
428+ return None ;
429+ }
430+
431+ Some ( ( const_. ty ( ) ?, const_. body ( ) ?) )
432+ }
433+
434+ fn find_method_call_expr_usage ( name : & ast:: NameLike ) -> Option < ast:: Expr > {
435+ let method_call = name. syntax ( ) . ancestors ( ) . find_map ( ast:: MethodCallExpr :: cast) ?;
436+ let receiver = method_call. receiver ( ) ?;
437+
438+ if !receiver. syntax ( ) . descendants ( ) . contains ( name. syntax ( ) ) {
439+ return None ;
440+ }
441+
442+ Some ( receiver)
443+ }
444+
416445/// Adds the definition of the new enum before the target node.
417446fn add_enum_def (
418447 edit : & mut SourceChangeBuilder ,
@@ -430,18 +459,31 @@ fn add_enum_def(
430459 . any ( |module| module. nearest_non_block_module ( ctx. db ( ) ) != * target_module) ;
431460 let enum_def = make_bool_enum ( make_enum_pub) ;
432461
433- let indent = IndentLevel :: from_node ( & target_node) ;
462+ let insert_before = node_to_insert_before ( target_node) ;
463+ let indent = IndentLevel :: from_node ( & insert_before) ;
434464 enum_def. reindent_to ( indent) ;
435465
436466 ted:: insert_all (
437- ted:: Position :: before ( & edit. make_syntax_mut ( target_node ) ) ,
467+ ted:: Position :: before ( & edit. make_syntax_mut ( insert_before ) ) ,
438468 vec ! [
439469 enum_def. syntax( ) . clone( ) . into( ) ,
440470 make:: tokens:: whitespace( & format!( "\n \n {indent}" ) ) . into( ) ,
441471 ] ,
442472 ) ;
443473}
444474
475+ /// Finds where to put the new enum definition.
476+ /// Tries to find the ast node at the nearest module or at top-level, otherwise just
477+ /// returns the input node.
478+ fn node_to_insert_before ( target_node : SyntaxNode ) -> SyntaxNode {
479+ target_node
480+ . ancestors ( )
481+ . take_while ( |it| !matches ! ( it. kind( ) , SyntaxKind :: MODULE | SyntaxKind :: SOURCE_FILE ) )
482+ . filter ( |it| ast:: Item :: can_cast ( it. kind ( ) ) )
483+ . last ( )
484+ . unwrap_or ( target_node)
485+ }
486+
445487fn make_bool_enum ( make_pub : bool ) -> ast:: Enum {
446488 let enum_def = make:: enum_ (
447489 if make_pub { Some ( make:: visibility_pub ( ) ) } else { None } ,
@@ -491,10 +533,10 @@ fn main() {
491533}
492534"# ,
493535 r#"
494- fn main() {
495- #[derive(PartialEq, Eq)]
496- enum Bool { True, False }
536+ #[derive(PartialEq, Eq)]
537+ enum Bool { True, False }
497538
539+ fn main() {
498540 let foo = Bool::True;
499541
500542 if foo == Bool::True {
@@ -520,10 +562,10 @@ fn main() {
520562}
521563"# ,
522564 r#"
523- fn main() {
524- #[derive(PartialEq, Eq)]
525- enum Bool { True, False }
565+ #[derive(PartialEq, Eq)]
566+ enum Bool { True, False }
526567
568+ fn main() {
527569 let foo = Bool::True;
528570
529571 if foo == Bool::False {
@@ -545,10 +587,10 @@ fn main() {
545587}
546588"# ,
547589 r#"
548- fn main() {
549- #[derive(PartialEq, Eq)]
550- enum Bool { True, False }
590+ #[derive(PartialEq, Eq)]
591+ enum Bool { True, False }
551592
593+ fn main() {
552594 let foo: Bool = Bool::False;
553595}
554596"# ,
@@ -565,10 +607,10 @@ fn main() {
565607}
566608"# ,
567609 r#"
568- fn main() {
569- #[derive(PartialEq, Eq)]
570- enum Bool { True, False }
610+ #[derive(PartialEq, Eq)]
611+ enum Bool { True, False }
571612
613+ fn main() {
572614 let foo = if 1 == 2 { Bool::True } else { Bool::False };
573615}
574616"# ,
@@ -590,10 +632,10 @@ fn main() {
590632}
591633"# ,
592634 r#"
593- fn main() {
594- #[derive(PartialEq, Eq)]
595- enum Bool { True, False }
635+ #[derive(PartialEq, Eq)]
636+ enum Bool { True, False }
596637
638+ fn main() {
597639 let foo = Bool::False;
598640 let bar = true;
599641
@@ -619,10 +661,10 @@ fn main() {
619661}
620662"# ,
621663 r#"
622- fn main() {
623- #[derive(PartialEq, Eq)]
624- enum Bool { True, False }
664+ #[derive(PartialEq, Eq)]
665+ enum Bool { True, False }
625666
667+ fn main() {
626668 let foo = Bool::True;
627669
628670 if *&foo == Bool::True {
@@ -645,10 +687,10 @@ fn main() {
645687}
646688"# ,
647689 r#"
648- fn main() {
649- #[derive(PartialEq, Eq)]
650- enum Bool { True, False }
690+ #[derive(PartialEq, Eq)]
691+ enum Bool { True, False }
651692
693+ fn main() {
652694 let foo: Bool;
653695 foo = Bool::True;
654696}
@@ -671,10 +713,10 @@ fn main() {
671713}
672714"# ,
673715 r#"
674- fn main() {
675- #[derive(PartialEq, Eq)]
676- enum Bool { True, False }
716+ #[derive(PartialEq, Eq)]
717+ enum Bool { True, False }
677718
719+ fn main() {
678720 let foo = Bool::True;
679721 let bar = foo == Bool::False;
680722
@@ -702,11 +744,11 @@ fn main() {
702744}
703745"# ,
704746 r#"
747+ #[derive(PartialEq, Eq)]
748+ enum Bool { True, False }
749+
705750fn main() {
706751 if !"foo".chars().any(|c| {
707- #[derive(PartialEq, Eq)]
708- enum Bool { True, False }
709-
710752 let foo = Bool::True;
711753 foo == Bool::True
712754 }) {
@@ -1244,6 +1286,38 @@ fn main() {
12441286 )
12451287 }
12461288
1289+ #[ test]
1290+ fn field_method_chain_usage ( ) {
1291+ check_assist (
1292+ bool_to_enum,
1293+ r#"
1294+ struct Foo {
1295+ $0bool: bool,
1296+ }
1297+
1298+ fn main() {
1299+ let foo = Foo { bool: true };
1300+
1301+ foo.bool.then(|| 2);
1302+ }
1303+ "# ,
1304+ r#"
1305+ #[derive(PartialEq, Eq)]
1306+ enum Bool { True, False }
1307+
1308+ struct Foo {
1309+ bool: Bool,
1310+ }
1311+
1312+ fn main() {
1313+ let foo = Foo { bool: Bool::True };
1314+
1315+ (foo.bool == Bool::True).then(|| 2);
1316+ }
1317+ "# ,
1318+ )
1319+ }
1320+
12471321 #[ test]
12481322 fn field_non_bool ( ) {
12491323 cov_mark:: check!( not_applicable_non_bool_field) ;
@@ -1445,6 +1519,90 @@ pub mod bar {
14451519 )
14461520 }
14471521
1522+ #[ test]
1523+ fn const_in_impl_cross_file ( ) {
1524+ check_assist (
1525+ bool_to_enum,
1526+ r#"
1527+ //- /main.rs
1528+ mod foo;
1529+
1530+ struct Foo;
1531+
1532+ impl Foo {
1533+ pub const $0BOOL: bool = true;
1534+ }
1535+
1536+ //- /foo.rs
1537+ use crate::Foo;
1538+
1539+ fn foo() -> bool {
1540+ Foo::BOOL
1541+ }
1542+ "# ,
1543+ r#"
1544+ //- /main.rs
1545+ mod foo;
1546+
1547+ struct Foo;
1548+
1549+ #[derive(PartialEq, Eq)]
1550+ pub enum Bool { True, False }
1551+
1552+ impl Foo {
1553+ pub const BOOL: Bool = Bool::True;
1554+ }
1555+
1556+ //- /foo.rs
1557+ use crate::{Foo, Bool};
1558+
1559+ fn foo() -> bool {
1560+ Foo::BOOL == Bool::True
1561+ }
1562+ "# ,
1563+ )
1564+ }
1565+
1566+ #[ test]
1567+ fn const_in_trait ( ) {
1568+ check_assist (
1569+ bool_to_enum,
1570+ r#"
1571+ trait Foo {
1572+ const $0BOOL: bool;
1573+ }
1574+
1575+ impl Foo for usize {
1576+ const BOOL: bool = true;
1577+ }
1578+
1579+ fn main() {
1580+ if <usize as Foo>::BOOL {
1581+ println!("foo");
1582+ }
1583+ }
1584+ "# ,
1585+ r#"
1586+ #[derive(PartialEq, Eq)]
1587+ enum Bool { True, False }
1588+
1589+ trait Foo {
1590+ const BOOL: Bool;
1591+ }
1592+
1593+ impl Foo for usize {
1594+ const BOOL: Bool = Bool::True;
1595+ }
1596+
1597+ fn main() {
1598+ if <usize as Foo>::BOOL == Bool::True {
1599+ println!("foo");
1600+ }
1601+ }
1602+ "# ,
1603+ )
1604+ }
1605+
14481606 #[ test]
14491607 fn const_non_bool ( ) {
14501608 cov_mark:: check!( not_applicable_non_bool_const) ;
0 commit comments