@@ -1,5 +1,5 @@
 import time
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 from uuid import uuid4
 from langchain_core.outputs import LLMResult, Generation
 from opentelemetry.instrumentation.langchain.callback_handler import TraceloopCallbackHandler
@@ -37,33 +37,37 @@ def test_ttft_metric_recorded_on_first_token(self):
         mock_span = Mock(spec=Span)
         mock_span.attributes = {SpanAttributes.LLM_SYSTEM: "Langchain"}

-        # Create span holder with specific start time
-        start_time = time.time()
-        span_holder = SpanHolder(
-            span=mock_span,
-            token=None,
-            context=None,
-            children=[],
-            workflow_name="test",
-            entity_name="test",
-            entity_path="test",
-            start_time=start_time
-        )
-        self.handler.spans[run_id] = span_holder
-
-        # Simulate first token arrival after a small delay
-        time.sleep(0.1)
-        self.handler.on_llm_new_token("Hello", run_id=run_id)
-
-        # Verify TTFT metric was recorded
-        self.ttft_histogram.record.assert_called_once()
-        args = self.ttft_histogram.record.call_args
-        ttft_value = args[0][0]
-        assert ttft_value > 0.05, "TTFT should be greater than 0.05 seconds"
-
-        # Verify attributes
-        attributes = args[1]["attributes"]
-        assert attributes[SpanAttributes.LLM_SYSTEM] == "Langchain"
+        # Use mock time for stable testing
+        with patch('opentelemetry.instrumentation.langchain.callback_handler.time.time') as mock_time, \
+             patch('opentelemetry.instrumentation.langchain.span_utils.time.time') as mock_span_time:
+
+            start_time = 1000.0
+            mock_time.return_value = start_time
+            mock_span_time.return_value = start_time
+
+            span_holder = SpanHolder(
+                span=mock_span,
+                token=None,
+                context=None,
+                children=[],
+                workflow_name="test",
+                entity_name="test",
+                entity_path="test",
+                start_time=start_time
+            )
+            self.handler.spans[run_id] = span_holder
+
+            mock_time.return_value = start_time + 0.1
+            mock_span_time.return_value = start_time + 0.1
+            self.handler.on_llm_new_token("Hello", run_id=run_id)
+
+            self.ttft_histogram.record.assert_called_once()
+            args = self.ttft_histogram.record.call_args
+            ttft_value = args[0][0]
+            assert abs(ttft_value - 0.1) < 0.001, f"TTFT should be approximately 0.1 seconds, got {ttft_value}"
+
+            attributes = args[1]["attributes"]
+            assert attributes[SpanAttributes.LLM_SYSTEM] == "Langchain"

     def test_ttft_metric_not_recorded_on_subsequent_tokens(self):
         """Test that TTFT metric is only recorded once."""
@@ -134,37 +138,47 @@ def test_streaming_time_to_generate_metric(self):
         mock_span = Mock(spec=Span)
         mock_span.attributes = {SpanAttributes.LLM_SYSTEM: "Langchain"}

-        start_time = time.time()
-        span_holder = SpanHolder(
-            span=mock_span,
-            token=None,
-            context=None,
-            children=[],
-            workflow_name="test",
-            entity_name="test",
-            entity_path="test",
-            start_time=start_time
-        )
-        self.handler.spans[run_id] = span_holder
-
-        # Simulate token arrival
-        time.sleep(0.05)
-        self.handler.on_llm_new_token("Hello", run_id=run_id)
-
-        # Simulate completion after more time
-        time.sleep(0.05)
-        llm_result = LLMResult(
-            generations=[[Generation(text="Hello world")]],
-            llm_output={"model_name": "test-model"}
-        )
-
-        self.handler.on_llm_end(llm_result, run_id=run_id)
-
-        # Verify streaming time metric was recorded
-        self.streaming_time_histogram.record.assert_called_once()
-        args = self.streaming_time_histogram.record.call_args
-        streaming_time = args[0][0]
-        assert streaming_time > 0.04, "Streaming time should be greater than 0.04 seconds"
+        with patch('opentelemetry.instrumentation.langchain.callback_handler.time.time') as mock_time, \
+             patch('opentelemetry.instrumentation.langchain.span_utils.time.time') as mock_span_time:
+
+            start_time = 1000.0
+            mock_time.return_value = start_time
+            mock_span_time.return_value = start_time
+
+            span_holder = SpanHolder(
+                span=mock_span,
+                token=None,
+                context=None,
+                children=[],
+                workflow_name="test",
+                entity_name="test",
+                entity_path="test",
+                start_time=start_time
+            )
+            self.handler.spans[run_id] = span_holder
+
+            first_token_time = start_time + 0.05
+            mock_time.return_value = first_token_time
+            mock_span_time.return_value = first_token_time
+            self.handler.on_llm_new_token("Hello", run_id=run_id)
+
+            completion_time = first_token_time + 0.05
+            mock_time.return_value = completion_time
+            mock_span_time.return_value = completion_time
+            llm_result = LLMResult(
+                generations=[[Generation(text="Hello world")]],
+                llm_output={"model_name": "test-model"}
+            )
+
+            self.handler.on_llm_end(llm_result, run_id=run_id)
+
+            self.streaming_time_histogram.record.assert_called_once()
+            args = self.streaming_time_histogram.record.call_args
+            streaming_time = args[0][0]
+            assert abs(streaming_time - 0.05) < 0.001, (
+                f"Streaming time should be approximately 0.05 seconds, "
+                f"got {streaming_time}"
+            )

     def test_exception_metric_recorded_on_error(self):
         """Test that exception metric is recorded on LLM errors."""
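If the same two-module time patch keeps recurring across these streaming tests, one possible follow-up (not part of this diff) is to factor it into a small helper context manager. The sketch below reuses the patch targets from the tests above; the mocked_clock helper and its set_now callback are hypothetical names, not existing project APIs.

from contextlib import contextmanager
from unittest.mock import patch


@contextmanager
def mocked_clock(start: float = 1000.0):
    """Patch time.time() in both instrumented modules and yield a setter.

    Hypothetical test helper; the patch targets mirror the ones used in
    the tests above.
    """
    with patch(
        "opentelemetry.instrumentation.langchain.callback_handler.time.time"
    ) as handler_time, patch(
        "opentelemetry.instrumentation.langchain.span_utils.time.time"
    ) as span_time:
        def set_now(value: float) -> None:
            # Keep both module clocks in lockstep so TTFT and streaming
            # durations are computed against the same simulated timeline.
            handler_time.return_value = value
            span_time.return_value = value

        set_now(start)
        yield set_now

A test could then advance the simulated clock, e.g. set_now(1000.1) before on_llm_new_token(...), and keep the exact-value assertions shown in the diff.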