Skip to content
Merged
202 changes: 137 additions & 65 deletions sentry_sdk/integrations/langchain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import contextvars
import itertools
import warnings
from collections import OrderedDict
from functools import wraps
import sys

import sentry_sdk
from sentry_sdk.ai.monitoring import set_ai_pipeline_name
Expand Down Expand Up @@ -73,6 +75,45 @@
}


# Per-context stack of agent names. Agents may be invoked re-entrantly, so a
# stack (rather than a single value) is kept; ``None`` means nothing has been
# pushed in the current context yet.
_agent_stack = contextvars.ContextVar(
    "langchain_agent_stack",
    default=None,
)  # type: contextvars.ContextVar[Optional[List[Optional[str]]]]


def _push_agent(agent_name):
    # type: (Optional[str]) -> None
    """Record ``agent_name`` as the innermost active agent.

    The list currently held by ``_agent_stack`` is never appended to in
    place: a new list (a copy plus the new entry) is stored instead.
    """
    current = _agent_stack.get()
    # Work on a copy so contexts still holding the previous list keep
    # seeing their own, unmodified stack.
    updated = [] if current is None else current.copy()
    updated.append(agent_name)
    _agent_stack.set(updated)


def _pop_agent():
    # type: () -> Optional[str]
    """Remove the innermost agent name from the stack and return it.

    Returns ``None`` when the stack is unset or empty; in that case the
    context variable is left untouched.
    """
    current = _agent_stack.get()
    if not current:
        return None

    # Pop from a copy so contexts still holding the previous list keep
    # seeing their own, unmodified stack.
    remaining = current.copy()
    popped = remaining.pop()
    _agent_stack.set(remaining)
    return popped


def _get_current_agent():
    # type: () -> Optional[str]
    """Return the innermost agent name without modifying the stack.

    Returns ``None`` when the stack is unset or empty.
    """
    active = _agent_stack.get()
    return active[-1] if active else None


class LangchainIntegration(Integration):
identifier = "langchain"
origin = f"auto.ai.{identifier}"
Expand Down Expand Up @@ -283,6 +324,10 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
elif "openai" in ai_type:
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")

agent_name = _get_current_agent()
if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

for key, attribute in DATA_FIELDS.items():
if key in all_params and all_params[key] is not None:
set_data_normalized(span, attribute, all_params[key], unpack=False)
Expand Down Expand Up @@ -435,6 +480,10 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
if tool_description is not None:
span.set_data(SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool_description)

agent_name = _get_current_agent()
if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

if should_send_default_pii() and self.include_prompts:
set_data_normalized(
span,
Expand Down Expand Up @@ -763,45 +812,50 @@ def new_invoke(self, *args, **kwargs):
name=f"invoke_agent {agent_name}" if agent_name else "invoke_agent",
origin=LangchainIntegration.origin,
) as span:
if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, False)

_set_tools_on_span(span, tools)

# Run the agent
result = f(self, *args, **kwargs)

input = result.get("input")
if (
input is not None
and should_send_default_pii()
and integration.include_prompts
):
normalized_messages = normalize_message_roles([input])
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(
normalized_messages, span, scope
)
if messages_data is not None:
set_data_normalized(
span,
SPANDATA.GEN_AI_REQUEST_MESSAGES,
messages_data,
unpack=False,
_push_agent(agent_name)
try:
if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, False)

_set_tools_on_span(span, tools)

# Run the agent
result = f(self, *args, **kwargs)

input = result.get("input")
if (
input is not None
and should_send_default_pii()
and integration.include_prompts
):
normalized_messages = normalize_message_roles([input])
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(
normalized_messages, span, scope
)
if messages_data is not None:
set_data_normalized(
span,
SPANDATA.GEN_AI_REQUEST_MESSAGES,
messages_data,
unpack=False,
)

output = result.get("output")
if (
output is not None
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)
output = result.get("output")
if (
output is not None
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)

return result
return result
finally:
# Ensure agent is popped even if an exception occurs
_pop_agent()

return new_invoke

Expand All @@ -821,11 +875,13 @@ def new_stream(self, *args, **kwargs):

span = start_span_function(
op=OP.GEN_AI_INVOKE_AGENT,
name=f"invoke_agent {agent_name}".strip(),
name=f"invoke_agent {agent_name}" if agent_name else "invoke_agent",
origin=LangchainIntegration.origin,
)
span.__enter__()

_push_agent(agent_name)

if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

Expand Down Expand Up @@ -860,41 +916,57 @@ def new_stream(self, *args, **kwargs):

def new_iterator():
# type: () -> Iterator[Any]
for event in old_iterator:
yield event

exc_info = (None, None, None) # type: tuple[Any, Any, Any]
try:
output = event.get("output")
except Exception:
output = None

if (
output is not None
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)
for event in old_iterator:
yield event

span.__exit__(None, None, None)
try:
output = event.get("output")
except Exception:
output = None

if (
output is not None
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)
except Exception:
exc_info = sys.exc_info()
set_span_errored(span)
raise
finally:
# Ensure cleanup happens even if iterator is abandoned or fails
_pop_agent()
span.__exit__(*exc_info)

async def new_iterator_async():
# type: () -> AsyncIterator[Any]
async for event in old_iterator:
yield event

exc_info = (None, None, None) # type: tuple[Any, Any, Any]
try:
output = event.get("output")
except Exception:
output = None

if (
output is not None
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)
async for event in old_iterator:
yield event

span.__exit__(None, None, None)
try:
output = event.get("output")
except Exception:
output = None

if (
output is not None
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)
except Exception:
exc_info = sys.exc_info()
set_span_errored(span)
raise
finally:
# Ensure cleanup happens even if iterator is abandoned or fails
_pop_agent()
span.__exit__(*exc_info)

if str(type(result)) == "<class 'async_generator'>":
result = new_iterator_async()
Expand Down
32 changes: 14 additions & 18 deletions tests/integrations/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def _llm_type(self) -> str:
return llm_type


@pytest.mark.xfail
@pytest.mark.parametrize(
"send_default_pii, include_prompts, use_unknown_llm_type",
[
Expand Down Expand Up @@ -202,29 +201,26 @@ def test_langchain_agent(
# We can't guarantee anything about the "shape" of the langchain execution graph
assert len(list(x for x in tx["spans"] if x["op"] == "gen_ai.chat")) > 0

assert "gen_ai.usage.input_tokens" in chat_spans[0]["data"]
assert "gen_ai.usage.output_tokens" in chat_spans[0]["data"]
assert "gen_ai.usage.total_tokens" in chat_spans[0]["data"]
# Token usage is only available in newer versions of langchain (v0.2+)
# where usage_metadata is supported on AIMessageChunk
if "gen_ai.usage.input_tokens" in chat_spans[0]["data"]:
assert chat_spans[0]["data"]["gen_ai.usage.input_tokens"] == 142
assert chat_spans[0]["data"]["gen_ai.usage.output_tokens"] == 50
assert chat_spans[0]["data"]["gen_ai.usage.total_tokens"] == 192

assert chat_spans[0]["data"]["gen_ai.usage.input_tokens"] == 142
assert chat_spans[0]["data"]["gen_ai.usage.output_tokens"] == 50
assert chat_spans[0]["data"]["gen_ai.usage.total_tokens"] == 192

assert "gen_ai.usage.input_tokens" in chat_spans[1]["data"]
assert "gen_ai.usage.output_tokens" in chat_spans[1]["data"]
assert "gen_ai.usage.total_tokens" in chat_spans[1]["data"]
assert chat_spans[1]["data"]["gen_ai.usage.input_tokens"] == 89
assert chat_spans[1]["data"]["gen_ai.usage.output_tokens"] == 28
assert chat_spans[1]["data"]["gen_ai.usage.total_tokens"] == 117
if "gen_ai.usage.input_tokens" in chat_spans[1]["data"]:
assert chat_spans[1]["data"]["gen_ai.usage.input_tokens"] == 89
assert chat_spans[1]["data"]["gen_ai.usage.output_tokens"] == 28
assert chat_spans[1]["data"]["gen_ai.usage.total_tokens"] == 117

if send_default_pii and include_prompts:
assert (
"You are very powerful"
in chat_spans[0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
)
assert "5" in chat_spans[0]["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
assert "word" in tool_exec_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
assert 5 == int(tool_exec_span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT])
assert "word" in tool_exec_span["data"][SPANDATA.GEN_AI_TOOL_INPUT]
assert 5 == int(tool_exec_span["data"][SPANDATA.GEN_AI_TOOL_OUTPUT])
assert (
"You are very powerful"
in chat_spans[1]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
Expand All @@ -248,8 +244,8 @@ def test_langchain_agent(
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in chat_spans[0].get("data", {})
assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in chat_spans[1].get("data", {})
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in chat_spans[1].get("data", {})
assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in tool_exec_span.get("data", {})
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in tool_exec_span.get("data", {})
assert SPANDATA.GEN_AI_TOOL_INPUT not in tool_exec_span.get("data", {})
assert SPANDATA.GEN_AI_TOOL_OUTPUT not in tool_exec_span.get("data", {})

# Verify tool calls are NOT recorded when PII is disabled
assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS not in chat_spans[0].get(
Expand Down
Loading