Skip to content

Commit ac04bfa

Browse files
authored
Add generic type support to Future and Client methods (#9123)
1 parent 506d307 commit ac04bfa

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

distributed/client.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
Any,
3232
Callable,
3333
ClassVar,
34+
Generic,
3435
Literal,
3536
NamedTuple,
3637
TypedDict,
38+
TypeVar,
3739
cast,
3840
)
3941

@@ -138,6 +140,8 @@
138140

139141
TOPIC_PREFIX_FORWARDED_LOG_RECORD = "forwarded-log-record"
140142

143+
_T = TypeVar("_T")
144+
141145

142146
class FutureCancelledError(CancelledError):
143147
key: str
@@ -241,7 +245,7 @@ def _del_global_client(c: Client) -> None:
241245
pass
242246

243247

244-
class Future(TaskRef):
248+
class Future(TaskRef, Generic[_T]):
245249
"""A remotely running computation
246250
247251
A Future is a local proxy to a result running on a remote worker. A user
@@ -371,7 +375,7 @@ def done(self):
371375
"""
372376
return self._state.done()
373377

374-
def result(self, timeout=None):
378+
def result(self, timeout=None) -> _T:
375379
"""Wait until computation completes, gather result to local process.
376380
377381
Parameters
@@ -2033,7 +2037,7 @@ def get_executor(self, **kwargs):
20332037

20342038
def submit(
20352039
self,
2036-
func,
2040+
func: Callable[..., _T],
20372041
*args,
20382042
key=None,
20392043
workers=None,
@@ -2046,7 +2050,7 @@ def submit(
20462050
actors=False,
20472051
pure=True,
20482052
**kwargs,
2049-
):
2053+
) -> Future[_T]:
20502054
"""Submit a function application to the scheduler
20512055
20522056
Parameters
@@ -2154,7 +2158,7 @@ def submit(
21542158
key,
21552159
func,
21562160
*(parse_input(a) for a in args),
2157-
**{k: parse_input(v) for k, v in kwargs.items()},
2161+
**{k: parse_input(v) for k, v in kwargs.items()}, # type: ignore
21582162
)
21592163
},
21602164
# We'd like to avoid hashing/tokenizing all of the above.
@@ -2181,7 +2185,7 @@ def submit(
21812185

21822186
def map(
21832187
self,
2184-
func: Callable,
2188+
func: Callable[..., _T],
21852189
*iterables: Collection,
21862190
key: str | list | None = None,
21872191
workers: str | Iterable[str] | None = None,
@@ -2195,7 +2199,7 @@ def map(
21952199
pure: bool = True,
21962200
batch_size=None,
21972201
**kwargs,
2198-
):
2202+
) -> list[Future[_T]]:
21992203
"""Map a function on a sequence of arguments
22002204
22012205
Arguments can be normal objects or Futures

0 commit comments

Comments
 (0)