66
77
88class TestDsmContext (unittest .TestCase ):
9+ def setUp (self ):
10+ patcher = patch ("datadog_lambda.dsm._dsm_set_sqs_context" )
11+ self .mock_dsm_set_sqs_context = patcher .start ()
12+ self .addCleanup (patcher .stop )
13+
14+ patcher = patch ("ddtrace.internal.datastreams.data_streams_processor" )
15+ self .mock_data_streams_processor = patcher .start ()
16+ self .addCleanup (patcher .stop )
17+
18+ patcher = patch ("ddtrace.internal.datastreams.botocore.get_datastreams_context" )
19+ self .mock_get_datastreams_context = patcher .start ()
20+ self .mock_get_datastreams_context .return_value = {}
21+ self .addCleanup (patcher .stop )
22+
23+ patcher = patch (
24+ "ddtrace.internal.datastreams.botocore.calculate_sqs_payload_size"
25+ )
26+ self .mock_calculate_sqs_payload_size = patcher .start ()
27+ self .mock_calculate_sqs_payload_size .return_value = 100
28+ self .addCleanup (patcher .stop )
29+
30+ patcher = patch ("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode" )
31+ self .mock_dsm_pathway_codec_decode = patcher .start ()
32+ self .addCleanup (patcher .stop )
33+
934 def test_non_sqs_event_source_does_nothing (self ):
1035 """Test that non-SQS event sources don't trigger DSM context setting"""
1136 event = {"Records" : [{"body" : "test" }]}
1237
1338 mock_event_source = MagicMock ()
1439 mock_event_source .equals .return_value = False # Not SQS
1540
16- with patch ("datadog_lambda.dsm._dsm_set_sqs_context" ) as mock_sqs_context :
17- set_dsm_context (event , mock_event_source )
41+ set_dsm_context (event , mock_event_source )
1842
19- mock_event_source .equals .assert_called_once_with (EventTypes .SQS )
20- mock_sqs_context .assert_not_called ()
43+ mock_event_source .equals .assert_called_once_with (EventTypes .SQS )
44+ self . mock_dsm_set_sqs_context .assert_not_called ()
2145
2246 def test_event_with_no_records_does_nothing (self ):
2347 """Test that events where Records is None don't trigger DSM processing"""
@@ -28,12 +52,8 @@ def test_event_with_no_records_does_nothing(self):
2852 ]
2953
3054 for event in events_with_no_records :
31- with patch (
32- "ddtrace.internal.datastreams.data_streams_processor"
33- ) as mock_processor :
34- _dsm_set_sqs_context (event )
35-
36- mock_processor .assert_not_called ()
55+ _dsm_set_sqs_context (event )
56+ self .mock_data_streams_processor .assert_not_called ()
3757
3858 def test_sqs_event_triggers_dsm_sqs_context (self ):
3959 """Test that SQS event sources trigger the SQS-specific DSM context function"""
@@ -50,10 +70,9 @@ def test_sqs_event_triggers_dsm_sqs_context(self):
5070 mock_event_source = MagicMock ()
5171 mock_event_source .equals .return_value = True
5272
53- with patch ("datadog_lambda.dsm._dsm_set_sqs_context" ) as mock_sqs_context :
54- set_dsm_context (sqs_event , mock_event_source )
73+ set_dsm_context (sqs_event , mock_event_source )
5574
56- mock_sqs_context .assert_called_once_with (sqs_event )
75+ self . mock_dsm_set_sqs_context .assert_called_once_with (sqs_event )
5776
5877 def test_multiple_records_process_each_record (self ):
5978 """Test that each record in an SQS event gets processed individually"""
@@ -74,40 +93,24 @@ def test_multiple_records_process_each_record(self):
7493 ]
7594 }
7695
77- mock_processor = MagicMock ()
7896 mock_context = MagicMock ()
97+ self .mock_dsm_pathway_codec_decode .return_value = mock_context
98+
99+ _dsm_set_sqs_context (multi_record_event )
100+
101+ self .assertEqual (mock_context .set_checkpoint .call_count , 3 )
102+
103+ calls = mock_context .set_checkpoint .call_args_list
104+ expected_arns = [
105+ "arn:aws:sqs:us-east-1:123456789012:queue1" ,
106+ "arn:aws:sqs:us-east-1:123456789012:queue2" ,
107+ "arn:aws:sqs:us-east-1:123456789012:queue3" ,
108+ ]
79109
80- with patch (
81- "ddtrace.internal.datastreams.data_streams_processor" ,
82- return_value = mock_processor ,
83- ):
84- with patch (
85- "ddtrace.internal.datastreams.botocore.get_datastreams_context" ,
86- return_value = {},
87- ):
88- with patch (
89- "ddtrace.internal.datastreams.botocore.calculate_sqs_payload_size" ,
90- return_value = 100 ,
91- ):
92- with patch (
93- "ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode" ,
94- return_value = mock_context ,
95- ):
96- _dsm_set_sqs_context (multi_record_event )
97-
98- assert mock_context .set_checkpoint .call_count == 3
99-
100- calls = mock_context .set_checkpoint .call_args_list
101- expected_arns = [
102- "arn:aws:sqs:us-east-1:123456789012:queue1" ,
103- "arn:aws:sqs:us-east-1:123456789012:queue2" ,
104- "arn:aws:sqs:us-east-1:123456789012:queue3" ,
105- ]
106-
107- for i , call in enumerate (calls ):
108- args , kwargs = call
109- tags = args [0 ]
110- self .assertIn ("direction:in" , tags )
111- self .assertIn (f"topic:{ expected_arns [i ]} " , tags )
112- self .assertIn ("type:sqs" , tags )
113- self .assertEqual (kwargs ["payload_size" ], 100 )
110+ for i , call in enumerate (calls ):
111+ args , kwargs = call
112+ tags = args [0 ]
113+ self .assertIn ("direction:in" , tags )
114+ self .assertIn (f"topic:{ expected_arns [i ]} " , tags )
115+ self .assertIn ("type:sqs" , tags )
116+ self .assertEqual (kwargs ["payload_size" ], 100 )
0 commit comments