33import logging
44import math
55import time
6- import uuid
76import threading
8- from typing import List , Optional , Union , Any , TYPE_CHECKING
7+ from typing import List , Union , Any , TYPE_CHECKING
98
109if TYPE_CHECKING :
1110 from databricks .sql .client import Cursor
12- from databricks .sql .result_set import ResultSet , ThriftResultSet
1311
14- from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
1512from databricks .sql .backend .types import (
1613 CommandState ,
1714 SessionId ,
1815 CommandId ,
19- BackendType ,
16+ ExecuteResponse ,
2017)
2118from databricks .sql .backend .utils import guid_to_hex_id
2219
20+
2321try :
2422 import pyarrow
2523except ImportError :
4240)
4341
4442from databricks .sql .utils import (
45- ExecuteResponse ,
43+ ResultSetQueueFactory ,
4644 _bound ,
4745 RequestErrorInfo ,
4846 NoRetryReason ,
5351)
5452from databricks .sql .types import SSLOptions
5553from databricks .sql .backend .databricks_client import DatabricksClient
54+ from databricks .sql .result_set import ResultSet , ThriftResultSet
5655
5756logger = logging .getLogger (__name__ )
5857
@@ -758,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
758757 )
759758 direct_results = resp .directResults
760759 has_been_closed_server_side = direct_results and direct_results .closeOperation
761- has_more_rows = (
760+
761+ is_direct_results = (
762762 (not direct_results )
763763 or (not direct_results .resultSet )
764764 or direct_results .resultSet .hasMoreRows
765765 )
766+
766767 description = self ._hive_schema_to_description (
767768 t_result_set_metadata_resp .schema
768769 )
@@ -778,42 +779,28 @@ def _results_message_to_execute_response(self, resp, operation_state):
778779 schema_bytes = None
779780
780781 lz4_compressed = t_result_set_metadata_resp .lz4Compressed
781- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
782- if direct_results and direct_results .resultSet :
783- assert direct_results .resultSet .results .startRowOffset == 0
784- assert direct_results .resultSetMetadata
785-
786- arrow_queue_opt = ResultSetQueueFactory .build_queue (
787- row_set_type = t_result_set_metadata_resp .resultFormat ,
788- t_row_set = direct_results .resultSet .results ,
789- arrow_schema_bytes = schema_bytes ,
790- max_download_threads = self .max_download_threads ,
791- lz4_compressed = lz4_compressed ,
792- description = description ,
793- ssl_options = self ._ssl_options ,
794- )
795- else :
796- arrow_queue_opt = None
797-
798782 command_id = CommandId .from_thrift_handle (resp .operationHandle )
799783
800- return ExecuteResponse (
801- arrow_queue = arrow_queue_opt ,
802- status = CommandState .from_thrift_state (operation_state ),
803- has_been_closed_server_side = has_been_closed_server_side ,
804- has_more_rows = has_more_rows ,
805- lz4_compressed = lz4_compressed ,
806- is_staging_operation = is_staging_operation ,
784+ status = CommandState .from_thrift_state (operation_state )
785+ if status is None :
786+ raise ValueError (f"Unknown command state: { operation_state } " )
787+
788+ execute_response = ExecuteResponse (
807789 command_id = command_id ,
790+ status = status ,
808791 description = description ,
792+ has_been_closed_server_side = has_been_closed_server_side ,
793+ lz4_compressed = lz4_compressed ,
794+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation ,
809795 arrow_schema_bytes = schema_bytes ,
796+ result_format = t_result_set_metadata_resp .resultFormat ,
810797 )
811798
799+ return execute_response , is_direct_results
800+
812801 def get_execution_result (
813802 self , command_id : CommandId , cursor : "Cursor"
814803 ) -> "ResultSet" :
815- from databricks .sql .result_set import ThriftResultSet
816-
817804 thrift_handle = command_id .to_thrift_handle ()
818805 if not thrift_handle :
819806 raise ValueError ("Not a valid Thrift command ID" )
@@ -835,9 +822,6 @@ def get_execution_result(
835822
836823 t_result_set_metadata_resp = resp .resultSetMetadata
837824
838- lz4_compressed = t_result_set_metadata_resp .lz4Compressed
839- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
840- has_more_rows = resp .hasMoreRows
841825 description = self ._hive_schema_to_description (
842826 t_result_set_metadata_resp .schema
843827 )
@@ -852,26 +836,21 @@ def get_execution_result(
852836 else :
853837 schema_bytes = None
854838
855- queue = ResultSetQueueFactory .build_queue (
856- row_set_type = resp .resultSetMetadata .resultFormat ,
857- t_row_set = resp .results ,
858- arrow_schema_bytes = schema_bytes ,
859- max_download_threads = self .max_download_threads ,
860- lz4_compressed = lz4_compressed ,
861- description = description ,
862- ssl_options = self ._ssl_options ,
863- )
839+ lz4_compressed = t_result_set_metadata_resp .lz4Compressed
840+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
841+ is_direct_results = resp .hasMoreRows
842+
843+ status = self .get_query_state (command_id )
864844
865845 execute_response = ExecuteResponse (
866- arrow_queue = queue ,
867- status = CommandState .from_thrift_state (resp .status ),
846+ command_id = command_id ,
847+ status = status ,
848+ description = description ,
868849 has_been_closed_server_side = False ,
869- has_more_rows = has_more_rows ,
870850 lz4_compressed = lz4_compressed ,
871851 is_staging_operation = is_staging_operation ,
872- command_id = command_id ,
873- description = description ,
874852 arrow_schema_bytes = schema_bytes ,
853+ result_format = t_result_set_metadata_resp .resultFormat ,
875854 )
876855
877856 return ThriftResultSet (
@@ -881,6 +860,10 @@ def get_execution_result(
881860 buffer_size_bytes = cursor .buffer_size_bytes ,
882861 arraysize = cursor .arraysize ,
883862 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
863+ t_row_set = resp .results ,
864+ max_download_threads = self .max_download_threads ,
865+ ssl_options = self ._ssl_options ,
866+ is_direct_results = is_direct_results ,
884867 )
885868
886869 def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -947,8 +930,6 @@ def execute_command(
947930 async_op = False ,
948931 enforce_embedded_schema_correctness = False ,
949932 ) -> Union ["ResultSet" , None ]:
950- from databricks .sql .result_set import ThriftResultSet
951-
952933 thrift_handle = session_id .to_thrift_handle ()
953934 if not thrift_handle :
954935 raise ValueError ("Not a valid Thrift session ID" )
@@ -995,7 +976,13 @@ def execute_command(
995976 self ._handle_execute_response_async (resp , cursor )
996977 return None
997978 else :
998- execute_response = self ._handle_execute_response (resp , cursor )
979+ execute_response , is_direct_results = self ._handle_execute_response (
980+ resp , cursor
981+ )
982+
983+ t_row_set = None
984+ if resp .directResults and resp .directResults .resultSet :
985+ t_row_set = resp .directResults .resultSet .results
999986
1000987 return ThriftResultSet (
1001988 connection = cursor .connection ,
@@ -1004,6 +991,10 @@ def execute_command(
1004991 buffer_size_bytes = max_bytes ,
1005992 arraysize = max_rows ,
1006993 use_cloud_fetch = use_cloud_fetch ,
994+ t_row_set = t_row_set ,
995+ max_download_threads = self .max_download_threads ,
996+ ssl_options = self ._ssl_options ,
997+ is_direct_results = is_direct_results ,
1007998 )
1008999
10091000 def get_catalogs (
@@ -1013,8 +1004,6 @@ def get_catalogs(
10131004 max_bytes : int ,
10141005 cursor : "Cursor" ,
10151006 ) -> "ResultSet" :
1016- from databricks .sql .result_set import ThriftResultSet
1017-
10181007 thrift_handle = session_id .to_thrift_handle ()
10191008 if not thrift_handle :
10201009 raise ValueError ("Not a valid Thrift session ID" )
@@ -1027,7 +1016,13 @@ def get_catalogs(
10271016 )
10281017 resp = self .make_request (self ._client .GetCatalogs , req )
10291018
1030- execute_response = self ._handle_execute_response (resp , cursor )
1019+ execute_response , is_direct_results = self ._handle_execute_response (
1020+ resp , cursor
1021+ )
1022+
1023+ t_row_set = None
1024+ if resp .directResults and resp .directResults .resultSet :
1025+ t_row_set = resp .directResults .resultSet .results
10311026
10321027 return ThriftResultSet (
10331028 connection = cursor .connection ,
@@ -1036,6 +1031,10 @@ def get_catalogs(
10361031 buffer_size_bytes = max_bytes ,
10371032 arraysize = max_rows ,
10381033 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1034+ t_row_set = t_row_set ,
1035+ max_download_threads = self .max_download_threads ,
1036+ ssl_options = self ._ssl_options ,
1037+ is_direct_results = is_direct_results ,
10391038 )
10401039
10411040 def get_schemas (
@@ -1047,8 +1046,6 @@ def get_schemas(
10471046 catalog_name = None ,
10481047 schema_name = None ,
10491048 ) -> "ResultSet" :
1050- from databricks .sql .result_set import ThriftResultSet
1051-
10521049 thrift_handle = session_id .to_thrift_handle ()
10531050 if not thrift_handle :
10541051 raise ValueError ("Not a valid Thrift session ID" )
@@ -1063,7 +1060,13 @@ def get_schemas(
10631060 )
10641061 resp = self .make_request (self ._client .GetSchemas , req )
10651062
1066- execute_response = self ._handle_execute_response (resp , cursor )
1063+ execute_response , is_direct_results = self ._handle_execute_response (
1064+ resp , cursor
1065+ )
1066+
1067+ t_row_set = None
1068+ if resp .directResults and resp .directResults .resultSet :
1069+ t_row_set = resp .directResults .resultSet .results
10671070
10681071 return ThriftResultSet (
10691072 connection = cursor .connection ,
@@ -1072,6 +1075,10 @@ def get_schemas(
10721075 buffer_size_bytes = max_bytes ,
10731076 arraysize = max_rows ,
10741077 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1078+ t_row_set = t_row_set ,
1079+ max_download_threads = self .max_download_threads ,
1080+ ssl_options = self ._ssl_options ,
1081+ is_direct_results = is_direct_results ,
10751082 )
10761083
10771084 def get_tables (
@@ -1085,8 +1092,6 @@ def get_tables(
10851092 table_name = None ,
10861093 table_types = None ,
10871094 ) -> "ResultSet" :
1088- from databricks .sql .result_set import ThriftResultSet
1089-
10901095 thrift_handle = session_id .to_thrift_handle ()
10911096 if not thrift_handle :
10921097 raise ValueError ("Not a valid Thrift session ID" )
@@ -1103,7 +1108,13 @@ def get_tables(
11031108 )
11041109 resp = self .make_request (self ._client .GetTables , req )
11051110
1106- execute_response = self ._handle_execute_response (resp , cursor )
1111+ execute_response , is_direct_results = self ._handle_execute_response (
1112+ resp , cursor
1113+ )
1114+
1115+ t_row_set = None
1116+ if resp .directResults and resp .directResults .resultSet :
1117+ t_row_set = resp .directResults .resultSet .results
11071118
11081119 return ThriftResultSet (
11091120 connection = cursor .connection ,
@@ -1112,6 +1123,10 @@ def get_tables(
11121123 buffer_size_bytes = max_bytes ,
11131124 arraysize = max_rows ,
11141125 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1126+ t_row_set = t_row_set ,
1127+ max_download_threads = self .max_download_threads ,
1128+ ssl_options = self ._ssl_options ,
1129+ is_direct_results = is_direct_results ,
11151130 )
11161131
11171132 def get_columns (
@@ -1125,8 +1140,6 @@ def get_columns(
11251140 table_name = None ,
11261141 column_name = None ,
11271142 ) -> "ResultSet" :
1128- from databricks .sql .result_set import ThriftResultSet
1129-
11301143 thrift_handle = session_id .to_thrift_handle ()
11311144 if not thrift_handle :
11321145 raise ValueError ("Not a valid Thrift session ID" )
@@ -1143,7 +1156,13 @@ def get_columns(
11431156 )
11441157 resp = self .make_request (self ._client .GetColumns , req )
11451158
1146- execute_response = self ._handle_execute_response (resp , cursor )
1159+ execute_response , is_direct_results = self ._handle_execute_response (
1160+ resp , cursor
1161+ )
1162+
1163+ t_row_set = None
1164+ if resp .directResults and resp .directResults .resultSet :
1165+ t_row_set = resp .directResults .resultSet .results
11471166
11481167 return ThriftResultSet (
11491168 connection = cursor .connection ,
@@ -1152,6 +1171,10 @@ def get_columns(
11521171 buffer_size_bytes = max_bytes ,
11531172 arraysize = max_rows ,
11541173 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1174+ t_row_set = t_row_set ,
1175+ max_download_threads = self .max_download_threads ,
1176+ ssl_options = self ._ssl_options ,
1177+ is_direct_results = is_direct_results ,
11551178 )
11561179
11571180 def _handle_execute_response (self , resp , cursor ):
0 commit comments