65 changes: 45 additions & 20 deletions llama-index-core/llama_index/core/llms/llm.py
@@ -442,6 +25 @@ class Test(BaseModel):
dispatcher.event(LLMStructuredPredictEndEvent(output=result))
return result

def _structured_stream_call(
self,
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> Generator[
Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None, None
]:
from llama_index.core.program.utils import get_program_for_llm

program = get_program_for_llm(
output_cls,
prompt,
self,
pydantic_program_mode=self.pydantic_program_mode,
)
return program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)

@dispatcher.span
def stream_structured_predict(
self,
@@ -484,28 +503,42 @@ class Test(BaseModel):
```

"""
from llama_index.core.program.utils import get_program_for_llm

dispatcher.event(
LLMStructuredPredictStartEvent(
output_cls=output_cls, template=prompt, template_args=prompt_args
)
)
program = get_program_for_llm(
output_cls,
prompt,
self,
pydantic_program_mode=self.pydantic_program_mode,
)

result = program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)
result = self._structured_stream_call(
output_cls, prompt, llm_kwargs, **prompt_args
)
for r in result:
dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
assert not isinstance(r, list)
yield r

dispatcher.event(LLMStructuredPredictEndEvent(output=r))

async def _structured_astream_call(
self,
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> AsyncGenerator[
Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None
]:
from llama_index.core.program.utils import get_program_for_llm

program = get_program_for_llm(
output_cls,
prompt,
self,
pydantic_program_mode=self.pydantic_program_mode,
)

return await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)

@dispatcher.span
async def astream_structured_predict(
self,
@@ -550,23 +583,15 @@ class Test(BaseModel):
"""

async def gen() -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
from llama_index.core.program.utils import (
get_program_for_llm,
)

dispatcher.event(
LLMStructuredPredictStartEvent(
output_cls=output_cls, template=prompt, template_args=prompt_args
)
)
program = get_program_for_llm(
output_cls,
prompt,
self,
pydantic_program_mode=self.pydantic_program_mode,
)

result = await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)
result = await self._structured_astream_call(
output_cls, prompt, llm_kwargs, **prompt_args
)
async for r in result:
dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
assert not isinstance(r, list)
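For reference, a minimal usage sketch of the streaming structured-predict API touched by this diff. The `Joke`/`JokeBook` models, the prompt text, and the OpenAI model name are illustrative assumptions, not part of the change; it assumes the `llama-index-llms-openai` integration is installed (any LLM subclass would work the same way):

from typing import List

from pydantic import BaseModel

from llama_index.core.prompts import PromptTemplate
from llama_index.llms.openai import OpenAI  # illustrative choice of LLM


# Hypothetical output models, for illustration only.
class Joke(BaseModel):
    setup: str
    punchline: str


class JokeBook(BaseModel):
    jokes: List[Joke]


llm = OpenAI(model="gpt-4o-mini")  # model name is an assumption
prompt = PromptTemplate("Write {num_jokes} short jokes about {topic}.")

# Each yielded value is a partial JokeBook built from the tokens seen so far;
# with the incremental utilities below, the jokes list grows one item at a time.
for partial in llm.stream_structured_predict(
    JokeBook, prompt, num_jokes=3, topic="streaming"
):
    print(partial)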
159 changes: 159 additions & 0 deletions llama-index-core/llama_index/core/program/streaming_utils.py
@@ -0,0 +1,159 @@
"""
Simplified streaming utilities for processing structured outputs from message content.
This module provides utilities for processing streaming responses that contain
structured data directly in the message content (not in function calls).
"""

from typing import Optional, Type, Union

from pydantic import ValidationError

from llama_index.core.base.llms.types import ChatResponse
from llama_index.core.program.utils import (
FlexibleModel,
_repair_incomplete_json,
create_flexible_model,
)
from llama_index.core.types import Model


def process_streaming_content_incremental(
chat_response: ChatResponse,
output_cls: Type[Model],
cur_object: Optional[Union[Model, FlexibleModel]] = None,
) -> Union[Model, FlexibleModel]:
"""
Process streaming response content with true incremental list handling.
This version can extract partial progress from incomplete JSON and build
lists incrementally (e.g., 1 joke → 2 jokes → 3 jokes) rather than
jumping from empty to complete lists.
Args:
chat_response (ChatResponse): The chat response to process
output_cls (Type[BaseModel]): The target output class
cur_object (Optional[BaseModel]): Current best object (for comparison)
flexible_mode (bool): Whether to use flexible schema during parsing
Returns:
Union[BaseModel, FlexibleModel]: Processed object with incremental updates
"""
partial_output_cls = create_flexible_model(output_cls)

# Get content from message
content = chat_response.message.content
if not content:
        return cur_object if cur_object is not None else partial_output_cls()

    try:
parsed_obj = partial_output_cls.model_validate_json(content)
except (ValidationError, ValueError):
try:
repaired_json = _repair_incomplete_json(content)
parsed_obj = partial_output_cls.model_validate_json(repaired_json)
except (ValidationError, ValueError):
extracted_obj = _extract_partial_list_progress(
content, output_cls, cur_object, partial_output_cls
)
parsed_obj = (
extracted_obj if extracted_obj is not None else partial_output_cls()
)

# If we still couldn't parse anything, use previous object
if parsed_obj is None:
if cur_object is not None:
return cur_object
else:
return partial_output_cls()

    # Promote to the strict output class once the partial object fully
    # validates; until then, return the flexible partial object as-is.
try:
return output_cls.model_validate(parsed_obj.model_dump(exclude_unset=True))
except ValidationError:
return parsed_obj


def _extract_partial_list_progress(
content: str,
output_cls: Type[Model],
cur_object: Optional[Union[Model, FlexibleModel]],
partial_output_cls: Type[FlexibleModel],
) -> Optional[FlexibleModel]:
"""
Try to extract partial list progress from incomplete JSON.
This attempts to build upon the current object by detecting partial
list additions even when JSON is malformed.
"""
if not isinstance(content, str) or cur_object is None:
return None

try:
import re

# Try to extract list patterns from incomplete JSON
# Look for patterns like: "jokes": [{"setup": "...", "punchline": "..."}
list_pattern = r'"(\w+)":\s*\[([^\]]*)'
matches = re.findall(list_pattern, content)

if not matches:
return None

# Start with current object data
current_data = (
cur_object.model_dump() if hasattr(cur_object, "model_dump") else {}
)

for field_name, list_content in matches:
if (
hasattr(output_cls, "model_fields")
and field_name in output_cls.model_fields
):
# Try to parse individual items from the list content
items = _parse_partial_list_items(list_content, field_name, output_cls)
if items:
current_data[field_name] = items

# Try to create object with updated data
return partial_output_cls.model_validate(current_data)

except Exception:
return None


def _parse_partial_list_items(
list_content: str, field_name: str, output_cls: Type[Model]
) -> list:
"""
Parse individual items from partial list content.
"""
try:
import json
import re

items = []

# Look for complete object patterns within the list
# Pattern: {"key": "value", "key2": "value2"}
object_pattern = r"\{[^{}]*\}"
object_matches = re.findall(object_pattern, list_content)

for obj_str in object_matches:
try:
# Try to parse as complete JSON object
obj_data = json.loads(obj_str)
items.append(obj_data)
except (json.JSONDecodeError, SyntaxError):
# Try to repair and parse
try:
repaired = _repair_incomplete_json(obj_str)
obj_data = json.loads(repaired)
items.append(obj_data)
except (json.JSONDecodeError, SyntaxError):
continue

return items

except Exception:
return []
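A minimal sketch of how `process_streaming_content_incremental` behaves across successive snapshots of a streamed message. The `JokeBook` model and the hand-built `ChatResponse` objects are illustrative assumptions standing in for real streaming output:

from typing import List

from pydantic import BaseModel

from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
from llama_index.core.program.streaming_utils import (
    process_streaming_content_incremental,
)


class Joke(BaseModel):
    setup: str
    punchline: str


class JokeBook(BaseModel):
    jokes: List[Joke]


# Progressively longer snapshots of one streamed JSON payload.
snapshots = [
    '{"jokes": [{"setup": "Why stream JSON?"',
    '{"jokes": [{"setup": "Why stream JSON?", "punchline": "No waiting."}',
    '{"jokes": [{"setup": "Why stream JSON?", "punchline": "No waiting."}]}',
]

cur = None
for content in snapshots:
    response = ChatResponse(
        message=ChatMessage(role=MessageRole.ASSISTANT, content=content)
    )
    cur = process_streaming_content_incremental(response, JokeBook, cur)
    print(cur)  # the jokes list fills in as more of the payload arrives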