11import logging
22
33from concurrent .futures import ThreadPoolExecutor , Future
4- from typing import List , Union , Tuple , Optional
4+ from typing import List , Union
55
66from databricks .sql .cloudfetch .downloader import (
77 ResultSetDownloadHandler ,
88 DownloadableResultSettings ,
99 DownloadedFile ,
1010)
1111from databricks .sql .types import SSLOptions
12- from databricks . sql . telemetry . models . event import StatementType
12+
1313from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
1414
1515logger = logging .getLogger (__name__ )
@@ -22,31 +22,24 @@ def __init__(
2222 max_download_threads : int ,
2323 lz4_compressed : bool ,
2424 ssl_options : SSLOptions ,
25- session_id_hex : Optional [str ],
26- statement_id : str ,
27- chunk_id : int ,
2825 ):
29- self ._pending_links : List [Tuple [int , TSparkArrowResultLink ]] = []
30- self .chunk_id = chunk_id
31- for i , link in enumerate (links , start = chunk_id ):
26+ self ._pending_links : List [TSparkArrowResultLink ] = []
27+ for link in links :
3228 if link .rowCount <= 0 :
3329 continue
3430 logger .debug (
35- "ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}" .format (
36- i , link .startRowOffset , link .rowCount
31+ "ResultFileDownloadManager: adding file link, start offset {}, row count: {}" .format (
32+ link .startRowOffset , link .rowCount
3733 )
3834 )
39- self ._pending_links .append ((i , link ))
40- self .chunk_id += len (links )
35+ self ._pending_links .append (link )
4136
4237 self ._download_tasks : List [Future [DownloadedFile ]] = []
4338 self ._max_download_threads : int = max_download_threads
4439 self ._thread_pool = ThreadPoolExecutor (max_workers = self ._max_download_threads )
4540
4641 self ._downloadable_result_settings = DownloadableResultSettings (lz4_compressed )
4742 self ._ssl_options = ssl_options
48- self .session_id_hex = session_id_hex
49- self .statement_id = statement_id
5043
5144 def get_next_downloaded_file (
5245 self , next_row_offset : int
@@ -96,19 +89,14 @@ def _schedule_downloads(self):
9689 while (len (self ._download_tasks ) < self ._max_download_threads ) and (
9790 len (self ._pending_links ) > 0
9891 ):
99- chunk_id , link = self ._pending_links .pop (0 )
92+ link = self ._pending_links .pop (0 )
10093 logger .debug (
101- "- chunk: {}, start: {}, row count: {}" .format (
102- chunk_id , link .startRowOffset , link .rowCount
103- )
94+ "- start: {}, row count: {}" .format (link .startRowOffset , link .rowCount )
10495 )
10596 handler = ResultSetDownloadHandler (
10697 settings = self ._downloadable_result_settings ,
10798 link = link ,
10899 ssl_options = self ._ssl_options ,
109- chunk_id = chunk_id ,
110- session_id_hex = self .session_id_hex ,
111- statement_id = self .statement_id ,
112100 )
113101 task = self ._thread_pool .submit (handler .run )
114102 self ._download_tasks .append (task )
@@ -129,8 +117,7 @@ def add_link(self, link: TSparkArrowResultLink):
129117 link .startRowOffset , link .rowCount
130118 )
131119 )
132- self ._pending_links .append ((self .chunk_id , link ))
133- self .chunk_id += 1
120+ self ._pending_links .append (link )
134121
135122 def _shutdown_manager (self ):
136123 # Clear download handlers and shutdown the thread pool
0 commit comments