Skip to content

Commit 199e1cd

Browse files
perf: Optimise low-level null scans and arg_max for bools (when chunked) (#22897)
1 parent 622d792 commit 199e1cd

File tree

5 files changed

+49
-35
lines changed

5 files changed

+49
-35
lines changed

crates/polars-core/src/utils/mod.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ mod schema;
1010

1111
pub use any_value::*;
1212
use arrow::bitmap::Bitmap;
13-
use arrow::bitmap::bitmask::BitMask;
1413
pub use arrow::legacy::utils::*;
1514
pub use arrow::trusted_len::TrustMyLength;
1615
use flatten::*;
@@ -1200,26 +1199,27 @@ pub(crate) fn index_to_chunked_index_rev<
12001199
)
12011200
}
12021201

1203-
pub(crate) fn first_non_null<'a, I>(iter: I) -> Option<usize>
1202+
pub fn first_non_null<'a, I>(iter: I) -> Option<usize>
12041203
where
12051204
I: Iterator<Item = Option<&'a Bitmap>>,
12061205
{
12071206
let mut offset = 0;
12081207
for validity in iter {
1209-
if let Some(validity) = validity {
1210-
let mask = BitMask::from_bitmap(validity);
1211-
if let Some(n) = mask.nth_set_bit_idx(0, 0) {
1208+
if let Some(mask) = validity {
1209+
let len_mask = mask.len();
1210+
let n = mask.leading_zeros();
1211+
if n < len_mask {
12121212
return Some(offset + n);
12131213
}
1214-
offset += validity.len()
1214+
offset += len_mask
12151215
} else {
12161216
return Some(offset);
12171217
}
12181218
}
12191219
None
12201220
}
12211221

1222-
pub(crate) fn last_non_null<'a, I>(iter: I, len: usize) -> Option<usize>
1222+
pub fn last_non_null<'a, I>(iter: I, len: usize) -> Option<usize>
12231223
where
12241224
I: DoubleEndedIterator<Item = Option<&'a Bitmap>>,
12251225
{
@@ -1228,15 +1228,15 @@ where
12281228
}
12291229
let mut offset = 0;
12301230
for validity in iter.rev() {
1231-
if let Some(validity) = validity {
1232-
let mask = BitMask::from_bitmap(validity);
1233-
if let Some(n) = mask.nth_set_bit_idx_rev(0, mask.len()) {
1234-
let mask_start = len - offset - mask.len();
1235-
return Some(mask_start + n);
1231+
if let Some(mask) = validity {
1232+
let len_mask = mask.len();
1233+
let n = mask.trailing_zeros();
1234+
if n < len_mask {
1235+
return Some(len - offset - n - 1);
12361236
}
1237-
offset += validity.len()
1237+
offset += len_mask;
12381238
} else {
1239-
return Some(len - 1 - offset);
1239+
return Some(len - offset - 1);
12401240
}
12411241
}
12421242
None

crates/polars-ops/src/series/ops/arg_min_max.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use argminmax::ArgMinMax;
22
use arrow::array::Array;
3-
use arrow::legacy::bit_util::*;
3+
use arrow::legacy::bit_util::first_unset_bit;
44
use polars_core::chunked_array::ops::float_sorted_arg_max::{
55
float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
66
};
77
use polars_core::series::IsSorted;
8-
use polars_core::with_match_physical_numeric_polars_type;
8+
use polars_core::utils::first_non_null;
99

1010
use super::*;
1111

@@ -167,14 +167,14 @@ where
167167
}
168168

169169
pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
170-
if ca.null_count() == ca.len() {
170+
let null_count = ca.null_count();
171+
if null_count == ca.len() {
171172
None
172-
}
173-
// don't check for any, that on itself is already an argmax search
174-
else if ca.null_count() == 0 && ca.chunks().len() == 1 {
175-
let arr = ca.downcast_iter().next().unwrap();
176-
let mask = arr.values();
177-
Some(first_set_bit(mask))
173+
} else if null_count == 0 {
174+
// if no null values we only have True/False, which can be downcast to
175+
// operate on as bitmap 1/0, allowing for a fast-path; if `first_non_null`
176+
// returns None this implies all values are zero (eg: False), so we return 0
177+
first_non_null(ca.downcast_iter().map(|arr| Some(arr.values()))).or(Some(0))
178178
} else {
179179
let mut first_false_idx: Option<usize> = None;
180180
ca.iter()
@@ -202,7 +202,6 @@ where
202202
IsSorted::Descending => float_arg_max_sorted_descending(ca),
203203
_ => unreachable!(),
204204
};
205-
206205
Some(out)
207206
}
208207

crates/polars-ops/src/series/ops/index_of.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult<Option<usize>>
8080

8181
// Series is not null, and the value is null:
8282
if needle.is_null() {
83+
let null_count = series.null_count();
84+
if null_count == 0 {
85+
return Ok(None);
86+
} else if null_count == series.len() {
87+
return Ok(Some(0));
88+
}
8389
let mut index = 0;
8490
for chunk in series.chunks() {
8591
let length = chunk.len();

py-polars/tests/unit/operations/test_index_of.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111

1212
import polars as pl
1313
from polars.exceptions import InvalidOperationError
14+
from polars.testing import assert_frame_equal
15+
from polars.testing.parametric import series
1416

1517
if TYPE_CHECKING:
1618
from polars._typing import IntoExpr
17-
from polars.testing import assert_frame_equal
1819

1920

2021
def isnan(value: object) -> bool:
@@ -351,3 +352,14 @@ def test_categorical(convert_to_literal: bool) -> None:
351352
]:
352353
for value in expected_values:
353354
assert_index_of(s, value, convert_to_literal=convert_to_literal)
355+
356+
357+
@given(s=series(name="s", allow_chunks=True, max_size=10))
358+
def test_index_of_null_parametric(s: pl.Series) -> None:
359+
idx_null = s.index_of(None)
360+
if s.len() == 0:
361+
assert idx_null is None
362+
elif s.null_count() == 0:
363+
assert idx_null is None
364+
elif s.null_count() == len(s):
365+
assert idx_null == 0

py-polars/tests/unit/series/test_series.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,18 +1437,16 @@ def test_arg_sort() -> None:
14371437
(pl.Series(["c", "b", "a"], dtype=pl.Categorical), 0, 2),
14381438
(pl.Series([None, "c", "b", None, "a"], dtype=pl.Categorical), 1, 4),
14391439
(pl.Series(["c", "b", "a"], dtype=pl.Categorical(ordering="lexical")), 2, 0),
1440-
(
1441-
pl.Series(
1442-
[None, "c", "b", None, "a"], dtype=pl.Categorical(ordering="lexical")
1443-
),
1444-
4,
1445-
1,
1446-
),
1440+
(pl.Series("s", [None, "c", "b", None, "a"], pl.Categorical("lexical")), 4, 1),
14471441
],
14481442
)
14491443
def test_arg_min_arg_max(series: pl.Series, argmin: int, argmax: int) -> None:
1450-
assert series.arg_min() == argmin
1451-
assert series.arg_max() == argmax
1444+
assert series.arg_min() == argmin, (
1445+
f"values: {series.to_list()}, expected {argmin} got {series.arg_min()}"
1446+
)
1447+
assert series.arg_max() == argmax, (
1448+
f"values: {series.to_list()}, expected {argmax} got {series.arg_max()}"
1449+
)
14521450

14531451

14541452
@pytest.mark.parametrize(
@@ -2216,7 +2214,6 @@ def test_construction_large_nested_u64_17231() -> None:
22162214

22172215
values = [{"f0": [9223372036854775808]}]
22182216
dtype = pl.Struct({"f0": pl.List(pl.UInt64)})
2219-
22202217
assert pl.Series(values, dtype=dtype).to_list() == values
22212218

22222219

0 commit comments

Comments
 (0)