@@ -470,6 +470,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
470470 }
471471 }
472472
473+ /// Lower an `async` construct to a generator that is then wrapped so it implements `Future`.
474+ ///
475+ /// This results in:
476+ ///
477+ /// ```text
478+ /// std::future::from_generator(static move? |_task_context| -> <ret_ty> {
479+ /// <body>
480+ /// })
481+ /// ```
473482 pub ( super ) fn make_async_expr (
474483 & mut self ,
475484 capture_clause : CaptureBy ,
@@ -480,17 +489,42 @@ impl<'hir> LoweringContext<'_, 'hir> {
480489 body : impl FnOnce ( & mut Self ) -> hir:: Expr < ' hir > ,
481490 ) -> hir:: ExprKind < ' hir > {
482491 let output = match ret_ty {
483- Some ( ty) => FnRetTy :: Ty ( ty ) ,
484- None => FnRetTy :: Default ( span) ,
492+ Some ( ty) => hir :: FnRetTy :: Return ( self . lower_ty ( & ty , ImplTraitContext :: disallowed ( ) ) ) ,
493+ None => hir :: FnRetTy :: DefaultReturn ( span) ,
485494 } ;
486- let ast_decl = FnDecl { inputs : vec ! [ ] , output } ;
487- let decl = self . lower_fn_decl ( & ast_decl, None , /* impl trait allowed */ false , None ) ;
488- let body_id = self . lower_fn_body ( & ast_decl, |this| {
495+
496+ // Resume argument type. We let the compiler infer this to simplify the lowering. It is
497+ // fully constrained by `future::from_generator`.
498+ let input_ty = hir:: Ty { hir_id : self . next_id ( ) , kind : hir:: TyKind :: Infer , span } ;
499+
500+ // The closure/generator `FnDecl` takes a single (resume) argument of type `input_ty`.
501+ let decl = self . arena . alloc ( hir:: FnDecl {
502+ inputs : arena_vec ! [ self ; input_ty] ,
503+ output,
504+ c_variadic : false ,
505+ implicit_self : hir:: ImplicitSelfKind :: None ,
506+ } ) ;
507+
508+ // Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
509+ let ( pat, task_context_hid) = self . pat_ident_binding_mode (
510+ span,
511+ Ident :: with_dummy_span ( sym:: _task_context) ,
512+ hir:: BindingAnnotation :: Mutable ,
513+ ) ;
514+ let param = hir:: Param { attrs : & [ ] , hir_id : self . next_id ( ) , pat, span } ;
515+ let params = arena_vec ! [ self ; param] ;
516+
517+ let body_id = self . lower_body ( move |this| {
489518 this. generator_kind = Some ( hir:: GeneratorKind :: Async ( async_gen_kind) ) ;
490- body ( this)
519+
520+ let old_ctx = this. task_context ;
521+ this. task_context = Some ( task_context_hid) ;
522+ let res = body ( this) ;
523+ this. task_context = old_ctx;
524+ ( params, res)
491525 } ) ;
492526
493- // `static || -> <ret_ty> { body }`:
527+ // `static |_task_context | -> <ret_ty> { body }`:
494528 let generator_kind = hir:: ExprKind :: Closure (
495529 capture_clause,
496530 decl,
@@ -523,13 +557,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
523557 /// ```rust
524558 /// match <expr> {
525559 /// mut pinned => loop {
526- /// match ::std::future::poll_with_tls_context(unsafe {
527- /// <::std::pin::Pin>::new_unchecked(&mut pinned)
528- /// }) {
560+ /// match unsafe { ::std::future::poll_with_context(
561+ /// <::std::pin::Pin>::new_unchecked(&mut pinned),
562+ /// task_context,
563+ /// ) } {
529564 /// ::std::task::Poll::Ready(result) => break result,
530565 /// ::std::task::Poll::Pending => {}
531566 /// }
532- /// yield ();
567+ /// task_context = yield ();
533568 /// }
534569 /// }
535570 /// ```
@@ -561,12 +596,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
561596 let ( pinned_pat, pinned_pat_hid) =
562597 self . pat_ident_binding_mode ( span, pinned_ident, hir:: BindingAnnotation :: Mutable ) ;
563598
564- // ::std::future::poll_with_tls_context(unsafe {
565- // ::std::pin::Pin::new_unchecked(&mut pinned)
566- // })`
599+ let task_context_ident = Ident :: with_dummy_span ( sym:: _task_context) ;
600+
601+ // unsafe {
602+ // ::std::future::poll_with_context(
603+ // ::std::pin::Pin::new_unchecked(&mut pinned),
604+ // task_context,
605+ // )
606+ // }
567607 let poll_expr = {
568608 let pinned = self . expr_ident ( span, pinned_ident, pinned_pat_hid) ;
569609 let ref_mut_pinned = self . expr_mut_addr_of ( span, pinned) ;
610+ let task_context = if let Some ( task_context_hid) = self . task_context {
611+ self . expr_ident_mut ( span, task_context_ident, task_context_hid)
612+ } else {
613+ // Use of `await` outside of an async context, we cannot use `task_context` here.
614+ self . expr_err ( span)
615+ } ;
570616 let pin_ty_id = self . next_id ( ) ;
571617 let new_unchecked_expr_kind = self . expr_call_std_assoc_fn (
572618 pin_ty_id,
@@ -575,14 +621,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
575621 "new_unchecked" ,
576622 arena_vec ! [ self ; ref_mut_pinned] ,
577623 ) ;
578- let new_unchecked =
579- self . arena . alloc ( self . expr ( span, new_unchecked_expr_kind, ThinVec :: new ( ) ) ) ;
580- let unsafe_expr = self . expr_unsafe ( new_unchecked) ;
581- self . expr_call_std_path (
624+ let new_unchecked = self . expr ( span, new_unchecked_expr_kind, ThinVec :: new ( ) ) ;
625+ let call = self . expr_call_std_path (
582626 gen_future_span,
583- & [ sym:: future, sym:: poll_with_tls_context] ,
584- arena_vec ! [ self ; unsafe_expr] ,
585- )
627+ & [ sym:: future, sym:: poll_with_context] ,
628+ arena_vec ! [ self ; new_unchecked, task_context] ,
629+ ) ;
630+ self . arena . alloc ( self . expr_unsafe ( call) )
586631 } ;
587632
588633 // `::std::task::Poll::Ready(result) => break result`
@@ -622,14 +667,26 @@ impl<'hir> LoweringContext<'_, 'hir> {
622667 self . stmt_expr ( span, match_expr)
623668 } ;
624669
670+ // task_context = yield ();
625671 let yield_stmt = {
626672 let unit = self . expr_unit ( span) ;
627673 let yield_expr = self . expr (
628674 span,
629675 hir:: ExprKind :: Yield ( unit, hir:: YieldSource :: Await ) ,
630676 ThinVec :: new ( ) ,
631677 ) ;
632- self . stmt_expr ( span, yield_expr)
678+ let yield_expr = self . arena . alloc ( yield_expr) ;
679+
680+ if let Some ( task_context_hid) = self . task_context {
681+ let lhs = self . expr_ident ( span, task_context_ident, task_context_hid) ;
682+ let assign =
683+ self . expr ( span, hir:: ExprKind :: Assign ( lhs, yield_expr, span) , AttrVec :: new ( ) ) ;
684+ self . stmt_expr ( span, assign)
685+ } else {
686+ // Use of `await` outside of an async context. Return `yield_expr` so that we can
687+ // proceed with type checking.
688+ self . stmt ( span, hir:: StmtKind :: Semi ( yield_expr) )
689+ }
633690 } ;
634691
635692 let loop_block = self . block_all ( span, arena_vec ! [ self ; inner_match_stmt, yield_stmt] , None ) ;
0 commit comments