
Commit 0a38e33

Merge remote-tracking branch 'upstream_gitee/main' into main_1009
# Conflicts:
#	docs/source/user_guide/feature_guide/eplb_swift_balancer.md
#	vllm_ascend/patch/platform/patch_common/__init__.py
2 parents: 41917ba + 0bf3f21


41 files changed (+2643 / -718 lines)

.github/workflows/_e2e_nightly.yaml

Lines changed: 9 additions & 0 deletions
```diff
@@ -96,6 +96,15 @@ jobs:
           pip install -r requirements-dev.txt
           pip install -v -e .
 
+      - name: Checkout aisbench repo and Install aisbench
+        run: |
+          git clone https://gitee.com/aisbench/benchmark.git
+          cd benchmark
+          git checkout v3.0-20250930-master
+          pip3 install -e ./
+          pip3 install -r requirements/api.txt
+          pip3 install -r requirements/extra.txt
+
       - name: Run vllm-project/vllm-ascend test
         env:
           VLLM_WORKER_MULTIPROC_METHOD: spawn
```

.github/workflows/_e2e_test.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -177,6 +177,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/test_expert_parallel.py
           pytest -sv tests/e2e/multicard/test_external_launcher.py
+          pytest -sv tests/e2e/multicard/test_single_request_aclgraph.py
           pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py
           pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
```

.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 9 deletions
```diff
@@ -119,15 +119,7 @@ jobs:
           TORCH_DEVICE_BACKEND_AUTOLOAD: 0
         run: |
           export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
-          pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
-            --ignore=tests/ut/test_platform.py \
-            --ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py \
-            --ignore=tests/ut/core/test_scheduler.py \
-            --ignore=tests/ut/kv_connector/test_llmdatadist_connector.py \
-            --ignore=tests/ut/kv_connector/test_mooncake_connector.py \
-            --ignore=tests/ut/kv_connector/test_remote_decode_lifecycle.py \
-            --ignore=tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
-            --ignore=tests/ut/torchair/models/test_torchair_deepseek_v2.py \
+          pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut
 
       - name: Upload coverage to Codecov
         # only upload coverage when commits merged
```

docs/source/user_guide/configuration/additional_config.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -58,6 +58,7 @@ The details of each config option are as follows:
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
 | `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
+| `enable_super_kernel` | bool | `False` | Whether to enable super kernel to fuse operators in deepseek moe layers. This option only takes effects on moe models using dynamic w8a8 quantization.|
 
 **ascend_scheduler_config**
```

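Note: options in this table are passed through vLLM's `additional_config`. A minimal sketch of enabling the new flag follows; the nesting under `torchair_graph_config` is an assumption based on the neighboring rows, and the model name and quantization setting are illustrative, not taken from this commit.

```python
# Hedged sketch: enabling the new option through vLLM's additional_config.
# The "torchair_graph_config" nesting is an assumption; check the
# vllm-ascend additional_config docs for the authoritative layout.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # hypothetical MoE model choice
    quantization="ascend",                 # assumption: dynamic w8a8 path
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_super_kernel": True,
        },
    },
)
```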
examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py

Lines changed: 167 additions & 63 deletions
```diff
@@ -84,17 +84,18 @@
 #
 # For more details, see the code and comments in this file.
 
-
 import argparse
 import asyncio
 import functools
 import heapq
+import json
 import os
 import sys
-import uuid
 import threading
+import uuid
 from contextlib import asynccontextmanager
-from typing import List
+from dataclasses import dataclass
+from typing import Any, List
 
 import httpx
 from fastapi import FastAPI, Request
@@ -106,6 +107,7 @@
 # Add uvloop for faster event loop if available
 try:
     import uvloop
+
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
```

```diff
@@ -324,7 +326,7 @@ async def listen_for_disconnect(request: Request) -> None:
 
 
 def with_cancellation(handler_func):
-
+
     @functools.wraps(handler_func)
     async def wrapper(*args, **kwargs):
         request = kwargs["request"]
@@ -337,9 +339,9 @@ async def wrapper(*args, **kwargs):
         if handler_task in done:
             return handler_task.result()
         return None
-
+
     return wrapper
-
+
 
 
 app = FastAPI(lifespan=lifespan)
```
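Note: the two hunks above only reflow whitespace in `with_cancellation`. The decorator itself is the usual asyncio pattern of racing the request handler against a client-disconnect watcher. A minimal, self-contained sketch of that pattern, reconstructed from the visible lines rather than copied verbatim from the file:

```python
import asyncio
import functools

async def listen_for_disconnect(request) -> None:
    # Poll the ASGI receive channel until the client goes away.
    while True:
        message = await request.receive()
        if message["type"] == "http.disconnect":
            break

def with_cancellation(handler_func):

    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        request = kwargs["request"]
        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        # Whichever finishes first wins; cancel the loser.
        done, pending = await asyncio.wait([handler_task, cancellation_task],
                                           return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        if handler_task in done:
            return handler_task.result()
        return None

    return wrapper
```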
```diff
@@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient,
         "remote_host": None,
         "remote_port": None,
         "aborted_request": list(aborted_requests),
-        "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        "metaserver":
+        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
     }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
```
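As a reading aid: the dict being reflowed here is assumed to be the `kv_transfer_params` block the proxy attaches before the prefill pass, and the two assignments below it force a non-streamed, single-token (effectively prefill-only) run. A sketch of the whole mutation in one place; the function name is hypothetical and only the keys visible in this hunk are reproduced:

```python
# Hedged sketch of how the proxy prepares the prefill-only request.
def prepare_prefill_request(req_data: dict, host: str, port: int,
                            aborted_requests: list) -> dict:
    req_data["kv_transfer_params"] = {
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
        "metaserver": f"http://{host}:{port}/v1/metaserver",
    }
    req_data["stream"] = False  # the prefiller's reply is not streamed
    req_data["max_tokens"] = 1  # one token => effectively prefill only
    return req_data
```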
```diff
@@ -455,72 +458,174 @@ def get_api_request_id(api, req_id):
     return "chatcmpl-" + req_id
 
 
+async def _handle_select_instance(api: str, req_data: Any,
+                                  request_length: int):
+    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+    logger.debug(
+        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+    )
+    request_id = await proxy_state.next_req_id()
+    # Select prefiller
+    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+    prefiller = proxy_state.prefillers[prefiller_idx]
+    result_future = asyncio.Future()  # type: ignore
+    request_id_api = get_api_request_id(api, request_id)
+    proxy_state.req_id_future[request_id_api] = result_future
+    # Send request to prefiller
+    asyncio.get_running_loop().create_task(
+        send_request_to_service(prefiller.client,
+                                prefiller_idx,
+                                api,
+                                req_data,
+                                request_id,
+                                max_retries=global_args.max_retries,
+                                base_delay=global_args.retry_delay))
+    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
+    response = await result_future
+    del proxy_state.req_id_future[request_id_api]
+    req_data["kv_transfer_params"] = response
+
+    # Select decoder
+    decoder_score = proxy_state.calculate_decode_scores(request_length)
+    logger.debug("Decoder score: %f", decoder_score)
+    # Use the prefiller's kv_transfer_params to select decoder
+    decoder_idx = proxy_state.select_decoder(decoder_score)
+    decoder = proxy_state.decoders[decoder_idx]
+    logger.debug("Using %s %s", prefiller.url, decoder.url)
+    return InstanceInfo(request_id=request_id,
+                        prefiller_idx=prefiller_idx,
+                        prefiller_score=prefiller_score,
+                        prefiller=prefiller,
+                        decoder=decoder,
+                        decoder_idx=decoder_idx,
+                        decoder_score=decoder_score)
+
+
+@dataclass
+class InstanceInfo:
+    request_id: str
+    prefiller_idx: int
+    prefiller_score: float
+    prefiller: ServerState
+    decoder_idx: int
+    decoder_score: float
+    decoder: ServerState
+
+
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-        logger.debug(
-            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-        )
-        request_id = await proxy_state.next_req_id()
-        # Select prefiller
-        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-        prefiller = proxy_state.prefillers[prefiller_idx]
-        result_future = asyncio.Future()  # type: ignore
-        request_id_api = get_api_request_id(api, request_id)
-        proxy_state.req_id_future[request_id_api] = result_future
-        # Send request to prefiller
-        asyncio.get_running_loop().create_task(send_request_to_service(
-            prefiller.client,
-            prefiller_idx,
-            api,
-            req_data,
-            request_id,
-            max_retries=global_args.max_retries,
-            base_delay=global_args.retry_delay))
-        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-        response = await result_future
-        del proxy_state.req_id_future[request_id_api]
-        req_data["kv_transfer_params"] = response
-
-        # Select decoder
-        decoder_score = proxy_state.calculate_decode_scores(request_length)
-        logger.debug("Decoder score: %f", decoder_score)
-        # Use the prefiller's kv_transfer_params to select decoder
-        decoder_idx = proxy_state.select_decoder(decoder_score)
-        decoder = proxy_state.decoders[decoder_idx]
-        logger.debug("Using %s %s", prefiller.url, decoder.url)
-        # Stream response from decoder
-        released_kv = False
+        instance_info = await _handle_select_instance(api, req_data,
+                                                      request_length)
+        stream_flag = bool(req_data.get("stream", False))
+        chat_flag = "messages" in req_data
+
+        if "prompt" in req_data:
+            origin_prompt = req_data["prompt"]
+        elif chat_flag:
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "")
+        else:
+            origin_prompt = ""
+        # refer to vLLM sampling_params: max_token default value
+        origin_max_tokens = req_data.get("max_tokens", 16)
 
         async def generate_stream():
-            nonlocal released_kv
+            nonlocal instance_info
+            generated_token = ""
+            released_kv = False
+            retry_count = 0
+            retry = True
+            completion_tokens = 0
             # Only one await per chunk, minimal logic in loop
             try:
-                async for chunk in stream_service_response_with_retry(
-                        decoder.client,
-                        api,
-                        req_data,
-                        request_id=request_id,
-                        max_retries=global_args.max_retries,
-                        base_delay=global_args.retry_delay):
-                    if not released_kv and chunk:
-                        proxy_state.release_prefiller_kv(
-                            prefiller_idx, prefiller_score)
-                        released_kv = True
-                    yield chunk
+                while retry:
+                    retry = False
+                    async for chunk in stream_service_response_with_retry(
+                            instance_info.decoder.client,
+                            api,
+                            req_data,
+                            request_id=instance_info.request_id,
+                            max_retries=global_args.max_retries,
+                            base_delay=global_args.retry_delay):
+                        if not released_kv and chunk:
+                            proxy_state.release_prefiller_kv(
+                                instance_info.prefiller_idx,
+                                instance_info.prefiller_score)
+                            released_kv = True
+                        chunk_str = chunk.decode("utf-8").strip()
+                        if not chunk_str:
+                            continue
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: "):]
+                        try:
+                            chunk_json = json.loads(chunk_str)
+                        except json.JSONDecodeError:
+                            # if chunk is [done], skip it.
+                            logger.warning(
+                                f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = (
+                            delta.get("content")
+                            or message.get("content")
+                            or choice.get("text")
+                            or ""
+                        )
+                        generated_token += content
+
+                        stop_reason = choice.get(
+                            "stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        completion_tokens = (completion_tokens + 1) if stream_flag else \
+                            (completion_tokens + usage.get("completion_tokens"))
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0][
+                                    "content"] = origin_prompt + generated_token
+                            else:
+                                req_data[
+                                    "prompt"] = origin_prompt + generated_token
+                            req_data[
+                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
+                            tmp_request_length = len(
+                                json.dumps(req_data).encode("utf-8"))
+                            instance_info = await _handle_select_instance(
+                                api, req_data, tmp_request_length)
+                            break
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choices[0]["message"][
+                                    "content"] = generated_token
+                            else:
+                                choices[0]["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
-                proxy_state.release_prefiller_kv(prefiller_idx,
-                                                 prefiller_score)
+                proxy_state.abort_prefiller_request(
+                    instance_info.prefiller_idx, instance_info.request_id)
+                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
+                                                 instance_info.prefiller_score)
 
         # After streaming done, release tokens
-        proxy_state.release_decoder(decoder_idx, decoder_score)
+        proxy_state.release_decoder(instance_info.decoder_idx,
+                                    instance_info.decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
```
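Note: the chunk handling added in this hunk is standard SSE parsing plus a tolerant content extraction. Distilled into standalone helpers that mirror the inline logic (these helpers are illustrative, not part of the file):

```python
import json
from typing import Any, Optional

def parse_sse_chunk(chunk: bytes) -> Optional[dict[str, Any]]:
    """Parse one SSE chunk from the decoder stream, as the loop above does.

    Returns the decoded JSON payload, or None for keep-alives and
    non-JSON markers such as the literal "[DONE]" sentinel.
    """
    chunk_str = chunk.decode("utf-8").strip()
    if not chunk_str:
        return None
    if chunk_str.startswith("data: "):
        chunk_str = chunk_str[len("data: "):]
    try:
        return json.loads(chunk_str)
    except json.JSONDecodeError:
        return None

def extract_content(chunk_json: dict[str, Any]) -> str:
    """Pull the generated text out of a chunk, covering the three shapes
    handled above: streaming chat (delta), non-streaming chat (message),
    and plain completions (text)."""
    choices = chunk_json.get("choices", [])
    if not choices:
        return ""
    choice = choices[0]
    delta = choice.get("delta") or {}
    message = choice.get("message") or {}
    return (delta.get("content") or message.get("content")
            or choice.get("text") or "")
```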
```diff
@@ -564,13 +669,12 @@ async def metaserver(request: Request):
         result_future = proxy_state.req_id_future[request_id]
         result_future.set_result(req_data)
     except Exception as e:
-        logger.error(
-            f"Post metaserver failed with: {str(e)}"
-        )
+        logger.error(f"Post metaserver failed with: {str(e)}")
 
 
 if __name__ == '__main__':
     global global_args
     global_args = parse_args()
     import uvicorn
+
     uvicorn.run(app, host=global_args.host, port=global_args.port)
```
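Context for this hunk: `/v1/metaserver` is the callback the prefiller POSTs to (see the `metaserver` field added earlier); it resolves the per-request Future that `_handle_select_instance` awaits. A simplified sketch of that rendezvous, with names following the file but the two helper functions hypothetical:

```python
import asyncio

# Maps request id -> Future awaiting the prefiller's kv_transfer_params.
req_id_future: dict[str, asyncio.Future] = {}

async def wait_for_prefill(request_id: str) -> dict:
    """Proxy side: park a Future under the request id and await it
    (mirrors result_future = asyncio.Future() in _handle_select_instance)."""
    fut: asyncio.Future = asyncio.Future()
    req_id_future[request_id] = fut
    try:
        return await fut  # resolved when the prefiller posts back
    finally:
        del req_id_future[request_id]

def on_metaserver_post(request_id: str, req_data: dict) -> None:
    """Endpoint side: hand the posted payload to the waiting coroutine
    (mirrors result_future.set_result(req_data) above)."""
    fut = req_id_future.get(request_id)
    if fut is not None and not fut.done():
        fut.set_result(req_data)
```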
