2121import databricks .sql .client as client
2222from databricks .sql import InterfaceError , DatabaseError , Error , NotSupportedError
2323from databricks .sql .types import Row
24+ from databricks .sql .result_set import ResultSet , ThriftResultSet
2425
2526from tests .unit .test_fetches import FetchTests
2627from tests .unit .test_thrift_backend import ThriftBackendTestSuite
@@ -34,12 +35,11 @@ def new(cls):
3435 ThriftBackendMock .return_value = ThriftBackendMock
3536
3637 cls .apply_property_to_mock (ThriftBackendMock , staging_allowed_local_path = None )
37- MockTExecuteStatementResp = MagicMock ( spec = TExecuteStatementResp ())
38-
38+
39+ mock_result_set = Mock ( spec = ThriftResultSet )
3940 cls .apply_property_to_mock (
40- MockTExecuteStatementResp ,
41+ mock_result_set ,
4142 description = None ,
42- arrow_queue = None ,
4343 is_staging_operation = False ,
4444 command_handle = b"\x22 " ,
4545 has_been_closed_server_side = True ,
@@ -48,7 +48,7 @@ def new(cls):
4848 arrow_schema_bytes = b"schema" ,
4949 )
5050
51- ThriftBackendMock .execute_command .return_value = MockTExecuteStatementResp
51+ ThriftBackendMock .execute_command .return_value = mock_result_set
5252
5353 return ThriftBackendMock
5454
@@ -81,21 +81,22 @@ class ClientTestSuite(unittest.TestCase):
8181 }
8282
8383 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME , ThriftDatabricksClientMockFactory .new ())
84- @patch ("%s.client.ResultSet" % PACKAGE_NAME )
85- def test_closing_connection_closes_commands (self , mock_result_set_class ):
86- # Test once with has_been_closed_server side, once without
84+ def test_closing_connection_closes_commands (self ):
8785 for closed in (True , False ):
8886 with self .subTest (closed = closed ):
89- mock_result_set_class .return_value = Mock ()
9087 connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
9188 cursor = connection .cursor ()
92- cursor .execute ("SELECT 1;" )
89+
90+ # Create a mock result set and set it as the active result set
91+ mock_result_set = Mock ()
92+ mock_result_set .has_been_closed_server_side = closed
93+ cursor .active_result_set = mock_result_set
94+
95+ # Close the connection
9396 connection .close ()
94-
95- self .assertTrue (
96- mock_result_set_class .return_value .has_been_closed_server_side
97- )
98- mock_result_set_class .return_value .close .assert_called_once_with ()
97+
98+ # After closing the connection, the close method should have been called on the result set
99+ mock_result_set .close .assert_called_once_with ()
99100
100101 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
101102 def test_cant_open_cursor_on_closed_connection (self , mock_client_class ):
@@ -122,10 +123,11 @@ def test_arraysize_buffer_size_passthrough(
122123 def test_closing_result_set_with_closed_connection_soft_closes_commands (self ):
123124 mock_connection = Mock ()
124125 mock_backend = Mock ()
125- result_set = client .ResultSet (
126+
127+ result_set = ThriftResultSet (
126128 connection = mock_connection ,
127- backend = mock_backend ,
128129 execute_response = Mock (),
130+ thrift_client = mock_backend ,
129131 )
130132 # Setup session mock on the mock_connection
131133 mock_session = Mock ()
@@ -147,7 +149,7 @@ def test_closing_result_set_hard_closes_commands(self):
147149 mock_session .open = True
148150 type(mock_connection ).session = PropertyMock (return_value = mock_session )
149151
150- result_set = client . ResultSet (
152+ result_set = ThriftResultSet (
151153 mock_connection , mock_results_response , mock_thrift_backend
152154 )
153155
@@ -157,16 +159,22 @@ def test_closing_result_set_hard_closes_commands(self):
157159 mock_results_response .command_handle
158160 )
159161
160- @patch ("%s.client.ResultSet " % PACKAGE_NAME )
162+ @patch ("%s.result_set.ThriftResultSet " % PACKAGE_NAME )
161163 def test_executing_multiple_commands_uses_the_most_recent_command (
162164 self , mock_result_set_class
163165 ):
164-
165166 mock_result_sets = [Mock (), Mock ()]
167+ # Set is_staging_operation to False to avoid _handle_staging_operation being called
168+ for mock_rs in mock_result_sets :
169+ mock_rs .is_staging_operation = False
170+
166171 mock_result_set_class .side_effect = mock_result_sets
172+
173+ mock_backend = ThriftDatabricksClientMockFactory .new ()
174+ mock_backend .execute_command .side_effect = mock_result_sets
167175
168176 cursor = client .Cursor (
169- connection = Mock (), backend = ThriftDatabricksClientMockFactory . new ()
177+ connection = Mock (), backend = mock_backend
170178 )
171179 cursor .execute ("SELECT 1;" )
172180 cursor .execute ("SELECT 1;" )
@@ -192,7 +200,7 @@ def test_closed_cursor_doesnt_allow_operations(self):
192200 self .assertIn ("closed" , e .msg )
193201
194202 def test_negative_fetch_throws_exception (self ):
195- result_set = client . ResultSet (Mock (), Mock (), Mock ())
203+ result_set = ThriftResultSet (Mock (), Mock (), Mock ())
196204
197205 with self .assertRaises (ValueError ) as e :
198206 result_set .fetchmany (- 1 )
@@ -334,14 +342,19 @@ def test_execute_parameter_passthrough(self):
334342 expected_query ,
335343 )
336344
337- @patch ("%s.client.ResultSet " % PACKAGE_NAME )
345+ @patch ("%s.result_set.ThriftResultSet " % PACKAGE_NAME )
338346 def test_executemany_parameter_passhthrough_and_uses_last_result_set (
339347 self , mock_result_set_class
340348 ):
341349 # Create a new mock result set each time the class is instantiated
342350 mock_result_set_instances = [Mock (), Mock (), Mock ()]
351+ # Set is_staging_operation to False to avoid _handle_staging_operation being called
352+ for mock_rs in mock_result_set_instances :
353+ mock_rs .is_staging_operation = False
354+
343355 mock_result_set_class .side_effect = mock_result_set_instances
344356 mock_backend = ThriftDatabricksClientMockFactory .new ()
357+ mock_backend .execute_command .side_effect = mock_result_set_instances
345358
346359 cursor = client .Cursor (Mock (), mock_backend )
347360
@@ -494,8 +507,9 @@ def test_staging_operation_response_is_handled(
494507 ThriftDatabricksClientMockFactory .apply_property_to_mock (
495508 mock_execute_response , is_staging_operation = True
496509 )
497- mock_client_class .execute_command .return_value = mock_execute_response
498- mock_client_class .return_value = mock_client_class
510+ mock_client = mock_client_class .return_value
511+ mock_client .execute_command .return_value = Mock (is_staging_operation = True )
512+ mock_client_class .return_value = mock_client
499513
500514 connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
501515 cursor = connection .cursor ()
0 commit comments