 #
 # For more details, see the code and comments in this file.
 
-
 import argparse
 import asyncio
 import functools
 import heapq
+import json
 import os
 import sys
-import uuid
 import threading
+import uuid
 from contextlib import asynccontextmanager
-from typing import List
+from dataclasses import dataclass
+from typing import Any, List
 
 import httpx
 from fastapi import FastAPI, Request
 # Add uvloop for faster event loop if available
 try:
     import uvloop
+
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -324,7 +326,7 @@ async def listen_for_disconnect(request: Request) -> None:
 
 
 def with_cancellation(handler_func):
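     # Race the wrapped handler against a client-disconnect watcher; the
     # handler's result is returned only if it finishes first.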
-
+
     @functools.wraps(handler_func)
     async def wrapper(*args, **kwargs):
         request = kwargs["request"]
@@ -337,9 +339,9 @@ async def wrapper(*args, **kwargs):
         if handler_task in done:
             return handler_task.result()
         return None
-
+
     return wrapper
-
+
 
 app = FastAPI(lifespan=lifespan)
 
@@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient,
         "remote_host": None,
         "remote_port": None,
         "aborted_request": list(aborted_requests),
-        "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        "metaserver":
+        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
     }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
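     # The prefill request asks for a single token with streaming disabled;
     # the decoder generates the real completion from the transferred KV.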
@@ -455,72 +458,174 @@ def get_api_request_id(api, req_id):
         return "chatcmpl-" + req_id
 
 
+async def _handle_select_instance(api: str, req_data: Any,
+                                  request_length: int):
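+    # Score-based selection: pick a prefiller, fire the prefill request, wait
+    # for its kv_transfer_params (delivered via /v1/metaserver), then pick a
+    # decoder for the same request.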
+    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+    logger.debug(
+        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+    )
+    request_id = await proxy_state.next_req_id()
+    # Select prefiller
+    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+    prefiller = proxy_state.prefillers[prefiller_idx]
+    result_future = asyncio.Future()  # type: ignore
+    request_id_api = get_api_request_id(api, request_id)
+    proxy_state.req_id_future[request_id_api] = result_future
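+    # The /v1/metaserver endpoint resolves this future when the prefiller
+    # posts back the request's kv_transfer_params.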
+    # Send request to prefiller
+    asyncio.get_running_loop().create_task(
+        send_request_to_service(prefiller.client,
+                                prefiller_idx,
+                                api,
+                                req_data,
+                                request_id,
+                                max_retries=global_args.max_retries,
+                                base_delay=global_args.retry_delay))
+    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
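+    # The prefiller slot is released once the request is dispatched; its KV
+    # accounting is released separately when the decoder starts streaming.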
+
+    response = await result_future
+    del proxy_state.req_id_future[request_id_api]
+    req_data["kv_transfer_params"] = response
+
+    # Select decoder
+    decoder_score = proxy_state.calculate_decode_scores(request_length)
+    logger.debug("Decoder score: %f", decoder_score)
+    # Use the prefiller's kv_transfer_params to select decoder
+    decoder_idx = proxy_state.select_decoder(decoder_score)
+    decoder = proxy_state.decoders[decoder_idx]
+    logger.debug("Using %s %s", prefiller.url, decoder.url)
+    return InstanceInfo(request_id=request_id,
+                        prefiller_idx=prefiller_idx,
+                        prefiller_score=prefiller_score,
+                        prefiller=prefiller,
+                        decoder=decoder,
+                        decoder_idx=decoder_idx,
+                        decoder_score=decoder_score)
+
+
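+# Routing decision for one request: the selected prefiller/decoder pair and
+# the scores needed to release them, so generate_stream can swap in a fresh
+# selection when a request has to be recomputed.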
+@dataclass
+class InstanceInfo:
+    request_id: str
+    prefiller_idx: int
+    prefiller_score: float
+    prefiller: ServerState
+    decoder_idx: int
+    decoder_score: float
+    decoder: ServerState
+
+
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-        logger.debug(
-            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-        )
-        request_id = await proxy_state.next_req_id()
-        # Select prefiller
-        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-        prefiller = proxy_state.prefillers[prefiller_idx]
-        result_future = asyncio.Future()  # type: ignore
-        request_id_api = get_api_request_id(api, request_id)
-        proxy_state.req_id_future[request_id_api] = result_future
-        # Send request to prefiller
-        asyncio.get_running_loop().create_task(send_request_to_service(
-            prefiller.client,
-            prefiller_idx,
-            api,
-            req_data,
-            request_id,
-            max_retries=global_args.max_retries,
-            base_delay=global_args.retry_delay))
-        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-        response = await result_future
-        del proxy_state.req_id_future[request_id_api]
-        req_data["kv_transfer_params"] = response
-
-        # Select decoder
-        decoder_score = proxy_state.calculate_decode_scores(request_length)
-        logger.debug("Decoder score: %f", decoder_score)
-        # Use the prefiller's kv_transfer_params to select decoder
-        decoder_idx = proxy_state.select_decoder(decoder_score)
-        decoder = proxy_state.decoders[decoder_idx]
-        logger.debug("Using %s %s", prefiller.url, decoder.url)
-        # Stream response from decoder
-        released_kv = False
+        instance_info = await _handle_select_instance(api, req_data,
+                                                      request_length)
+        stream_flag = bool(req_data.get("stream", False))
+        chat_flag = "messages" in req_data
+
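+        # Capture the original prompt and max_tokens so the request can be
+        # rebuilt if the decoder asks for a recompute (see generate_stream).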
+        if "prompt" in req_data:
+            origin_prompt = req_data["prompt"]
+        elif chat_flag:
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "")
+        else:
+            origin_prompt = ""
+        # 16 matches the default max_tokens in vLLM's SamplingParams
+        origin_max_tokens = req_data.get("max_tokens", 16)
+
         async def generate_stream():
-            nonlocal released_kv
+            nonlocal instance_info
+            generated_token = ""
+            released_kv = False
+            retry_count = 0
+            retry = True
+            completion_tokens = 0
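+            # Loop until the decoder finishes without requesting a recompute;
+            # each retry pass routes the rebuilt request to a fresh pair.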
             # Only one await per chunk, minimal logic in loop
             try:
-                async for chunk in stream_service_response_with_retry(
-                        decoder.client,
-                        api,
-                        req_data,
-                        request_id=request_id,
-                        max_retries=global_args.max_retries,
-                        base_delay=global_args.retry_delay):
-                    if not released_kv and chunk:
-                        proxy_state.release_prefiller_kv(
-                            prefiller_idx, prefiller_score)
-                        released_kv = True
-                    yield chunk
+                while retry:
+                    retry = False
+                    async for chunk in stream_service_response_with_retry(
+                            instance_info.decoder.client,
+                            api,
+                            req_data,
+                            request_id=instance_info.request_id,
+                            max_retries=global_args.max_retries,
+                            base_delay=global_args.retry_delay):
+                        if not released_kv and chunk:
+                            proxy_state.release_prefiller_kv(
+                                instance_info.prefiller_idx,
+                                instance_info.prefiller_score)
+                            released_kv = True
+                        chunk_str = chunk.decode("utf-8").strip()
+                        if not chunk_str:
+                            continue
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: "):]
+                        try:
+                            chunk_json = json.loads(chunk_str)
+                        except json.JSONDecodeError:
+                            # Not JSON (e.g. the "[DONE]" sentinel); pass it through.
+                            logger.warning(f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = (delta.get("content")
+                                   or message.get("content")
+                                   or choice.get("text") or "")
+                        generated_token += content
+
+                        stop_reason = choice.get("stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        if stream_flag:
+                            completion_tokens += 1
+                        else:
+                            completion_tokens += usage.get("completion_tokens", 0)
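+                        # "recomputed" signals the decoder could not resume from
+                        # the transferred KV: fold the tokens generated so far
+                        # into the prompt and re-route to a new pair.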
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0]["content"] = (origin_prompt +
+                                                          generated_token)
+                            else:
+                                req_data["prompt"] = (origin_prompt +
+                                                      generated_token)
+                            req_data["max_tokens"] = (origin_max_tokens -
+                                                      completion_tokens +
+                                                      retry_count)
+                            tmp_request_length = len(
+                                json.dumps(req_data).encode("utf-8"))
+                            instance_info = await _handle_select_instance(
+                                api, req_data, tmp_request_length)
+                            break
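+                        # After a retry, a non-streaming response carries only
+                        # the final segment, so substitute the full accumulated
+                        # text before yielding.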
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choices[0]["message"]["content"] = generated_token
+                            else:
+                                choices[0]["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {decoder.url}: {str(e)}; the aborted request {request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)}; the aborted request {instance_info.request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
-                proxy_state.release_prefiller_kv(prefiller_idx,
-                                                 prefiller_score)
+                proxy_state.abort_prefiller_request(
+                    instance_info.prefiller_idx, instance_info.request_id)
+                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
+                                                 instance_info.prefiller_score)
 
             # After streaming done, release tokens
-            proxy_state.release_decoder(decoder_idx, decoder_score)
+            proxy_state.release_decoder(instance_info.decoder_idx,
+                                        instance_info.decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -564,13 +669,12 @@ async def metaserver(request: Request):
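             # The prefiller calls back here with kv_transfer_params; resolve
             # the future that _handle_select_instance is awaiting.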
             result_future = proxy_state.req_id_future[request_id]
             result_future.set_result(req_data)
     except Exception as e:
-        logger.error(
-            f"Post metaserver failed with: {str(e)}"
-        )
+        logger.error(f"Post metaserver failed with: {str(e)}")
 
 
 if __name__ == '__main__':
     global global_args
     global_args = parse_args()
     import uvicorn
+
     uvicorn.run(app, host=global_args.host, port=global_args.port)