99 List ,
1010 Optional ,
1111 Any ,
12+ Dict ,
1213 Callable ,
14+ TypeVar ,
15+ Generic ,
1316 cast ,
17+ TYPE_CHECKING ,
1418)
1519
20+ from databricks .sql .backend .types import ExecuteResponse , CommandId
21+ from databricks .sql .backend .sea .models .base import ResultData
1622from databricks .sql .backend .sea .backend import SeaDatabricksClient
17- from databricks .sql .backend .types import ExecuteResponse
1823
19- from databricks .sql .result_set import ResultSet , SeaResultSet
24+ if TYPE_CHECKING :
25+ from databricks .sql .result_set import ResultSet , SeaResultSet
2026
2127logger = logging .getLogger (__name__ )
2228
2329
2430class ResultSetFilter :
2531 """
26- A general-purpose filter for result sets.
32+ A general-purpose filter for result sets that can be applied to any backend.
33+
34+ This class provides methods to filter result sets based on various criteria,
35+ similar to the client-side filtering in the JDBC connector.
2736 """
2837
2938 @staticmethod
3039 def _filter_sea_result_set (
31- result_set : SeaResultSet , filter_func : Callable [[List [Any ]], bool ]
32- ) -> SeaResultSet :
40+ result_set : " SeaResultSet" , filter_func : Callable [[List [Any ]], bool ]
41+ ) -> " SeaResultSet" :
3342 """
3443 Filter a SEA result set using the provided filter function.
3544
@@ -40,13 +49,15 @@ def _filter_sea_result_set(
4049 Returns:
4150 A filtered SEA result set
4251 """
43-
4452 # Get all remaining rows
4553 all_rows = result_set .results .remaining_rows ()
4654
4755 # Filter rows
4856 filtered_rows = [row for row in all_rows if filter_func (row )]
4957
58+ # Import SeaResultSet here to avoid circular imports
59+ from databricks .sql .result_set import SeaResultSet
60+
5061 # Reuse the command_id from the original result set
5162 command_id = result_set .command_id
5263
@@ -62,13 +73,10 @@ def _filter_sea_result_set(
6273 )
6374
6475 # Create a new ResultData object with filtered data
65-
6676 from databricks .sql .backend .sea .models .base import ResultData
6777
6878 result_data = ResultData (data = filtered_rows , external_links = None )
6979
70- from databricks .sql .result_set import SeaResultSet
71-
7280 # Create a new SeaResultSet with the filtered data
7381 filtered_result_set = SeaResultSet (
7482 connection = result_set .connection ,
@@ -83,11 +91,11 @@ def _filter_sea_result_set(
8391
8492 @staticmethod
8593 def filter_by_column_values (
86- result_set : ResultSet ,
94+ result_set : " ResultSet" ,
8795 column_index : int ,
8896 allowed_values : List [str ],
8997 case_sensitive : bool = False ,
90- ) -> ResultSet :
98+ ) -> " ResultSet" :
9199 """
92100 Filter a result set by values in a specific column.
93101
@@ -100,7 +108,6 @@ def filter_by_column_values(
100108 Returns:
101109 A filtered result set
102110 """
103-
104111 # Convert to uppercase for case-insensitive comparison if needed
105112 if not case_sensitive :
106113 allowed_values = [v .upper () for v in allowed_values ]
@@ -131,8 +138,8 @@ def filter_by_column_values(
131138
132139 @staticmethod
133140 def filter_tables_by_type (
134- result_set : ResultSet , table_types : Optional [List [str ]] = None
135- ) -> ResultSet :
141+ result_set : " ResultSet" , table_types : Optional [List [str ]] = None
142+ ) -> " ResultSet" :
136143 """
137144 Filter a result set of tables by the specified table types.
138145
@@ -147,7 +154,6 @@ def filter_tables_by_type(
147154 Returns:
148155 A filtered result set containing only tables of the specified types
149156 """
150-
151157 # Default table types if none specified
152158 DEFAULT_TABLE_TYPES = ["TABLE" , "VIEW" , "SYSTEM TABLE" ]
153159 valid_types = (
0 commit comments