From 928e12876d9f179f7211a5fb8a2211da53122f9e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Sun, 13 Jul 2025 21:26:38 +0530 Subject: [PATCH 01/19] chunk download latency Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/backend/thrift_backend.py | 22 +++++--- .../sql/cloudfetch/download_manager.py | 23 ++++++--- src/databricks/sql/cloudfetch/downloader.py | 13 ++++- src/databricks/sql/result_set.py | 4 ++ .../sql/telemetry/latency_logger.py | 51 ++++++++++++++++--- src/databricks/sql/telemetry/models/event.py | 5 +- src/databricks/sql/utils.py | 12 ++++- 7 files changed, 105 insertions(+), 25 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 226db8986..4d678bae6 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -6,6 +6,7 @@ import time import threading from typing import List, Optional, Union, Any, TYPE_CHECKING +from uuid import UUID from databricks.sql.result_set import ThriftResultSet @@ -1021,7 +1022,7 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results = self._handle_execute_response( + execute_response, is_direct_results, statement_id = self._handle_execute_response( resp, cursor ) @@ -1040,6 +1041,8 @@ def execute_command( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + session_id_hex=self._session_id_hex, + statement_id=statement_id, ) def get_catalogs( @@ -1061,7 +1064,7 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, is_direct_results, _ = self._handle_execute_response( resp, cursor ) @@ -1107,7 +1110,7 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, is_direct_results, _ = self._handle_execute_response( resp, cursor ) @@ -1157,7 +1160,7 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, is_direct_results, _ = self._handle_execute_response( resp, cursor ) @@ -1207,7 +1210,7 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, is_direct_results, _ = self._handle_execute_response( resp, cursor ) @@ -1241,7 +1244,11 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + execute_response, is_direct_results = self._results_message_to_execute_response( + resp, final_operation_state + ) + + return execute_response, is_direct_results, cursor.active_command_id.to_hex_guid() def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1261,6 +1268,7 @@ def fetch_results( arrow_schema_bytes, description, use_cloud_fetch=True, + statement_id=None, ): thrift_handle = command_id.to_thrift_handle() if not thrift_handle: @@ -1297,6 +1305,8 @@ def fetch_results( lz4_compressed=lz4_compressed, description=description, ssl_options=self._ssl_options, + session_id_hex=self._session_id_hex, + statement_id=statement_id ) return queue, 
resp.hasMoreRows diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 7e96cd323..5954014c7 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,7 +1,7 @@ import logging from concurrent.futures import ThreadPoolExecutor, Future -from typing import List, Union +from typing import List, Union, Optional, Tuple from databricks.sql.cloudfetch.downloader import ( ResultSetDownloadHandler, @@ -22,17 +22,19 @@ def __init__( max_download_threads: int, lz4_compressed: bool, ssl_options: SSLOptions, + session_id_hex: Optional[str] = None, + statement_id: Optional[str] = None, ): - self._pending_links: List[TSparkArrowResultLink] = [] - for link in links: + self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = [] + for i, link in enumerate(links): if link.rowCount <= 0: continue logger.debug( - "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format( - link.startRowOffset, link.rowCount + "ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}".format( + i, link.startRowOffset, link.rowCount ) ) - self._pending_links.append(link) + self._pending_links.append((i, link)) self._download_tasks: List[Future[DownloadedFile]] = [] self._max_download_threads: int = max_download_threads @@ -40,6 +42,8 @@ def __init__( self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) self._ssl_options = ssl_options + self.session_id_hex = session_id_hex + self.statement_id = statement_id def get_next_downloaded_file( self, next_row_offset: int @@ -89,14 +93,17 @@ def _schedule_downloads(self): while (len(self._download_tasks) < self._max_download_threads) and ( len(self._pending_links) > 0 ): - link = self._pending_links.pop(0) + chunk_id, link = self._pending_links.pop(0) logger.debug( - "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount) + "- chunk: {}, start: {}, row count: {}".format(chunk_id, link.startRowOffset, link.rowCount) ) handler = ResultSetDownloadHandler( settings=self._downloadable_result_settings, link=link, ssl_options=self._ssl_options, + chunk_id=chunk_id, + session_id_hex=self.session_id_hex, + statement_id=self.statement_id ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 228e07d6c..68c310af6 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -1,5 +1,6 @@ import logging from dataclasses import dataclass +from typing import Optional import requests from requests.adapters import HTTPAdapter, Retry @@ -9,6 +10,7 @@ from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions +from databricks.sql.telemetry.latency_logger import log_latency logger = logging.getLogger(__name__) @@ -66,11 +68,18 @@ def __init__( settings: DownloadableResultSettings, link: TSparkArrowResultLink, ssl_options: SSLOptions, + chunk_id: int, + session_id_hex: Optional[str] = None, + statement_id: Optional[str] = None, ): self.settings = settings self.link = link self._ssl_options = ssl_options + self.chunk_id = chunk_id + self.session_id_hex = session_id_hex + self.statement_id = statement_id + @log_latency() def run(self) -> DownloadedFile: """ Download the file described in the cloud 
fetch link. @@ -80,8 +89,8 @@ def run(self) -> DownloadedFile: """ logger.debug( - "ResultSetDownloadHandler: starting file download, offset {}, row count {}".format( - self.link.startRowOffset, self.link.rowCount + "ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format( + self.chunk_id, self.link.startRowOffset, self.link.rowCount ) ) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 8934d0d56..3595d65d7 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -198,6 +198,8 @@ def __init__( max_download_threads: int = 10, ssl_options=None, is_direct_results: bool = True, + session_id_hex: Optional[str] = None, + statement_id: Optional[str] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. @@ -233,6 +235,8 @@ def __init__( lz4_compressed=execute_response.lz4_compressed, description=execute_response.description, ssl_options=ssl_options, + session_id_hex=session_id_hex, + statement_id=statement_id, ) # Call parent constructor with common attributes diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 0b0c564da..9ce9c96b7 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -7,7 +7,6 @@ SqlExecutionEvent, ) from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType -from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue from uuid import UUID logger = logging.getLogger(__name__) @@ -42,6 +41,9 @@ def get_execution_result(self): def get_retry_count(self): pass + def get_chunk_id(self): + pass + class CursorExtractor(TelemetryExtractor): """ @@ -63,7 +65,8 @@ def get_is_compressed(self) -> bool: def get_execution_result(self) -> ExecutionResultFormat: if self.active_result_set is None: return ExecutionResultFormat.FORMAT_UNSPECIFIED - + + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue if isinstance(self.active_result_set.results, ColumnQueue): return ExecutionResultFormat.COLUMNAR_INLINE elif isinstance(self.active_result_set.results, CloudFetchQueue): @@ -74,11 +77,14 @@ def get_execution_result(self) -> ExecutionResultFormat: def get_retry_count(self) -> int: if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy + hasattr(self.backend, "retry_policy") + and self.backend.retry_policy ): - return len(self.thrift_backend.retry_policy.history) + return len(self.backend.retry_policy.history) return 0 + + def get_chunk_id(self): + return None class ResultSetExtractor(TelemetryExtractor): @@ -101,6 +107,7 @@ def get_is_compressed(self) -> bool: return self.lz4_compressed def get_execution_result(self) -> ExecutionResultFormat: + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue if isinstance(self.results, ColumnQueue): return ExecutionResultFormat.COLUMNAR_INLINE elif isinstance(self.results, CloudFetchQueue): @@ -116,7 +123,34 @@ def get_retry_count(self) -> int: ): return len(self.thrift_backend.retry_policy.history) return 0 + + def get_chunk_id(self): + return None + + +class ResultSetDownloadHandlerExtractor(TelemetryExtractor): + """ + Telemetry extractor specialized for ResultSetDownloadHandler objects. 
+ """ + def get_session_id_hex(self) -> Optional[str]: + return self._obj.session_id_hex + + def get_statement_id(self) -> Optional[str]: + return self._obj.statement_id + + def get_is_compressed(self) -> bool: + return self._obj.settings.is_lz4_compressed + + def get_execution_result(self) -> ExecutionResultFormat: + return ExecutionResultFormat.EXTERNAL_LINKS + + def get_retry_count(self) -> Optional[int]: + # standard requests and urllib3 libraries don't expose retry count + return None + def get_chunk_id(self) -> Optional[int]: + return self._obj.chunk_id + def get_extractor(obj): """ @@ -133,12 +167,15 @@ def get_extractor(obj): TelemetryExtractor: A specialized extractor instance: - CursorExtractor for Cursor objects - ResultSetExtractor for ResultSet objects + - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects - None for all other objects """ if obj.__class__.__name__ == "Cursor": return CursorExtractor(obj) elif obj.__class__.__name__ == "ResultSet": return ResultSetExtractor(obj) + elif obj.__class__.__name__=="ResultSetDownloadHandler": + return ResultSetDownloadHandlerExtractor(obj) else: logger.debug("No extractor found for %s", obj.__class__.__name__) return None @@ -196,6 +233,7 @@ def _safe_call(func_to_call): duration_ms = int((end_time - start_time) * 1000) extractor = get_extractor(self) + print("function name", func.__name__, "latency", duration_ms, "session_id_hex", extractor.get_session_id_hex(), "statement_id", extractor.get_statement_id(), flush=True) if extractor is not None: session_id_hex = _safe_call(extractor.get_session_id_hex) @@ -205,7 +243,8 @@ def _safe_call(func_to_call): statement_type=statement_type, is_compressed=_safe_call(extractor.get_is_compressed), execution_result=_safe_call(extractor.get_execution_result), - retry_count=_safe_call(extractor.get_retry_count), + retry_count=extractor.get_retry_count(), + chunk_id=_safe_call(extractor.get_chunk_id), ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index f5496deec..fb5c1d090 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -122,13 +122,14 @@ class SqlExecutionEvent(JsonSerializableMixin): is_compressed (bool): Whether the result is compressed execution_result (ExecutionResultFormat): Format of the execution result retry_count (int): Number of retry attempts made + chunk_id (int): ID of the chunk if applicable """ statement_type: StatementType is_compressed: bool execution_result: ExecutionResultFormat - retry_count: int - + retry_count: Optional[int] + chunk_id: Optional[int] @dataclass class TelemetryEvent(JsonSerializableMixin): diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 35c7bce4d..b96617536 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -62,6 +62,8 @@ def build_queue( ssl_options: SSLOptions, lz4_compressed: bool = True, description: List[Tuple] = [], + session_id_hex: Optional[str] = None, + statement_id: Optional[str] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. 
@@ -106,6 +108,8 @@ def build_queue( description=description, max_download_threads=max_download_threads, ssl_options=ssl_options, + session_id_hex=session_id_hex, + statement_id=statement_id, ) else: raise AssertionError("Row set type is not valid") @@ -211,6 +215,8 @@ def __init__( result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, description: List[Tuple] = [], + session_id_hex: Optional[str] = None, + statement_id: Optional[str] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. @@ -231,7 +237,9 @@ def __init__( self.lz4_compressed = lz4_compressed self.description = description self._ssl_options = ssl_options - + self.session_id_hex = session_id_hex + self.statement_id = statement_id + logger.debug( "Initialize CloudFetch loader, row set start offset: {}, file list:".format( start_row_offset @@ -249,6 +257,8 @@ def __init__( max_download_threads=self.max_download_threads, lz4_compressed=self.lz4_compressed, ssl_options=self._ssl_options, + session_id_hex=self.session_id_hex, + statement_id=self.statement_id, ) self.table = self._create_next_table() From 804f9e0529f2d8c4ecc88a7a1a8431ecdd9abdd2 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Sun, 13 Jul 2025 21:42:43 +0530 Subject: [PATCH 02/19] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/backend/thrift_backend.py | 16 ++++++---- .../sql/cloudfetch/download_manager.py | 6 ++-- .../sql/telemetry/latency_logger.py | 30 ++++++++++++------- src/databricks/sql/telemetry/models/event.py | 1 + src/databricks/sql/utils.py | 2 +- 5 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 4d678bae6..324a7cd6e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1022,9 +1022,11 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results, statement_id = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + statement_id, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1248,7 +1250,11 @@ def _handle_execute_response(self, resp, cursor): resp, final_operation_state ) - return execute_response, is_direct_results, cursor.active_command_id.to_hex_guid() + return ( + execute_response, + is_direct_results, + cursor.active_command_id.to_hex_guid(), + ) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1306,7 +1312,7 @@ def fetch_results( description=description, ssl_options=self._ssl_options, session_id_hex=self._session_id_hex, - statement_id=statement_id + statement_id=statement_id, ) return queue, resp.hasMoreRows diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 5954014c7..832a6f2ea 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -95,7 +95,9 @@ def _schedule_downloads(self): ): chunk_id, link = self._pending_links.pop(0) logger.debug( - "- chunk: {}, start: {}, row count: {}".format(chunk_id, link.startRowOffset, link.rowCount) + "- chunk: {}, start: {}, row count: {}".format( + chunk_id, link.startRowOffset, link.rowCount + ) ) handler = ResultSetDownloadHandler( settings=self._downloadable_result_settings, 
@@ -103,7 +105,7 @@ def _schedule_downloads(self): ssl_options=self._ssl_options, chunk_id=chunk_id, session_id_hex=self.session_id_hex, - statement_id=self.statement_id + statement_id=self.statement_id, ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 9ce9c96b7..c38b90f1d 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -65,8 +65,9 @@ def get_is_compressed(self) -> bool: def get_execution_result(self) -> ExecutionResultFormat: if self.active_result_set is None: return ExecutionResultFormat.FORMAT_UNSPECIFIED - + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue + if isinstance(self.active_result_set.results, ColumnQueue): return ExecutionResultFormat.COLUMNAR_INLINE elif isinstance(self.active_result_set.results, CloudFetchQueue): @@ -76,13 +77,10 @@ def get_execution_result(self) -> ExecutionResultFormat: return ExecutionResultFormat.FORMAT_UNSPECIFIED def get_retry_count(self) -> int: - if ( - hasattr(self.backend, "retry_policy") - and self.backend.retry_policy - ): + if hasattr(self.backend, "retry_policy") and self.backend.retry_policy: return len(self.backend.retry_policy.history) return 0 - + def get_chunk_id(self): return None @@ -108,6 +106,7 @@ def get_is_compressed(self) -> bool: def get_execution_result(self) -> ExecutionResultFormat: from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue + if isinstance(self.results, ColumnQueue): return ExecutionResultFormat.COLUMNAR_INLINE elif isinstance(self.results, CloudFetchQueue): @@ -123,7 +122,7 @@ def get_retry_count(self) -> int: ): return len(self.thrift_backend.retry_policy.history) return 0 - + def get_chunk_id(self): return None @@ -132,6 +131,7 @@ class ResultSetDownloadHandlerExtractor(TelemetryExtractor): """ Telemetry extractor specialized for ResultSetDownloadHandler objects. 
""" + def get_session_id_hex(self) -> Optional[str]: return self._obj.session_id_hex @@ -150,7 +150,7 @@ def get_retry_count(self) -> Optional[int]: def get_chunk_id(self) -> Optional[int]: return self._obj.chunk_id - + def get_extractor(obj): """ @@ -174,7 +174,7 @@ def get_extractor(obj): return CursorExtractor(obj) elif obj.__class__.__name__ == "ResultSet": return ResultSetExtractor(obj) - elif obj.__class__.__name__=="ResultSetDownloadHandler": + elif obj.__class__.__name__ == "ResultSetDownloadHandler": return ResultSetDownloadHandlerExtractor(obj) else: logger.debug("No extractor found for %s", obj.__class__.__name__) @@ -233,7 +233,17 @@ def _safe_call(func_to_call): duration_ms = int((end_time - start_time) * 1000) extractor = get_extractor(self) - print("function name", func.__name__, "latency", duration_ms, "session_id_hex", extractor.get_session_id_hex(), "statement_id", extractor.get_statement_id(), flush=True) + print( + "function name", + func.__name__, + "latency", + duration_ms, + "session_id_hex", + extractor.get_session_id_hex(), + "statement_id", + extractor.get_statement_id(), + flush=True, + ) if extractor is not None: session_id_hex = _safe_call(extractor.get_session_id_hex) diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index fb5c1d090..83f72cd3b 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -131,6 +131,7 @@ class SqlExecutionEvent(JsonSerializableMixin): retry_count: Optional[int] chunk_id: Optional[int] + @dataclass class TelemetryEvent(JsonSerializableMixin): """ diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index b96617536..044038b2b 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -239,7 +239,7 @@ def __init__( self._ssl_options = ssl_options self.session_id_hex = session_id_hex self.statement_id = statement_id - + logger.debug( "Initialize CloudFetch loader, row set start offset: {}, file list:".format( start_row_offset From e1119092ace70f44747b6d3031ab389bae815060 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Sun, 13 Jul 2025 23:26:32 +0530 Subject: [PATCH 03/19] test fixes Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/latency_logger.py | 2 +- tests/unit/test_downloader.py | 14 +++++------ tests/unit/test_thrift_backend.py | 23 +++++++++++-------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index c38b90f1d..022c254f8 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -253,7 +253,7 @@ def _safe_call(func_to_call): statement_type=statement_type, is_compressed=_safe_call(extractor.get_is_compressed), execution_result=_safe_call(extractor.get_execution_result), - retry_count=extractor.get_retry_count(), + retry_count=_safe_call(extractor.get_retry_count), chunk_id=_safe_call(extractor.get_chunk_id), ) diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 2a3b715b5..d6ee28984 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -27,7 +27,7 @@ def test_run_link_expired(self, mock_time): # Already expired result_link.expiryTime = 999 d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, result_link, ssl_options=SSLOptions(), chunk_id=0 ) with self.assertRaises(Error) as context: @@ -43,7 +43,7 
@@ def test_run_link_past_expiry_buffer(self, mock_time): # Within the expiry buffer time result_link.expiryTime = 1004 d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, result_link, ssl_options=SSLOptions(), chunk_id=0 ) with self.assertRaises(Error) as context: @@ -63,7 +63,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): result_link = Mock(expiryTime=1001) d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, result_link, ssl_options=SSLOptions(), chunk_id=0 ) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() @@ -82,7 +82,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session): result_link = Mock(bytesNum=100, expiryTime=1001) d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, result_link, ssl_options=SSLOptions(), chunk_id=0 ) file = d.run() @@ -105,7 +105,7 @@ def test_run_compressed_successful(self, mock_time, mock_session): result_link = Mock(bytesNum=100, expiryTime=1001) d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, result_link, ssl_options=SSLOptions(), chunk_id=0 ) file = d.run() @@ -121,7 +121,7 @@ def test_download_connection_error(self, mock_time, mock_session): mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, result_link, ssl_options=SSLOptions(), chunk_id=0 ) with self.assertRaises(ConnectionError): d.run() @@ -136,7 +136,7 @@ def test_download_timeout(self, mock_time, mock_session): mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, result_link, ssl_options=SSLOptions(), chunk_id=0 ) with self.assertRaises(TimeoutError): d.run() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 37569f755..abfb13a02 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -649,7 +649,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, _, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -892,6 +892,7 @@ def test_handle_execute_response_can_handle_without_direct_results( ( execute_response, _, + _, ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, @@ -927,7 +928,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock(return_value=(Mock(), Mock())) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -965,7 +966,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, _, _ = 
thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -997,7 +998,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + _, _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1046,6 +1047,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ( execute_response, has_more_rows_result, + _ ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual(is_direct_results, has_more_rows_result) @@ -1179,7 +1181,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1215,7 +1217,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1248,7 +1250,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1290,7 +1292,7 @@ def test_get_tables_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1336,7 +1338,7 @@ def test_get_columns_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1682,7 +1684,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock(return_value=(Mock(), Mock())) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2254,6 +2256,7 @@ def test_execute_command_sets_complex_type_fields_correctly( mock_handle_execute_response.return_value = ( mock_execute_response, mock_arrow_schema, + Mock() ) # Iterate through each possible combination of native types (True, False and unset) From 88d8a956cc90c0915efbe0c6a8d7db2856a7f46a Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Sun, 13 Jul 2025 23:33:53 +0530 Subject: [PATCH 04/19] sea-migration static type checking fixes Signed-off-by: Sai Shree Pradhan 
--- src/databricks/sql/backend/databricks_client.py | 3 ++- src/databricks/sql/backend/thrift_backend.py | 4 ---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 0337d8d06..cff4720a4 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -79,9 +79,10 @@ def execute_command( lz4_compression: bool, cursor: Cursor, use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: Union[List[ttypes.TSparkParameter], List[Dict[str, Any]]], async_op: bool, enforce_embedded_schema_correctness: bool, + row_limit: Optional[int] = None, ) -> Union[ResultSet, None]: """ Executes a SQL command or query within the specified session. diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 324a7cd6e..1b88987f1 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -240,10 +240,6 @@ def __init__( def max_download_threads(self) -> int: return self._max_download_threads - @property - def max_download_threads(self) -> int: - return self._max_download_threads - # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): # Configure retries & timing: use user-settings or defaults, and bound From e84ba7d97b9fcf08c75e42696e800b25df864bd9 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Sun, 13 Jul 2025 23:42:42 +0530 Subject: [PATCH 05/19] check types fix Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/backend/sea/backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index cfb27adbd..5a3295aa5 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -14,6 +14,7 @@ WaitTimeout, MetadataCommands, ) +from databricks.sql.thrift_api.TCLIService import ttypes if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -402,7 +403,7 @@ def execute_command( lz4_compression: bool, cursor: Cursor, use_cloud_fetch: bool, - parameters: List[Dict[str, Any]], + parameters: Union[List[Dict[str, Any]], List["ttypes.TSparkParameter"]], async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, From 4e07f0276e9bf13be32cd56169ed70b0778f29b4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 14 Jul 2025 09:26:17 +0530 Subject: [PATCH 06/19] fix type issues Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/databricks_client.py | 1 + src/databricks/sql/backend/sea/backend.py | 11 +++++++---- tests/unit/test_sea_backend.py | 9 ++++++--- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index cff4720a4..84e169993 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -101,6 +101,7 @@ def execute_command( parameters: List of parameters to bind to the query async_op: Whether to execute the command asynchronously enforce_embedded_schema_correctness: Whether to enforce schema correctness + row_limit: Maximum number of rows in the response. 
Returns: If async_op is False, returns a ResultSet object containing the diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 5a3295aa5..d94799322 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -15,6 +15,7 @@ MetadataCommands, ) from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.thrift_api.TCLIService import ttypes if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -403,7 +404,7 @@ def execute_command( lz4_compression: bool, cursor: Cursor, use_cloud_fetch: bool, - parameters: Union[List[Dict[str, Any]], List["ttypes.TSparkParameter"]], + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, @@ -438,9 +439,11 @@ def execute_command( for param in parameters: sea_parameters.append( StatementParameter( - name=param["name"], - value=param["value"], - type=param["type"] if "type" in param else None, + name=param.name, + value=( + param.value.stringValue if param.value is not None else None + ), + type=param.type, ) ) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 7eae8e5a8..da45b4299 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -13,6 +13,8 @@ _filter_session_configuration, ) from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.parameters.native import IntegerParameter, TDbsqlParameter +from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.exc import ( @@ -355,7 +357,8 @@ def test_command_execution_advanced( "status": {"state": "SUCCEEDED"}, } mock_http_client._make_request.return_value = execute_response - param = {"name": "param1", "value": "value1", "type": "STRING"} + dbsql_param = IntegerParameter(name="param1", value=1) + param = dbsql_param.as_tspark_param(named=True) with patch.object(sea_client, "get_execution_result"): sea_client.execute_command( @@ -374,8 +377,8 @@ def test_command_execution_advanced( assert "parameters" in kwargs["data"] assert len(kwargs["data"]["parameters"]) == 1 assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" + assert kwargs["data"]["parameters"][0]["value"] == "1" + assert kwargs["data"]["parameters"][0]["type"] == "INT" # Test execution failure mock_http_client.reset_mock() From c054039aca624514566eaeab807ab6b453b64168 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 14 Jul 2025 09:39:59 +0530 Subject: [PATCH 07/19] type fix revert Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/backend/databricks_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 84e169993..fb276251a 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -79,7 +79,7 @@ def execute_command( lz4_compression: bool, cursor: Cursor, use_cloud_fetch: bool, - parameters: Union[List[ttypes.TSparkParameter], List[Dict[str, Any]]], + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, From 
3e0bb1e1e86931e369e5ca2ee7ffcd96b6050c97 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 14 Jul 2025 09:46:35 +0530 Subject: [PATCH 08/19] - Signed-off-by: Sai Shree Pradhan From a6c690bdf8f4dc18a8902dc5ab1af725ed85be4c Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 14 Jul 2025 10:20:25 +0530 Subject: [PATCH 09/19] statement id in get metadata functions Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/backend/thrift_backend.py | 40 ++++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 1b88987f1..441a23c0b 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1062,9 +1062,11 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results, _ = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + statement_id, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1081,6 +1083,8 @@ def get_catalogs( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + session_id_hex=self._session_id_hex, + statement_id=statement_id, ) def get_schemas( @@ -1108,9 +1112,11 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results, _ = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + statement_id, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1127,6 +1133,8 @@ def get_schemas( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + session_id_hex=self._session_id_hex, + statement_id=statement_id, ) def get_tables( @@ -1158,9 +1166,11 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results, _ = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + statement_id, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1177,6 +1187,8 @@ def get_tables( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + session_id_hex=self._session_id_hex, + statement_id=statement_id, ) def get_columns( @@ -1208,9 +1220,11 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, is_direct_results, _ = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + statement_id, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1227,6 +1241,8 @@ def get_columns( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + session_id_hex=self._session_id_hex, + statement_id=statement_id, ) def _handle_execute_response(self, resp, cursor): From 534538498e14546186e143949a990964528aae62 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 14 Jul 2025 14:01:36 +0530 Subject: [PATCH 10/19] removed result set extractor Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/latency_logger.py | 60 
+------------------ 1 file changed, 2 insertions(+), 58 deletions(-) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 022c254f8..b7f1a8fe6 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -85,48 +85,6 @@ def get_chunk_id(self): return None -class ResultSetExtractor(TelemetryExtractor): - """ - Telemetry extractor specialized for ResultSet objects. - - Extracts telemetry information from database result set objects, including - operation IDs, session information, compression settings, and result formats. - """ - - def get_statement_id(self) -> Optional[str]: - if self.command_id: - return str(UUID(bytes=self.command_id.operationId.guid)) - return None - - def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() - - def get_is_compressed(self) -> bool: - return self.lz4_compressed - - def get_execution_result(self) -> ExecutionResultFormat: - from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue - - if isinstance(self.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - def get_retry_count(self) -> int: - if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy - ): - return len(self.thrift_backend.retry_policy.history) - return 0 - - def get_chunk_id(self): - return None - - class ResultSetDownloadHandlerExtractor(TelemetryExtractor): """ Telemetry extractor specialized for ResultSetDownloadHandler objects. @@ -160,20 +118,17 @@ def get_extractor(obj): that can extract telemetry information from that object type. Args: - obj: The object to create an extractor for. Can be a Cursor, ResultSet, - or any other object. + obj: The object to create an extractor for. Can be a Cursor, + ResultSetDownloadHandler, or any other object. 
Returns: TelemetryExtractor: A specialized extractor instance: - CursorExtractor for Cursor objects - - ResultSetExtractor for ResultSet objects - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects - None for all other objects """ if obj.__class__.__name__ == "Cursor": return CursorExtractor(obj) - elif obj.__class__.__name__ == "ResultSet": - return ResultSetExtractor(obj) elif obj.__class__.__name__ == "ResultSetDownloadHandler": return ResultSetDownloadHandlerExtractor(obj) else: @@ -233,17 +188,6 @@ def _safe_call(func_to_call): duration_ms = int((end_time - start_time) * 1000) extractor = get_extractor(self) - print( - "function name", - func.__name__, - "latency", - duration_ms, - "session_id_hex", - extractor.get_session_id_hex(), - "statement_id", - extractor.get_statement_id(), - flush=True, - ) if extractor is not None: session_id_hex = _safe_call(extractor.get_session_id_hex) From a0318e342a880add4c8231f3b56d00b2009fab20 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 14 Jul 2025 14:18:14 +0530 Subject: [PATCH 11/19] databricks client type Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 2 +- src/databricks/sql/session.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 75e89d92a..c04d460cc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -280,7 +280,7 @@ def read(self) -> Optional[OAuthToken]: driver_connection_params = DriverConnectionParameters( http_path=http_path, - mode=DatabricksClientType.THRIFT, + mode=DatabricksClientType.SEA if self.session.use_sea else DatabricksClientType.THRIFT, host_info=HostDetails(host_url=server_hostname, port=self.session.port), auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider), auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider), diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 4f59857e9..6b8aa4f5f 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -97,10 +97,10 @@ def _create_backend( kwargs: dict, ) -> DatabricksClient: """Create and return the appropriate backend client.""" - use_sea = kwargs.get("use_sea", False) + self.use_sea = kwargs.get("use_sea", False) databricks_client_class: Type[DatabricksClient] - if use_sea: + if self.use_sea: logger.debug("Creating SEA backend client") databricks_client_class = SeaDatabricksClient else: From e79c325e64d06b7086a005898c1082a5ca7f6892 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 14 Jul 2025 14:19:22 +0530 Subject: [PATCH 12/19] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c04d460cc..c279f2c1f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -280,7 +280,9 @@ def read(self) -> Optional[OAuthToken]: driver_connection_params = DriverConnectionParameters( http_path=http_path, - mode=DatabricksClientType.SEA if self.session.use_sea else DatabricksClientType.THRIFT, + mode=DatabricksClientType.SEA + if self.session.use_sea + else DatabricksClientType.THRIFT, host_info=HostDetails(host_url=server_hostname, port=self.session.port), auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider), auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider), From 6d122c4c4f1b559d884f452743f0642401fb1666 Mon Sep 17 
00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 15 Jul 2025 15:34:53 +0530 Subject: [PATCH 13/19] remove defaults, fix chunk id Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/backend/thrift_backend.py | 83 +++++++++---------- src/databricks/sql/client.py | 35 +++++--- .../sql/cloudfetch/download_manager.py | 14 ++-- src/databricks/sql/cloudfetch/downloader.py | 7 +- src/databricks/sql/result_set.py | 18 +++- .../sql/telemetry/latency_logger.py | 22 +++-- src/databricks/sql/utils.py | 20 +++-- tests/unit/test_client.py | 17 ++-- tests/unit/test_cloud_fetch_queue.py | 58 ++++++++++++- tests/unit/test_download_manager.py | 6 +- tests/unit/test_downloader.py | 14 ++-- tests/unit/test_fetches.py | 10 ++- tests/unit/test_thrift_backend.py | 46 +++++----- 13 files changed, 236 insertions(+), 114 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 441a23c0b..df4c31da1 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor from databricks.sql.result_set import ResultSet + from databricks.sql.telemetry.models.event import StatementType from databricks.sql.backend.types import ( CommandState, @@ -832,7 +833,7 @@ def _results_message_to_execute_response(self, resp, operation_state): return execute_response, is_direct_results def get_execution_result( - self, command_id: CommandId, cursor: "Cursor" + self, command_id: CommandId, cursor: "Cursor", statement_type: StatementType ) -> "ResultSet": thrift_handle = command_id.to_thrift_handle() if not thrift_handle: @@ -900,6 +901,8 @@ def get_execution_result( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + session_id_hex=self._session_id_hex, + statement_type=statement_type, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -965,6 +968,7 @@ def execute_command( max_bytes: int, lz4_compression: bool, cursor: Cursor, + statement_type: StatementType, use_cloud_fetch=True, parameters=[], async_op=False, @@ -1018,11 +1022,9 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - ( - execute_response, - is_direct_results, - statement_id, - ) = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1040,7 +1042,7 @@ def execute_command( ssl_options=self._ssl_options, is_direct_results=is_direct_results, session_id_hex=self._session_id_hex, - statement_id=statement_id, + statement_type=statement_type, ) def get_catalogs( @@ -1049,6 +1051,7 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", + statement_type: StatementType, ) -> "ResultSet": thrift_handle = session_id.to_thrift_handle() if not thrift_handle: @@ -1062,11 +1065,9 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - ( - execute_response, - is_direct_results, - statement_id, - ) = self._handle_execute_response(resp, cursor) + execute_response, is_direct_results = self._handle_execute_response( + resp, cursor + ) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1084,7 +1085,7 @@ def get_catalogs( ssl_options=self._ssl_options, is_direct_results=is_direct_results, session_id_hex=self._session_id_hex, - statement_id=statement_id, + 
statement_id=statement_id,
+            statement_type=statement_type,
         )
 
     def get_schemas(
@@ -1093,6 +1094,7 @@ def get_schemas(
         max_rows: int,
         max_bytes: int,
         cursor: Cursor,
+        statement_type: StatementType,
         catalog_name=None,
         schema_name=None,
     ) -> "ResultSet":
@@ -1112,11 +1114,9 @@ def get_schemas(
         )
         resp = self.make_request(self._client.GetSchemas, req)
 
-        (
-            execute_response,
-            is_direct_results,
-            statement_id,
-        ) = self._handle_execute_response(resp, cursor)
+        execute_response, is_direct_results = self._handle_execute_response(
+            resp, cursor
+        )
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1134,7 +1134,7 @@ def get_schemas(
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
             session_id_hex=self._session_id_hex,
-            statement_id=statement_id,
+            statement_type=statement_type,
         )
 
     def get_tables(
@@ -1143,6 +1143,7 @@ def get_tables(
         max_rows: int,
         max_bytes: int,
         cursor: Cursor,
+        statement_type: StatementType,
         catalog_name=None,
         schema_name=None,
         table_name=None,
@@ -1166,11 +1167,9 @@ def get_tables(
         )
         resp = self.make_request(self._client.GetTables, req)
 
-        (
-            execute_response,
-            is_direct_results,
-            statement_id,
-        ) = self._handle_execute_response(resp, cursor)
+        execute_response, is_direct_results = self._handle_execute_response(
+            resp, cursor
+        )
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1188,7 +1187,7 @@ def get_tables(
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
             session_id_hex=self._session_id_hex,
-            statement_id=statement_id,
+            statement_type=statement_type,
         )
 
     def get_columns(
@@ -1197,6 +1196,7 @@ def get_columns(
         max_rows: int,
         max_bytes: int,
         cursor: Cursor,
+        statement_type: StatementType,
         catalog_name=None,
         schema_name=None,
         table_name=None,
@@ -1220,11 +1220,9 @@ def get_columns(
         )
         resp = self.make_request(self._client.GetColumns, req)
 
-        (
-            execute_response,
-            is_direct_results,
-            statement_id,
-        ) = self._handle_execute_response(resp, cursor)
+        execute_response, is_direct_results = self._handle_execute_response(
+            resp, cursor
+        )
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1242,7 +1240,7 @@ def get_columns(
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
             session_id_hex=self._session_id_hex,
-            statement_id=statement_id,
+            statement_type=statement_type,
         )
 
     def _handle_execute_response(self, resp, cursor):
@@ -1258,15 +1256,7 @@ def _handle_execute_response(self, resp, cursor):
             resp.directResults and resp.directResults.operationStatus,
         )
 
-        execute_response, is_direct_results = self._results_message_to_execute_response(
-            resp, final_operation_state
-        )
-
-        return (
-            execute_response,
-            is_direct_results,
-            cursor.active_command_id.to_hex_guid(),
-        )
+        return self._results_message_to_execute_response(resp, final_operation_state)
 
     def _handle_execute_response_async(self, resp, cursor):
         command_id = CommandId.from_thrift_handle(resp.operationHandle)
@@ -1285,8 +1275,9 @@ def fetch_results(
         lz4_compressed: bool,
         arrow_schema_bytes,
         description,
+        statement_type,
+        chunk_id: int,
         use_cloud_fetch=True,
-        statement_id=None,
     ):
         thrift_handle = command_id.to_thrift_handle()
         if not thrift_handle:
@@ -1324,10 +1315,16 @@ def fetch_results(
             description=description,
             ssl_options=self._ssl_options,
             session_id_hex=self._session_id_hex,
-            statement_id=statement_id,
+            statement_id=command_id.to_hex_guid(),
+            statement_type=statement_type,
+            chunk_id=chunk_id,
         )
 
-        return queue, resp.hasMoreRows
+        return (
+            queue,
+            resp.hasMoreRows,
+            len(resp.results.resultLinks) if
resp.results.resultLinks else 0, + ) def cancel_command(self, command_id: CommandId) -> None: thrift_handle = command_id.to_thrift_handle() diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c279f2c1f..aadb0547a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -708,7 +708,7 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) - @log_latency(StatementType.SQL) + @log_latency() def _handle_staging_put( self, presigned_url: str, local_file: str, headers: Optional[dict] = None ): @@ -717,6 +717,7 @@ def _handle_staging_put( Raise an exception if request fails. Returns no data. """ + self.statement_type = StatementType.SQL if local_file is None: raise ProgrammingError( "Cannot perform PUT without specifying a local_file", @@ -748,7 +749,7 @@ def _handle_staging_put( + "but not yet applied on the server. It's possible this command may fail later." ) - @log_latency(StatementType.SQL) + @log_latency() def _handle_staging_get( self, local_file: str, presigned_url: str, headers: Optional[dict] = None ): @@ -757,6 +758,7 @@ def _handle_staging_get( Raise an exception if request fails. Returns no data. """ + self.statement_type = StatementType.SQL if local_file is None: raise ProgrammingError( "Cannot perform GET without specifying a local_file", @@ -776,12 +778,13 @@ def _handle_staging_get( with open(local_file, "wb") as fp: fp.write(r.content) - @log_latency(StatementType.SQL) + @log_latency() def _handle_staging_remove( self, presigned_url: str, headers: Optional[dict] = None ): """Make an HTTP DELETE request to the presigned_url""" + self.statement_type = StatementType.SQL r = requests.delete(url=presigned_url, headers=headers) if not r.ok: @@ -790,7 +793,7 @@ def _handle_staging_remove( session_id_hex=self.connection.get_session_id_hex(), ) - @log_latency(StatementType.QUERY) + @log_latency() def execute( self, operation: str, @@ -829,6 +832,7 @@ def execute( :returns self """ + self.statement_type = StatementType.QUERY logger.debug( "Cursor.execute(operation=%s, parameters=%s)", operation, parameters ) @@ -866,6 +870,7 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, row_limit=self.row_limit, + statement_type=self.statement_type, ) if self.active_result_set and self.active_result_set.is_staging_operation: @@ -875,7 +880,7 @@ def execute( return self - @log_latency(StatementType.QUERY) + @log_latency() def execute_async( self, operation: str, @@ -891,6 +896,7 @@ def execute_async( :return: """ + self.statement_type = StatementType.QUERY param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: prepared_params = NO_NATIVE_PARAMS @@ -924,6 +930,7 @@ def execute_async( async_op=True, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, row_limit=self.row_limit, + statement_type=self.statement_type, ) return self @@ -964,7 +971,7 @@ def get_async_execution_result(self): operation_state = self.get_query_state() if operation_state == CommandState.SUCCEEDED: self.active_result_set = self.backend.get_execution_result( - self.active_command_id, self + self.active_command_id, cursor=self, statement_type=self.statement_type ) if self.active_result_set and self.active_result_set.is_staging_operation: @@ -994,13 +1001,14 @@ def executemany(self, operation, seq_of_parameters): self.execute(operation, parameters) return self - @log_latency(StatementType.METADATA) + @log_latency() def 
     def catalogs(self) -> "Cursor":
         """
         Get all available catalogs.
 
         :returns self
         """
+        self.statement_type = StatementType.METADATA
         self._check_not_closed()
         self._close_and_clear_active_result_set()
         self.active_result_set = self.backend.get_catalogs(
@@ -1008,10 +1016,11 @@ def catalogs(self) -> "Cursor":
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             cursor=self,
+            statement_type=self.statement_type,
         )
         return self
 
-    @log_latency(StatementType.METADATA)
+    @log_latency()
     def schemas(
         self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None
     ) -> "Cursor":
@@ -1021,6 +1030,7 @@ def schemas(
         Names can contain % wildcards.
 
         :returns self
         """
+        self.statement_type = StatementType.METADATA
         self._check_not_closed()
         self._close_and_clear_active_result_set()
         self.active_result_set = self.backend.get_schemas(
@@ -1030,10 +1040,11 @@ def schemas(
             cursor=self,
             catalog_name=catalog_name,
             schema_name=schema_name,
+            statement_type=self.statement_type,
         )
         return self
 
-    @log_latency(StatementType.METADATA)
+    @log_latency()
     def tables(
         self,
         catalog_name: Optional[str] = None,
@@ -1047,6 +1058,7 @@ def tables(
         Names can contain % wildcards.
 
         :returns self
         """
+        self.statement_type = StatementType.METADATA
         self._check_not_closed()
         self._close_and_clear_active_result_set()
 
@@ -1059,10 +1071,11 @@ def tables(
             schema_name=schema_name,
             table_name=table_name,
             table_types=table_types,
+            statement_type=self.statement_type,
         )
         return self
 
-    @log_latency(StatementType.METADATA)
+    @log_latency()
     def columns(
         self,
         catalog_name: Optional[str] = None,
@@ -1076,6 +1089,7 @@ def columns(
         Names can contain % wildcards.
 
         :returns self
         """
+        self.statement_type = StatementType.METADATA
         self._check_not_closed()
         self._close_and_clear_active_result_set()
 
@@ -1088,6 +1102,7 @@ def columns(
             schema_name=schema_name,
             table_name=table_name,
             column_name=column_name,
+            statement_type=self.statement_type,
         )
         return self
 
diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py
index 832a6f2ea..fcb3d3996 100644
--- a/src/databricks/sql/cloudfetch/download_manager.py
+++ b/src/databricks/sql/cloudfetch/download_manager.py
@@ -1,7 +1,7 @@
 import logging
 
 from concurrent.futures import ThreadPoolExecutor, Future
-from typing import List, Union, Optional, Tuple
+from typing import List, Union, Tuple, Optional
 
 from databricks.sql.cloudfetch.downloader import (
     ResultSetDownloadHandler,
@@ -9,7 +9,7 @@
     DownloadedFile,
 )
 from databricks.sql.types import SSLOptions
-
+from databricks.sql.telemetry.models.event import StatementType
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 
 logger = logging.getLogger(__name__)
@@ -22,11 +22,13 @@ def __init__(
         max_download_threads: int,
         lz4_compressed: bool,
         ssl_options: SSLOptions,
-        session_id_hex: Optional[str] = None,
-        statement_id: Optional[str] = None,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        statement_type: StatementType,
+        chunk_id: int,
     ):
         self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
-        for i, link in enumerate(links):
+        for i, link in enumerate(links, start=chunk_id):
             if link.rowCount <= 0:
                 continue
             logger.debug(
@@ -44,6 +46,7 @@ def __init__(
         self._ssl_options = ssl_options
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
+        self.statement_type = statement_type
 
     def get_next_downloaded_file(
         self, next_row_offset: int
@@ -106,6 +109,7 @@ def _schedule_downloads(self):
                 chunk_id=chunk_id,
                 session_id_hex=self.session_id_hex,
                 statement_id=self.statement_id,
+                statement_type=self.statement_type,
             )
             task = self._thread_pool.submit(handler.run)
             self._download_tasks.append(task)
diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
index 68c310af6..49f4ccc3c 100644
--- a/src/databricks/sql/cloudfetch/downloader.py
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -11,6 +11,7 @@
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
 from databricks.sql.telemetry.latency_logger import log_latency
+from databricks.sql.telemetry.models.event import StatementType
 
 logger = logging.getLogger(__name__)
@@ -69,8 +70,9 @@ def __init__(
         link: TSparkArrowResultLink,
         ssl_options: SSLOptions,
         chunk_id: int,
-        session_id_hex: Optional[str] = None,
-        statement_id: Optional[str] = None,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        statement_type: StatementType,
     ):
         self.settings = settings
         self.link = link
@@ -78,6 +80,7 @@ def __init__(
         self.chunk_id = chunk_id
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
+        self.statement_type = statement_type
 
     @log_latency()
     def run(self) -> DownloadedFile:
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index 3595d65d7..e10822262 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -22,6 +22,7 @@
     ColumnQueue,
 )
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
+from databricks.sql.telemetry.models.event import StatementType
 
 logger = logging.getLogger(__name__)
@@ -191,6 +192,8 @@ def __init__(
         connection: "Connection",
         execute_response: "ExecuteResponse",
         thrift_client: "ThriftDatabricksClient",
+        session_id_hex: Optional[str],
+        statement_type: StatementType,
         buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,
@@ -198,8 +201,6 @@ def __init__(
         max_download_threads: int = 10,
         ssl_options=None,
         is_direct_results: bool = True,
-        session_id_hex: Optional[str] = None,
-        statement_id: Optional[str] = None,
     ):
         """
         Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -216,6 +217,8 @@ def __init__(
         :param ssl_options: SSL options for cloud fetch
         :param is_direct_results: Whether there are more rows to fetch
         """
+        self.statement_type = statement_type
+        self.chunk_id = 0
 
         # Initialize ThriftResultSet-specific attributes
         self._use_cloud_fetch = use_cloud_fetch
@@ -236,8 +239,12 @@ def __init__(
                     description=execute_response.description,
                     ssl_options=ssl_options,
                     session_id_hex=session_id_hex,
-                    statement_id=statement_id,
+                    statement_id=execute_response.command_id.to_hex_guid(),
+                    statement_type=statement_type,
+                    chunk_id=self.chunk_id,
                 )
+                if t_row_set and t_row_set.resultLinks:
+                    self.chunk_id = len(t_row_set.resultLinks)
 
         # Call parent constructor with common attributes
         super().__init__(
@@ -261,7 +268,7 @@ def __init__(
             self._fill_results_buffer()
 
     def _fill_results_buffer(self):
-        results, is_direct_results = self.backend.fetch_results(
+        results, is_direct_results, result_links_count = self.backend.fetch_results(
             command_id=self.command_id,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
@@ -270,9 +277,12 @@ def _fill_results_buffer(self):
             arrow_schema_bytes=self._arrow_schema_bytes,
             description=self.description,
             use_cloud_fetch=self._use_cloud_fetch,
+            statement_type=self.statement_type,
+            chunk_id=self.chunk_id,
         )
         self.results = results
         self.is_direct_results = is_direct_results
+        self.chunk_id += result_links_count
 
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py
index b7f1a8fe6..b0782e263 100644
--- a/src/databricks/sql/telemetry/latency_logger.py
+++ b/src/databricks/sql/telemetry/latency_logger.py
@@ -7,7 +7,6 @@
     SqlExecutionEvent,
 )
 from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType
-from uuid import UUID
 
 logger = logging.getLogger(__name__)
@@ -35,7 +34,7 @@ def get_statement_id(self):
     def get_is_compressed(self):
         pass
 
-    def get_execution_result(self):
+    def get_execution_result_format(self):
         pass
 
     def get_retry_count(self):
@@ -44,6 +43,9 @@ def get_retry_count(self):
     def get_chunk_id(self):
         pass
 
+    def get_statement_type(self):
+        pass
+
 
 class CursorExtractor(TelemetryExtractor):
     """
@@ -62,7 +64,7 @@ def get_session_id_hex(self) -> Optional[str]:
     def get_is_compressed(self) -> bool:
         return self.connection.lz4_compression
 
-    def get_execution_result(self) -> ExecutionResultFormat:
+    def get_execution_result_format(self) -> ExecutionResultFormat:
         if self.active_result_set is None:
             return ExecutionResultFormat.FORMAT_UNSPECIFIED
 
@@ -84,6 +86,9 @@ def get_retry_count(self) -> int:
     def get_chunk_id(self):
         return None
 
+    def get_statement_type(self):
+        return self.statement_type
+
 
 class ResultSetDownloadHandlerExtractor(TelemetryExtractor):
     """
@@ -99,7 +104,7 @@ def get_statement_id(self) -> Optional[str]:
     def get_is_compressed(self) -> bool:
         return self._obj.settings.is_lz4_compressed
 
-    def get_execution_result(self) -> ExecutionResultFormat:
+    def get_execution_result_format(self) -> ExecutionResultFormat:
         return ExecutionResultFormat.EXTERNAL_LINKS
 
     def get_retry_count(self) -> Optional[int]:
@@ -109,6 +114,9 @@ def get_retry_count(self) -> Optional[int]:
     def get_chunk_id(self) -> Optional[int]:
         return self._obj.chunk_id
 
+    def get_statement_type(self):
+        return self.statement_type
+
 
 def get_extractor(obj):
     """
@@ -194,9 +202,11 @@ def _safe_call(func_to_call):
             statement_id = _safe_call(extractor.get_statement_id)
 
             sql_exec_event = SqlExecutionEvent(
-                statement_type=statement_type,
+                statement_type=_safe_call(extractor.get_statement_type),
                 is_compressed=_safe_call(extractor.get_is_compressed),
-                execution_result=_safe_call(extractor.get_execution_result),
+                execution_result=_safe_call(
+                    extractor.get_execution_result_format
+                ),
                 retry_count=_safe_call(extractor.get_retry_count),
                 chunk_id=_safe_call(extractor.get_chunk_id),
             )
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 044038b2b..4db8ea725 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -31,7 +31,7 @@
 )
 from databricks.sql.types import SSLOptions
 from databricks.sql.backend.types import CommandId
-
+from databricks.sql.telemetry.models.event import StatementType
 from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter
 
 import logging
@@ -60,10 +60,12 @@ def build_queue(
         arrow_schema_bytes: bytes,
         max_download_threads: int,
         ssl_options: SSLOptions,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        statement_type: StatementType,
+        chunk_id: int,
         lz4_compressed: bool = True,
         description: List[Tuple] = [],
-        session_id_hex: Optional[str] = None,
-        statement_id: Optional[str] = None,
    ) -> ResultSetQueue:
         """
         Factory method to build a result set queue.
@@ -110,6 +112,8 @@ def build_queue(
                 ssl_options=ssl_options,
                 session_id_hex=session_id_hex,
                 statement_id=statement_id,
+                statement_type=statement_type,
+                chunk_id=chunk_id,
             )
         else:
             raise AssertionError("Row set type is not valid")
@@ -211,12 +215,14 @@ def __init__(
         schema_bytes,
         max_download_threads: int,
         ssl_options: SSLOptions,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        statement_type: StatementType,
+        chunk_id: int,
         start_row_offset: int = 0,
         result_links: Optional[List[TSparkArrowResultLink]] = None,
         lz4_compressed: bool = True,
         description: List[Tuple] = [],
-        session_id_hex: Optional[str] = None,
-        statement_id: Optional[str] = None,
     ):
         """
         A queue-like wrapper over CloudFetch arrow batches.
@@ -239,6 +245,8 @@ def __init__(
         self._ssl_options = ssl_options
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
+        self.statement_type = statement_type
+        self.chunk_id = chunk_id
 
         logger.debug(
             "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
@@ -259,6 +267,8 @@ def __init__(
             ssl_options=self._ssl_options,
             session_id_hex=self.session_id_hex,
             statement_id=self.statement_id,
+            statement_type=self.statement_type,
+            chunk_id=self.chunk_id,
         )
         self.table = self._create_next_table()
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index a2525ed97..b1639fa66 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -115,7 +115,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
         # Mock the backend that will be used by the real ThriftResultSet
         mock_backend = Mock(spec=ThriftDatabricksClient)
         mock_backend.staging_allowed_local_path = None
-        mock_backend.fetch_results.return_value = (Mock(), False)
+        mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
         # Configure the decorator's mock to return our specific mock_backend
         mock_thrift_client_class.return_value = mock_backend
@@ -128,6 +128,8 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
             connection=connection,
             execute_response=mock_execute_response,
             thrift_client=mock_backend,
+            session_id_hex=Mock(),
+            statement_type=Mock(),
         )
 
         # Mock execute_command to return our real result set
@@ -188,12 +190,14 @@ def test_arraysize_buffer_size_passthrough(
     def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
         mock_connection = Mock()
         mock_backend = Mock()
-        mock_backend.fetch_results.return_value = (Mock(), False)
+        mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
         result_set = ThriftResultSet(
             connection=mock_connection,
             execute_response=Mock(),
             thrift_client=mock_backend,
+            session_id_hex=Mock(),
+            statement_type=Mock(),
         )
         # Setup session mock on the mock_connection
         mock_session = Mock()
@@ -215,10 +219,9 @@ def test_closing_result_set_hard_closes_commands(self):
         mock_session.open = True
         type(mock_connection).session = PropertyMock(return_value=mock_session)
 
-        mock_thrift_backend.fetch_results.return_value = (Mock(), False)
+        mock_thrift_backend.fetch_results.return_value = (Mock(), False, 0)
         result_set = ThriftResultSet(
-            mock_connection, mock_results_response, mock_thrift_backend
-        )
+            mock_connection, mock_results_response, mock_thrift_backend, session_id_hex=Mock(), statement_type=Mock(),)
 
         result_set.close()
 
@@ -261,9 +264,9 @@ def test_closed_cursor_doesnt_allow_operations(self):
 
     def test_negative_fetch_throws_exception(self):
         mock_backend = Mock()
-        mock_backend.fetch_results.return_value = (Mock(), False)
+        mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
-        result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
+        result_set = ThriftResultSet(Mock(), Mock(), mock_backend, session_id_hex=Mock(), statement_type=Mock())
 
         with self.assertRaises(ValueError) as e:
             result_set.fetchmany(-1)
diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py
index 7dec4e680..2eab268d6 100644
--- a/tests/unit/test_cloud_fetch_queue.py
+++ b/tests/unit/test_cloud_fetch_queue.py
@@ -4,7 +4,7 @@
     pyarrow = None
 import unittest
 import pytest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, Mock
 
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 import databricks.sql.utils as utils
@@ -63,6 +63,10 @@ def test_initializer_adds_links(self, mock_create_next_table):
             result_links=result_links,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
 
         assert len(queue.download_manager._pending_links) == 10
@@ -77,6 +81,10 @@ def test_initializer_no_links_to_add(self):
             result_links=result_links,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
 
         assert len(queue.download_manager._pending_links) == 0
@@ -93,6 +101,10 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
             result_links=[],
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
 
         assert queue._create_next_table() is None
@@ -114,6 +126,10 @@ def test_initializer_create_next_table_success(
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         expected_result = self.make_arrow_table()
 
@@ -139,6 +155,10 @@ def test_next_n_rows_0_rows(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -159,6 +179,10 @@ def test_next_n_rows_partial_table(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -179,6 +203,10 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -204,6 +232,10 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -223,6 +255,10 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table is None
 
@@ -240,6 +276,10 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -259,6 +299,10 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -278,6 +322,10 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -303,6 +351,10 @@ def test_remaining_rows_multiple_tables_fully_returned(
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
@@ -328,6 +380,10 @@ def test_remaining_rows_empty_table(self, mock_create_next_table):
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
         assert queue.table is None
 
diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py
index 64edbdebe..b3d6c7988 100644
--- a/tests/unit/test_download_manager.py
+++ b/tests/unit/test_download_manager.py
@@ -1,5 +1,5 @@
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import patch, MagicMock, Mock
 
 import databricks.sql.cloudfetch.download_manager as download_manager
 from databricks.sql.types import SSLOptions
@@ -19,6 +19,10 @@ def create_download_manager(
             max_download_threads,
             lz4_compressed,
             ssl_options=SSLOptions(),
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
 
     def create_result_link(
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index d6ee28984..c440bf116 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -27,7 +27,7 @@ def test_run_link_expired(self, mock_time):
         # Already expired
         result_link.expiryTime = 999
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
         )
 
         with self.assertRaises(Error) as context:
@@ -43,7 +43,7 @@ def test_run_link_past_expiry_buffer(self, mock_time):
         # Within the expiry buffer time
         result_link.expiryTime = 1004
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
         )
 
         with self.assertRaises(Error) as context:
@@ -63,7 +63,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session):
 
         result_link = Mock(expiryTime=1001)
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
         )
         with self.assertRaises(requests.exceptions.HTTPError) as context:
             d.run()
@@ -82,7 +82,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session):
 
         result_link = Mock(bytesNum=100, expiryTime=1001)
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
         )
 
         file = d.run()
@@ -105,7 +105,7 @@ def test_run_compressed_successful(self, mock_time, mock_session):
 
         result_link = Mock(bytesNum=100, expiryTime=1001)
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
         )
 
         file = d.run()
@@ -121,7 +121,7 @@ def test_download_connection_error(self, mock_time, mock_session):
         mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
         )
         with self.assertRaises(ConnectionError):
             d.run()
@@ -136,7 +136,7 @@ def test_download_timeout(self, mock_time, mock_session):
         mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
         )
         with self.assertRaises(TimeoutError):
             d.run()
diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py
index a649941e1..07c995015 100644
--- a/tests/unit/test_fetches.py
+++ b/tests/unit/test_fetches.py
@@ -43,7 +43,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
         # Create a mock backend that will return the queue when _fill_results_buffer is called
         mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
-        mock_thrift_backend.fetch_results.return_value = (arrow_queue, False)
+        mock_thrift_backend.fetch_results.return_value = (arrow_queue, False, 0)
 
         num_cols = len(initial_results[0]) if initial_results else 0
         description = [
@@ -63,6 +63,8 @@ def make_dummy_result_set_from_initial_results(initial_results):
             ),
             thrift_client=mock_thrift_backend,
             t_row_set=None,
+            session_id_hex=Mock(),
+            statement_type=Mock(),
         )
         return rs
 
@@ -79,12 +81,14 @@ def fetch_results(
             arrow_schema_bytes,
             description,
             use_cloud_fetch=True,
+            statement_type=Mock(),
+            chunk_id=0,
         ):
             nonlocal batch_index
             results = FetchTests.make_arrow_queue(batch_list[batch_index])
             batch_index += 1
 
-            return results, batch_index < len(batch_list)
+            return results, batch_index < len(batch_list), 0
 
         mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
         mock_thrift_backend.fetch_results = fetch_results
@@ -106,6 +110,8 @@ def fetch_results(
                 is_staging_operation=False,
             ),
             thrift_client=mock_thrift_backend,
+            session_id_hex=Mock(),
+            statement_type=Mock(),
         )
         return rs
 
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index abfb13a02..2bdad60b0 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -649,7 +649,7 @@ def test_handle_execute_response_sets_compression_in_direct_results(
             ssl_options=SSLOptions(),
         )
 
-        execute_response, _, _ = thrift_backend._handle_execute_response(
+        execute_response, _ = thrift_backend._handle_execute_response(
             t_execute_resp, Mock()
         )
         self.assertEqual(execute_response.lz4_compressed, lz4Compressed)
@@ -731,7 +731,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class):
             ssl_options=SSLOptions(),
         )
         with self.assertRaises(DatabaseError) as cm:
-            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock())
+            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock(), Mock())
 
         self.assertEqual(display_message, str(cm.exception))
         self.assertIn(diagnostic_info, str(cm.exception.message_with_context()))
@@ -772,7 +772,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla
             ssl_options=SSLOptions(),
         )
         with self.assertRaises(DatabaseError) as cm:
-            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock())
+            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock(), Mock())
 
         self.assertEqual(display_message, str(cm.exception))
         self.assertIn(diagnostic_info, str(cm.exception.message_with_context()))
@@ -892,7 +892,6 @@ def test_handle_execute_response_can_handle_without_direct_results(
         (
             execute_response,
             _,
-            _,
         ) = thrift_backend._handle_execute_response(execute_resp, Mock())
         self.assertEqual(
             execute_response.status,
@@ -928,7 +927,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self):
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        thrift_backend._results_message_to_execute_response = Mock(return_value=(Mock(), Mock()))
+        thrift_backend._results_message_to_execute_response = Mock()
 
         thrift_backend._handle_execute_response(execute_resp, Mock())
 
@@ -966,7 +965,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class):
             )
         )
         thrift_backend = self._make_fake_thrift_backend()
-        execute_response, _, _ = thrift_backend._handle_execute_response(
+        execute_response, _ = thrift_backend._handle_execute_response(
             t_execute_resp, Mock()
         )
 
@@ -998,7 +997,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
             )
         )
         thrift_backend = self._make_fake_thrift_backend()
-        _, _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock())
+        _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock())
 
         self.assertEqual(
             hive_schema_mock,
@@ -1047,7 +1046,6 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
             (
                 execute_response,
                 has_more_rows_result,
-                _
             ) = thrift_backend._handle_execute_response(execute_resp, Mock())
 
             self.assertEqual(is_direct_results, has_more_rows_result)
@@ -1099,7 +1097,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
             thrift_backend = self._make_fake_thrift_backend()
             thrift_backend._handle_execute_response(execute_resp, Mock())
 
-            _, has_more_rows_resp = thrift_backend.fetch_results(
+            _, has_more_rows_resp, _ = thrift_backend.fetch_results(
                 command_id=Mock(),
                 max_rows=1,
                 max_bytes=1,
@@ -1107,6 +1105,8 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
                 lz4_compressed=False,
                 arrow_schema_bytes=Mock(),
                 description=Mock(),
+                statement_type=Mock(),
+                chunk_id=0,
             )
 
             self.assertEqual(is_direct_results, has_more_rows_resp)
@@ -1152,7 +1152,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        arrow_queue, has_more_results = thrift_backend.fetch_results(
+        arrow_queue, has_more_results, _ = thrift_backend.fetch_results(
             command_id=Mock(),
             max_rows=1,
             max_bytes=1,
@@ -1160,6 +1160,8 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
             lz4_compressed=False,
             arrow_schema_bytes=schema,
             description=MagicMock(),
+            statement_type=Mock(),
+            chunk_id=0,
         )
 
         self.assertEqual(arrow_queue.n_valid_rows, 15 * 10)
@@ -1181,11 +1183,11 @@ def test_execute_statement_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
         cursor_mock = Mock()
 
         result = thrift_backend.execute_command(
-            "foo", Mock(), 100, 200, Mock(), cursor_mock
+            "foo", Mock(), 100, 200, Mock(), cursor_mock, Mock()
         )
         # Verify the result is a ResultSet
         self.assertEqual(result, mock_result_set.return_value)
@@ -1217,10 +1219,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
         cursor_mock = Mock()
 
-        result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock)
+        result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock, Mock())
         # Verify the result is a ResultSet
         self.assertEqual(result, mock_result_set.return_value)
 
@@ -1250,7 +1252,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
         cursor_mock = Mock()
 
         result = thrift_backend.get_schemas(
@@ -1258,6 +1260,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
             100,
             200,
             cursor_mock,
+            Mock(),
             catalog_name="catalog_pattern",
             schema_name="schema_pattern",
         )
@@ -1292,7 +1295,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
         cursor_mock = Mock()
 
         result = thrift_backend.get_tables(
@@ -1300,6 +1303,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
             100,
             200,
             cursor_mock,
+            Mock(),
             catalog_name="catalog_pattern",
             schema_name="schema_pattern",
             table_name="table_pattern",
@@ -1338,7 +1342,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
         cursor_mock = Mock()
 
         result = thrift_backend.get_columns(
@@ -1346,6 +1350,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
             100,
             200,
             cursor_mock,
+            Mock(),
             catalog_name="catalog_pattern",
             schema_name="schema_pattern",
             table_name="table_pattern",
@@ -1450,7 +1455,7 @@ def test_non_arrow_non_column_based_set_triggers_exception(
         thrift_backend = self._make_fake_thrift_backend()
 
         with self.assertRaises(OperationalError) as cm:
-            thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock())
+            thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock(), Mock())
         self.assertIn(
             "Expected results to be in Arrow or column based format", str(cm.exception)
         )
@@ -1684,7 +1689,7 @@ def test_handle_execute_response_sets_active_op_handle(self):
         thrift_backend = self._make_fake_thrift_backend()
         thrift_backend._check_direct_results_for_error = Mock()
         thrift_backend._wait_until_command_done = Mock()
-        thrift_backend._results_message_to_execute_response = Mock(return_value=(Mock(), Mock()))
+        thrift_backend._results_message_to_execute_response = Mock()
 
         # Create a mock response with a real operation handle
         mock_resp = Mock()
@@ -2256,7 +2261,6 @@ def test_execute_command_sets_complex_type_fields_correctly(
         mock_handle_execute_response.return_value = (
             mock_execute_response,
             mock_arrow_schema,
-            Mock()
         )
 
         # Iterate through each possible combination of native types (True, False and unset)
@@ -2280,7 +2284,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
                 ssl_options=SSLOptions(),
                 **complex_arg_types,
             )
-            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock())
+            thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock(), Mock())
             t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[
                 0
             ][0]
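The changes above drop the decorator-level statement type: each Cursor method now stamps `self.statement_type` on the object before running, and `log_latency()` resolves the type through the telemetry extractor once the wrapped call finishes. A minimal sketch of that flow, with a stub `emit()` standing in for the real telemetry client (hypothetical names, for illustration only):

    import functools
    import time

    def emit(statement_type, duration_ms):
        # Stand-in for the real telemetry export.
        print(statement_type, round(duration_ms, 3), "ms")

    def log_latency():
        def decorator(func):
            @functools.wraps(func)
            def wrapper(self, *args, **kwargs):
                start = time.perf_counter()
                try:
                    return func(self, *args, **kwargs)
                finally:
                    # The wrapped object carries its own statement type now.
                    emit(getattr(self, "statement_type", None),
                         (time.perf_counter() - start) * 1000)
            return wrapper
        return decorator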
cursor: "Cursor", - statement_type: StatementType, ) -> "ResultSet": thrift_handle = session_id.to_thrift_handle() if not thrift_handle: @@ -1073,6 +1071,8 @@ def get_catalogs( if resp.directResults and resp.directResults.resultSet: t_row_set = resp.directResults.resultSet.results + execute_response.command_id.set_statement_type(StatementType.METADATA) + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1085,7 +1085,6 @@ def get_catalogs( ssl_options=self._ssl_options, is_direct_results=is_direct_results, session_id_hex=self._session_id_hex, - statement_id=statement_type, ) def get_schemas( @@ -1094,7 +1093,6 @@ def get_schemas( max_rows: int, max_bytes: int, cursor: Cursor, - statement_type: StatementType, catalog_name=None, schema_name=None, ) -> "ResultSet": @@ -1122,6 +1120,8 @@ def get_schemas( if resp.directResults and resp.directResults.resultSet: t_row_set = resp.directResults.resultSet.results + execute_response.command_id.set_statement_type(StatementType.METADATA) + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1134,7 +1134,6 @@ def get_schemas( ssl_options=self._ssl_options, is_direct_results=is_direct_results, session_id_hex=self._session_id_hex, - statement_type=statement_type, ) def get_tables( @@ -1143,7 +1142,6 @@ def get_tables( max_rows: int, max_bytes: int, cursor: Cursor, - statement_type: StatementType, catalog_name=None, schema_name=None, table_name=None, @@ -1175,6 +1173,8 @@ def get_tables( if resp.directResults and resp.directResults.resultSet: t_row_set = resp.directResults.resultSet.results + execute_response.command_id.set_statement_type(StatementType.METADATA) + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1187,7 +1187,6 @@ def get_tables( ssl_options=self._ssl_options, is_direct_results=is_direct_results, session_id_hex=self._session_id_hex, - statement_type=statement_type, ) def get_columns( @@ -1196,7 +1195,6 @@ def get_columns( max_rows: int, max_bytes: int, cursor: Cursor, - statement_type: StatementType, catalog_name=None, schema_name=None, table_name=None, @@ -1228,6 +1226,8 @@ def get_columns( if resp.directResults and resp.directResults.resultSet: t_row_set = resp.directResults.resultSet.results + execute_response.command_id.set_statement_type(StatementType.METADATA) + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1240,7 +1240,6 @@ def get_columns( ssl_options=self._ssl_options, is_direct_results=is_direct_results, session_id_hex=self._session_id_hex, - statement_type=statement_type, ) def _handle_execute_response(self, resp, cursor): @@ -1275,7 +1274,6 @@ def fetch_results( lz4_compressed: bool, arrow_schema_bytes, description, - statement_type, chunk_id: int, use_cloud_fetch=True, ): @@ -1316,7 +1314,7 @@ def fetch_results( ssl_options=self._ssl_options, session_id_hex=self._session_id_hex, statement_id=command_id.to_hex_guid(), - statement_type=statement_type, + statement_type=command_id.statement_type, chunk_id=chunk_id, ) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index f6428a187..da1d59ee5 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -4,6 +4,7 @@ import logging from databricks.sql.backend.utils.guid_utils import guid_to_hex_id +from databricks.sql.telemetry.models.enums import StatementType from databricks.sql.thrift_api.TCLIService import ttypes logger = 
logging.getLogger(__name__) @@ -281,6 +282,7 @@ def __init__( operation_type: Optional[int] = None, has_result_set: bool = False, modified_row_count: Optional[int] = None, + statement_type: Optional[StatementType] = None, ): """ Initialize a CommandId. @@ -300,6 +302,7 @@ def __init__( self.operation_type = operation_type self.has_result_set = has_result_set self.modified_row_count = modified_row_count + self._statement_type = statement_type def __str__(self) -> str: """ @@ -411,6 +414,19 @@ def to_hex_guid(self) -> str: else: return str(self.guid) + def set_statement_type(self, statement_type: StatementType): + """ + Set the statement type for this command. + """ + self._statement_type = statement_type + + @property + def statement_type(self) -> Optional[StatementType]: + """ + Get the statement type for this command. + """ + return self._statement_type + @dataclass class ExecuteResponse: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index aadb0547a..cbacaa5f2 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -870,7 +870,6 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, row_limit=self.row_limit, - statement_type=self.statement_type, ) if self.active_result_set and self.active_result_set.is_staging_operation: @@ -930,7 +929,6 @@ def execute_async( async_op=True, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, row_limit=self.row_limit, - statement_type=self.statement_type, ) return self @@ -971,7 +969,7 @@ def get_async_execution_result(self): operation_state = self.get_query_state() if operation_state == CommandState.SUCCEEDED: self.active_result_set = self.backend.get_execution_result( - self.active_command_id, cursor=self, statement_type=self.statement_type + self.active_command_id, self ) if self.active_result_set and self.active_result_set.is_staging_operation: @@ -1016,7 +1014,6 @@ def catalogs(self) -> "Cursor": max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, - statement_type=self.statement_type, ) return self @@ -1040,7 +1037,6 @@ def schemas( cursor=self, catalog_name=catalog_name, schema_name=schema_name, - statement_type=self.statement_type, ) return self @@ -1071,7 +1067,6 @@ def tables( schema_name=schema_name, table_name=table_name, table_types=table_types, - statement_type=self.statement_type, ) return self @@ -1102,7 +1097,6 @@ def columns( schema_name=schema_name, table_name=table_name, column_name=column_name, - statement_type=self.statement_type, ) return self diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 843e5051d..139d650c6 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -194,7 +194,6 @@ def __init__( execute_response: "ExecuteResponse", thrift_client: "ThriftDatabricksClient", session_id_hex: Optional[str], - statement_type: StatementType, buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, @@ -218,7 +217,7 @@ def __init__( :param ssl_options: SSL options for cloud fetch :param is_direct_results: Whether there are more rows to fetch """ - self.statement_type = statement_type + self.statement_type = execute_response.command_id.statement_type self.chunk_id = 0 # Initialize ThriftResultSet-specific attributes @@ -241,7 +240,7 @@ def __init__( ssl_options=ssl_options, session_id_hex=session_id_hex, statement_id=execute_response.command_id.to_hex_guid(), - statement_type=statement_type, + 
statement_type=self.statement_type, chunk_id=self.chunk_id, ) if t_row_set and t_row_set.resultLinks: @@ -278,7 +277,6 @@ def _fill_results_buffer(self): arrow_schema_bytes=self._arrow_schema_bytes, description=self.description, use_cloud_fetch=self._use_cloud_fetch, - statement_type=self.statement_type, chunk_id=self.chunk_id, ) self.results = results diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index b0782e263..10f1b2291 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -144,7 +144,7 @@ def get_extractor(obj): return None -def log_latency(statement_type: StatementType = StatementType.NONE): +def log_latency(): """ Decorator for logging execution latency and telemetry information. @@ -158,11 +158,8 @@ def log_latency(statement_type: StatementType = StatementType.NONE): - Creates a SqlExecutionEvent with execution details - Sends the telemetry data asynchronously via TelemetryClient - Args: - statement_type (StatementType): The type of SQL statement being executed. - Usage: - @log_latency(StatementType.SQL) + @log_latency() def execute(self, query): # Method implementation pass diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 6dfe75258..e9acfc9ba 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -129,7 +129,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): execute_response=mock_execute_response, thrift_client=mock_backend, session_id_hex=Mock(), - statement_type=Mock(), ) # Mock execute_command to return our real result set @@ -197,7 +196,6 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): execute_response=Mock(), thrift_client=mock_backend, session_id_hex=Mock(), - statement_type=Mock(), ) result_set.results = mock_results @@ -225,7 +223,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_thrift_backend.fetch_results.return_value = (Mock(), False, 0) result_set = ThriftResultSet( - mock_connection, mock_results_response, mock_thrift_backend, session_id_hex=Mock(), statement_type=Mock(),) + mock_connection, mock_results_response, mock_thrift_backend, session_id_hex=Mock()) result_set.close() @@ -271,7 +269,7 @@ def test_negative_fetch_throws_exception(self): mock_backend = Mock() mock_backend.fetch_results.return_value = (Mock(), False, 0) - result_set = ThriftResultSet(Mock(), Mock(), mock_backend, session_id_hex=Mock(), statement_type=Mock()) + result_set = ThriftResultSet(Mock(), Mock(), mock_backend, session_id_hex=Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 07c995015..9bb29de8f 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -54,7 +54,7 @@ def make_dummy_result_set_from_initial_results(initial_results): rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( - command_id=None, + command_id=Mock(), status=None, has_been_closed_server_side=True, description=description, @@ -64,7 +64,6 @@ def make_dummy_result_set_from_initial_results(initial_results): thrift_client=mock_thrift_backend, t_row_set=None, session_id_hex=Mock(), - statement_type=Mock(), ) return rs @@ -81,7 +80,6 @@ def fetch_results( arrow_schema_bytes, description, use_cloud_fetch=True, - statement_type=Mock(), chunk_id=0, ): nonlocal batch_index @@ -102,7 +100,7 @@ def fetch_results( rs = ThriftResultSet( 
connection=Mock(), execute_response=ExecuteResponse( - command_id=None, + command_id=Mock(), status=None, has_been_closed_server_side=False, description=description, @@ -111,7 +109,6 @@ def fetch_results( ), thrift_client=mock_thrift_backend, session_id_hex=Mock(), - statement_type=Mock(), ) return rs diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 2bdad60b0..452eb4d3e 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1105,7 +1105,6 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( lz4_compressed=False, arrow_schema_bytes=Mock(), description=Mock(), - statement_type=Mock(), chunk_id=0, ) @@ -1160,7 +1159,6 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): lz4_compressed=False, arrow_schema_bytes=schema, description=MagicMock(), - statement_type=Mock(), chunk_id=0, ) @@ -1222,7 +1220,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() - result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock, Mock()) + result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet self.assertEqual(result, mock_result_set.return_value) @@ -1260,7 +1258,6 @@ def test_get_schemas_calls_client_and_handle_execute_response( 100, 200, cursor_mock, - Mock(), catalog_name="catalog_pattern", schema_name="schema_pattern", ) @@ -1303,7 +1300,6 @@ def test_get_tables_calls_client_and_handle_execute_response( 100, 200, cursor_mock, - Mock(), catalog_name="catalog_pattern", schema_name="schema_pattern", table_name="table_pattern", @@ -1350,7 +1346,6 @@ def test_get_columns_calls_client_and_handle_execute_response( 100, 200, cursor_mock, - Mock(), catalog_name="catalog_pattern", schema_name="schema_pattern", table_name="table_pattern", From 153d34638792c089b5492b1ac05cfdedd1e77938 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 15 Jul 2025 17:12:41 +0530 Subject: [PATCH 15/19] check types fix Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/backend/types.py | 5 ++--- tests/unit/test_client.py | 5 ++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index da1d59ee5..30b6fbcfd 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -282,7 +282,6 @@ def __init__( operation_type: Optional[int] = None, has_result_set: bool = False, modified_row_count: Optional[int] = None, - statement_type: Optional[StatementType] = None, ): """ Initialize a CommandId. @@ -302,7 +301,7 @@ def __init__( self.operation_type = operation_type self.has_result_set = has_result_set self.modified_row_count = modified_row_count - self._statement_type = statement_type + self._statement_type = StatementType.NONE def __str__(self) -> str: """ @@ -421,7 +420,7 @@ def set_statement_type(self, statement_type: StatementType): self._statement_type = statement_type @property - def statement_type(self) -> Optional[StatementType]: + def statement_type(self) -> StatementType: """ Get the statement type for this command. 
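With patch 14 the statement type rides on `CommandId` instead of being threaded as a parameter through every backend method: the backend tags the command once at execute time, and `fetch_results` reads it back when building the queue. A condensed, self-contained sketch of that hand-off (unrelated fields elided):

    from enum import Enum

    class StatementType(Enum):
        NONE = "none"
        QUERY = "query"
        METADATA = "metadata"

    class CommandId:
        def __init__(self):
            self._statement_type = None

        def set_statement_type(self, statement_type):
            self._statement_type = statement_type

        @property
        def statement_type(self):
            return self._statement_type

    command_id = CommandId()
    command_id.set_statement_type(StatementType.QUERY)       # tagged at execute time
    assert command_id.statement_type is StatementType.QUERY  # read back at fetch time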
""" diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e9acfc9ba..d015cfacc 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -189,6 +189,7 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() + mock_results = Mock() mock_backend.fetch_results.return_value = (Mock(), False, 0) result_set = ThriftResultSet( @@ -223,7 +224,9 @@ def test_closing_result_set_hard_closes_commands(self): mock_thrift_backend.fetch_results.return_value = (Mock(), False, 0) result_set = ThriftResultSet( - mock_connection, mock_results_response, mock_thrift_backend, session_id_hex=Mock()) + mock_connection, mock_results_response, mock_thrift_backend, session_id_hex=Mock() + ) + result_set.results = mock_results result_set.close() From 79d58dd953573ac0fb56046644ed7d3944a04d0c Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 16 Jul 2025 11:55:54 +0530 Subject: [PATCH 16/19] renamed chunk_id to num_downloaded_chunks Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/result_set.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 139d650c6..f74cd0a5a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -218,7 +218,7 @@ def __init__( :param is_direct_results: Whether there are more rows to fetch """ self.statement_type = execute_response.command_id.statement_type - self.chunk_id = 0 + self.num_downloaded_chunks = 0 # Initialize ThriftResultSet-specific attributes self._use_cloud_fetch = use_cloud_fetch @@ -241,10 +241,10 @@ def __init__( session_id_hex=session_id_hex, statement_id=execute_response.command_id.to_hex_guid(), statement_type=self.statement_type, - chunk_id=self.chunk_id, + chunk_id=self.num_downloaded_chunks, ) if t_row_set and t_row_set.resultLinks: - self.chunk_id = len(t_row_set.resultLinks) + self.num_downloaded_chunks += len(t_row_set.resultLinks) # Call parent constructor with common attributes super().__init__( @@ -277,11 +277,11 @@ def _fill_results_buffer(self): arrow_schema_bytes=self._arrow_schema_bytes, description=self.description, use_cloud_fetch=self._use_cloud_fetch, - chunk_id=self.chunk_id, + chunk_id=self.num_downloaded_chunks, ) self.results = results self.is_direct_results = is_direct_results - self.chunk_id += result_links_count + self.num_downloaded_chunks += result_links_count def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] From a35d3e9904d6f318d2b8f332c902e95cc31f03c9 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 16 Jul 2025 15:41:09 +0530 Subject: [PATCH 17/19] set statement type to query for chunk download Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/backend/sea/queue.py | 4 +-- src/databricks/sql/backend/thrift_backend.py | 12 --------- src/databricks/sql/backend/types.py | 14 ---------- src/databricks/sql/client.py | 27 +++++++------------ .../sql/cloudfetch/download_manager.py | 3 --- src/databricks/sql/cloudfetch/downloader.py | 4 +-- src/databricks/sql/result_set.py | 2 -- .../sql/telemetry/latency_logger.py | 15 +++-------- src/databricks/sql/utils.py | 8 ------ tests/unit/test_cloud_fetch_queue.py | 14 ---------- tests/unit/test_download_manager.py | 1 - tests/unit/test_downloader.py | 14 +++++----- 12 files changed, 21 insertions(+), 97 deletions(-) diff --git 
diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
index ab869d813..b525c9455 100644
--- a/src/databricks/sql/backend/sea/queue.py
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -135,14 +135,12 @@ def __init__(
         super().__init__(
             max_download_threads=max_download_threads,
             ssl_options=ssl_options,
-            # TODO: fix these arguments when telemetry is implemented in SEA
-            session_id_hex=None,
             statement_id=statement_id,
-            statement_type=StatementType.NONE,
             chunk_id=0,
             schema_bytes=None,
             lz4_compressed=lz4_compressed,
             description=description,
+            session_id_hex=None,  # TODO: fix this argument when telemetry is implemented in SEA
         )
 
         self._sea_client = sea_client
diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index a86ee2fed..84679cb33 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -889,7 +889,6 @@ def get_execution_result(
             arrow_schema_bytes=schema_bytes,
             result_format=t_result_set_metadata_resp.resultFormat,
         )
-        execute_response.command_id.set_statement_type(StatementType.QUERY)
 
         return ThriftResultSet(
             connection=cursor.connection,
@@ -1029,8 +1028,6 @@ def execute_command(
         if resp.directResults and resp.directResults.resultSet:
             t_row_set = resp.directResults.resultSet.results
 
-        execute_response.command_id.set_statement_type(StatementType.QUERY)
-
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
@@ -1072,8 +1069,6 @@ def get_catalogs(
         if resp.directResults and resp.directResults.resultSet:
             t_row_set = resp.directResults.resultSet.results
 
-        execute_response.command_id.set_statement_type(StatementType.METADATA)
-
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
@@ -1121,8 +1116,6 @@ def get_schemas(
         if resp.directResults and resp.directResults.resultSet:
             t_row_set = resp.directResults.resultSet.results
 
-        execute_response.command_id.set_statement_type(StatementType.METADATA)
-
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
@@ -1174,8 +1167,6 @@ def get_tables(
         if resp.directResults and resp.directResults.resultSet:
             t_row_set = resp.directResults.resultSet.results
 
-        execute_response.command_id.set_statement_type(StatementType.METADATA)
-
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
@@ -1227,8 +1218,6 @@ def get_columns(
         if resp.directResults and resp.directResults.resultSet:
             t_row_set = resp.directResults.resultSet.results
 
-        execute_response.command_id.set_statement_type(StatementType.METADATA)
-
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
@@ -1315,7 +1304,6 @@ def fetch_results(
             ssl_options=self._ssl_options,
             session_id_hex=self._session_id_hex,
             statement_id=command_id.to_hex_guid(),
-            statement_type=command_id.statement_type,
             chunk_id=chunk_id,
         )
 
diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py
index 30b6fbcfd..a4ec307d4 100644
--- a/src/databricks/sql/backend/types.py
+++ b/src/databricks/sql/backend/types.py
@@ -301,7 +301,6 @@ def __init__(
         self.operation_type = operation_type
         self.has_result_set = has_result_set
         self.modified_row_count = modified_row_count
-        self._statement_type = StatementType.NONE
 
     def __str__(self) -> str:
         """
@@ -413,19 +412,6 @@ def to_hex_guid(self) -> str:
         else:
             return str(self.guid)
 
-    def set_statement_type(self, statement_type: StatementType):
-        """
-        Set the statement type for this command.
-        """
-        self._statement_type = statement_type
-
-    @property
-    def statement_type(self) -> StatementType:
-        """
-        Get the statement type for this command.
-        """
-        return self._statement_type
-
 
 @dataclass
 class ExecuteResponse:
diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index cbacaa5f2..c279f2c1f 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -708,7 +708,7 @@ def _handle_staging_operation(
                 session_id_hex=self.connection.get_session_id_hex(),
             )
 
-    @log_latency()
+    @log_latency(StatementType.SQL)
     def _handle_staging_put(
         self, presigned_url: str, local_file: str, headers: Optional[dict] = None
     ):
@@ -717,7 +717,6 @@ def _handle_staging_put(
 
         Raise an exception if request fails. Returns no data.
         """
-        self.statement_type = StatementType.SQL
         if local_file is None:
             raise ProgrammingError(
                 "Cannot perform PUT without specifying a local_file",
@@ -749,7 +748,7 @@ def _handle_staging_put(
                 + "but not yet applied on the server. It's possible this command may fail later."
             )
 
-    @log_latency()
+    @log_latency(StatementType.SQL)
     def _handle_staging_get(
         self, local_file: str, presigned_url: str, headers: Optional[dict] = None
     ):
@@ -758,7 +757,6 @@ def _handle_staging_get(
 
         Raise an exception if request fails. Returns no data.
         """
-        self.statement_type = StatementType.SQL
         if local_file is None:
             raise ProgrammingError(
                 "Cannot perform GET without specifying a local_file",
@@ -778,13 +776,12 @@ def _handle_staging_get(
         with open(local_file, "wb") as fp:
             fp.write(r.content)
 
-    @log_latency()
+    @log_latency(StatementType.SQL)
     def _handle_staging_remove(
         self, presigned_url: str, headers: Optional[dict] = None
     ):
         """Make an HTTP DELETE request to the presigned_url"""
 
-        self.statement_type = StatementType.SQL
         r = requests.delete(url=presigned_url, headers=headers)
 
         if not r.ok:
@@ -793,7 +790,7 @@ def _handle_staging_remove(
                 session_id_hex=self.connection.get_session_id_hex(),
             )
 
-    @log_latency()
+    @log_latency(StatementType.QUERY)
     def execute(
         self,
         operation: str,
@@ -832,7 +829,6 @@ def execute(
 
         :returns self
         """
-        self.statement_type = StatementType.QUERY
         logger.debug(
             "Cursor.execute(operation=%s, parameters=%s)", operation, parameters
         )
@@ -879,7 +875,7 @@ def execute(
 
         return self
 
-    @log_latency()
+    @log_latency(StatementType.QUERY)
     def execute_async(
         self,
         operation: str,
@@ -895,7 +891,6 @@ def execute_async(
 
         :return:
         """
-        self.statement_type = StatementType.QUERY
         param_approach = self._determine_parameter_approach(parameters)
         if param_approach == ParameterApproach.NONE:
             prepared_params = NO_NATIVE_PARAMS
@@ -999,14 +994,13 @@ def executemany(self, operation, seq_of_parameters):
             self.execute(operation, parameters)
         return self
 
-    @log_latency()
+    @log_latency(StatementType.METADATA)
     def catalogs(self) -> "Cursor":
         """
         Get all available catalogs.
 
         :returns self
         """
-        self.statement_type = StatementType.METADATA
         self._check_not_closed()
         self._close_and_clear_active_result_set()
         self.active_result_set = self.backend.get_catalogs(
@@ -1017,7 +1011,7 @@ def catalogs(self) -> "Cursor":
         )
         return self
 
-    @log_latency()
+    @log_latency(StatementType.METADATA)
     def schemas(
         self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None
     ) -> "Cursor":
@@ -1027,7 +1021,6 @@ def schemas(
         Names can contain % wildcards.
 
         :returns self
         """
-        self.statement_type = StatementType.METADATA
         self._check_not_closed()
         self._close_and_clear_active_result_set()
         self.active_result_set = self.backend.get_schemas(
@@ -1040,7 +1033,7 @@ def schemas(
         )
         return self
 
-    @log_latency()
+    @log_latency(StatementType.METADATA)
     def tables(
         self,
         catalog_name: Optional[str] = None,
@@ -1054,7 +1047,6 @@ def tables(
         Names can contain % wildcards.
 
         :returns self
         """
-        self.statement_type = StatementType.METADATA
         self._check_not_closed()
         self._close_and_clear_active_result_set()
 
@@ -1070,7 +1062,7 @@ def tables(
         )
         return self
 
-    @log_latency()
+    @log_latency(StatementType.METADATA)
     def columns(
         self,
         catalog_name: Optional[str] = None,
@@ -1084,7 +1076,6 @@ def columns(
         Names can contain % wildcards.
 
         :returns self
         """
-        self.statement_type = StatementType.METADATA
         self._check_not_closed()
         self._close_and_clear_active_result_set()
 
diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py
index c37a921f1..32b698bed 100644
--- a/src/databricks/sql/cloudfetch/download_manager.py
+++ b/src/databricks/sql/cloudfetch/download_manager.py
@@ -24,7 +24,6 @@ def __init__(
         ssl_options: SSLOptions,
         session_id_hex: Optional[str],
         statement_id: str,
-        statement_type: StatementType,
         chunk_id: int,
     ):
@@ -48,7 +47,6 @@ def __init__(
         self._ssl_options = ssl_options
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
-        self.statement_type = statement_type
 
     def get_next_downloaded_file(
         self, next_row_offset: int
@@ -111,7 +109,6 @@ def _schedule_downloads(self):
                 chunk_id=chunk_id,
                 session_id_hex=self.session_id_hex,
                 statement_id=self.statement_id,
-                statement_type=self.statement_type,
             )
             task = self._thread_pool.submit(handler.run)
             self._download_tasks.append(task)
diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
index 49f4ccc3c..e19a69046 100644
--- a/src/databricks/sql/cloudfetch/downloader.py
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -72,7 +72,6 @@ def __init__(
         chunk_id: int,
         session_id_hex: Optional[str],
         statement_id: str,
-        statement_type: StatementType,
     ):
         self.settings = settings
         self.link = link
@@ -80,9 +79,8 @@ def __init__(
         self.chunk_id = chunk_id
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
-        self.statement_type = statement_type
 
-    @log_latency()
+    @log_latency(StatementType.QUERY)
     def run(self) -> DownloadedFile:
         """
         Download the file described in the cloud fetch link.
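Chunk ids must stay unique across successive `fetch_results` batches, which is why the download manager enumerates pending links starting at the offset the result set has already consumed (`enumerate(links, start=chunk_id)` earlier in the series), while the result set advances `num_downloaded_chunks` after each batch, as the next diff shows. A toy illustration of that numbering, assuming link batches arrive in order:

    # Each string stands in for a TSparkArrowResultLink (hypothetical data).
    batches = [["link-a", "link-b"], ["link-c"], ["link-d", "link-e"]]

    num_downloaded_chunks = 0
    for batch in batches:
        for chunk_id, link in enumerate(batch, start=num_downloaded_chunks):
            print(chunk_id, link)  # prints 0..4 with no repeats across batches
        num_downloaded_chunks += len(batch)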
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index f74cd0a5a..cd2f980e8 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -217,7 +217,6 @@ def __init__(
         :param ssl_options: SSL options for cloud fetch
         :param is_direct_results: Whether there are more rows to fetch
         """
-        self.statement_type = execute_response.command_id.statement_type
         self.num_downloaded_chunks = 0
 
         # Initialize ThriftResultSet-specific attributes
@@ -240,7 +239,6 @@ def __init__(
                 ssl_options=ssl_options,
                 session_id_hex=session_id_hex,
                 statement_id=execute_response.command_id.to_hex_guid(),
-                statement_type=self.statement_type,
                 chunk_id=self.num_downloaded_chunks,
             )
             if t_row_set and t_row_set.resultLinks:
diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py
index 10f1b2291..52fcbfc02 100644
--- a/src/databricks/sql/telemetry/latency_logger.py
+++ b/src/databricks/sql/telemetry/latency_logger.py
@@ -43,9 +43,6 @@ def get_retry_count(self):
         pass
 
     def get_chunk_id(self):
         pass
 
-    def get_statement_type(self):
-        pass
-
 
 class CursorExtractor(TelemetryExtractor):
     """
@@ -86,9 +83,6 @@ def get_retry_count(self) -> int:
     def get_chunk_id(self):
         return None
 
-    def get_statement_type(self):
-        return self.statement_type
-
 
 class ResultSetDownloadHandlerExtractor(TelemetryExtractor):
     """
@@ -114,9 +108,6 @@ def get_retry_count(self) -> Optional[int]:
     def get_chunk_id(self) -> Optional[int]:
         return self._obj.chunk_id
 
-    def get_statement_type(self):
-        return self.statement_type
-
 
 def get_extractor(obj):
     """
@@ -144,7 +135,7 @@ def get_extractor(obj):
     return None
 
 
-def log_latency():
+def log_latency(statement_type: StatementType = StatementType.NONE):
     """
     Decorator for logging execution latency and telemetry information.
@@ -159,7 +150,7 @@ def log_latency():
     - Sends the telemetry data asynchronously via TelemetryClient
 
     Usage:
-        @log_latency()
+        @log_latency(StatementType.QUERY)
         def execute(self, query):
             # Method implementation
             pass
@@ -199,7 +190,7 @@ def _safe_call(func_to_call):
                 statement_id = _safe_call(extractor.get_statement_id)
 
                 sql_exec_event = SqlExecutionEvent(
-                    statement_type=_safe_call(extractor.get_statement_type),
+                    statement_type=statement_type,
                     is_compressed=_safe_call(extractor.get_is_compressed),
                     execution_result=_safe_call(
                         extractor.get_execution_result_format
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index a262cd7ad..f2f9fcb95 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -63,7 +63,6 @@ def build_queue(
         ssl_options: SSLOptions,
         session_id_hex: Optional[str],
         statement_id: str,
-        statement_type: StatementType,
         chunk_id: int,
         lz4_compressed: bool = True,
         description: List[Tuple] = [],
@@ -113,7 +112,6 @@ def build_queue(
             ssl_options=ssl_options,
             session_id_hex=session_id_hex,
             statement_id=statement_id,
-            statement_type=statement_type,
             chunk_id=chunk_id,
         )
     else:
@@ -225,7 +223,6 @@ def __init__(
         ssl_options: SSLOptions,
         session_id_hex: Optional[str],
         statement_id: str,
-        statement_type: StatementType,
         chunk_id: int,
         schema_bytes: Optional[bytes] = None,
         lz4_compressed: bool = True,
@@ -249,7 +246,6 @@ def __init__(
         self._ssl_options = ssl_options
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
-        self.statement_type = statement_type
         self.chunk_id = chunk_id
 
         # Table state
@@ -264,7 +260,6 @@ def __init__(
             ssl_options=ssl_options,
             session_id_hex=session_id_hex,
             statement_id=statement_id,
-            statement_type=statement_type,
             chunk_id=chunk_id,
         )
 
@@ -371,7 +366,6 @@ def __init__(
         ssl_options: SSLOptions,
         session_id_hex: Optional[str],
         statement_id: str,
-        statement_type: StatementType,
         chunk_id: int,
         start_row_offset: int = 0,
         result_links: Optional[List[TSparkArrowResultLink]] = None,
@@ -398,7 +392,6 @@ def __init__(
             description=description,
             session_id_hex=session_id_hex,
             statement_id=statement_id,
-            statement_type=statement_type,
             chunk_id=chunk_id,
         )
 
@@ -406,7 +399,6 @@ def __init__(
         self.result_links = result_links or []
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
-        self.statement_type = statement_type
         self.chunk_id = chunk_id
 
         logger.debug(
diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py
index 3801590b8..f50c1b82d 100644
--- a/tests/unit/test_cloud_fetch_queue.py
+++ b/tests/unit/test_cloud_fetch_queue.py
@@ -65,7 +65,6 @@ def test_initializer_adds_links(self, mock_create_next_table):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
 
@@ -83,7 +82,6 @@ def test_initializer_no_links_to_add(self):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
 
@@ -103,7 +101,6 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
 
@@ -128,7 +125,6 @@ def test_initializer_create_next_table_success(
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         expected_result = self.make_arrow_table()
@@ -157,7 +153,6 @@ def test_next_n_rows_0_rows(self, mock_create_next_table):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
@@ -182,7 +177,6 @@ def test_next_n_rows_partial_table(self, mock_create_next_table):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
@@ -206,7 +200,6 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
@@ -235,7 +228,6 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
@@ -261,7 +253,6 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table is None
@@ -282,7 +273,6 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
@@ -305,7 +295,6 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
@@ -328,7 +317,6 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
@@ -357,7 +345,6 @@ def test_remaining_rows_multiple_tables_fully_returned(
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table == self.make_arrow_table()
@@ -389,7 +376,6 @@ def test_remaining_rows_empty_table(self, mock_create_next_table):
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
         assert queue.table is None
diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py
index b3d6c7988..6eb17a05a 100644
--- a/tests/unit/test_download_manager.py
+++ b/tests/unit/test_download_manager.py
@@ -21,7 +21,6 @@ def create_download_manager(
             ssl_options=SSLOptions(),
             session_id_hex=Mock(),
             statement_id=Mock(),
-            statement_type=Mock(),
             chunk_id=0,
         )
 
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index c440bf116..9879e17c7 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -27,7 +27,7 @@ def test_run_link_expired(self, mock_time):
         # Already expired
         result_link.expiryTime = 999
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
 
         with self.assertRaises(Error) as context:
@@ -43,7 +43,7 @@ def test_run_link_past_expiry_buffer(self, mock_time):
         # Within the expiry buffer time
         result_link.expiryTime = 1004
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
        )
 
         with self.assertRaises(Error) as context:
@@ -63,7 +63,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session):
         result_link = Mock(expiryTime=1001)
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
         with self.assertRaises(requests.exceptions.HTTPError) as context:
             d.run()
@@ -82,7 +82,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session):
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
         file = d.run()
 
@@ -105,7 +105,7 @@ def test_run_compressed_successful(self, mock_time, mock_session):
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
         file = d.run()
 
@@ -121,7 +121,7 @@ def test_download_connection_error(self, mock_time, mock_session):
         mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
         )
         with self.assertRaises(ConnectionError):
             d.run()
@@ -136,7 +136,7 @@ def test_download_timeout(self, mock_time, mock_session):
         mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
 
         d = downloader.ResultSetDownloadHandler(
-            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), statement_type=Mock()
+            settings, result_link, ssl_options=SSLOptions(), chunk_id=0, session_id_hex=Mock(), statement_id=Mock()
        )
         with self.assertRaises(TimeoutError):
             d.run()

From d78ffba9c3f1b509630d5c20b699c23ebeacb252 Mon Sep 17 00:00:00 2001
From: Sai Shree Pradhan
Date: Wed, 16 Jul 2025 15:46:17 +0530
Subject: [PATCH 18/19] comment fix

Signed-off-by: Sai Shree Pradhan

---
 src/databricks/sql/backend/sea/queue.py        | 5 +++--
 src/databricks/sql/telemetry/latency_logger.py | 3 +++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
index b525c9455..e9764ce76 100644
--- a/src/databricks/sql/backend/sea/queue.py
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -136,11 +136,12 @@ def __init__(
             max_download_threads=max_download_threads,
             ssl_options=ssl_options,
             statement_id=statement_id,
-            chunk_id=0,
             schema_bytes=None,
             lz4_compressed=lz4_compressed,
             description=description,
-            session_id_hex=None,  # TODO: fix this argument when telemetry is implemented in SEA
+            # TODO: fix these arguments when telemetry is implemented in SEA
+            session_id_hex=None,
+            chunk_id=0,
         )
 
         self._sea_client = sea_client
diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py
index 52fcbfc02..12cacd851 100644
--- a/src/databricks/sql/telemetry/latency_logger.py
+++ b/src/databricks/sql/telemetry/latency_logger.py
@@ -149,6 +149,9 @@ def log_latency(statement_type: StatementType = StatementType.NONE):
     - Creates a SqlExecutionEvent with execution details
     - Sends the telemetry data asynchronously via TelemetryClient
 
+    Args:
+        statement_type (StatementType): The type of SQL statement being executed.
+
     Usage:
         @log_latency(StatementType.QUERY)
         def execute(self, query):

From 6896f93b909a64ff366de3736782ddfcaa36f74e Mon Sep 17 00:00:00 2001
From: Sai Shree Pradhan
Date: Mon, 21 Jul 2025 11:06:45 +0530
Subject: [PATCH 19/19] removed duplicate check for TRowSet

Signed-off-by: Sai Shree Pradhan

---
 src/databricks/sql/result_set.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index cd2f980e8..cb553f952 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -241,7 +241,7 @@ def __init__(
                 statement_id=execute_response.command_id.to_hex_guid(),
                 chunk_id=self.num_downloaded_chunks,
             )
-            if t_row_set and t_row_set.resultLinks:
+            if t_row_set.resultLinks:
                 self.num_downloaded_chunks += len(t_row_set.resultLinks)
 
         # Call parent constructor with common attributes
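
The thread running through this series is the move from per-call instance state (each method setting self.statement_type before running) to a statement type bound once, at decoration time, via a decorator factory. Below is a minimal, self-contained sketch of that pattern; it is an illustration only, not the connector's actual implementation — the StatementType values, the toy Cursor class, and the print stand-in for the telemetry client are all assumptions.

    import functools
    import time
    from enum import Enum

    class StatementType(Enum):
        NONE = "none"
        QUERY = "query"
        SQL = "sql"
        METADATA = "metadata"

    def log_latency(statement_type: StatementType = StatementType.NONE):
        # Decorator factory: the statement type is captured once, when the
        # method is defined, so concurrent calls cannot race on shared
        # mutable instance state.
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                start = time.perf_counter()
                try:
                    return func(*args, **kwargs)
                finally:
                    elapsed_ms = (time.perf_counter() - start) * 1000
                    # The real driver builds a SqlExecutionEvent here and hands
                    # it to a telemetry client asynchronously; print stands in.
                    print(f"{func.__name__} [{statement_type.value}]: {elapsed_ms:.1f} ms")
            return wrapper
        return decorator

    class Cursor:
        @log_latency(StatementType.QUERY)
        def execute(self, operation: str):
            time.sleep(0.01)  # stand-in for real work
            return self

    Cursor().execute("SELECT 1")

Binding the type in the decorator argument is also what lets ResultSetDownloadHandler.run be tagged as QUERY without the handler carrying a statement_type field of its own — which is exactly the plumbing the download_manager and downloader hunks above delete.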