1+ import pytest
2+ from unittest .mock import patch , MagicMock
3+ import io
4+ import time
5+
6+ from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory
7+ from databricks .sql .auth .retry import DatabricksRetryPolicy
8+
9+ PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
10+
11+ def create_mock_conn (responses ):
12+ """Creates a mock connection object whose getresponse() method yields a series of responses."""
13+ mock_conn = MagicMock ()
14+ mock_http_responses = []
15+ for resp in responses :
16+ mock_http_response = MagicMock ()
17+ mock_http_response .status = resp .get ("status" )
18+ mock_http_response .headers = resp .get ("headers" , {})
19+ body = resp .get ("body" , b'{}' )
20+ mock_http_response .fp = io .BytesIO (body )
21+ def release ():
22+ mock_http_response .fp .close ()
23+ mock_http_response .release_conn = release
24+ mock_http_responses .append (mock_http_response )
25+ mock_conn .getresponse .side_effect = mock_http_responses
26+ return mock_conn
27+
28+ class TestTelemetryClientRetries :
29+ @pytest .fixture (autouse = True )
30+ def setup_and_teardown (self ):
31+ TelemetryClientFactory ._initialized = False
32+ TelemetryClientFactory ._clients = {}
33+ TelemetryClientFactory ._executor = None
34+ yield
35+ if TelemetryClientFactory ._executor :
36+ TelemetryClientFactory ._executor .shutdown (wait = True )
37+ TelemetryClientFactory ._initialized = False
38+ TelemetryClientFactory ._clients = {}
39+ TelemetryClientFactory ._executor = None
40+
41+ def get_client (self , session_id , num_retries = 3 ):
42+ """
43+ Configures a client with a specific number of retries.
44+ """
45+ TelemetryClientFactory .initialize_telemetry_client (
46+ telemetry_enabled = True ,
47+ session_id_hex = session_id ,
48+ auth_provider = None ,
49+ host_url = "test.databricks.com" ,
50+ )
51+ client = TelemetryClientFactory .get_telemetry_client (session_id )
52+
53+ retry_policy = DatabricksRetryPolicy (
54+ delay_min = 0.01 ,
55+ delay_max = 0.02 ,
56+ stop_after_attempts_duration = 2.0 ,
57+ stop_after_attempts_count = num_retries ,
58+ delay_default = 0.1 ,
59+ force_dangerous_codes = [],
60+ urllib3_kwargs = {'total' : num_retries }
61+ )
62+ adapter = client ._http_client .session .adapters .get ("https://" )
63+ adapter .max_retries = retry_policy
64+ return client
65+
66+ @pytest .mark .parametrize (
67+ "status_code, description" ,
68+ [
69+ (401 , "Unauthorized" ),
70+ (403 , "Forbidden" ),
71+ (501 , "Not Implemented" ),
72+ (200 , "Success" ),
73+ ],
74+ )
75+ def test_non_retryable_status_codes_are_not_retried (self , status_code , description ):
76+ """
77+ Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried.
78+ """
79+ # Use the status code in the session ID for easier debugging if it fails
80+ client = self .get_client (f"session-{ status_code } " )
81+ mock_responses = [{"status" : status_code }]
82+
83+ with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
84+ client .export_failure_log ("TestError" , "Test message" )
85+ TelemetryClientFactory .close (client ._session_id_hex )
86+
87+ mock_get_conn .return_value .getresponse .assert_called_once ()
88+
89+ def test_exceeds_retry_count_limit (self ):
90+ """
91+ Verifies that the client retries up to the specified number of times before giving up.
92+ Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
93+ """
94+ num_retries = 3
95+ expected_total_calls = num_retries + 1
96+ retry_after = 1
97+ client = self .get_client ("session-exceed-limit" , num_retries = num_retries )
98+ mock_responses = [{"status" : 503 , "headers" : {"Retry-After" : str (retry_after )}}, {"status" : 429 }, {"status" : 502 }, {"status" : 503 }]
99+
100+ with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
101+ start_time = time .time ()
102+ client .export_failure_log ("TestError" , "Test message" )
103+ TelemetryClientFactory .close (client ._session_id_hex )
104+ end_time = time .time ()
105+
106+ assert mock_get_conn .return_value .getresponse .call_count == expected_total_calls
107+ assert end_time - start_time > retry_after
0 commit comments