Skip to content
113 changes: 59 additions & 54 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

if TYPE_CHECKING:
from databricks.sql.client import Cursor
from databricks.sql.backend.sea.result_set import SeaResultSet

from databricks.sql.backend.sea.result_set import SeaResultSet

from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.backend.types import (
Expand Down Expand Up @@ -332,7 +333,7 @@ def _extract_description_from_manifest(
return columns

def _results_message_to_execute_response(
self, response: GetStatementResponse
self, response: Union[ExecuteStatementResponse, GetStatementResponse]
) -> ExecuteResponse:
"""
Convert a SEA response to an ExecuteResponse and extract result data.
Expand Down Expand Up @@ -366,6 +367,27 @@ def _results_message_to_execute_response(

return execute_response

def _response_to_result_set(
self,
response: Union[ExecuteStatementResponse, GetStatementResponse],
cursor: Cursor,
) -> SeaResultSet:
"""
Convert a SEA response to a SeaResultSet.
"""

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)

def _check_command_not_in_failed_or_closed_state(
self, state: CommandState, command_id: CommandId
) -> None:
Expand All @@ -386,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(

def _wait_until_command_done(
self, response: ExecuteStatementResponse
) -> CommandState:
) -> Union[ExecuteStatementResponse, GetStatementResponse]:
"""
Wait until a command is done.
"""

state = response.status.state
command_id = CommandId.from_sea_statement_id(response.statement_id)
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response

state = final_response.status.state
command_id = CommandId.from_sea_statement_id(final_response.statement_id)

while state in [CommandState.PENDING, CommandState.RUNNING]:
time.sleep(self.POLL_INTERVAL_SECONDS)
state = self.get_query_state(command_id)
final_response = self._poll_query(command_id)
state = final_response.status.state

self._check_command_not_in_failed_or_closed_state(state, command_id)

return state
return final_response

def execute_command(
self,
Expand Down Expand Up @@ -506,8 +531,11 @@ def execute_command(
if async_op:
return None

self._wait_until_command_done(response)
return self.get_execution_result(command_id, cursor)
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
if response.status.state != CommandState.SUCCEEDED:
final_response = self._wait_until_command_done(response)

return self._response_to_result_set(final_response, cursor)

def cancel_command(self, command_id: CommandId) -> None:
"""
Expand Down Expand Up @@ -559,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None:
data=request.to_dict(),
)

def get_query_state(self, command_id: CommandId) -> CommandState:
def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ValueError: If the command ID is invalid
Poll for the current command info.
"""

if command_id.backend_type != BackendType.SEA:
Expand All @@ -586,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)

# Parse the response
response = GetStatementResponse.from_dict(response_data)

return response

def get_query_state(self, command_id: CommandId) -> CommandState:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ProgrammingError: If the command ID is invalid
"""

response = self._poll_query(command_id)
return response.status.state

def get_execution_result(
Expand All @@ -610,38 +645,8 @@ def get_execution_result(
ValueError: If the command ID is invalid
"""

if command_id.backend_type != BackendType.SEA:
raise ValueError("Not a valid SEA command ID")

sea_statement_id = command_id.to_sea_statement_id()
if sea_statement_id is None:
raise ValueError("Not a valid SEA command ID")

# Create the request model
request = GetStatementRequest(statement_id=sea_statement_id)

# Get the statement result
response_data = self._http_client._make_request(
method="GET",
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = GetStatementResponse.from_dict(response_data)

# Create and return a SeaResultSet
from databricks.sql.backend.sea.result_set import SeaResultSet

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)
response = self._poll_query(command_id)
return self._response_to_result_set(response, cursor)

def get_chunk_links(
self, statement_id: str, chunk_index: int
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/sql/backend/sea/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging

from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter

Expand All @@ -15,6 +14,7 @@

if TYPE_CHECKING:
from databricks.sql.client import Connection
from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.types import Row
from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
from databricks.sql.backend.types import ExecuteResponse
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_command_execution_sync(
mock_http_client._make_request.return_value = execute_response

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
result = sea_client.execute_command(
operation="SELECT 1",
Expand All @@ -242,9 +242,6 @@ def test_command_execution_sync(
enforce_embedded_schema_correctness=False,
)
assert result == "mock_result_set"
cmd_id_arg = mock_get_result.call_args[0][0]
assert isinstance(cmd_id_arg, CommandId)
assert cmd_id_arg.guid == "test-statement-123"

# Test with invalid session ID
with pytest.raises(ValueError) as excinfo:
Expand Down Expand Up @@ -332,7 +329,7 @@ def test_command_execution_advanced(
mock_http_client._make_request.side_effect = [initial_response, poll_response]

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
with patch("time.sleep"):
result = sea_client.execute_command(
Expand Down Expand Up @@ -360,7 +357,7 @@ def test_command_execution_advanced(
dbsql_param = IntegerParameter(name="param1", value=1)
param = dbsql_param.as_tspark_param(named=True)

with patch.object(sea_client, "get_execution_result"):
with patch.object(sea_client, "_response_to_result_set"):
sea_client.execute_command(
operation="SELECT * FROM table WHERE col = :param1",
session_id=sea_session_id,
Expand Down
Loading