@@ -446,10 +446,10 @@ impl<C: Config> Client<C> {
446446 path : & str ,
447447 request : I ,
448448 event_mapper : impl Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static ,
449- ) -> Pin < Box < dyn Stream < Item = Result < O , OpenAIError > > + Send > >
449+ ) -> OpenAIEventMappedStream < O >
450450 where
451451 I : Serialize ,
452- O : DeserializeOwned + Send + ' static ,
452+ O : DeserializeOwned + Send + ' static
453453 {
454454 let event_source = self
455455 . http_client
@@ -460,8 +460,7 @@ impl<C: Config> Client<C> {
460460 . eventsource ( )
461461 . unwrap ( ) ;
462462
463- // stream_mapped_raw_events(event_source, event_mapper).await
464- todo ! ( )
463+ OpenAIEventMappedStream :: new ( event_source, event_mapper)
465464 }
466465
467466 /// Make HTTP GET request to receive SSE
@@ -491,19 +490,21 @@ impl<C: Config> Client<C> {
491490/// Request which responds with SSE.
492491/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
493492#[ pin_project]
494- pub struct OpenAIEventStream < O > {
493+ pub struct OpenAIEventStream < O : DeserializeOwned + Send + ' static > {
495494 #[ pin]
496495 stream : Filter < EventSource , future:: Ready < bool > , fn ( & Result < Event , reqwest_eventsource:: Error > ) -> future:: Ready < bool > > ,
496+ done : bool ,
497497 _phantom_data : PhantomData < O > ,
498498}
499499
500- impl < O > OpenAIEventStream < O > {
500+ impl < O : DeserializeOwned + Send + ' static > OpenAIEventStream < O > {
501501 pub ( crate ) fn new ( event_source : EventSource ) -> Self {
502502 Self {
503503 stream : event_source. filter ( |result|
504504 // filter out the first event which is always Event::Open
505505 future:: ready ( !( result. is_ok ( ) && result. as_ref ( ) . unwrap ( ) . eq ( & Event :: Open ) ) )
506506 ) ,
507+ done : false ,
507508 _phantom_data : PhantomData ,
508509 }
509510 }
@@ -514,6 +515,9 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
514515
515516 fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
516517 let this = self . project ( ) ;
518+ if * this. done {
519+ return Poll :: Ready ( None ) ;
520+ }
517521 let stream: Pin < & mut _ > = this. stream ;
518522 match stream. poll_next ( cx) {
519523 Poll :: Ready ( response) => {
@@ -524,17 +528,24 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
524528 Event :: Open => unreachable ! ( ) , // it has been filtered out
525529 Event :: Message ( message) => {
526530 if message. data == "[DONE]" {
531+ * this. done = true ;
527532 Poll :: Ready ( None ) // end of the stream, defined by OpenAI
528533 } else {
529534 // deserialize the data
530535 match serde_json:: from_str :: < O > ( & message. data ) {
531- Err ( e) => Poll :: Ready ( Some ( Err ( map_deserialization_error ( e, & message. data . as_bytes ( ) ) ) ) ) ,
536+ Err ( e) => {
537+ * this. done = true ;
538+ Poll :: Ready ( Some ( Err ( map_deserialization_error ( e, & message. data . as_bytes ( ) ) ) ) )
539+ }
532540 Ok ( output) => Poll :: Ready ( Some ( Ok ( output) ) ) ,
533541 }
534542 }
535543 }
536544 }
537- Err ( e) => Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
545+ Err ( e) => {
546+ * this. done = true ;
547+ Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
548+ }
538549 }
539550 }
540551 }
@@ -543,6 +554,77 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
543554 }
544555}
545556
557+ #[ pin_project]
558+ pub struct OpenAIEventMappedStream < O >
559+ where O : Send + ' static
560+ {
561+ #[ pin]
562+ stream : Filter < EventSource , future:: Ready < bool > , fn ( & Result < Event , reqwest_eventsource:: Error > ) -> future:: Ready < bool > > ,
563+ event_mapper : Box < dyn Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static > ,
564+ done : bool ,
565+ _phantom_data : PhantomData < O > ,
566+ }
567+
568+ impl < O > OpenAIEventMappedStream < O >
569+ where O : Send + ' static
570+ {
571+ pub ( crate ) fn new < M > ( event_source : EventSource , event_mapper : M ) -> Self
572+ where M : Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static {
573+ Self {
574+ stream : event_source. filter ( |result|
575+ // filter out the first event which is always Event::Open
576+ future:: ready ( !( result. is_ok ( ) && result. as_ref ( ) . unwrap ( ) . eq ( & Event :: Open ) ) )
577+ ) ,
578+ done : false ,
579+ event_mapper : Box :: new ( event_mapper) ,
580+ _phantom_data : PhantomData ,
581+ }
582+ }
583+ }
584+
585+
586+ impl < O > Stream for OpenAIEventMappedStream < O >
587+ where O : Send + ' static
588+ {
589+ type Item = Result < O , OpenAIError > ;
590+
591+ fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
592+ let this = self . project ( ) ;
593+ if * this. done {
594+ return Poll :: Ready ( None ) ;
595+ }
596+ let stream: Pin < & mut _ > = this. stream ;
597+ match stream. poll_next ( cx) {
598+ Poll :: Ready ( response) => {
599+ match response {
600+ None => Poll :: Ready ( None ) , // end of the stream
601+ Some ( result) => match result {
602+ Ok ( event) => match event {
603+ Event :: Open => unreachable ! ( ) , // it has been filtered out
604+ Event :: Message ( message) => {
605+ if message. data == "[DONE]" {
606+ * this. done = true ;
607+ }
608+ let response = ( this. event_mapper ) ( message) ;
609+ match response {
610+ Ok ( output) => Poll :: Ready ( Some ( Ok ( output) ) ) ,
611+ Err ( _) => Poll :: Ready ( None )
612+ }
613+ }
614+ }
615+ Err ( e) => {
616+ * this. done = true ;
617+ Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
618+ }
619+ }
620+ }
621+ }
622+ Poll :: Pending => Poll :: Pending
623+ }
624+ }
625+ }
626+
627+
546628// pub(crate) async fn stream_mapped_raw_events<O>(
547629// mut event_source: EventSource,
548630// event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
0 commit comments