11import logging
22
33from concurrent .futures import ThreadPoolExecutor , Future
4- from typing import List , Union
4+ from typing import List , Union , Tuple , Optional
55
66from databricks .sql .cloudfetch .downloader import (
77 ResultSetDownloadHandler ,
88 DownloadableResultSettings ,
99 DownloadedFile ,
1010)
1111from databricks .sql .types import SSLOptions
12-
12+ from databricks . sql . telemetry . models . event import StatementType
1313from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
1414
1515logger = logging .getLogger (__name__ )
@@ -22,24 +22,31 @@ def __init__(
2222 max_download_threads : int ,
2323 lz4_compressed : bool ,
2424 ssl_options : SSLOptions ,
25+ session_id_hex : Optional [str ],
26+ statement_id : str ,
27+ chunk_id : int ,
2528 ):
26- self ._pending_links : List [TSparkArrowResultLink ] = []
27- for link in links :
29+ self ._pending_links : List [Tuple [int , TSparkArrowResultLink ]] = []
30+ self .chunk_id = chunk_id
31+ for i , link in enumerate (links , start = chunk_id ):
2832 if link .rowCount <= 0 :
2933 continue
3034 logger .debug (
31- "ResultFileDownloadManager: adding file link, start offset {}, row count: {}" .format (
32- link .startRowOffset , link .rowCount
35+ "ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}" .format (
36+ i , link .startRowOffset , link .rowCount
3337 )
3438 )
35- self ._pending_links .append (link )
39+ self ._pending_links .append ((i , link ))
40+ self .chunk_id += len (links )
3641
3742 self ._download_tasks : List [Future [DownloadedFile ]] = []
3843 self ._max_download_threads : int = max_download_threads
3944 self ._thread_pool = ThreadPoolExecutor (max_workers = self ._max_download_threads )
4045
4146 self ._downloadable_result_settings = DownloadableResultSettings (lz4_compressed )
4247 self ._ssl_options = ssl_options
48+ self .session_id_hex = session_id_hex
49+ self .statement_id = statement_id
4350
4451 def get_next_downloaded_file (
4552 self , next_row_offset : int
@@ -89,14 +96,19 @@ def _schedule_downloads(self):
8996 while (len (self ._download_tasks ) < self ._max_download_threads ) and (
9097 len (self ._pending_links ) > 0
9198 ):
92- link = self ._pending_links .pop (0 )
99+ chunk_id , link = self ._pending_links .pop (0 )
93100 logger .debug (
94- "- start: {}, row count: {}" .format (link .startRowOffset , link .rowCount )
101+ "- chunk: {}, start: {}, row count: {}" .format (
102+ chunk_id , link .startRowOffset , link .rowCount
103+ )
95104 )
96105 handler = ResultSetDownloadHandler (
97106 settings = self ._downloadable_result_settings ,
98107 link = link ,
99108 ssl_options = self ._ssl_options ,
109+ chunk_id = chunk_id ,
110+ session_id_hex = self .session_id_hex ,
111+ statement_id = self .statement_id ,
100112 )
101113 task = self ._thread_pool .submit (handler .run )
102114 self ._download_tasks .append (task )
@@ -117,7 +129,8 @@ def add_link(self, link: TSparkArrowResultLink):
117129 link .startRowOffset , link .rowCount
118130 )
119131 )
120- self ._pending_links .append (link )
132+ self ._pending_links .append ((self .chunk_id , link ))
133+ self .chunk_id += 1
121134
122135 def _shutdown_manager (self ):
123136 # Clear download handlers and shutdown the thread pool
0 commit comments