@@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions};
6060pub use rustc_type_ir:: { DebugWithInfcx , InferCtxtLike , WithInfcx } ;
6161pub use vtable:: * ;
6262
63+ use std:: assert_matches:: assert_matches;
6364use std:: fmt:: Debug ;
6465use std:: hash:: { Hash , Hasher } ;
6566use std:: marker:: PhantomData ;
@@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> {
18261827
18271828 /// Returns layout of a coroutine. Layout might be unavailable if the
18281829 /// coroutine is tainted by errors.
1829- pub fn coroutine_layout ( self , def_id : DefId ) -> Option < & ' tcx CoroutineLayout < ' tcx > > {
1830- self . optimized_mir ( def_id) . coroutine_layout ( )
1830+ ///
1831+ /// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
1832+ /// e.g. `args.as_coroutine().kind_ty()`.
1833+ pub fn coroutine_layout (
1834+ self ,
1835+ def_id : DefId ,
1836+ coroutine_kind_ty : Ty < ' tcx > ,
1837+ ) -> Option < & ' tcx CoroutineLayout < ' tcx > > {
1838+ let mir = self . optimized_mir ( def_id) ;
1839+ // Regular coroutine
1840+ if coroutine_kind_ty. is_unit ( ) {
1841+ mir. coroutine_layout_raw ( )
1842+ } else {
1843+ // If we have a `Coroutine` that comes from an coroutine-closure,
1844+ // then it may be a by-move or by-ref body.
1845+ let ty:: Coroutine ( _, identity_args) =
1846+ * self . type_of ( def_id) . instantiate_identity ( ) . kind ( )
1847+ else {
1848+ unreachable ! ( ) ;
1849+ } ;
1850+ let identity_kind_ty = identity_args. as_coroutine ( ) . kind_ty ( ) ;
1851+ // If the types differ, then we must be getting the by-move body of
1852+ // a by-ref coroutine.
1853+ if identity_kind_ty == coroutine_kind_ty {
1854+ mir. coroutine_layout_raw ( )
1855+ } else {
1856+ assert_matches ! ( coroutine_kind_ty. to_opt_closure_kind( ) , Some ( ClosureKind :: FnOnce ) ) ;
1857+ assert_matches ! (
1858+ identity_kind_ty. to_opt_closure_kind( ) ,
1859+ Some ( ClosureKind :: Fn | ClosureKind :: FnMut )
1860+ ) ;
1861+ mir. coroutine_by_move_body ( ) . unwrap ( ) . coroutine_layout_raw ( )
1862+ }
1863+ }
18311864 }
18321865
18331866 /// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
0 commit comments