Skip to content

Commit b57c3f3

Browse files
authored
Chunk download latency (#634)
Squashed commit message:
* chunk download latency
* formatting
* test fixes
* sea-migration static type checking fixes
* check types fix
* fix type issues (Signed-off-by: varun-edachali-dbx <[email protected]>)
* type fix revert
* (empty message)
* statement id in get metadata functions
* removed result set extractor
* databricks client type
* formatting
* remove defaults, fix chunk id
* added statement type to command id
* check types fix
* renamed chunk_id to num_downloaded_chunks
* set statement type to query for chunk download
* comment fix
* removed duplicate check for TRowSet

Signed-off-by: Sai Shree Pradhan <[email protected]>
1 parent 806e5f5 commit b57c3f3

File tree

17 files changed

+218
-89
lines changed

17 files changed

+218
-89
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
66

77
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
8+
from databricks.sql.telemetry.models.enums import StatementType
89

910
from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler
1011

@@ -327,9 +328,13 @@ def __init__(
327328
super().__init__(
328329
max_download_threads=max_download_threads,
329330
ssl_options=ssl_options,
331+
statement_id=statement_id,
330332
schema_bytes=None,
331333
lz4_compressed=lz4_compressed,
332334
description=description,
335+
# TODO: fix these arguments when telemetry is implemented in SEA
336+
session_id_hex=None,
337+
chunk_id=0,
333338
)
334339

335340
logger.debug(

src/databricks/sql/backend/thrift_backend.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import time
77
import threading
88
from typing import List, Optional, Union, Any, TYPE_CHECKING
9+
from uuid import UUID
910

1011
from databricks.sql.result_set import ThriftResultSet
11-
12+
from databricks.sql.telemetry.models.event import StatementType
1213

1314
if TYPE_CHECKING:
1415
from databricks.sql.client import Cursor
@@ -900,6 +901,7 @@ def get_execution_result(
900901
max_download_threads=self.max_download_threads,
901902
ssl_options=self._ssl_options,
902903
is_direct_results=is_direct_results,
904+
session_id_hex=self._session_id_hex,
903905
)
904906

905907
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -1037,6 +1039,7 @@ def execute_command(
10371039
max_download_threads=self.max_download_threads,
10381040
ssl_options=self._ssl_options,
10391041
is_direct_results=is_direct_results,
1042+
session_id_hex=self._session_id_hex,
10401043
)
10411044

10421045
def get_catalogs(
@@ -1077,6 +1080,7 @@ def get_catalogs(
10771080
max_download_threads=self.max_download_threads,
10781081
ssl_options=self._ssl_options,
10791082
is_direct_results=is_direct_results,
1083+
session_id_hex=self._session_id_hex,
10801084
)
10811085

10821086
def get_schemas(
@@ -1123,6 +1127,7 @@ def get_schemas(
11231127
max_download_threads=self.max_download_threads,
11241128
ssl_options=self._ssl_options,
11251129
is_direct_results=is_direct_results,
1130+
session_id_hex=self._session_id_hex,
11261131
)
11271132

11281133
def get_tables(
@@ -1173,6 +1178,7 @@ def get_tables(
11731178
max_download_threads=self.max_download_threads,
11741179
ssl_options=self._ssl_options,
11751180
is_direct_results=is_direct_results,
1181+
session_id_hex=self._session_id_hex,
11761182
)
11771183

11781184
def get_columns(
@@ -1223,6 +1229,7 @@ def get_columns(
12231229
max_download_threads=self.max_download_threads,
12241230
ssl_options=self._ssl_options,
12251231
is_direct_results=is_direct_results,
1232+
session_id_hex=self._session_id_hex,
12261233
)
12271234

12281235
def _handle_execute_response(self, resp, cursor):
@@ -1257,6 +1264,7 @@ def fetch_results(
12571264
lz4_compressed: bool,
12581265
arrow_schema_bytes,
12591266
description,
1267+
chunk_id: int,
12601268
use_cloud_fetch=True,
12611269
):
12621270
thrift_handle = command_id.to_thrift_handle()
@@ -1294,9 +1302,16 @@ def fetch_results(
12941302
lz4_compressed=lz4_compressed,
12951303
description=description,
12961304
ssl_options=self._ssl_options,
1305+
session_id_hex=self._session_id_hex,
1306+
statement_id=command_id.to_hex_guid(),
1307+
chunk_id=chunk_id,
12971308
)
12981309

1299-
return queue, resp.hasMoreRows
1310+
return (
1311+
queue,
1312+
resp.hasMoreRows,
1313+
len(resp.results.resultLinks) if resp.results.resultLinks else 0,
1314+
)
13001315

13011316
def cancel_command(self, command_id: CommandId) -> None:
13021317
thrift_handle = command_id.to_thrift_handle()

src/databricks/sql/backend/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55

66
from databricks.sql.backend.utils.guid_utils import guid_to_hex_id
7+
from databricks.sql.telemetry.models.enums import StatementType
78
from databricks.sql.thrift_api.TCLIService import ttypes
89

910
logger = logging.getLogger(__name__)

src/databricks/sql/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ def read(self) -> Optional[OAuthToken]:
284284

285285
driver_connection_params = DriverConnectionParameters(
286286
http_path=http_path,
287-
mode=DatabricksClientType.THRIFT,
287+
mode=DatabricksClientType.SEA
288+
if self.session.use_sea
289+
else DatabricksClientType.THRIFT,
288290
host_info=HostDetails(host_url=server_hostname, port=self.session.port),
289291
auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
290292
auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import logging
22

33
from concurrent.futures import ThreadPoolExecutor, Future
4-
from typing import List, Union
4+
from typing import List, Union, Tuple, Optional
55

66
from databricks.sql.cloudfetch.downloader import (
77
ResultSetDownloadHandler,
88
DownloadableResultSettings,
99
DownloadedFile,
1010
)
1111
from databricks.sql.types import SSLOptions
12-
12+
from databricks.sql.telemetry.models.event import StatementType
1313
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
1414

1515
logger = logging.getLogger(__name__)
@@ -22,24 +22,31 @@ def __init__(
2222
max_download_threads: int,
2323
lz4_compressed: bool,
2424
ssl_options: SSLOptions,
25+
session_id_hex: Optional[str],
26+
statement_id: str,
27+
chunk_id: int,
2528
):
26-
self._pending_links: List[TSparkArrowResultLink] = []
27-
for link in links:
29+
self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
30+
self.chunk_id = chunk_id
31+
for i, link in enumerate(links, start=chunk_id):
2832
if link.rowCount <= 0:
2933
continue
3034
logger.debug(
31-
"ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
32-
link.startRowOffset, link.rowCount
35+
"ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}".format(
36+
i, link.startRowOffset, link.rowCount
3337
)
3438
)
35-
self._pending_links.append(link)
39+
self._pending_links.append((i, link))
40+
self.chunk_id += len(links)
3641

3742
self._download_tasks: List[Future[DownloadedFile]] = []
3843
self._max_download_threads: int = max_download_threads
3944
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
4045

4146
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
4247
self._ssl_options = ssl_options
48+
self.session_id_hex = session_id_hex
49+
self.statement_id = statement_id
4350

4451
def get_next_downloaded_file(
4552
self, next_row_offset: int
@@ -89,14 +96,19 @@ def _schedule_downloads(self):
8996
while (len(self._download_tasks) < self._max_download_threads) and (
9097
len(self._pending_links) > 0
9198
):
92-
link = self._pending_links.pop(0)
99+
chunk_id, link = self._pending_links.pop(0)
93100
logger.debug(
94-
"- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
101+
"- chunk: {}, start: {}, row count: {}".format(
102+
chunk_id, link.startRowOffset, link.rowCount
103+
)
95104
)
96105
handler = ResultSetDownloadHandler(
97106
settings=self._downloadable_result_settings,
98107
link=link,
99108
ssl_options=self._ssl_options,
109+
chunk_id=chunk_id,
110+
session_id_hex=self.session_id_hex,
111+
statement_id=self.statement_id,
100112
)
101113
task = self._thread_pool.submit(handler.run)
102114
self._download_tasks.append(task)
@@ -117,7 +129,8 @@ def add_link(self, link: TSparkArrowResultLink):
117129
link.startRowOffset, link.rowCount
118130
)
119131
)
120-
self._pending_links.append(link)
132+
self._pending_links.append((self.chunk_id, link))
133+
self.chunk_id += 1
121134

122135
def _shutdown_manager(self):
123136
# Clear download handlers and shutdown the thread pool

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from dataclasses import dataclass
3+
from typing import Optional
34

45
import requests
56
from requests.adapters import HTTPAdapter, Retry
@@ -9,6 +10,8 @@
910
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
1011
from databricks.sql.exc import Error
1112
from databricks.sql.types import SSLOptions
13+
from databricks.sql.telemetry.latency_logger import log_latency
14+
from databricks.sql.telemetry.models.event import StatementType
1215

1316
logger = logging.getLogger(__name__)
1417

@@ -66,11 +69,18 @@ def __init__(
6669
settings: DownloadableResultSettings,
6770
link: TSparkArrowResultLink,
6871
ssl_options: SSLOptions,
72+
chunk_id: int,
73+
session_id_hex: Optional[str],
74+
statement_id: str,
6975
):
7076
self.settings = settings
7177
self.link = link
7278
self._ssl_options = ssl_options
79+
self.chunk_id = chunk_id
80+
self.session_id_hex = session_id_hex
81+
self.statement_id = statement_id
7382

83+
@log_latency(StatementType.QUERY)
7484
def run(self) -> DownloadedFile:
7585
"""
7686
Download the file described in the cloud fetch link.
@@ -80,8 +90,8 @@ def run(self) -> DownloadedFile:
8090
"""
8191

8292
logger.debug(
83-
"ResultSetDownloadHandler: starting file download, offset {}, row count {}".format(
84-
self.link.startRowOffset, self.link.rowCount
93+
"ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format(
94+
self.chunk_id, self.link.startRowOffset, self.link.rowCount
8595
)
8696
)
8797

src/databricks/sql/result_set.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ColumnQueue,
2323
)
2424
from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
25+
from databricks.sql.telemetry.models.event import StatementType
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -192,6 +193,7 @@ def __init__(
192193
connection: "Connection",
193194
execute_response: "ExecuteResponse",
194195
thrift_client: "ThriftDatabricksClient",
196+
session_id_hex: Optional[str],
195197
buffer_size_bytes: int = 104857600,
196198
arraysize: int = 10000,
197199
use_cloud_fetch: bool = True,
@@ -215,6 +217,7 @@ def __init__(
215217
:param ssl_options: SSL options for cloud fetch
216218
:param is_direct_results: Whether there are more rows to fetch
217219
"""
220+
self.num_downloaded_chunks = 0
218221

219222
# Initialize ThriftResultSet-specific attributes
220223
self._use_cloud_fetch = use_cloud_fetch
@@ -234,7 +237,12 @@ def __init__(
234237
lz4_compressed=execute_response.lz4_compressed,
235238
description=execute_response.description,
236239
ssl_options=ssl_options,
240+
session_id_hex=session_id_hex,
241+
statement_id=execute_response.command_id.to_hex_guid(),
242+
chunk_id=self.num_downloaded_chunks,
237243
)
244+
if t_row_set.resultLinks:
245+
self.num_downloaded_chunks += len(t_row_set.resultLinks)
238246

239247
# Call parent constructor with common attributes
240248
super().__init__(
@@ -258,7 +266,7 @@ def __init__(
258266
self._fill_results_buffer()
259267

260268
def _fill_results_buffer(self):
261-
results, is_direct_results = self.backend.fetch_results(
269+
results, is_direct_results, result_links_count = self.backend.fetch_results(
262270
command_id=self.command_id,
263271
max_rows=self.arraysize,
264272
max_bytes=self.buffer_size_bytes,
@@ -267,9 +275,11 @@ def _fill_results_buffer(self):
267275
arrow_schema_bytes=self._arrow_schema_bytes,
268276
description=self.description,
269277
use_cloud_fetch=self._use_cloud_fetch,
278+
chunk_id=self.num_downloaded_chunks,
270279
)
271280
self.results = results
272281
self.is_direct_results = is_direct_results
282+
self.num_downloaded_chunks += result_links_count
273283

274284
def _convert_columnar_table(self, table):
275285
column_names = [c[0] for c in self.description]

src/databricks/sql/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ def _create_backend(
9797
kwargs: dict,
9898
) -> DatabricksClient:
9999
"""Create and return the appropriate backend client."""
100-
use_sea = kwargs.get("use_sea", False)
100+
self.use_sea = kwargs.get("use_sea", False)
101101

102102
databricks_client_class: Type[DatabricksClient]
103-
if use_sea:
103+
if self.use_sea:
104104
logger.debug("Creating SEA backend client")
105105
databricks_client_class = SeaDatabricksClient
106106
else:

0 commit comments

Comments (0)