20 changes: 20 additions & 0 deletions README.md
@@ -61,3 +61,23 @@ for chunk in llm.stream("Tell me what happened to the Dinosaurs?"):
```

More features coming soon.

## Structured Output (Pydantic)

You can ask the model to return JSON that matches a Pydantic model and
have the client validate it into typed objects using `with_structured_output()`:

```python
from pydantic import BaseModel
from langchain_gradient import ChatGradient

class Person(BaseModel):
name: str
age: int
email: str

llm = ChatGradient(model="llama3.3-70b-instruct", api_key="your_key")
structured = llm.with_structured_output(Person)
person = structured.invoke(["Create a person named John, age 30, email [email protected]"])
print(person)
```
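
If you expect a list of objects, pass `multiple=True` and the wrapper will validate a JSON array into a list of models. A short sketch continuing the example above (the prompt wording is illustrative):

```python
structured_many = llm.with_structured_output(Person, multiple=True)
people = structured_many.invoke(["Create two people as a JSON array of objects"])
print([p.name for p in people])
```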
93 changes: 92 additions & 1 deletion langchain_gradient/chat_models.py
@@ -1,7 +1,10 @@
"""LangchainDigitalocean chat models."""

import os
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Union, Type
import json

from pydantic import BaseModel, ValidationError

from gradient import Gradient
from langchain_core.callbacks import (
@@ -361,3 +364,91 @@ def __getstate__(self) -> dict:

def __setstate__(self, state: dict) -> None:
self.__dict__.update(state)

def with_structured_output(
self,
response_model: Type[BaseModel],
*,
multiple: bool = False,
response_format: Optional[Any] = None,
) -> "StructuredChatGradient":
"""Return a lightweight wrapper around this ChatGradient that will parse
and validate the model output into the provided Pydantic ``response_model``.

Parameters
----------
response_model:
A Pydantic model class (subclass of BaseModel) to validate the AI output.
multiple:
If True, expect the model to return a JSON array of objects matching
``response_model`` and return a list of validated models.
response_format:
Arbitrary metadata to be forwarded to the underlying LLM invocation
(for example instructions about the desired response format). This
wrapper will pass it through when invoking the underlying model if
not overridden at call time.
"""

return StructuredChatGradient(
llm=self,
response_model=response_model,
multiple=multiple,
response_format=response_format,
)
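
As a hedged illustration of the `response_format` pass-through described in the docstring above (the prompt strings are illustrative, not part of the API):

```python
# Set a format hint at construction time...
structured = llm.with_structured_output(
    Person,
    response_format="Return a single JSON object with keys name, age, email.",
)
# ...or override it for a single call via kwargs, which invoke() pops off.
person = structured.invoke(
    ["Create a person"],
    response_format="Respond with raw JSON only, no surrounding prose.",
)
```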


class StructuredChatGradient:
"""A small wrapper that invokes a ChatGradient and parses/validates
its output into Pydantic models.

This keeps the core ChatGradient implementation unchanged while providing
a convenient typed interface for consumers who want structured outputs.
"""

def __init__(
self,
llm: ChatGradient,
response_model: Type[BaseModel],
multiple: bool = False,
response_format: Optional[Any] = None,
) -> None:
self.llm = llm
self.response_model = response_model
self.multiple = multiple
self.response_format = response_format

def invoke(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
# Allow callers to override response_format via kwargs
rf = kwargs.pop("response_format", self.response_format)
# If response_format exists, append it to the messages as a system hint
if rf:
# inject formatting instruction as a system message at the start
messages = [BaseMessage(content=str(rf))] + messages # type: ignore
**Copilot AI** commented on Oct 28, 2025:

Creating a `BaseMessage` directly is incorrect. `BaseMessage` is likely an abstract base class from `langchain_core.messages`. Use a concrete message type like `SystemMessage` instead: `from langchain_core.messages import SystemMessage` and `messages = [SystemMessage(content=str(rf))] + messages`.
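A minimal sketch of the change this comment suggests (assuming `SystemMessage` from `langchain_core.messages`; the variable names match the `invoke` body above):

```python
from langchain_core.messages import SystemMessage

if rf:
    # Inject the formatting hint as a concrete SystemMessage; BaseMessage is
    # abstract in langchain_core and should not be instantiated directly.
    messages = [SystemMessage(content=str(rf))] + messages
```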

result = self.llm.invoke(messages, **kwargs)

# result.content is expected to be a string containing JSON
raw = getattr(result, "content", result)
if not isinstance(raw, str):
raise ValueError("LLM returned non-string content for structured output")

try:
parsed = json.loads(raw)
except Exception as e: # noqa: BLE001 - we want to catch JSON errors
**Copilot AI** commented on Oct 28, 2025:

Catching broad `Exception` is too permissive. Catch `json.JSONDecodeError` specifically to handle JSON parsing errors while allowing other exceptions to propagate normally.

Suggested change:

```diff
-except Exception as e:  # noqa: BLE001 - we want to catch JSON errors
+except json.JSONDecodeError as e:
```
raise ValueError(f"Failed to parse JSON from model output: {e}\nRaw output: {raw}")

try:
if self.multiple:
if not isinstance(parsed, list):
raise ValueError("Expected JSON array for multiple=True structured output")
return [self.response_model.parse_obj(item) for item in parsed]
else:
# For single objects, accept dict or single-element list
if isinstance(parsed, list) and len(parsed) == 1:
parsed = parsed[0]
if not isinstance(parsed, dict):
raise ValueError("Expected JSON object for structured output")
return self.response_model.parse_obj(parsed)
except ValidationError as ve:
# Provide clear messages about which fields failed validation
raise ValueError(f"Validation error when parsing model output: {ve}")
66 changes: 66 additions & 0 deletions tests/unit_tests/test_structured_output.py
@@ -0,0 +1,66 @@
"""Unit tests for structured output parsing and validation."""

from typing import List

import pytest
from pydantic import BaseModel

from langchain_gradient.chat_models import ChatGradient


class Person(BaseModel):
name: str
age: int
email: str


class DummyLLM(ChatGradient):
"""A tiny fake LLM that returns a pre-canned response for testing."""

def __init__(self, content: str, **kwargs):
super().__init__(**kwargs)
self._content = content

def invoke(self, messages: List, **kwargs):
# Mimic the response object returned by ChatGradient.invoke (only .content is used)
class R:
def __init__(self, content):
self.content = content

return R(self._content)


def test_single_structured_output_success():
json_str = '{"name": "John", "age": 30, "email": "[email protected]"}'
llm = DummyLLM(content=json_str)
structured = llm.with_structured_output(Person)
person = structured.invoke(messages=["prompt"])
assert isinstance(person, Person)
assert person.name == "John"


def test_multiple_structured_output_success():
json_str = '[{"name": "Alice", "age": 25, "email": "[email protected]"}, {"name": "Bob", "age": 28, "email": "[email protected]"}]'
llm = DummyLLM(content=json_str)
structured = llm.with_structured_output(Person, multiple=True)
people = structured.invoke(messages=["prompt"])
assert isinstance(people, list)
assert all(isinstance(p, Person) for p in people)
assert people[0].name == "Alice"


def test_invalid_json_raises():
llm = DummyLLM(content="not a json")
structured = llm.with_structured_output(Person)
with pytest.raises(ValueError):
structured.invoke(messages=["prompt"])


def test_validation_error_raises():
# Missing age field
json_str = '{"name": "John", "email": "[email protected]"}'
llm = DummyLLM(content=json_str)
structured = llm.with_structured_output(Person)
with pytest.raises(ValueError) as e:
structured.invoke(messages=["prompt"])
assert "Validation error" in str(e.value)