77
88from datadog_lambda .patch import (
99 _patch_httplib ,
10- _patch_requests ,
10+ _ensure_patch_requests ,
1111)
1212from datadog_lambda .constants import TraceHeader
1313
1414
1515class TestPatchHTTPClients (unittest .TestCase ):
1616
1717 def setUp (self ):
18- patcher = patch (' datadog_lambda.patch.get_dd_trace_context' )
18+ patcher = patch (" datadog_lambda.patch.get_dd_trace_context" )
1919 self .mock_get_dd_trace_context = patcher .start ()
2020 self .mock_get_dd_trace_context .return_value = {
21- TraceHeader .TRACE_ID : ' 123' ,
22- TraceHeader .PARENT_ID : ' 321' ,
23- TraceHeader .SAMPLING_PRIORITY : '2' ,
21+ TraceHeader .TRACE_ID : " 123" ,
22+ TraceHeader .PARENT_ID : " 321" ,
23+ TraceHeader .SAMPLING_PRIORITY : "2" ,
2424 }
2525 self .addCleanup (patcher .stop )
2626
@@ -34,10 +34,29 @@ def test_patch_httplib(self):
3434 self .mock_get_dd_trace_context .assert_called ()
3535
3636 def test_patch_requests (self ):
37- _patch_requests ()
37+ _ensure_patch_requests ()
3838 import requests
3939 r = requests .get ("https://www.datadoghq.com/" )
4040 self .mock_get_dd_trace_context .assert_called ()
41- self .assertEqual (r .request .headers [TraceHeader .TRACE_ID ], '123' )
42- self .assertEqual (r .request .headers [TraceHeader .PARENT_ID ], '321' )
43- self .assertEqual (r .request .headers [TraceHeader .SAMPLING_PRIORITY ], '2' )
41+ self .assertEqual (r .request .headers [TraceHeader .TRACE_ID ], "123" )
42+ self .assertEqual (r .request .headers [TraceHeader .PARENT_ID ], "321" )
43+ self .assertEqual (r .request .headers [TraceHeader .SAMPLING_PRIORITY ], "2" )
44+
45+ def test_patch_requests_with_headers (self ):
46+ _ensure_patch_requests ()
47+ import requests
48+ r = requests .get ("https://www.datadoghq.com/" , headers = {"key" : "value" })
49+ self .mock_get_dd_trace_context .assert_called ()
50+ self .assertEqual (r .request .headers ["key" ], "value" )
51+ self .assertEqual (r .request .headers [TraceHeader .TRACE_ID ], "123" )
52+ self .assertEqual (r .request .headers [TraceHeader .PARENT_ID ], "321" )
53+ self .assertEqual (r .request .headers [TraceHeader .SAMPLING_PRIORITY ], "2" )
54+
55+ def test_patch_requests_with_headers_none (self ):
56+ _ensure_patch_requests ()
57+ import requests
58+ r = requests .get ("https://www.datadoghq.com/" , headers = None )
59+ self .mock_get_dd_trace_context .assert_called ()
60+ self .assertEqual (r .request .headers [TraceHeader .TRACE_ID ], "123" )
61+ self .assertEqual (r .request .headers [TraceHeader .PARENT_ID ], "321" )
62+ self .assertEqual (r .request .headers [TraceHeader .SAMPLING_PRIORITY ], "2" )
0 commit comments