Implement Structured Output with Pydantic Models #16
The change set touches `README.md`, `langchain_gradient/chat_models.py`, and a new unit-test file.

`README.md` (hunk `@@ -61,3 +61,23 @@`): after the existing streaming example and the line "More features coming soon.", the README gains a section documenting the new feature:

````markdown
## Structured Output (Pydantic)

You can ask the model to return JSON that matches a Pydantic model and
have the client validate and return typed objects using `with_structured_output()`:

```python
from pydantic import BaseModel
from langchain_gradient import ChatGradient


class Person(BaseModel):
    name: str
    age: int
    email: str


llm = ChatGradient(model="llama3.3-70b-instruct", api_key="your_key")
structured = llm.with_structured_output(Person)
person = structured.invoke(["Create a person named John, age 30, email [email protected]"])
print(person)
```
````
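The README only shows the single-object case. Given the `multiple` and `response_format` parameters visible in the diff below, batch usage would presumably look like the following sketch; the prompt wording and the format-hint string are illustrative, not taken from the PR:

```python
# Sketch based on the with_structured_output() signature added in this PR.
# The prompt text and the response_format string are illustrative only.
structured_many = llm.with_structured_output(
    Person,
    multiple=True,  # expect a JSON array; returns a list of Person objects
    response_format="Respond with a JSON array of Person objects only.",
)
people = structured_many.invoke(
    ["Create two people: Alice, 25, [email protected] and Bob, 28, [email protected]"]
)
for p in people:
    print(p.name, p.age, p.email)
```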
`langchain_gradient/chat_models.py`: the first hunk (`@@ -1,7 +1,10 @@`) extends the imports, adding `Type`, `json`, and the Pydantic symbols:

```python
"""LangchainDigitalocean chat models."""

import os
from typing import Any, Dict, Iterator, List, Optional, Union, Type
import json

from pydantic import BaseModel, ValidationError

from gradient import Gradient
from langchain_core.callbacks import (
```

The second hunk (`@@ -361,3 +364,91 @@`, anchored after `__getstate__`) adds a `with_structured_output()` method to `ChatGradient` and a new `StructuredChatGradient` wrapper class. The capture of this hunk ends at the `try`/`except`, where a review suggestion is anchored:

```python
    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def with_structured_output(
        self,
        response_model: Type[BaseModel],
        *,
        multiple: bool = False,
        response_format: Optional[Any] = None,
    ) -> "StructuredChatGradient":
        """Return a lightweight wrapper around this ChatGradient that parses
        and validates the model output into the provided Pydantic ``response_model``.

        Parameters
        ----------
        response_model:
            A Pydantic model class (subclass of ``BaseModel``) used to validate the AI output.
        multiple:
            If True, expect the model to return a JSON array of objects matching
            ``response_model`` and return a list of validated models.
        response_format:
            Arbitrary metadata forwarded to the underlying LLM invocation
            (for example, instructions about the desired response format). The
            wrapper passes it through when invoking the underlying model unless
            it is overridden at call time.
        """
        return StructuredChatGradient(
            llm=self,
            response_model=response_model,
            multiple=multiple,
            response_format=response_format,
        )


class StructuredChatGradient:
    """A small wrapper that invokes a ChatGradient and parses/validates
    its output into Pydantic models.

    This keeps the core ChatGradient implementation unchanged while providing
    a convenient typed interface for consumers who want structured outputs.
    """

    def __init__(
        self,
        llm: ChatGradient,
        response_model: Type[BaseModel],
        multiple: bool = False,
        response_format: Optional[Any] = None,
    ) -> None:
        self.llm = llm
        self.response_model = response_model
        self.multiple = multiple
        self.response_format = response_format

    def invoke(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
        # Allow callers to override response_format via kwargs.
        rf = kwargs.pop("response_format", self.response_format)
        if rf:
            # Prepend the formatting instruction as a system hint.
            messages = [BaseMessage(content=str(rf))] + messages  # type: ignore

        result = self.llm.invoke(messages, **kwargs)

        # result.content is expected to be a string containing JSON.
        raw = getattr(result, "content", result)
        if not isinstance(raw, str):
            raise ValueError("LLM returned non-string content for structured output")

        try:
            parsed = json.loads(raw)
        except Exception as e:  # noqa: BLE001 - we want to catch JSON errors
```
A reviewer left a suggested change on that `except` line, narrowing the broad catch:

```diff
-        except Exception as e:  # noqa: BLE001 - we want to catch JSON errors
+        except json.JSONDecodeError as e:
```
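The captured diff stops inside that `try` block, but the test expectations further down (invalid JSON raises `ValueError`; a missing field raises a `ValueError` whose message contains "Validation error") suggest the rest of `invoke()` behaves roughly like this sketch. It applies the reviewer's `json.JSONDecodeError` suggestion, assumes Pydantic v2's `model_validate()`, and uses a hypothetical helper name (`_parse_structured`) since the PR inlines this logic:

```python
# Hedged reconstruction of the tail of StructuredChatGradient.invoke(),
# inferred from this PR's unit tests; the real implementation may differ.
import json
from typing import Any, Type

from pydantic import BaseModel, ValidationError


def _parse_structured(raw: str, response_model: Type[BaseModel], multiple: bool) -> Any:
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to parse model output as JSON: {e}") from e

    try:
        if multiple:
            # A JSON array is expected; validate each element separately.
            return [response_model.model_validate(item) for item in parsed]
        return response_model.model_validate(parsed)
    except ValidationError as e:
        # The tests assert that "Validation error" appears in the message.
        raise ValueError(f"Validation error: {e}") from e
```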
Finally, the PR adds a new unit-test file (`@@ -0,0 +1,66 @@`; the file path is not shown in this capture):

```python
"""Unit tests for structured output parsing and validation."""

from typing import List

import pytest
from pydantic import BaseModel

from langchain_gradient.chat_models import ChatGradient


class Person(BaseModel):
    name: str
    age: int
    email: str


class DummyLLM(ChatGradient):
    """A tiny fake LLM that returns a pre-canned response for testing."""

    def __init__(self, content: str, **kwargs):
        super().__init__(**kwargs)
        self._content = content

    def invoke(self, messages: List, **kwargs):
        # Mimic the response object that ChatGradient.invoke returns.
        class R:
            def __init__(self, content):
                self.content = content

        return R(self._content)


def test_single_structured_output_success():
    json_str = '{"name": "John", "age": 30, "email": "[email protected]"}'
    llm = DummyLLM(content=json_str)
    structured = llm.with_structured_output(Person)
    person = structured.invoke(messages=["prompt"])
    assert isinstance(person, Person)
    assert person.name == "John"


def test_multiple_structured_output_success():
    json_str = '[{"name": "Alice", "age": 25, "email": "[email protected]"}, {"name": "Bob", "age": 28, "email": "[email protected]"}]'
    llm = DummyLLM(content=json_str)
    structured = llm.with_structured_output(Person, multiple=True)
    people = structured.invoke(messages=["prompt"])
    assert isinstance(people, list)
    assert all(isinstance(p, Person) for p in people)
    assert people[0].name == "Alice"


def test_invalid_json_raises():
    llm = DummyLLM(content="not a json")
    structured = llm.with_structured_output(Person)
    with pytest.raises(ValueError):
        structured.invoke(messages=["prompt"])


def test_validation_error_raises():
    # The required "age" field is missing.
    json_str = '{"name": "John", "email": "[email protected]"}'
    llm = DummyLLM(content=json_str)
    structured = llm.with_structured_output(Person)
    with pytest.raises(ValueError) as e:
        structured.invoke(messages=["prompt"])
    assert "Validation error" in str(e.value)
```
A reviewer also flagged the message-injection line in `StructuredChatGradient.invoke()`:

> Creating a `BaseMessage` directly is incorrect. `BaseMessage` is likely an abstract base class from `langchain_core.messages`. Use a concrete message type like `SystemMessage` instead: `from langchain_core.messages import SystemMessage` and `messages = [SystemMessage(content=str(rf))] + messages`.
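Applied to the code above, that suggestion would presumably look like the following sketch. The helper function is hypothetical (the PR inlines this logic), but `SystemMessage` and `BaseMessage` do live in `langchain_core.messages`:

```python
from typing import List

from langchain_core.messages import BaseMessage, SystemMessage


def _with_format_hint(messages: List[BaseMessage], rf: object) -> List[BaseMessage]:
    """Prepend the response-format hint as a concrete SystemMessage.

    Hypothetical helper illustrating the reviewer's proposed fix.
    """
    if rf:
        # SystemMessage is a concrete message type, unlike the abstract
        # BaseMessage, so it can actually be instantiated.
        return [SystemMessage(content=str(rf))] + messages
    return messages
```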