@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr
     }
 
 
-def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
+def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
         all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
-    return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
+    return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)
 
 
 def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
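
Note on the hunk above: `_combine_positional_deletes` now works on a half-open window `[start_index, end_index)` of file row positions instead of a total row count, and rebases the surviving positions so they are relative to the window start, which makes them directly usable with `RecordBatch.take()` later in this diff. A minimal standalone sketch of the same numpy logic (the values here are made up for illustration, not taken from the library):

    import numpy as np

    # Positional deletes for the whole file (file-level row numbers).
    deletes = np.array([3, 5])

    # Window covering one record batch: file rows [4, 8).
    start_index, end_index = 4, 8

    surviving = np.setdiff1d(np.arange(start_index, end_index), deletes, assume_unique=False)
    batch_local = np.subtract(surviving, start_index)

    # surviving   -> [4 6 7]  (file-level positions that were not deleted)
    # batch_local -> [0 2 3]  (batch-local positions, suitable for RecordBatch.take)
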
@@ -995,17 +995,16 @@ def _field_id(self, field: pa.Field) -> int:
         return -1
 
 
-def _task_to_table(
+def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
     bound_row_filter: BooleanExpression,
     projected_schema: Schema,
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
-    limit: Optional[int] = None,
     name_mapping: Optional[NameMapping] = None,
-) -> Optional[pa.Table]:
+) -> Iterator[pa.RecordBatch]:
     _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with fs.open_input_file(path) as fin:
@@ -1035,36 +1034,39 @@ def _task_to_table(
             columns=[col.name for col in file_project_schema.columns],
         )
 
-        if positional_deletes:
-            # Create the mask of indices that we're interested in
-            indices = _combine_positional_deletes(positional_deletes, fragment.count_rows())
-
-            if limit:
-                if pyarrow_filter is not None:
-                    # In case of the filter, we don't exactly know how many rows
-                    # we need to fetch upfront, can be optimized in the future:
-                    # https://github.com/apache/arrow/issues/35301
-                    arrow_table = fragment_scanner.take(indices)
-                    arrow_table = arrow_table.filter(pyarrow_filter)
-                    arrow_table = arrow_table.slice(0, limit)
-                else:
-                    arrow_table = fragment_scanner.take(indices[0:limit])
-            else:
-                arrow_table = fragment_scanner.take(indices)
+        current_index = 0
+        batches = fragment_scanner.to_batches()
+        for batch in batches:
+            if positional_deletes:
+                # Create the mask of indices that we're interested in
+                indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
+                batch = batch.take(indices)
             # Apply the user filter
             if pyarrow_filter is not None:
+                # we need to switch back and forth between RecordBatch and Table
+                # as Expression filter isn't yet supported in RecordBatch
+                # https://github.com/apache/arrow/issues/39220
+                arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
-        else:
-            # If there are no deletes, we can just take the head
-            # and the user-filter is already applied
-            if limit:
-                arrow_table = fragment_scanner.head(limit)
-            else:
-                arrow_table = fragment_scanner.to_table()
+                batch = arrow_table.to_batches()[0]
+            yield to_requested_schema(projected_schema, file_project_schema, batch)
+            current_index += len(batch)
 
-        if len(arrow_table) < 1:
-            return None
-        return to_requested_schema(projected_schema, file_project_schema, arrow_table)
+
+def _task_to_table(
+    fs: FileSystem,
+    task: FileScanTask,
+    bound_row_filter: BooleanExpression,
+    projected_schema: Schema,
+    projected_field_ids: Set[int],
+    positional_deletes: Optional[List[ChunkedArray]],
+    case_sensitive: bool,
+    name_mapping: Optional[NameMapping] = None,
+) -> pa.Table:
+    batches = _task_to_record_batches(
+        fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
+    )
+    return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
 
 
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
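
The hunk above is the core of the change: the fragment is scanned with `to_batches()` and each `RecordBatch` is processed independently, with positional deletes applied per batch window and the projected batches yielded from a generator. The Table round-trip around the user filter exists because `RecordBatch.filter()` did not accept a `pyarrow.compute.Expression` at the time (apache/arrow#39220). A hedged, standalone sketch of that workaround on toy data (not the library code path):

    import pyarrow as pa
    import pyarrow.compute as pc

    batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3], "qty": [10, 0, 7]})
    expr = pc.field("qty") > 0  # an Expression, e.g. as built from the bound row filter

    # Expression filtering is only available on Table, so round-trip through a
    # single-batch Table and back to a RecordBatch.
    filtered = pa.Table.from_batches([batch]).filter(expr)
    batch = filtered.to_batches()[0]  # assumes at least one row survived the filter

    # batch now has 2 rows: id = [1, 3], qty = [10, 7]
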
@@ -1143,7 +1145,6 @@ def project_table(
             projected_field_ids,
             deletes_per_file.get(task.file.file_path),
             case_sensitive,
-            limit,
             table_metadata.name_mapping(),
         )
         for task in tasks
@@ -1177,16 +1178,86 @@ def project_table(
     return result
 
 
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
-    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def project_batches(
+    tasks: Iterable[FileScanTask],
+    table_metadata: TableMetadata,
+    io: FileIO,
+    row_filter: BooleanExpression,
+    projected_schema: Schema,
+    case_sensitive: bool = True,
+    limit: Optional[int] = None,
+) -> Iterator[pa.RecordBatch]:
+    """Resolve the right columns based on the identifier.
+
+    Args:
+        tasks (Iterable[FileScanTask]): A URI or a path to a local file.
+        table_metadata (TableMetadata): The table metadata of the table that's being queried
+        io (FileIO): A FileIO to open streams to the object store
+        row_filter (BooleanExpression): The expression for filtering rows.
+        projected_schema (Schema): The output schema.
+        case_sensitive (bool): Case sensitivity when looking up column names.
+        limit (Optional[int]): Limit the number of records.
+
+    Raises:
+        ResolveError: When an incompatible query is done.
+    """
+    scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
+    if isinstance(io, PyArrowFileIO):
+        fs = io.fs_by_scheme(scheme, netloc)
+    else:
+        try:
+            from pyiceberg.io.fsspec import FsspecFileIO
+
+            if isinstance(io, FsspecFileIO):
+                from pyarrow.fs import PyFileSystem
+
+                fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
+            else:
+                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
+        except ModuleNotFoundError as e:
+            # When FsSpec is not installed
+            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e
+
+    bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+
+    projected_field_ids = {
+        id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
+    }.union(extract_field_ids(bound_row_filter))
+
+    deletes_per_file = _read_all_delete_files(fs, tasks)
+
+    total_row_count = 0
+
+    for task in tasks:
+        batches = _task_to_record_batches(
+            fs,
+            task,
+            bound_row_filter,
+            projected_schema,
+            projected_field_ids,
+            deletes_per_file.get(task.file.file_path),
+            case_sensitive,
+            table_metadata.name_mapping(),
+        )
+        for batch in batches:
+            if limit is not None:
+                if total_row_count + len(batch) >= limit:
+                    yield batch.slice(0, limit - total_row_count)
+                    break
+            yield batch
+            total_row_count += len(batch)
+
+
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
+    struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
 
     arrays = []
     fields = []
     for pos, field in enumerate(requested_schema.fields):
         array = struct_array.field(pos)
         arrays.append(array)
         fields.append(pa.field(field.name, array.type, field.optional))
-    return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
+    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
 
 
 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
@@ -1293,8 +1364,10 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str
 
             if isinstance(partner_struct, pa.StructArray):
                 return partner_struct.field(name)
-            elif isinstance(partner_struct, pa.Table):
-                return partner_struct.column(name).combine_chunks()
+            elif isinstance(partner_struct, pa.RecordBatch):
+                return partner_struct.column(name)
+            else:
+                raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}")
 
         return None
 
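
`project_batches` mirrors `project_table` but returns a lazy `Iterator[pa.RecordBatch]`; the limit is enforced across tasks by counting yielded rows and slicing the final batch. A hedged sketch of how such a generator might be consumed without materializing a full table (the `plan_files`, `table_metadata`, `io`, `row_filter`, `projected_schema`, and `process()` names below are placeholders, not part of this diff):

    import pyarrow as pa

    batches = project_batches(
        tasks=plan_files,            # Iterable[FileScanTask]
        table_metadata=table_metadata,
        io=io,
        row_filter=row_filter,
        projected_schema=projected_schema,
        limit=1_000,
    )

    # Wrap the generator in a RecordBatchReader so downstream consumers can stream it.
    reader = pa.RecordBatchReader.from_batches(
        schema_to_pyarrow(projected_schema, include_field_ids=False), batches
    )
    for batch in reader:
        process(batch)  # placeholder for batch-wise downstream processing
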
@@ -1831,15 +1904,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]
 
     def write_parquet(task: WriteTask) -> DataFile:
         table_schema = task.schema
-        arrow_table = pa.Table.from_batches(task.record_batches)
+
         # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
             file_schema = sanitized_schema
         else:
             file_schema = table_schema
 
-        arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
+        batches = [
+            to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
+            for batch in task.record_batches
+        ]
+        arrow_table = pa.Table.from_batches(batches)
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos: