Commit bc2d63b

feat: add OpenAI JSON Schema structured output support (#18897)
1 parent ac5c3ff commit bc2d63b

File tree

7 files changed: +737, -35 lines
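This commit teaches structured prediction to stream partial Pydantic objects as the model emits JSON. A minimal usage sketch, assuming llama-index-llms-openai is installed, OPENAI_API_KEY is set, and the model name is a placeholder (this snippet is illustrative, not code from the commit):

from pydantic import BaseModel

from llama_index.core.prompts import PromptTemplate
from llama_index.llms.openai import OpenAI


class Joke(BaseModel):
    setup: str
    punchline: str


llm = OpenAI(model="gpt-4o")  # placeholder; any JSON-Schema-capable model

# Partially populated Joke/FlexibleModel instances arrive as the streamed
# JSON fills in, instead of a single object at the end.
for partial in llm.stream_structured_predict(
    Joke, PromptTemplate("Tell me a joke about {topic}."), topic="computers"
):
    print(partial)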

llama-index-core/llama_index/core/llms/llm.py

Lines changed: 45 additions & 20 deletions
@@ -442,6 +442,25 @@ class Test(BaseModel):
         dispatcher.event(LLMStructuredPredictEndEvent(output=result))
         return result
 
+    def _structured_stream_call(
+        self,
+        output_cls: Type[Model],
+        prompt: PromptTemplate,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **prompt_args: Any,
+    ) -> Generator[
+        Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None, None
+    ]:
+        from llama_index.core.program.utils import get_program_for_llm
+
+        program = get_program_for_llm(
+            output_cls,
+            prompt,
+            self,
+            pydantic_program_mode=self.pydantic_program_mode,
+        )
+        return program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)
+
     @dispatcher.span
     def stream_structured_predict(
         self,
@@ -484,28 +503,42 @@ class Test(BaseModel):
         ```
 
         """
-        from llama_index.core.program.utils import get_program_for_llm
-
         dispatcher.event(
             LLMStructuredPredictStartEvent(
                 output_cls=output_cls, template=prompt, template_args=prompt_args
             )
         )
-        program = get_program_for_llm(
-            output_cls,
-            prompt,
-            self,
-            pydantic_program_mode=self.pydantic_program_mode,
-        )
 
-        result = program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)
+        result = self._structured_stream_call(
+            output_cls, prompt, llm_kwargs, **prompt_args
+        )
         for r in result:
             dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
             assert not isinstance(r, list)
             yield r
 
         dispatcher.event(LLMStructuredPredictEndEvent(output=r))
 
+    async def _structured_astream_call(
+        self,
+        output_cls: Type[Model],
+        prompt: PromptTemplate,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **prompt_args: Any,
+    ) -> AsyncGenerator[
+        Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None
+    ]:
+        from llama_index.core.program.utils import get_program_for_llm
+
+        program = get_program_for_llm(
+            output_cls,
+            prompt,
+            self,
+            pydantic_program_mode=self.pydantic_program_mode,
+        )
+
+        return await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)
+
     @dispatcher.span
     async def astream_structured_predict(
         self,
@@ -550,23 +583,15 @@ class Test(BaseModel):
         """
 
         async def gen() -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
-            from llama_index.core.program.utils import (
-                get_program_for_llm,
-            )
-
             dispatcher.event(
                 LLMStructuredPredictStartEvent(
                     output_cls=output_cls, template=prompt, template_args=prompt_args
                 )
             )
-            program = get_program_for_llm(
-                output_cls,
-                prompt,
-                self,
-                pydantic_program_mode=self.pydantic_program_mode,
-            )
 
-            result = await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)
+            result = await self._structured_astream_call(
+                output_cls, prompt, llm_kwargs, **prompt_args
+            )
             async for r in result:
                 dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
                 assert not isinstance(r, list)
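Net effect of these hunks: stream_structured_predict and astream_structured_predict keep their dispatcher events, while program construction moves into the new _structured_stream_call / _structured_astream_call hooks that subclasses can override. A sketch of what an integration-side override might look like; the MyJSONSchemaLLM class, its omitted abstract methods, and the import path for process_streaming_content_incremental are assumptions for illustration, not code from this commit:

from typing import Any, Dict, Generator, List, Optional, Type, Union

from llama_index.core.llms.llm import LLM
from llama_index.core.program.utils import FlexibleModel
from llama_index.core.prompts import PromptTemplate
from llama_index.core.types import Model

# Assumed import path for the new module added later in this commit.
from llama_index.core.program.streaming_utils import (
    process_streaming_content_incremental,
)


class MyJSONSchemaLLM(LLM):  # hypothetical subclass; other abstract methods omitted
    def _structured_stream_call(
        self,
        output_cls: Type[Model],
        prompt: PromptTemplate,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        **prompt_args: Any,
    ) -> Generator[
        Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None, None
    ]:
        # Instead of building a generic pydantic program, an override could
        # request a JSON Schema response format from the provider and re-parse
        # the accumulated message content into partial output_cls objects.
        llm_kwargs = llm_kwargs or {}
        messages = prompt.format_messages(**prompt_args)
        cur = None
        for chunk in self.stream_chat(messages, **llm_kwargs):
            cur = process_streaming_content_incremental(chunk, output_cls, cur)
            yield cur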
Lines changed: 159 additions & 0 deletions (new file)

@@ -0,0 +1,159 @@
+"""
+Simplified streaming utilities for processing structured outputs from message content.
+
+This module provides utilities for processing streaming responses that contain
+structured data directly in the message content (not in function calls).
+"""
+
+from typing import Optional, Type, Union
+
+from pydantic import ValidationError
+
+from llama_index.core.base.llms.types import ChatResponse
+from llama_index.core.program.utils import (
+    FlexibleModel,
+    _repair_incomplete_json,
+    create_flexible_model,
+)
+from llama_index.core.types import Model
+
+
+def process_streaming_content_incremental(
+    chat_response: ChatResponse,
+    output_cls: Type[Model],
+    cur_object: Optional[Union[Model, FlexibleModel]] = None,
+) -> Union[Model, FlexibleModel]:
+    """
+    Process streaming response content with true incremental list handling.
+
+    This version can extract partial progress from incomplete JSON and build
+    lists incrementally (e.g., 1 joke → 2 jokes → 3 jokes) rather than
+    jumping from empty to complete lists.
+
+    Args:
+        chat_response (ChatResponse): The chat response to process
+        output_cls (Type[BaseModel]): The target output class
+        cur_object (Optional[BaseModel]): Current best object (for comparison)
+
+    Returns:
+        Union[BaseModel, FlexibleModel]: Processed object with incremental updates
+
+    """
+    partial_output_cls = create_flexible_model(output_cls)
+
+    # Get content from message
+    content = chat_response.message.content
+    if not content:
+        return cur_object if cur_object is not None else partial_output_cls()
+    try:
+        parsed_obj = partial_output_cls.model_validate_json(content)
+    except (ValidationError, ValueError):
+        try:
+            repaired_json = _repair_incomplete_json(content)
+            parsed_obj = partial_output_cls.model_validate_json(repaired_json)
+        except (ValidationError, ValueError):
+            extracted_obj = _extract_partial_list_progress(
+                content, output_cls, cur_object, partial_output_cls
+            )
+            parsed_obj = (
+                extracted_obj if extracted_obj is not None else partial_output_cls()
+            )
+
+    # If we still couldn't parse anything, use previous object
+    if parsed_obj is None:
+        if cur_object is not None:
+            return cur_object
+        else:
+            return partial_output_cls()
+
+    # Promote to the strict output_cls if the partial data already validates
+    try:
+        return output_cls.model_validate(parsed_obj.model_dump(exclude_unset=True))
+    except ValidationError:
+        return parsed_obj
+
+
+def _extract_partial_list_progress(
+    content: str,
+    output_cls: Type[Model],
+    cur_object: Optional[Union[Model, FlexibleModel]],
+    partial_output_cls: Type[FlexibleModel],
+) -> Optional[FlexibleModel]:
+    """
+    Try to extract partial list progress from incomplete JSON.
+
+    This attempts to build upon the current object by detecting partial
+    list additions even when JSON is malformed.
+    """
+    if not isinstance(content, str) or cur_object is None:
+        return None
+
+    try:
+        import re
+
+        # Try to extract list patterns from incomplete JSON
+        # Look for patterns like: "jokes": [{"setup": "...", "punchline": "..."}
+        list_pattern = r'"(\w+)":\s*\[([^\]]*)'
+        matches = re.findall(list_pattern, content)
+
+        if not matches:
+            return None
+
+        # Start with current object data
+        current_data = (
+            cur_object.model_dump() if hasattr(cur_object, "model_dump") else {}
+        )
+
+        for field_name, list_content in matches:
+            if (
+                hasattr(output_cls, "model_fields")
+                and field_name in output_cls.model_fields
+            ):
+                # Try to parse individual items from the list content
+                items = _parse_partial_list_items(list_content, field_name, output_cls)
+                if items:
+                    current_data[field_name] = items
+
+        # Try to create object with updated data
+        return partial_output_cls.model_validate(current_data)
+
+    except Exception:
+        return None
+
+
+def _parse_partial_list_items(
+    list_content: str, field_name: str, output_cls: Type[Model]
+) -> list:
+    """
+    Parse individual items from partial list content.
+    """
+    try:
+        import json
+        import re
+
+        items = []
+
+        # Look for complete object patterns within the list
+        # Pattern: {"key": "value", "key2": "value2"}
+        object_pattern = r"\{[^{}]*\}"
+        object_matches = re.findall(object_pattern, list_content)
+
+        for obj_str in object_matches:
+            try:
+                # Try to parse as complete JSON object
+                obj_data = json.loads(obj_str)
+                items.append(obj_data)
+            except (json.JSONDecodeError, SyntaxError):
+                # Try to repair and parse
+                try:
+                    repaired = _repair_incomplete_json(obj_str)
+                    obj_data = json.loads(repaired)
+                    items.append(obj_data)
+                except (json.JSONDecodeError, SyntaxError):
+                    continue
+
+        return items
+
+    except Exception:
+        return []
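To make the incremental behavior concrete, here is a sketch that feeds the parser successively longer snapshots of message content, the way a streaming loop would. The import path for the new module is an assumption (the file path is not shown above), and the Joke/Jokes models are illustrative:

from typing import List

from pydantic import BaseModel

from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole

# Assumed import path for the new module shown above.
from llama_index.core.program.streaming_utils import (
    process_streaming_content_incremental,
)


class Joke(BaseModel):
    setup: str
    punchline: str


class Jokes(BaseModel):
    jokes: List[Joke]


# Successive snapshots of the streamed (still-incomplete) message content.
snapshots = [
    '{"jokes": [{"setup": "a", "punchline": "b"}',
    '{"jokes": [{"setup": "a", "punchline": "b"}, {"setup": "c", "punch',
    '{"jokes": [{"setup": "a", "punchline": "b"}, {"setup": "c", "punchline": "d"}]}',
]

cur = None
for text in snapshots:
    resp = ChatResponse(message=ChatMessage(role=MessageRole.ASSISTANT, content=text))
    cur = process_streaming_content_incremental(resp, Jokes, cur)
    print(cur)  # the jokes list should grow 1 -> 2 items rather than jump to complete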
