 from __future__ import annotations

 from abc import ABC
-from typing import List, Optional, Tuple, Union, TYPE_CHECKING
+import threading
+from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING

 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
@@ -121,6 +122,179 @@ def close(self):
         return


+class LinkFetcher:
+    """
+    Background helper that incrementally retrieves *external links* for a
+    result set produced by the SEA backend and feeds them to a
+    :class:`databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager`.
+
+    The SEA backend splits large result sets into *chunks*. Each chunk is
+    stored remotely (e.g., in object storage) and exposed via a signed URL
+    encapsulated by an :class:`ExternalLink`. Only the first batch of links is
+    returned with the initial query response; the remaining links must be
+    pulled on demand using the next-chunk index carried by each link in
+    :attr:`ExternalLink.next_chunk_index`.
+
+    LinkFetcher takes care of this choreography so callers (primarily
+    ``SeaCloudFetchQueue``) can simply ask for the link of a specific
+    ``chunk_index`` and block until it becomes available.
+
+    Key responsibilities:
+
+    • Maintain an in-memory mapping from ``chunk_index`` → ``ExternalLink``.
+    • Launch a background worker thread that continuously requests the next
+      batch of links from the backend until all chunks have been discovered or
+      an unrecoverable error occurs.
+    • Bridge SEA link objects to the Thrift representation expected by the
+      existing download manager.
+    • Provide a synchronous API (``get_chunk_link``) that blocks until the
+      desired link is present in the cache.
+    """
+
+    def __init__(
+        self,
+        download_manager: ResultFileDownloadManager,
+        backend: SeaDatabricksClient,
+        statement_id: str,
+        initial_links: List[ExternalLink],
+        total_chunk_count: int,
+    ):
+        self.download_manager = download_manager
+        self.backend = backend
+        self._statement_id = statement_id
+
+        self._shutdown_event = threading.Event()
+
+        self._link_data_update = threading.Condition()
+        self._error: Optional[Exception] = None
+        self.chunk_index_to_link: Dict[int, ExternalLink] = {}
+
+        self._add_links(initial_links)
+        self.total_chunk_count = total_chunk_count
+
+        # DEBUG: capture initial state for observability
+        logger.debug(
+            "LinkFetcher[%s]: initialized with %d initial link(s); expecting %d total chunk(s)",
+            statement_id,
+            len(initial_links),
+            total_chunk_count,
+        )
+
+    def _add_links(self, links: List[ExternalLink]):
+        """Cache *links* locally and enqueue them with the download manager."""
+        logger.debug(
+            "LinkFetcher[%s]: caching %d link(s) – chunks %s",
+            self._statement_id,
+            len(links),
+            ", ".join(str(l.chunk_index) for l in links) if links else "<none>",
+        )
+        for link in links:
+            self.chunk_index_to_link[link.chunk_index] = link
+            self.download_manager.add_link(LinkFetcher._convert_to_thrift_link(link))
+
+    def _get_next_chunk_index(self) -> Optional[int]:
+        """Return the next *chunk_index* that should be requested from the backend, or ``None`` if we have them all."""
+        with self._link_data_update:
+            max_chunk_index = max(self.chunk_index_to_link.keys(), default=None)
+            if max_chunk_index is None:
+                return 0
+            max_link = self.chunk_index_to_link[max_chunk_index]
+            return max_link.next_chunk_index
+
+    def _trigger_next_batch_download(self) -> bool:
+        """Fetch the next batch of links from the backend and return *True* on success."""
+        logger.debug(
+            "LinkFetcher[%s]: requesting next batch of links", self._statement_id
+        )
+        next_chunk_index = self._get_next_chunk_index()
+        if next_chunk_index is None:
+            return False
+
+        try:
+            links = self.backend.get_chunk_links(self._statement_id, next_chunk_index)
+            with self._link_data_update:
+                self._add_links(links)
+                self._link_data_update.notify_all()
+        except Exception as e:
+            logger.error(
+                "LinkFetcher: Error fetching links for chunk %s: %s",
+                next_chunk_index,
+                e,
+            )
+            with self._link_data_update:
+                self._error = e
+                self._link_data_update.notify_all()
+            return False
+
+        logger.debug(
+            "LinkFetcher[%s]: received %d new link(s)",
+            self._statement_id,
+            len(links),
+        )
+        return True
+
+    def get_chunk_link(self, chunk_index: int) -> Optional[ExternalLink]:
+        """Return the :class:`ExternalLink` for *chunk_index*, blocking until it becomes available."""
+        logger.debug(
+            "LinkFetcher[%s]: waiting for link of chunk %d",
+            self._statement_id,
+            chunk_index,
+        )
+        if chunk_index >= self.total_chunk_count:
+            return None
+
+        with self._link_data_update:
+            while chunk_index not in self.chunk_index_to_link:
+                if self._error:
+                    raise self._error
+                if self._shutdown_event.is_set():
+                    raise ProgrammingError(
+                        "LinkFetcher is shutting down without providing link for chunk index {}".format(
+                            chunk_index
+                        )
+                    )
+                self._link_data_update.wait()
+
+            return self.chunk_index_to_link[chunk_index]
+
+    @staticmethod
+    def _convert_to_thrift_link(link: ExternalLink) -> TSparkArrowResultLink:
+        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
+        # Parse the ISO format expiration time
+        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
+        return TSparkArrowResultLink(
+            fileLink=link.external_link,
+            expiryTime=expiry_time,
+            rowCount=link.row_count,
+            bytesNum=link.byte_count,
+            startRowOffset=link.row_offset,
+            httpHeaders=link.http_headers or {},
+        )
+
+    def _worker_loop(self):
+        """Entry point for the background thread."""
+        logger.debug("LinkFetcher[%s]: worker thread started", self._statement_id)
+        while not self._shutdown_event.is_set():
+            links_downloaded = self._trigger_next_batch_download()
+            if not links_downloaded:
+                self._shutdown_event.set()
+        logger.debug("LinkFetcher[%s]: worker thread exiting", self._statement_id)
+        # notify_all() must be called with the condition's lock held (it raises
+        # RuntimeError otherwise); wake any waiters still blocked in get_chunk_link()
+        with self._link_data_update:
+            self._link_data_update.notify_all()
+
+    def start(self):
+        """Spawn the worker thread."""
+        logger.debug("LinkFetcher[%s]: starting worker thread", self._statement_id)
+        self._worker_thread = threading.Thread(
+            target=self._worker_loop, name=f"LinkFetcher-{self._statement_id}"
+        )
+        self._worker_thread.start()
+
+    def stop(self):
+        """Signal the worker thread to stop and wait for its termination."""
+        logger.debug("LinkFetcher[%s]: stopping worker thread", self._statement_id)
+        self._shutdown_event.set()
+        self._worker_thread.join()
+        logger.debug("LinkFetcher[%s]: worker thread stopped", self._statement_id)
+
+
 class SeaCloudFetchQueue(CloudFetchQueue):
     """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""

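Taken together, the new class gives `SeaCloudFetchQueue` a simple blocking façade over the background fetch: construct, `start()`, call `get_chunk_link()` per chunk, `stop()`. A minimal usage sketch (illustrative only; it assumes `download_manager`, `sea_client`, `statement_id`, `initial_links`, and `total_chunk_count` have already been obtained from the initial query response, exactly as in `__init__` below):

    fetcher = LinkFetcher(
        download_manager=download_manager,
        backend=sea_client,
        statement_id=statement_id,
        initial_links=initial_links,
        total_chunk_count=total_chunk_count,
    )
    fetcher.start()  # spawn the background worker thread

    try:
        chunk_index = 0
        while True:
            link = fetcher.get_chunk_link(chunk_index)  # blocks until cached
            if link is None:  # chunk_index is past the last chunk
                break
            # ...consume the downloaded chunk starting at link.row_offset...
            chunk_index += 1
    finally:
        fetcher.stop()  # signal shutdown and join the worker

The worker keeps calling `get_chunk_links` on the backend until a link's `next_chunk_index` comes back `None`, so consumers never issue backend calls themselves.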
@@ -158,80 +332,49 @@ def __init__(
             description=description,
         )

-        self._sea_client = sea_client
-        self._statement_id = statement_id
-        self._total_chunk_count = total_chunk_count
-
         logger.debug(
             "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format(
                 statement_id, total_chunk_count
             )
         )

         initial_links = result_data.external_links or []
-        self._chunk_index_to_link = {link.chunk_index: link for link in initial_links}

         # Track the current chunk we're processing
         self._current_chunk_index = 0
-        first_link = self._chunk_index_to_link.get(self._current_chunk_index, None)
-        if not first_link:
-            # possibly an empty response
-            return None

-        # Track the current chunk we're processing
-        self._current_chunk_index = 0
-        # Initialize table and position
-        self.table = self._create_table_from_link(first_link)
+        # for empty responses, we do not need a link fetcher
+        self.link_fetcher: Optional[LinkFetcher] = None
+        if total_chunk_count > 0:
+            self.link_fetcher = LinkFetcher(
+                download_manager=self.download_manager,
+                backend=sea_client,
+                statement_id=statement_id,
+                initial_links=initial_links,
+                total_chunk_count=total_chunk_count,
+            )
+            self.link_fetcher.start()

-    def _convert_to_thrift_link(self, link: ExternalLink) -> TSparkArrowResultLink:
-        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
-        # Parse the ISO format expiration time
-        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
-        return TSparkArrowResultLink(
-            fileLink=link.external_link,
-            expiryTime=expiry_time,
-            rowCount=link.row_count,
-            bytesNum=link.byte_count,
-            startRowOffset=link.row_offset,
-            httpHeaders=link.http_headers or {},
-        )
+        # Initialize table and position
+        self.table = self._create_next_table()

-    def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
-        if chunk_index >= self._total_chunk_count:
+    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+        """Create next table by retrieving the logical next downloaded file."""
+        if self.link_fetcher is None:
             return None

-        if chunk_index not in self._chunk_index_to_link:
-            links = self._sea_client.get_chunk_links(self._statement_id, chunk_index)
-            self._chunk_index_to_link.update({l.chunk_index: l for l in links})
-
-        link = self._chunk_index_to_link.get(chunk_index, None)
-        if not link:
-            raise ServerOperationError(
-                f"Error fetching link for chunk {chunk_index}",
-                {
-                    "operation-id": self._statement_id,
-                    "diagnostic-info": None,
-                },
-            )
-        return link
-
-    def _create_table_from_link(
-        self, link: ExternalLink
-    ) -> Union["pyarrow.Table", None]:
-        """Create a table from a link."""
-
-        thrift_link = self._convert_to_thrift_link(link)
-        self.download_manager.add_link(thrift_link)
+        chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index)
+        if chunk_link is None:
+            return None

-        row_offset = link.row_offset
+        row_offset = chunk_link.row_offset
+        # NOTE: link has already been submitted to download manager at this point
         arrow_table = self._create_table_at_offset(row_offset)

+        self._current_chunk_index += 1
+
         return arrow_table

-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
-        """Create next table by retrieving the logical next downloaded file."""
-        self._current_chunk_index += 1
-        next_chunk_link = self._get_chunk_link(self._current_chunk_index)
-        if not next_chunk_link:
-            return None
-        return self._create_table_from_link(next_chunk_link)
+    def close(self):
+        super().close()
+        if self.link_fetcher:
+            self.link_fetcher.stop()
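Design note: `get_chunk_link`, `_add_links`, and `_worker_loop` together form a classic condition-variable handoff. The producer mutates `chunk_index_to_link` and calls `notify_all()` while holding `_link_data_update`; consumers `wait()` in a loop until their key appears or an error/shutdown is flagged. A self-contained sketch of the same pattern, with toy names that are not part of this PR:

    import threading

    class ChunkCache:
        """Toy reduction of LinkFetcher's wait/notify protocol."""

        def __init__(self):
            self._cond = threading.Condition()
            self._items = {}  # chunk_index -> payload
            self._done = threading.Event()

        def put(self, index, payload):
            with self._cond:  # mutate shared state under the lock
                self._items[index] = payload
                self._cond.notify_all()

        def shutdown(self):
            self._done.set()
            with self._cond:  # notify_all() requires the lock to be held
                self._cond.notify_all()

        def get(self, index):
            with self._cond:
                while index not in self._items:
                    if self._done.is_set():
                        raise RuntimeError(f"shut down before chunk {index} arrived")
                    self._cond.wait()  # releases the lock while sleeping
                return self._items[index]

    cache = ChunkCache()
    threading.Thread(target=lambda: cache.put(0, "link-0")).start()
    print(cache.get(0))  # blocks briefly, then prints "link-0"

This is also why `_worker_loop` wraps its final `notify_all()` in `with self._link_data_update:`: calling `Condition.notify_all()` without holding the lock raises `RuntimeError`, and a waiter blocked in `get_chunk_link` would otherwise sleep through shutdown.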