Skip to content

Commit d3d72b9

Browse files
chore(internal): add Sequence related utils
1 parent 3ab273f commit d3d72b9

File tree

4 files changed

+50
-2
lines changed

4 files changed

+50
-2
lines changed

src/openai/_types.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,21 @@
1313
Mapping,
1414
TypeVar,
1515
Callable,
16+
Iterator,
1617
Optional,
1718
Sequence,
1819
)
19-
from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
20+
from typing_extensions import (
21+
Set,
22+
Literal,
23+
Protocol,
24+
TypeAlias,
25+
TypedDict,
26+
SupportsIndex,
27+
overload,
28+
override,
29+
runtime_checkable,
30+
)
2031

2132
import httpx
2233
import pydantic
@@ -219,3 +230,26 @@ class _GenericAlias(Protocol):
219230
class HttpxSendArgs(TypedDict, total=False):
220231
auth: httpx.Auth
221232
follow_redirects: bool
233+
234+
235+
_T_co = TypeVar("_T_co", covariant=True)
236+
237+
238+
if TYPE_CHECKING:
239+
# This works because str.__contains__ does not accept object (either in typeshed or at runtime)
240+
# https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285
241+
class SequenceNotStr(Protocol[_T_co]):
242+
@overload
243+
def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
244+
@overload
245+
def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
246+
def __contains__(self, value: object, /) -> bool: ...
247+
def __len__(self) -> int: ...
248+
def __iter__(self) -> Iterator[_T_co]: ...
249+
def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
250+
def count(self, value: Any, /) -> int: ...
251+
def __reversed__(self) -> Iterator[_T_co]: ...
252+
else:
253+
# just point this to a normal `Sequence` at runtime to avoid having to special case
254+
# deserializing our custom sequence type
255+
SequenceNotStr = Sequence

src/openai/_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
extract_type_arg as extract_type_arg,
4242
is_iterable_type as is_iterable_type,
4343
is_required_type as is_required_type,
44+
is_sequence_type as is_sequence_type,
4445
is_annotated_type as is_annotated_type,
4546
is_type_alias_type as is_type_alias_type,
4647
strip_annotated_type as strip_annotated_type,

src/openai/_utils/_typing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ def is_list_type(typ: type) -> bool:
2626
return (get_origin(typ) or typ) == list
2727

2828

29+
def is_sequence_type(typ: type) -> bool:
30+
origin = get_origin(typ) or typ
31+
return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence
32+
33+
2934
def is_iterable_type(typ: type) -> bool:
3035
"""If the given type is `typing.Iterable[T]`"""
3136
origin = get_origin(typ) or typ

tests/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import inspect
66
import traceback
77
import contextlib
8-
from typing import Any, TypeVar, Iterator, ForwardRef, cast
8+
from typing import Any, TypeVar, Iterator, ForwardRef, Sequence, cast
99
from datetime import date, datetime
1010
from typing_extensions import Literal, get_args, get_origin, assert_type
1111

@@ -18,6 +18,7 @@
1818
is_list_type,
1919
is_union_type,
2020
extract_type_arg,
21+
is_sequence_type,
2122
is_annotated_type,
2223
is_type_alias_type,
2324
)
@@ -78,6 +79,13 @@ def assert_matches_type(
7879
if is_list_type(type_):
7980
return _assert_list_type(type_, value)
8081

82+
if is_sequence_type(type_):
83+
assert isinstance(value, Sequence)
84+
inner_type = get_args(type_)[0]
85+
for entry in value: # type: ignore
86+
assert_type(inner_type, entry) # type: ignore
87+
return
88+
8189
if origin == str:
8290
assert isinstance(value, str)
8391
elif origin == int:

0 commit comments

Comments
 (0)