Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/polars-io/src/parquet/write/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub struct MetadataKeyValue {
pub struct ParquetFieldOverwrites {
pub name: Option<PlSmallStr>,
pub children: ChildFieldOverwrites,

pub required: Option<bool>,
pub field_id: Option<i32>,
pub metadata: Option<Vec<MetadataKeyValue>>,
}
Expand Down
14 changes: 9 additions & 5 deletions crates/polars-io/src/parquet/write/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ fn to_column_write_options_rec(
let mut column_options = ColumnWriteOptions {
field_id: None,
metadata: Vec::new(),
required: None,

// Dummy value.
children: ChildWriteOptions::Leaf(FieldWriteOptions {
Expand All @@ -191,6 +192,7 @@ fn to_column_write_options_rec(
if let Some(overwrites) = overwrites {
column_options.field_id = overwrites.field_id;
column_options.metadata = convert_metadata(&overwrites.metadata);
column_options.required = overwrites.required;
}

use arrow::datatypes::PhysicalType::*;
Expand All @@ -202,8 +204,9 @@ fn to_column_write_options_rec(
});
},
List | FixedSizeList | LargeList => {
let child_overwrites = overwrites.map(|o| match &o.children {
ChildFieldOverwrites::ListLike(child_overwrites) => child_overwrites.as_ref(),
let child_overwrites = overwrites.and_then(|o| match &o.children {
ChildFieldOverwrites::None => None,
ChildFieldOverwrites::ListLike(child_overwrites) => Some(child_overwrites.as_ref()),
_ => unreachable!(),
});

Expand All @@ -223,12 +226,13 @@ fn to_column_write_options_rec(
},
Struct => {
if let ArrowDataType::Struct(fields) = field.dtype().to_logical_type() {
let children_overwrites = overwrites.map(|o| match &o.children {
ChildFieldOverwrites::Struct(child_overwrites) => PlHashMap::from_iter(
let children_overwrites = overwrites.and_then(|o| match &o.children {
ChildFieldOverwrites::None => None,
ChildFieldOverwrites::Struct(child_overwrites) => Some(PlHashMap::from_iter(
child_overwrites
.iter()
.map(|f| (f.name.as_ref().unwrap(), f)),
),
)),
_ => unreachable!(),
});

Expand Down
13 changes: 13 additions & 0 deletions crates/polars-parquet/src/arrow/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub use crate::parquet::metadata::{
Descriptor, FileMetadata, KeyValue, SchemaDescriptor, ThriftFileMetadata,
};
pub use crate::parquet::page::{CompressedDataPage, CompressedPage, Page};
use crate::parquet::schema::Repetition;
use crate::parquet::schema::types::PrimitiveType as ParquetPrimitiveType;
pub use crate::parquet::schema::types::{
FieldInfo, ParquetType, PhysicalType as ParquetPhysicalType,
Expand Down Expand Up @@ -95,6 +96,7 @@ pub struct WriteOptions {
pub struct ColumnWriteOptions {
pub field_id: Option<i32>,
pub metadata: Vec<KeyValue>,
pub required: Option<bool>,
pub children: ChildWriteOptions,
}

Expand Down Expand Up @@ -129,6 +131,7 @@ impl ColumnWriteOptions {
Self {
field_id: None,
metadata: Vec::new(),
required: None,
children,
}
}
Expand Down Expand Up @@ -448,6 +451,10 @@ pub fn array_to_page_simple(
) -> PolarsResult<Page> {
let dtype = array.dtype();

if type_.field_info.repetition == Repetition::Required && array.null_count() > 0 {
polars_bail!(InvalidOperation: "writing a missing value to required parquet column '{}'", type_.field_info.name);
}

match dtype.to_logical_type() {
ArrowDataType::Boolean => boolean::array_to_page(
array.as_any().downcast_ref().unwrap(),
Expand Down Expand Up @@ -816,6 +823,12 @@ fn array_to_page_nested(
options: WriteOptions,
_encoding: Encoding,
) -> PolarsResult<Page> {
if type_.field_info.repetition == Repetition::Required
&& array.validity().is_some_and(|v| v.unset_bits() > 0)
{
polars_bail!(InvalidOperation: "writing a missing value to required parquet column '{}'", type_.field_info.name);
}

use ArrowDataType::*;
match array.dtype().to_logical_type() {
Null => {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-parquet/src/arrow/write/pages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ fn to_nested_recursive(
) -> PolarsResult<()> {
let is_optional = is_nullable(type_.get_field_info());

if !is_optional && array.null_count() > 0 {
polars_bail!(InvalidOperation: "writing a missing value to required field '{}'", type_.name());
}

use PhysicalType::*;
match array.dtype().to_physical_type() {
Struct => {
Expand Down
13 changes: 10 additions & 3 deletions crates/polars-parquet/src/arrow/write/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ fn insert_field_metadata(field: &mut Cow<Field>, options: &ColumnWriteOptions) {
field.metadata = Some(Arc::new(metadata));
}

if let Some(v) = options.required {
if v == field.is_nullable {
let field = field.to_mut();
field.is_nullable = !v;
}
}

use ArrowDataType as D;
match field.dtype() {
D::Struct(f) => {
Expand Down Expand Up @@ -180,10 +187,10 @@ pub fn schema_to_metadata_key(schema: &ArrowSchema, options: &[ColumnWriteOption
/// Creates a [`ParquetType`] from a [`Field`].
pub fn to_parquet_type(field: &Field, options: &ColumnWriteOptions) -> PolarsResult<ParquetType> {
let name = field.name.clone();
let repetition = if field.is_nullable {
Repetition::Optional
} else {
let repetition = if options.required.unwrap_or(!field.is_nullable) {
Repetition::Required
} else {
Repetition::Optional
};

let field_id = options.field_id;
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/dsl-schema.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0557f815a5f4b373cee07384fd2999bb9908ecc31c5c777366c5ffd73e6c10cf
f6b5cbe9618e32af921a0ae4bda402141d840d28c89e8da7f9c0ea127c1fecd6
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use super::*;
// - changing a name, type, or meaning of a field or an enum variant
// - changing a default value of a field or a default enum variant
// - restricting the range of allowed values a field can have
pub static DSL_VERSION: (u16, u16) = (9, 0);
pub static DSL_VERSION: (u16, u16) = (9, 1);
static DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-python/src/lazyframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1692,11 +1692,16 @@ impl<'py> FromPyObject<'py> for Wrap<polars_io::parquet::write::ParquetFieldOver
.collect()
});

let required = PyDictMethods::get_item(&parsed, "required")?
.map(|v| v.extract::<bool>())
.transpose()?;

Ok(Wrap(ParquetFieldOverwrites {
name,
children,
field_id,
metadata,
required,
}))
}
}
6 changes: 6 additions & 0 deletions py-polars/polars/io/parquet/field_overwrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def _parquet_field_overwrites_to_dict(pqo: ParquetFieldOverwrites) -> dict[str,
if pqo.metadata is not None:
d["metadata"] = list(pqo.metadata.items())

if pqo.required is not None:
d["required"] = pqo.required

return d


Expand Down Expand Up @@ -104,6 +107,7 @@ class ParquetFieldOverwrites:
metadata: (
dict[str, None | str] | None
) #: Arrow metadata added to the field before writing
required: bool | None = None #: Is the field not allowed to have missing values

def __init__(
self,
Expand All @@ -117,6 +121,7 @@ def __init__(
) = None,
field_id: int | None = None,
metadata: Mapping[str, None | str] | None = None,
required: bool | None = None,
) -> None:
self.name = name

Expand All @@ -132,3 +137,4 @@ def __init__(
self.metadata = dict(metadata)
else:
self.metadata = metadata
self.required = required
143 changes: 143 additions & 0 deletions py-polars/tests/unit/io/test_parquet_field_overwrites.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import io

import pyarrow.parquet as pq
import pytest

import polars as pl
from polars.io.parquet import ParquetFieldOverwrites


def test_required_flat() -> None:
    """A flat column's Parquet nullability follows the `required` field overwrite.

    - `required=False` keeps the column nullable in the written schema.
    - `required=True` marks it non-nullable.
    - Writing an actual null into a `required=True` column raises
      `InvalidOperationError`.
    """
    f = io.BytesIO()
    pl.Series("a", [1, 2, 3]).to_frame().lazy().sink_parquet(
        f,
        field_overwrites=ParquetFieldOverwrites(name="a", required=False),
    )

    f.seek(0)
    assert pq.read_schema(f).field(0).nullable

    # Rewrite the same buffer with `required=True` and verify the flag flips.
    f.seek(0)
    pl.Series("a", [1, 2, 3]).to_frame().lazy().sink_parquet(
        f,
        field_overwrites=ParquetFieldOverwrites(name="a", required=True),
    )

    # Drop any trailing bytes left over from the first write before re-reading.
    f.truncate()
    f.seek(0)
    assert not pq.read_schema(f).field(0).nullable

    f = io.BytesIO()
    with pytest.raises(pl.exceptions.InvalidOperationError, match="missing value"):
        pl.Series("a", [1, 2, 3, None]).to_frame().lazy().sink_parquet(
            f,
            field_overwrites=ParquetFieldOverwrites(name="a", required=True),
        )


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64()), pl.Array(pl.Int64(), 1)])
def test_required_list(dtype: pl.DataType) -> None:
    """`required` overwrites apply independently to list columns and their elements.

    - `required=True` on the outer list makes only the outer field non-nullable;
      the element field stays nullable unless its child overwrite also sets
      `required=True`.
    - A null outer list (or a null element under a required child overwrite)
      raises `InvalidOperationError`.
    """
    # Outer field required, elements still nullable: inner None is allowed.
    f = io.BytesIO()
    pl.Series("a", [[1], [2], [3], [None]], dtype).to_frame().lazy().sink_parquet(
        f,
        field_overwrites=ParquetFieldOverwrites(name="a", required=True),
    )
    f.seek(0)
    schema = pq.read_schema(f)
    assert not schema.field(0).nullable
    assert schema.field(0).type.value_field.nullable

    # A null outer list violates the required outer field.
    with pytest.raises(pl.exceptions.InvalidOperationError, match="missing value"):
        pl.Series("a", [[1], [2], [3], None], dtype).to_frame().lazy().sink_parquet(
            io.BytesIO(),
            field_overwrites=ParquetFieldOverwrites(name="a", required=True),
        )

    # A null element violates a required child overwrite.
    with pytest.raises(pl.exceptions.InvalidOperationError, match="missing value"):
        pl.Series("a", [[1], [2], [3], [None]], dtype).to_frame().lazy().sink_parquet(
            io.BytesIO(),
            field_overwrites=ParquetFieldOverwrites(
                name="a",
                required=True,
                children=ParquetFieldOverwrites(required=True),
            ),
        )

    # Fully non-null data with both levels required writes both as non-nullable.
    f = io.BytesIO()
    pl.Series("a", [[1], [2], [3], [4]], dtype).to_frame().lazy().sink_parquet(
        f,
        field_overwrites=ParquetFieldOverwrites(
            name="a",
            required=True,
            children=ParquetFieldOverwrites(required=True),
        ),
    )
    f.seek(0)
    schema = pq.read_schema(f)
    assert not schema.field(0).nullable
    assert not schema.field(0).type.value_field.nullable


def test_required_struct() -> None:
    """`required` overwrites apply independently to struct columns and their fields.

    - `required=True` on the struct makes only the struct field non-nullable;
      inner fields stay nullable unless a child overwrite marks them required.
    - A null inner value under a required child overwrite raises
      `InvalidOperationError`.
    """
    # Struct required, inner field left nullable.
    f = io.BytesIO()
    pl.Series(
        "a", [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]
    ).to_frame().lazy().sink_parquet(
        f,
        field_overwrites=ParquetFieldOverwrites(name="a", required=True),
    )
    f.seek(0)
    schema = pq.read_schema(f)
    assert not schema.field(0).nullable
    assert schema.field(0).type.fields[0].nullable

    # Inner None is allowed when only the outer struct is required.
    f = io.BytesIO()
    pl.Series(
        "a", [{"x": 1}, {"x": None}, {"x": 2}, {"x": 3}]
    ).to_frame().lazy().sink_parquet(
        f,
        field_overwrites=ParquetFieldOverwrites(name="a", required=True),
    )
    f.seek(0)
    schema = pq.read_schema(f)
    assert not schema.field(0).nullable
    assert schema.field(0).type.fields[0].nullable

    # Inner None violates a required child overwrite.
    with pytest.raises(pl.exceptions.InvalidOperationError, match="missing value"):
        pl.Series(
            "a", [{"x": 1}, {"x": None}, {"x": 2}, {"x": 3}]
        ).to_frame().lazy().sink_parquet(
            io.BytesIO(),
            field_overwrites=ParquetFieldOverwrites(
                name="a",
                required=True,
                children={"x": ParquetFieldOverwrites(required=True)},
            ),
        )

    # Fully non-null data with both levels required writes both as non-nullable.
    f = io.BytesIO()
    pl.Series(
        "a", [{"x": 1}, {"x": 2}, {"x": 2}, {"x": 3}]
    ).to_frame().lazy().sink_parquet(
        f,
        field_overwrites=ParquetFieldOverwrites(
            name="a",
            required=True,
            children={"x": ParquetFieldOverwrites(required=True)},
        ),
    )
    f.seek(0)
    schema = pq.read_schema(f)
    assert not schema.field(0).nullable
    assert not schema.field(0).type.fields[0].nullable
Loading