Skip to content

Commit 2545497

Browse files
refactor(python): Improve/fix internal LRUCache implementation and move into "_utils" module (#23813)
1 parent 5d429ea commit 2545497

File tree

5 files changed

+402
-138
lines changed

5 files changed

+402
-138
lines changed

py-polars/polars/_utils/cache.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from __future__ import annotations
2+
3+
from collections import OrderedDict
4+
from collections.abc import MutableMapping
5+
from typing import TYPE_CHECKING, Any, TypeVar, overload
6+
7+
from polars._utils.various import no_default
8+
9+
if TYPE_CHECKING:
10+
import sys
11+
from collections.abc import ItemsView, Iterable, Iterator, KeysView, ValuesView
12+
13+
from polars._utils.various import NoDefault
14+
15+
if sys.version_info >= (3, 11):
16+
from typing import Self
17+
else:
18+
from typing_extensions import Self
19+
20+
D = TypeVar("D")
21+
K = TypeVar("K")
22+
V = TypeVar("V")
23+
24+
25+
class LRUCache(MutableMapping[K, V]):
26+
def __init__(self, maxsize: int) -> None:
27+
"""
28+
Initialize an LRU (Least Recently Used) cache with a specified maximum size.
29+
30+
Parameters
31+
----------
32+
maxsize : int
33+
The maximum number of items the cache can hold.
34+
35+
Examples
36+
--------
37+
>>> from polars._utils.cache import LRUCache
38+
>>> cache = LRUCache[str, int](maxsize=3)
39+
>>> cache["a"] = 1
40+
>>> cache["b"] = 2
41+
>>> cache["c"] = 3
42+
>>> cache["d"] = 4 # evicts the least recently used item ("a"), as maxsize=3
43+
>>> print(cache["b"]) # accessing "b" marks it as recently used
44+
2
45+
>>> print(list(cache.keys())) # show the current keys in LRU order
46+
['c', 'd', 'b']
47+
>>> cache.get("xyz", "not found")
48+
'not found'
49+
"""
50+
self._items: OrderedDict[K, V] = OrderedDict()
51+
self.maxsize = maxsize
52+
53+
def __bool__(self) -> bool:
54+
"""Returns True if the cache is not empty, False otherwise."""
55+
return bool(self._items)
56+
57+
def __contains__(self, key: Any) -> bool:
58+
"""Check if the key is in the cache."""
59+
return key in self._items
60+
61+
def __delitem__(self, key: K) -> None:
62+
"""Remove the item with the specified key from the cache."""
63+
if key not in self._items:
64+
msg = f"{key!r} not found in cache"
65+
raise KeyError(msg)
66+
del self._items[key]
67+
68+
def __getitem__(self, key: K) -> V:
69+
"""Raises KeyError if the key is not found."""
70+
if key not in self._items:
71+
msg = f"{key!r} not found in cache"
72+
raise KeyError(msg)
73+
74+
# moving accessed items to the end marks them as recently used
75+
self._items.move_to_end(key)
76+
return self._items[key]
77+
78+
def __iter__(self) -> Iterator[K]:
79+
"""Iterate over the keys in the cache."""
80+
yield from self._items
81+
82+
def __len__(self) -> int:
83+
"""Number of items in the cache."""
84+
return len(self._items)
85+
86+
def __setitem__(self, key: K, value: V) -> None:
87+
"""Insert a value into the cache."""
88+
if self._max_size == 0:
89+
return
90+
while len(self) >= self._max_size:
91+
self.popitem()
92+
if key in self:
93+
# moving accessed items to the end marks them as recently used
94+
self._items.move_to_end(key)
95+
self._items[key] = value
96+
97+
def __repr__(self) -> str:
98+
"""Return a string representation of the cache."""
99+
all_items = list(self._items.items())
100+
if len(self) > 4:
101+
items = (
102+
", ".join(f"{k!r}: {v!r}" for k, v in all_items[:2])
103+
+ " ..., "
104+
+ ", ".join(f"{k!r}: {v!r}" for k, v in all_items[-2:])
105+
)
106+
else:
107+
items = ", ".join(f"{k!r}: {v!r}" for k, v in all_items)
108+
return f"{self.__class__.__name__}({{{items}}}, maxsize={self._max_size}, currsize={len(self)})"
109+
110+
def clear(self) -> None:
111+
"""Clear the cache, removing all items."""
112+
self._items.clear()
113+
114+
@overload
115+
def get(self, key: K, default: None = None) -> V | None: ...
116+
117+
@overload
118+
def get(self, key: K, default: D = ...) -> V | D: ...
119+
120+
def get(self, key: K, default: D | V | None = None) -> V | D | None:
121+
"""Return value associated with `key` if present, otherwise return `default`."""
122+
if key in self:
123+
# moving accessed items to the end marks them as recently used
124+
self._items.move_to_end(key)
125+
return self._items[key]
126+
return default
127+
128+
@classmethod
129+
def fromkeys(cls, maxsize: int, *, keys: Iterable[K], value: V) -> Self:
130+
"""Initialize cache with keys from an iterable, all set to the same value."""
131+
cache = cls(maxsize)
132+
for key in keys:
133+
cache[key] = value
134+
return cache
135+
136+
def items(self) -> ItemsView[K, V]:
137+
"""Return an iterable view of the cache's items (keys and values)."""
138+
return self._items.items()
139+
140+
def keys(self) -> KeysView[K]:
141+
"""Return an iterable view of the cache's keys."""
142+
return self._items.keys()
143+
144+
@property
145+
def maxsize(self) -> int:
146+
return self._max_size
147+
148+
@maxsize.setter
149+
def maxsize(self, n: int) -> None:
150+
"""Set new maximum cache size; cache is trimmed if value is smaller."""
151+
if n < 0:
152+
msg = f"`maxsize` cannot be negative; found {n}"
153+
raise ValueError(msg)
154+
while len(self) > n:
155+
self.popitem()
156+
self._max_size = n
157+
158+
def pop(self, key: K, default: D | NoDefault = no_default) -> V | D:
159+
"""
160+
Remove specified key from the cache and return the associated value.
161+
162+
If the key is not found, `default` is returned (if given).
163+
Otherwise, a KeyError is raised.
164+
"""
165+
if (item := self._items.pop(key, default)) is no_default:
166+
msg = f"{key!r} not found in cache"
167+
raise KeyError(msg)
168+
return item
169+
170+
def popitem(self) -> tuple[K, V]:
171+
"""Remove the least recently used value; raises KeyError if cache is empty."""
172+
return self._items.popitem(last=False)
173+
174+
def values(self) -> ValuesView[V]:
175+
"""Return an iterable view of the cache's values."""
176+
return self._items.values()

py-polars/polars/io/cloud/_utils.py

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
from __future__ import annotations
22

3-
from collections import OrderedDict
43
from pathlib import Path
5-
from typing import TYPE_CHECKING, Any, Generic, TypeVar
4+
from typing import Any, Generic, TypeVar
65

76
from polars._utils.various import is_path_or_str_sequence
87
from polars.io.partition import PartitionMaxSize
98

10-
if TYPE_CHECKING:
11-
from collections.abc import KeysView
12-
139
T = TypeVar("T")
14-
K = TypeVar("K")
15-
V = TypeVar("V")
1610

1711

1812
class NoPickleOption(Generic[T]):
@@ -39,68 +33,6 @@ def __setstate__(self, _state: tuple[()]) -> None:
3933
NoPickleOption.__init__(self)
4034

4135

42-
class LRUCache(Generic[K, V]):
43-
def __init__(self, max_items: int) -> None:
44-
self._max_items = 0
45-
self._dict: OrderedDict[K, V] = OrderedDict()
46-
47-
self.set_max_items(max_items)
48-
49-
def __len__(self) -> int:
50-
return len(self._dict)
51-
52-
def get(self, key: K) -> V:
53-
"""Raises KeyError if the key is not found."""
54-
self._dict.move_to_end(key)
55-
return self._dict[key]
56-
57-
def keys(self) -> KeysView[K]:
58-
return self._dict.keys()
59-
60-
def contains(self, key: K) -> bool:
61-
return key in self._dict
62-
63-
def insert(self, key: K, value: V) -> None:
64-
"""Insert a value into the cache."""
65-
if self.max_items() == 0:
66-
return
67-
68-
while len(self) >= self.max_items():
69-
self.remove_lru()
70-
71-
self._dict[key] = value
72-
73-
def remove(self, key: K) -> V:
74-
"""Raises KeyError if the key is not found."""
75-
return self._dict.pop(key)
76-
77-
def max_items(self) -> int:
78-
return self._max_items
79-
80-
def set_max_items(self, max_items: int) -> None:
81-
"""
82-
Set a new maximum number of items.
83-
84-
The cache is trimmed if its length exceeds the new maximum.
85-
"""
86-
if max_items < 0:
87-
msg = f"max_items cannot be negative: {max_items}"
88-
raise ValueError(msg)
89-
90-
while len(self) > max_items:
91-
self.remove_lru()
92-
93-
self._max_items = max_items
94-
95-
def remove_lru(self) -> tuple[K, V]:
96-
"""
97-
Remove the least recently used value.
98-
99-
Raises KeyError if the cache is empty.
100-
"""
101-
return self._dict.popitem(last=False)
102-
103-
10436
def _first_scan_path(
10537
source: Any,
10638
) -> str | Path | None:

py-polars/polars/io/cloud/credential_provider/_builder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from typing import TYPE_CHECKING, Any, Callable, Literal, Union
77

88
import polars._utils.logging
9+
from polars._utils.cache import LRUCache
910
from polars._utils.logging import eprint, verbose
1011
from polars._utils.unstable import issue_unstable_warning
11-
from polars.io.cloud._utils import LRUCache, NoPickleOption
12+
from polars.io.cloud._utils import NoPickleOption
1213
from polars.io.cloud.credential_provider._providers import (
1314
CachingCredentialProvider,
1415
CredentialProvider,
@@ -167,7 +168,7 @@ def __repr__(self) -> str:
167168
return f"{provider_repr} @ {builder_name}"
168169

169170

170-
# Wraps an already ininitialized credential provider into the builder interface.
171+
# Wraps an already initialized credential provider into the builder interface.
171172
# Used for e.g. user-provided credential providers.
172173
class InitializedCredentialProvider(CredentialProviderBuilderImpl):
173174
"""Wraps an already initialized credential provider."""
@@ -221,10 +222,10 @@ def _auto_init_with_cache(
221222
cache_key = get_cache_key_func()
222223

223224
try:
224-
provider = AUTO_INIT_LRU_CACHE.get(cache_key)
225+
provider = AUTO_INIT_LRU_CACHE[cache_key]
225226
except KeyError:
226227
provider = build_provider_func()
227-
AUTO_INIT_LRU_CACHE.insert(cache_key, provider)
228+
AUTO_INIT_LRU_CACHE[cache_key] = provider
228229

229230
return provider
230231

py-polars/tests/unit/io/cloud/test_credential_provider.py

Lines changed: 1 addition & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import polars as pl
1212
import polars.io.cloud.credential_provider
13-
from polars.io.cloud._utils import LRUCache, NoPickleOption
13+
from polars.io.cloud._utils import NoPickleOption
1414
from polars.io.cloud.credential_provider._builder import (
1515
AutoInit,
1616
_init_credential_provider_builder,
@@ -607,70 +607,6 @@ def test_credential_provider_aws_expiry(
607607
assert expiry is None
608608

609609

610-
def test_lru_cache() -> None:
611-
def _test(cache: LRUCache[int, str]) -> None:
612-
with pytest.raises(ValueError):
613-
cache.set_max_items(-1)
614-
615-
assert len(cache) == 0
616-
assert cache.max_items() == 2
617-
618-
cache.insert(1, "1")
619-
cache.insert(2, "2")
620-
621-
assert cache.get(2) == "2"
622-
assert cache.get(1) == "1"
623-
624-
assert cache.contains(1)
625-
assert cache.contains(2)
626-
627-
assert list(cache.keys()) == [2, 1]
628-
629-
cache.insert(3, "3")
630-
631-
# Note: We have 1, 3 due to cache.get() ordering above.
632-
# The calls to contains() should not shift the LRU order.
633-
assert list(cache.keys()) == [1, 3]
634-
635-
cache.insert(4, "4")
636-
637-
assert cache.contains(3)
638-
assert cache.contains(4)
639-
640-
assert list(cache.keys()) == [3, 4]
641-
642-
cache.remove(4)
643-
cache.insert(5, "5")
644-
645-
assert list(cache.keys()) == [3, 5]
646-
647-
assert cache.max_items() == 2
648-
assert len(cache) == 2
649-
650-
cache.set_max_items(1)
651-
assert cache.max_items() == 1
652-
assert len(cache) == 1
653-
assert list(cache.keys()) == [5]
654-
655-
cache: LRUCache[int, str] = LRUCache(2)
656-
657-
_test(cache)
658-
659-
cache.set_max_items(0)
660-
assert len(cache) == 0
661-
assert cache.max_items() == 0
662-
663-
cache.insert(1, "1")
664-
assert len(cache.keys()) == 0
665-
assert not cache.contains(1)
666-
667-
with pytest.raises(KeyError):
668-
cache.remove(1)
669-
670-
cache.set_max_items(2)
671-
_test(cache)
672-
673-
674610
@pytest.mark.slow
675611
@pytest.mark.parametrize(
676612
(

0 commit comments

Comments
 (0)