 import math
 import time
 import threading
+import lz4.frame
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context

 import pyarrow
@@ -451,7 +452,7 @@ def open_session(self, session_configuration, catalog, schema):
                 initial_namespace = None

             open_session_req = ttypes.TOpenSessionReq(
-                client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5,
+                client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6,
                 client_protocol=None,
                 initialNamespace=initial_namespace,
                 canUseMultipleCatalogs=True,
@@ -507,7 +508,7 @@ def _poll_for_status(self, op_handle):
         )
         return self.make_request(self._client.GetOperationStatus, req)

-    def _create_arrow_table(self, t_row_set, schema_bytes, description):
+    def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
         if t_row_set.columns is not None:
             (
                 arrow_table,
@@ -520,7 +521,7 @@ def _create_arrow_table(self, t_row_set, schema_bytes, description):
                 arrow_table,
                 num_rows,
             ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
-                t_row_set.arrowBatches, schema_bytes
+                t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
             raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
@@ -545,13 +546,20 @@ def _convert_decimals_in_arrow_table(table, description):
         return table

     @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema_bytes):
+    def _convert_arrow_based_set_to_arrow_table(
+        arrow_batches, lz4_compressed, schema_bytes
+    ):
         ba = bytearray()
         ba += schema_bytes
         n_rows = 0
-        for arrow_batch in arrow_batches:
-            n_rows += arrow_batch.rowCount
-            ba += arrow_batch.batch
+        if lz4_compressed:
+            for arrow_batch in arrow_batches:
+                n_rows += arrow_batch.rowCount
+                ba += lz4.frame.decompress(arrow_batch.batch)
+        else:
+            for arrow_batch in arrow_batches:
+                n_rows += arrow_batch.rowCount
+                ba += arrow_batch.batch
         arrow_table = pyarrow.ipc.open_stream(ba).read_all()
         return arrow_table, n_rows

@@ -708,7 +716,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
                     ]
                 )
             )
-
         direct_results = resp.directResults
         has_been_closed_server_side = direct_results and direct_results.closeOperation
         has_more_rows = (
@@ -725,12 +732,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
             .serialize()
             .to_pybytes()
         )
-
+        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         if direct_results and direct_results.resultSet:
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata
+
             arrow_results, n_rows = self._create_arrow_table(
-                direct_results.resultSet.results, schema_bytes, description
+                direct_results.resultSet.results,
+                lz4_compressed,
+                schema_bytes,
+                description,
             )
             arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
         else:
@@ -740,6 +751,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
             status=operation_state,
             has_been_closed_server_side=has_been_closed_server_side,
             has_more_rows=has_more_rows,
+            lz4_compressed=lz4_compressed,
             command_handle=resp.operationHandle,
             description=description,
             arrow_schema_bytes=schema_bytes,
@@ -783,7 +795,9 @@ def _check_direct_results_for_error(t_spark_direct_results):
                     t_spark_direct_results.closeOperation
                 )

-    def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor):
+    def execute_command(
+        self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
+    ):
         assert session_handle is not None

         spark_arrow_types = ttypes.TSparkArrowTypes(
@@ -802,7 +816,7 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
                 maxRows=max_rows, maxBytes=max_bytes
             ),
             canReadArrowResult=True,
-            canDecompressLZ4Result=False,
+            canDecompressLZ4Result=lz4_compression,
             canDownloadResult=False,
             confOverlay={
                 # We want to receive proper Timestamp arrow types.
@@ -916,6 +930,7 @@ def fetch_results(
         max_rows,
         max_bytes,
         expected_row_start_offset,
+        lz4_compressed,
         arrow_schema_bytes,
         description,
     ):
@@ -941,7 +956,7 @@ def fetch_results(
                 )
             )
         arrow_results, n_rows = self._create_arrow_table(
-            resp.results, arrow_schema_bytes, description
+            resp.results, lz4_compressed, arrow_schema_bytes, description
         )
         arrow_queue = ArrowQueue(arrow_results, n_rows)
