55import re
66from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
77
8- from databricks .sql .backend .sea .models .base import ResultManifest
8+ from databricks .sql .backend .sea .models .base import ExternalLink , ResultManifest
99from databricks .sql .backend .sea .utils .constants import (
1010 ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
1111 ResultFormat ,
2828 BackendType ,
2929 ExecuteResponse ,
3030)
31- from databricks .sql .exc import DatabaseError , ProgrammingError , ServerOperationError
31+ from databricks .sql .exc import DatabaseError , ServerOperationError
3232from databricks .sql .backend .sea .utils .http_client import SeaHttpClient
3333from databricks .sql .types import SSLOptions
3434
4444 GetStatementResponse ,
4545 CreateSessionResponse ,
4646)
47+ from databricks .sql .backend .sea .models .responses import GetChunksResponse
4748
4849logger = logging .getLogger (__name__ )
4950
@@ -88,6 +89,7 @@ class SeaDatabricksClient(DatabricksClient):
8889 STATEMENT_PATH = BASE_PATH + "statements"
8990 STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9091 CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
92+ CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
9193
9294 # SEA constants
9395 POLL_INTERVAL_SECONDS = 0.2
@@ -123,18 +125,22 @@ def __init__(
123125 )
124126
125127 self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
128+ self ._ssl_options = ssl_options
129+ self ._use_arrow_native_complex_types = kwargs .get (
130+ "_use_arrow_native_complex_types" , True
131+ )
126132
127133 # Extract warehouse ID from http_path
128134 self .warehouse_id = self ._extract_warehouse_id (http_path )
129135
130136 # Initialize HTTP client
131- self .http_client = SeaHttpClient (
137+ self ._http_client = SeaHttpClient (
132138 server_hostname = server_hostname ,
133139 port = port ,
134140 http_path = http_path ,
135141 http_headers = http_headers ,
136142 auth_provider = auth_provider ,
137- ssl_options = ssl_options ,
143+ ssl_options = self . _ssl_options ,
138144 ** kwargs ,
139145 )
140146
@@ -173,7 +179,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
173179 f"Note: SEA only works for warehouses."
174180 )
175181 logger .error (error_message )
176- raise ProgrammingError (error_message )
182+ raise ValueError (error_message )
177183
178184 @property
179185 def max_download_threads (self ) -> int :
@@ -220,7 +226,7 @@ def open_session(
220226 schema = schema ,
221227 )
222228
223- response = self .http_client ._make_request (
229+ response = self ._http_client ._make_request (
224230 method = "POST" , path = self .SESSION_PATH , data = request_data .to_dict ()
225231 )
226232
@@ -245,7 +251,7 @@ def close_session(self, session_id: SessionId) -> None:
245251 session_id: The session identifier returned by open_session()
246252
247253 Raises:
248- ProgrammingError : If the session ID is invalid
254+ ValueError : If the session ID is invalid
249255 OperationalError: If there's an error closing the session
250256 """
251257
@@ -260,7 +266,7 @@ def close_session(self, session_id: SessionId) -> None:
260266 session_id = sea_session_id ,
261267 )
262268
263- self .http_client ._make_request (
269+ self ._http_client ._make_request (
264270 method = "DELETE" ,
265271 path = self .SESSION_PATH_WITH_ID .format (sea_session_id ),
266272 data = request_data .to_dict (),
@@ -342,7 +348,7 @@ def _results_message_to_execute_response(
342348
343349 # Check for compression
344350 lz4_compressed = (
345- response .manifest .result_compression == ResultCompression .LZ4_FRAME
351+ response .manifest .result_compression == ResultCompression .LZ4_FRAME . value
346352 )
347353
348354 execute_response = ExecuteResponse (
@@ -424,7 +430,7 @@ def execute_command(
424430 enforce_embedded_schema_correctness: Whether to enforce schema correctness
425431
426432 Returns:
427- ResultSet : A SeaResultSet instance for the executed command
433+ SeaResultSet : A SeaResultSet instance for the executed command
428434 """
429435
430436 if session_id .backend_type != BackendType .SEA :
@@ -471,7 +477,7 @@ def execute_command(
471477 result_compression = result_compression ,
472478 )
473479
474- response_data = self .http_client ._make_request (
480+ response_data = self ._http_client ._make_request (
475481 method = "POST" , path = self .STATEMENT_PATH , data = request .to_dict ()
476482 )
477483 response = ExecuteStatementResponse .from_dict (response_data )
@@ -505,7 +511,7 @@ def cancel_command(self, command_id: CommandId) -> None:
505511 command_id: Command identifier to cancel
506512
507513 Raises:
508- ProgrammingError : If the command ID is invalid
514+ ValueError : If the command ID is invalid
509515 """
510516
511517 if command_id .backend_type != BackendType .SEA :
@@ -516,7 +522,7 @@ def cancel_command(self, command_id: CommandId) -> None:
516522 raise ValueError ("Not a valid SEA command ID" )
517523
518524 request = CancelStatementRequest (statement_id = sea_statement_id )
519- self .http_client ._make_request (
525+ self ._http_client ._make_request (
520526 method = "POST" ,
521527 path = self .CANCEL_STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
522528 data = request .to_dict (),
@@ -530,7 +536,7 @@ def close_command(self, command_id: CommandId) -> None:
530536 command_id: Command identifier to close
531537
532538 Raises:
533- ProgrammingError : If the command ID is invalid
539+ ValueError : If the command ID is invalid
534540 """
535541
536542 if command_id .backend_type != BackendType .SEA :
@@ -541,7 +547,7 @@ def close_command(self, command_id: CommandId) -> None:
541547 raise ValueError ("Not a valid SEA command ID" )
542548
543549 request = CloseStatementRequest (statement_id = sea_statement_id )
544- self .http_client ._make_request (
550+ self ._http_client ._make_request (
545551 method = "DELETE" ,
546552 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
547553 data = request .to_dict (),
@@ -558,7 +564,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
558564 CommandState: The current state of the command
559565
560566 Raises:
561- ProgrammingError : If the command ID is invalid
567+ ValueError : If the command ID is invalid
562568 """
563569
564570 if command_id .backend_type != BackendType .SEA :
@@ -569,7 +575,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
569575 raise ValueError ("Not a valid SEA command ID" )
570576
571577 request = GetStatementRequest (statement_id = sea_statement_id )
572- response_data = self .http_client ._make_request (
578+ response_data = self ._http_client ._make_request (
573579 method = "GET" ,
574580 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
575581 data = request .to_dict (),
@@ -609,7 +615,7 @@ def get_execution_result(
609615 request = GetStatementRequest (statement_id = sea_statement_id )
610616
611617 # Get the statement result
612- response_data = self .http_client ._make_request (
618+ response_data = self ._http_client ._make_request (
613619 method = "GET" ,
614620 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
615621 data = request .to_dict (),
@@ -631,6 +637,35 @@ def get_execution_result(
631637 arraysize = cursor .arraysize ,
632638 )
633639
640+ def get_chunk_link (self , statement_id : str , chunk_index : int ) -> ExternalLink :
641+ """
642+ Get links for chunks starting from the specified index.
643+ Args:
644+ statement_id: The statement ID
645+ chunk_index: The starting chunk index
646+ Returns:
647+ ExternalLink: External link for the chunk
648+ """
649+
650+ response_data = self ._http_client ._make_request (
651+ method = "GET" ,
652+ path = self .CHUNK_PATH_WITH_ID_AND_INDEX .format (statement_id , chunk_index ),
653+ )
654+ response = GetChunksResponse .from_dict (response_data )
655+
656+ links = response .external_links or []
657+ link = next ((l for l in links if l .chunk_index == chunk_index ), None )
658+ if not link :
659+ raise ServerOperationError (
660+ f"No link found for chunk index { chunk_index } " ,
661+ {
662+ "operation-id" : statement_id ,
663+ "diagnostic-info" : None ,
664+ },
665+ )
666+
667+ return link
668+
634669 # == Metadata Operations ==
635670
636671 def get_catalogs (
0 commit comments