1+ import inspect
2+ import pytest
3+
4+ from databricks .sql .thrift_api .TCLIService import ttypes
5+
6+
7+ class TestThriftFieldIds :
8+ """
9+ Unit test to validate that all Thrift-generated field IDs comply with the maximum limit.
10+
11+ Field IDs in Thrift must stay below 3329 to avoid conflicts with reserved ranges
12+ and ensure compatibility with various Thrift implementations and protocols.
13+ """
14+
15+ MAX_ALLOWED_FIELD_ID = 3329
16+
17+ # Known exceptions that exceed the field ID limit
18+ KNOWN_EXCEPTIONS = {
19+ ('TExecuteStatementReq' , 'enforceEmbeddedSchemaCorrectness' ): 3353 ,
20+ ('TSessionHandle' , 'serverProtocolVersion' ): 3329 ,
21+ }
22+
23+ def test_all_thrift_field_ids_are_within_allowed_range (self ):
24+ """
25+ Validates that all field IDs in Thrift-generated classes are within the allowed range.
26+
27+ This test prevents field ID conflicts and ensures compatibility with different
28+ Thrift implementations and protocols.
29+ """
30+ violations = []
31+
32+ # Get all classes from the ttypes module
33+ for name , obj in inspect .getmembers (ttypes ):
34+ if (inspect .isclass (obj ) and
35+ hasattr (obj , 'thrift_spec' ) and
36+ obj .thrift_spec is not None ):
37+
38+ self ._check_class_field_ids (obj , name , violations )
39+
40+ if violations :
41+ error_message = self ._build_error_message (violations )
42+ pytest .fail (error_message )
43+
44+ def _check_class_field_ids (self , cls , class_name , violations ):
45+ """
46+ Checks all field IDs in a Thrift class and reports violations.
47+
48+ Args:
49+ cls: The Thrift class to check
50+ class_name: Name of the class for error reporting
51+ violations: List to append violation messages to
52+ """
53+ thrift_spec = cls .thrift_spec
54+
55+ if not isinstance (thrift_spec , (tuple , list )):
56+ return
57+
58+ for spec_entry in thrift_spec :
59+ if spec_entry is None :
60+ continue
61+
62+ # Thrift spec format: (field_id, field_type, field_name, ...)
63+ if isinstance (spec_entry , (tuple , list )) and len (spec_entry ) >= 3 :
64+ field_id = spec_entry [0 ]
65+ field_name = spec_entry [2 ]
66+
67+ # Skip known exceptions
68+ if (class_name , field_name ) in self .KNOWN_EXCEPTIONS :
69+ continue
70+
71+ if isinstance (field_id , int ) and field_id >= self .MAX_ALLOWED_FIELD_ID :
72+ violations .append (
73+ "{} field '{}' has field ID {} (exceeds maximum of {})" .format (
74+ class_name , field_name , field_id , self .MAX_ALLOWED_FIELD_ID - 1
75+ )
76+ )
77+
78+ def _build_error_message (self , violations ):
79+ """
80+ Builds a comprehensive error message for field ID violations.
81+
82+ Args:
83+ violations: List of violation messages
84+
85+ Returns:
86+ Formatted error message
87+ """
88+ error_message = (
89+ "Found Thrift field IDs that exceed the maximum allowed value of {}.\n "
90+ "This can cause compatibility issues and conflicts with reserved ID ranges.\n "
91+ "Violations found:\n " .format (self .MAX_ALLOWED_FIELD_ID - 1 )
92+ )
93+
94+ for violation in violations :
95+ error_message += " - {}\n " .format (violation )
96+
97+ return error_message
0 commit comments