1818
1919if TYPE_CHECKING :
2020 from databricks .sql .client import Cursor
21- from databricks .sql .backend .sea .result_set import SeaResultSet
21+
22+ from databricks .sql .backend .sea .result_set import SeaResultSet
2223
2324from databricks .sql .backend .databricks_client import DatabricksClient
2425from databricks .sql .backend .types import (
@@ -332,7 +333,7 @@ def _extract_description_from_manifest(
332333 return columns
333334
334335 def _results_message_to_execute_response (
335- self , response : GetStatementResponse
336+ self , response : Union [ ExecuteStatementResponse , GetStatementResponse ]
336337 ) -> ExecuteResponse :
337338 """
338339 Convert a SEA response to an ExecuteResponse and extract result data.
@@ -366,6 +367,27 @@ def _results_message_to_execute_response(
366367
367368 return execute_response
368369
370+ def _response_to_result_set (
371+ self ,
372+ response : Union [ExecuteStatementResponse , GetStatementResponse ],
373+ cursor : Cursor ,
374+ ) -> SeaResultSet :
375+ """
376+ Convert a SEA response to a SeaResultSet.
377+ """
378+
379+ execute_response = self ._results_message_to_execute_response (response )
380+
381+ return SeaResultSet (
382+ connection = cursor .connection ,
383+ execute_response = execute_response ,
384+ sea_client = self ,
385+ result_data = response .result ,
386+ manifest = response .manifest ,
387+ buffer_size_bytes = cursor .buffer_size_bytes ,
388+ arraysize = cursor .arraysize ,
389+ )
390+
369391 def _check_command_not_in_failed_or_closed_state (
370392 self , state : CommandState , command_id : CommandId
371393 ) -> None :
@@ -386,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(
386408
387409 def _wait_until_command_done (
388410 self , response : ExecuteStatementResponse
389- ) -> CommandState :
411+ ) -> Union [ ExecuteStatementResponse , GetStatementResponse ] :
390412 """
391413 Wait until a command is done.
392414 """
393415
394- state = response .status .state
395- command_id = CommandId .from_sea_statement_id (response .statement_id )
416+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
417+
418+ state = final_response .status .state
419+ command_id = CommandId .from_sea_statement_id (final_response .statement_id )
396420
397421 while state in [CommandState .PENDING , CommandState .RUNNING ]:
398422 time .sleep (self .POLL_INTERVAL_SECONDS )
399- state = self .get_query_state (command_id )
423+ final_response = self ._poll_query (command_id )
424+ state = final_response .status .state
400425
401426 self ._check_command_not_in_failed_or_closed_state (state , command_id )
402427
403- return state
428+ return final_response
404429
405430 def execute_command (
406431 self ,
@@ -506,8 +531,11 @@ def execute_command(
506531 if async_op :
507532 return None
508533
509- self ._wait_until_command_done (response )
510- return self .get_execution_result (command_id , cursor )
534+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
535+ if response .status .state != CommandState .SUCCEEDED :
536+ final_response = self ._wait_until_command_done (response )
537+
538+ return self ._response_to_result_set (final_response , cursor )
511539
512540 def cancel_command (self , command_id : CommandId ) -> None :
513541 """
@@ -559,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None:
559587 data = request .to_dict (),
560588 )
561589
562- def get_query_state (self , command_id : CommandId ) -> CommandState :
590+ def _poll_query (self , command_id : CommandId ) -> GetStatementResponse :
563591 """
564- Get the state of a running query.
565-
566- Args:
567- command_id: Command identifier
568-
569- Returns:
570- CommandState: The current state of the command
571-
572- Raises:
573- ValueError: If the command ID is invalid
592+ Poll for the current command info.
574593 """
575594
576595 if command_id .backend_type != BackendType .SEA :
@@ -586,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
586605 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
587606 data = request .to_dict (),
588607 )
589-
590- # Parse the response
591608 response = GetStatementResponse .from_dict (response_data )
609+
610+ return response
611+
612+ def get_query_state (self , command_id : CommandId ) -> CommandState :
613+ """
614+ Get the state of a running query.
615+
616+ Args:
617+ command_id: Command identifier
618+
619+ Returns:
620+ CommandState: The current state of the command
621+
622+ Raises:
623+ ProgrammingError: If the command ID is invalid
624+ """
625+
626+ response = self ._poll_query (command_id )
592627 return response .status .state
593628
594629 def get_execution_result (
@@ -610,38 +645,8 @@ def get_execution_result(
610645 ValueError: If the command ID is invalid
611646 """
612647
613- if command_id .backend_type != BackendType .SEA :
614- raise ValueError ("Not a valid SEA command ID" )
615-
616- sea_statement_id = command_id .to_sea_statement_id ()
617- if sea_statement_id is None :
618- raise ValueError ("Not a valid SEA command ID" )
619-
620- # Create the request model
621- request = GetStatementRequest (statement_id = sea_statement_id )
622-
623- # Get the statement result
624- response_data = self ._http_client ._make_request (
625- method = "GET" ,
626- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
627- data = request .to_dict (),
628- )
629- response = GetStatementResponse .from_dict (response_data )
630-
631- # Create and return a SeaResultSet
632- from databricks .sql .backend .sea .result_set import SeaResultSet
633-
634- execute_response = self ._results_message_to_execute_response (response )
635-
636- return SeaResultSet (
637- connection = cursor .connection ,
638- execute_response = execute_response ,
639- sea_client = self ,
640- result_data = response .result ,
641- manifest = response .manifest ,
642- buffer_size_bytes = cursor .buffer_size_bytes ,
643- arraysize = cursor .arraysize ,
644- )
648+ response = self ._poll_query (command_id )
649+ return self ._response_to_result_set (response , cursor )
645650
646651 def get_chunk_links (
647652 self , statement_id : str , chunk_index : int
0 commit comments