diff --git a/src/cohere/client.py b/src/cohere/client.py index ba6780581..5cb5cf272 100644 --- a/src/cohere/client.py +++ b/src/cohere/client.py @@ -28,10 +28,21 @@ def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[[typing.Any], typing.Any]) -> None: method = getattr(obj, method_name) - def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + def _wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: check_fn(*args, **kwargs) return method(*args, **kwargs) + async def _async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + # The `return await` looks redundant, but it's necessary to ensure that the return type is correct. + check_fn(*args, **kwargs) + return await method(*args, **kwargs) + + wrapped = _wrapped + if asyncio.iscoroutinefunction(method): + wrapped = _async_wrapped + + wrapped.__name__ = method.__name__ + wrapped.__doc__ = method.__doc__ setattr(obj, method_name, wrapped)