Skip to content

Commit 4f95d30

Browse files
DataTable sort by function (or other callable) (#3090)
* DataTable sort by function (or other callable) The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). Covers #2261 using [suggested function signature](#2512 (comment)) from @darrenburns on PR #2512. * argument change and functionaloty update Changed back to orinal `columns` argument and added a new `key` argument which takes a function (or other callable). This allows the PR to NOT BE a breaking change. * better example for docs - Updated the example file for the docs to better show the functionality of the change (especially when using `columns` and `key` together). - Added one new tests to cover a similar situation to the example changes * removed unecessary code from example - the sort by clicked column function was bloat in my opinion * requested changes * simplify method and terminology * combine key_wrapper and default sort * Removing some tests from DataTable.sort as duplicates. Ensure there is test coverage of the case where a key, but no columns, is passed to DataTable.sort. * Remove unused import * Fix merge issues in CHANGELOG, update DataTable sort-by-key changelog PR link --------- Co-authored-by: Darren Burns <[email protected]> Co-authored-by: Darren Burns <[email protected]>
1 parent 665dca9 commit 4f95d30

File tree

5 files changed

+225
-23
lines changed

5 files changed

+225
-23
lines changed

CHANGELOG.md

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
2727

2828
- Add Document `get_index_from_location` / `get_location_from_index` https://github.com/Textualize/textual/pull/3410
2929
- Add setter for `TextArea.text` https://github.com/Textualize/textual/discussions/3525
30+
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/pull/3090
31+
- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
32+
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571
33+
- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498
34+
3035

3136
### Changed
3237

@@ -49,15 +54,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
4954
- Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575
5055
- Workers are now created/run in a thread-safe way https://github.com/Textualize/textual/pull/3586
5156

52-
### Added
53-
54-
- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
55-
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571
56-
57-
### Added
58-
59-
- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498
60-
6157
## [0.40.0] - 2023-10-11
6258

6359
### Added
@@ -251,7 +247,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
251247

252248
- DescendantBlur and DescendantFocus can now be used with @on decorator
253249

254-
255250
## [0.32.0] - 2023-08-03
256251

257252
### Added
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from rich.text import Text
2+
3+
from textual.app import App, ComposeResult
4+
from textual.widgets import DataTable, Footer
5+
6+
ROWS = [
7+
("lane", "swimmer", "country", "time 1", "time 2"),
8+
(4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84),
9+
(2, "Michael Phelps", Text("United States", style="italic"), 50.39, 51.84),
10+
(5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73),
11+
(6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58),
12+
(3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26),
13+
(8, "Mehdy Metella", Text("France", style="italic"), 51.58, 52.15),
14+
(7, "Tom Shields", Text("United States", style="italic"), 51.73, 51.12),
15+
(1, "Aleksandr Sadovnikov", Text("Russia", style="italic"), 51.84, 50.85),
16+
(10, "Darren Burns", Text("Scotland", style="italic"), 51.84, 51.55),
17+
]
18+
19+
20+
class TableApp(App):
21+
BINDINGS = [
22+
("a", "sort_by_average_time", "Sort By Average Time"),
23+
("n", "sort_by_last_name", "Sort By Last Name"),
24+
("c", "sort_by_country", "Sort By Country"),
25+
("d", "sort_by_columns", "Sort By Columns (Only)"),
26+
]
27+
28+
current_sorts: set = set()
29+
30+
def compose(self) -> ComposeResult:
31+
yield DataTable()
32+
yield Footer()
33+
34+
def on_mount(self) -> None:
35+
table = self.query_one(DataTable)
36+
for col in ROWS[0]:
37+
table.add_column(col, key=col)
38+
table.add_rows(ROWS[1:])
39+
40+
def sort_reverse(self, sort_type: str):
41+
"""Determine if `sort_type` is ascending or descending."""
42+
reverse = sort_type in self.current_sorts
43+
if reverse:
44+
self.current_sorts.remove(sort_type)
45+
else:
46+
self.current_sorts.add(sort_type)
47+
return reverse
48+
49+
def action_sort_by_average_time(self) -> None:
50+
"""Sort DataTable by average of times (via a function) and
51+
passing of column data through positional arguments."""
52+
53+
def sort_by_average_time_then_last_name(row_data):
54+
name, *scores = row_data
55+
return (sum(scores) / len(scores), name.split()[-1])
56+
57+
table = self.query_one(DataTable)
58+
table.sort(
59+
"swimmer",
60+
"time 1",
61+
"time 2",
62+
key=sort_by_average_time_then_last_name,
63+
reverse=self.sort_reverse("time"),
64+
)
65+
66+
def action_sort_by_last_name(self) -> None:
67+
"""Sort DataTable by last name of swimmer (via a lambda)."""
68+
table = self.query_one(DataTable)
69+
table.sort(
70+
"swimmer",
71+
key=lambda swimmer: swimmer.split()[-1],
72+
reverse=self.sort_reverse("swimmer"),
73+
)
74+
75+
def action_sort_by_country(self) -> None:
76+
"""Sort DataTable by country which is a `Rich.Text` object."""
77+
table = self.query_one(DataTable)
78+
table.sort(
79+
"country",
80+
key=lambda country: country.plain,
81+
reverse=self.sort_reverse("country"),
82+
)
83+
84+
def action_sort_by_columns(self) -> None:
85+
"""Sort DataTable without a key."""
86+
table = self.query_one(DataTable)
87+
table.sort("swimmer", "lane", reverse=self.sort_reverse("columns"))
88+
89+
90+
app = TableApp()
91+
if __name__ == "__main__":
92+
app.run()

docs/widgets/data_table.md

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,22 @@ visible as you scroll through the data table.
143143

144144
### Sorting
145145

146-
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method.
147-
In order to sort your data by a column, you must have supplied a `key` to the `add_column` method
148-
when you added it.
149-
You can then pass this key to the `sort` method to sort by that column.
150-
Additionally, you can sort by multiple columns by passing multiple keys to `sort`.
146+
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns.
147+
148+
Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes.
149+
150+
Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified.
151+
152+
=== "Output"
153+
154+
```{.textual path="docs/examples/widgets/data_table_sort.py"}
155+
```
156+
157+
=== "data_table_sort.py"
158+
159+
```python
160+
--8<-- "docs/examples/widgets/data_table_sort.py"
161+
```
151162

152163
### Labelled rows
153164

src/textual/widgets/_data_table.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass
55
from itertools import chain, zip_longest
66
from operator import itemgetter
7-
from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
7+
from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
88

99
import rich.repr
1010
from rich.console import RenderableType
@@ -2348,30 +2348,40 @@ def _get_fixed_offset(self) -> Spacing:
23482348
def sort(
23492349
self,
23502350
*columns: ColumnKey | str,
2351+
key: Callable[[Any], Any] | None = None,
23512352
reverse: bool = False,
23522353
) -> Self:
2353-
"""Sort the rows in the `DataTable` by one or more column keys.
2354+
"""Sort the rows in the `DataTable` by one or more column keys or a
2355+
key function (or other callable). If both columns and a key function
2356+
are specified, only data from those columns will sent to the key function.
23542357
23552358
Args:
23562359
columns: One or more columns to sort by the values in.
2360+
key: A function (or other callable) that returns a key to
2361+
use for sorting purposes.
23572362
reverse: If True, the sort order will be reversed.
23582363
23592364
Returns:
23602365
The `DataTable` instance.
23612366
"""
23622367

2363-
def sort_by_column_keys(
2364-
row: tuple[RowKey, dict[ColumnKey | str, CellType]]
2365-
) -> Any:
2368+
def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
23662369
_, row_data = row
2367-
result = itemgetter(*columns)(row_data)
2370+
if columns:
2371+
result = itemgetter(*columns)(row_data)
2372+
else:
2373+
result = tuple(row_data.values())
2374+
if key is not None:
2375+
return key(result)
23682376
return result
23692377

23702378
ordered_rows = sorted(
2371-
self._data.items(), key=sort_by_column_keys, reverse=reverse
2379+
self._data.items(),
2380+
key=key_wrapper,
2381+
reverse=reverse,
23722382
)
23732383
self._row_locations = TwoWayDict(
2374-
{key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
2384+
{row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
23752385
)
23762386
self._update_count += 1
23772387
self.refresh()

tests/test_data_table.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,100 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse():
11971197
assert not table._show_hover_cursor
11981198

11991199

1200+
async def test_sort_by_all_columns_no_key():
1201+
"""Test sorting a `DataTable` by all columns."""
1202+
1203+
app = DataTableApp()
1204+
async with app.run_test():
1205+
table = app.query_one(DataTable)
1206+
a, b, c = table.add_columns("A", "B", "C")
1207+
table.add_row(1, 3, 8)
1208+
table.add_row(2, 9, 5)
1209+
table.add_row(1, 1, 9)
1210+
assert table.get_row_at(0) == [1, 3, 8]
1211+
assert table.get_row_at(1) == [2, 9, 5]
1212+
assert table.get_row_at(2) == [1, 1, 9]
1213+
1214+
table.sort()
1215+
assert table.get_row_at(0) == [1, 1, 9]
1216+
assert table.get_row_at(1) == [1, 3, 8]
1217+
assert table.get_row_at(2) == [2, 9, 5]
1218+
1219+
table.sort(reverse=True)
1220+
assert table.get_row_at(0) == [2, 9, 5]
1221+
assert table.get_row_at(1) == [1, 3, 8]
1222+
assert table.get_row_at(2) == [1, 1, 9]
1223+
1224+
1225+
async def test_sort_by_multiple_columns_no_key():
1226+
"""Test sorting a `DataTable` by multiple columns."""
1227+
1228+
app = DataTableApp()
1229+
async with app.run_test():
1230+
table = app.query_one(DataTable)
1231+
a, b, c = table.add_columns("A", "B", "C")
1232+
table.add_row(1, 3, 8)
1233+
table.add_row(2, 9, 5)
1234+
table.add_row(1, 1, 9)
1235+
1236+
table.sort(a, b, c)
1237+
assert table.get_row_at(0) == [1, 1, 9]
1238+
assert table.get_row_at(1) == [1, 3, 8]
1239+
assert table.get_row_at(2) == [2, 9, 5]
1240+
1241+
table.sort(a, c, b)
1242+
assert table.get_row_at(0) == [1, 3, 8]
1243+
assert table.get_row_at(1) == [1, 1, 9]
1244+
assert table.get_row_at(2) == [2, 9, 5]
1245+
1246+
table.sort(c, a, b, reverse=True)
1247+
assert table.get_row_at(0) == [1, 1, 9]
1248+
assert table.get_row_at(1) == [1, 3, 8]
1249+
assert table.get_row_at(2) == [2, 9, 5]
1250+
1251+
table.sort(a, c)
1252+
assert table.get_row_at(0) == [1, 3, 8]
1253+
assert table.get_row_at(1) == [1, 1, 9]
1254+
assert table.get_row_at(2) == [2, 9, 5]
1255+
1256+
1257+
async def test_sort_by_function_sum():
1258+
"""Test sorting a `DataTable` using a custom sort function."""
1259+
1260+
def custom_sort(row_data):
1261+
return sum(row_data)
1262+
1263+
row_data = (
1264+
[1, 3, 8], # SUM=12
1265+
[2, 9, 5], # SUM=16
1266+
[1, 1, 9], # SUM=11
1267+
)
1268+
1269+
app = DataTableApp()
1270+
async with app.run_test():
1271+
table = app.query_one(DataTable)
1272+
a, b, c = table.add_columns("A", "B", "C")
1273+
for i, row in enumerate(row_data):
1274+
table.add_row(*row)
1275+
1276+
# Sorting by all columns
1277+
table.sort(a, b, c, key=custom_sort)
1278+
sorted_row_data = sorted(row_data, key=sum)
1279+
for i, row in enumerate(sorted_row_data):
1280+
assert table.get_row_at(i) == row
1281+
1282+
# Passing a sort function but no columns also sorts by all columns
1283+
table.sort(key=custom_sort)
1284+
sorted_row_data = sorted(row_data, key=sum)
1285+
for i, row in enumerate(sorted_row_data):
1286+
assert table.get_row_at(i) == row
1287+
1288+
table.sort(a, b, c, key=custom_sort, reverse=True)
1289+
sorted_row_data = sorted(row_data, key=sum, reverse=True)
1290+
for i, row in enumerate(sorted_row_data):
1291+
assert table.get_row_at(i) == row
1292+
1293+
12001294
@pytest.mark.parametrize(
12011295
["cell", "height"],
12021296
[

0 commit comments

Comments
 (0)