Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id
from deepsparse.transformers.utils.timings import TextGenerationTimings
from deepsparse.utils import TimerManager
from deepsparse.utils.data import numpy_softmax
from deepsparse.utils.onnx import (
CACHE_INPUT_PREFIX,
Expand Down Expand Up @@ -66,6 +68,7 @@ def __init__(
deterministic: bool = True,
engine_context: Optional[Context] = None,
internal_kv_cache=False,
timer_manager: TimerManager = None,
):
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None
Expand Down Expand Up @@ -94,7 +97,7 @@ def __init__(
engine_args=engine_args,
context=engine_context,
)

self.timer_manager = timer_manager or TimerManager()
self.sequence_length = sequence_length
self.sampling_temperature = sampling_temperature
self.deterministic = deterministic
Expand Down Expand Up @@ -186,18 +189,21 @@ def __call__(
:param val_inp: Whether the input is for validation or not
:return: The generated token and corresponding logits
"""
timer = self.timer_manager.current
if self.kv_cache:
# if model has kv cache enabled, we need
# to add the kv cache state to the input
inp = self.add_kv_cache_to_input(inp)

out = self.run(inp, val_inp)
with timer.time(f"EXECUTE_ENGINE_SEQ_LEN_{self.sequence_length}"):
out = self.run(inp, val_inp)

if self.kv_cache:
logits, *kv_cache_state = out
self.update_kv_cache(
kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length
)
with timer.time(TextGenerationTimings.KV_CACHE_UPDATE):
logits, *kv_cache_state = out
self.update_kv_cache(
kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length
)
else:
logits = out[0]

Expand Down
20 changes: 7 additions & 13 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import logging
import os
import warnings
from dataclasses import dataclass
from typing import (
Any,
Callable,
Expand All @@ -42,6 +41,7 @@
create_causal_mask,
pad_to_fixed_length,
)
from deepsparse.transformers.utils.timings import TextGenerationTimings
from deepsparse.utils.onnx import default_cached_outputs


Expand All @@ -50,14 +50,6 @@
__all__ = ["TextGenerationPipeline"]


@dataclass(frozen=True)
class _TextGenerationTimings:
PROMPT_PREFILL: str = "engine_prompt_prefill"
PROMPT_PREFILL_SINGLE: str = "engine_prompt_prefill_single"
TOKEN_GENERATION: str = "engine_token_generation"
TOKEN_GENERATION_SINGLE: str = "engine_token_generation_single"


class TextGenerationInput(BaseModel):
class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -314,6 +306,7 @@ def initialize_engines(
input_ids_length=input_ids_length,
tokenizer=self.tokenizer,
internal_kv_cache=self.internal_kv_cache,
timer_manager=self.timer_manager,
)

if self.cache_support_enabled:
Expand All @@ -328,6 +321,7 @@ def initialize_engines(
input_ids_length=1,
tokenizer=self.tokenizer,
internal_kv_cache=self.internal_kv_cache,
timer_manager=self.timer_manager,
)

assert (engine is not None) or (
Expand Down Expand Up @@ -471,7 +465,7 @@ def engine_forward(

else:
# run the prompt through
with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
with timer.time(TextGenerationTimings.PROMPT_PREFILL):
tokens, prompt_logits = self.prompt_inference(engine_inputs)

if streamer is not None:
Expand All @@ -495,9 +489,9 @@ def engine_forward(
callback = context.get("callback")
stop = context.get("stop")

with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
with timer.time(TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE):
token, logits = self.autoregressive_inference(tokens)
tokens.append(token)
generated_tokens.append(token)
Expand Down Expand Up @@ -573,7 +567,7 @@ def prompt_inference(
for token in tokens[num_tokens_processed:]:
run_tokens.append(token)
with self.timer_manager.current.time(
_TextGenerationTimings.PROMPT_PREFILL_SINGLE
TextGenerationTimings.PROMPT_PREFILL_SINGLE
):
new_token, new_logits = self.autoregressive_inference(run_tokens)

Expand Down
1 change: 1 addition & 0 deletions src/deepsparse/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
# flake8: noqa
from .decoder_kv_cache import *
from .helpers import *
from .timings import *
28 changes: 28 additions & 0 deletions src/deepsparse/transformers/utils/timings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass


__all__ = ["TextGenerationTimings"]


@dataclass(frozen=True)
class TextGenerationTimings:
PROMPT_PREFILL: str = "engine_prompt_prefill"
PROMPT_PREFILL_SINGLE: str = "engine_prompt_prefill_single"
TOKEN_GENERATION: str = "engine_token_generation"
TOKEN_GENERATION_SINGLE: str = "engine_token_generation_single"
KV_CACHE_UPDATE: str = "kv_cache_update"