Skip to content

Commit 8f5ff50

Browse files
Add rerank support for sagemaker (#526)
* Add rerank to sagemaker cli * Restore * Restore * Fix * Add skip
1 parent 01b3c22 commit 8f5ff50

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

src/cohere/aws_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tokenizers import Tokenizer # type: ignore
1212

1313
from . import GenerateStreamedResponse, Generation, \
14-
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse
14+
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse
1515
from .client import Client, ClientEnvironment
1616
from .core import construct_type
1717

@@ -97,6 +97,7 @@ def __iter__(self) -> typing.Iterator[bytes]:
9797
"chat": NonStreamedChatResponse,
9898
"embed": EmbedResponse,
9999
"generate": Generation,
100+
"rerank": RerankResponse
100101
}
101102

102103
stream_response_mapping: typing.Dict[str, typing.Any] = {

tests/test_aws_client.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
embed_job = os.path.join(package_dir, 'embed_job.jsonl')
1010

1111

12-
models = {
12+
model_mapping = {
1313
"bedrock": {
1414
"chat_model": "cohere.command-r-plus-v1:0",
1515
"embed_model": "cohere.embed-multilingual-v3",
@@ -19,37 +19,60 @@
1919
"chat_model": "cohere.command-r-plus-v1:0",
2020
"embed_model": "cohere.embed-multilingual-v3",
2121
"generate_model": "cohere-command-light",
22+
"rerank_model": "rerank",
2223
},
2324
}
2425

2526

2627
@parameterized_class([
2728
{
29+
"platform": "bedrock",
2830
"client": cohere.BedrockClient(
2931
timeout=10000,
3032
aws_region="us-east-1",
3133
aws_access_key="...",
3234
aws_secret_key="...",
3335
aws_session_token="...",
3436
),
35-
"models": models["bedrock"],
37+
"models": model_mapping["bedrock"],
3638
},
3739
{
40+
"platform": "sagemaker",
3841
"client": cohere.SagemakerClient(
3942
timeout=10000,
4043
aws_region="us-east-1",
4144
aws_access_key="...",
4245
aws_secret_key="...",
4346
aws_session_token="...",
4447
),
45-
"models": models["sagemaker"],
48+
"models": model_mapping["sagemaker"],
4649
}
4750
])
4851
@unittest.skip("skip tests until they work in CI")
4952
class TestClient(unittest.TestCase):
53+
platform: str
5054
client: cohere.AwsClient
5155
models: typing.Dict[str, str]
5256

57+
def test_rerank(self) -> None:
58+
if self.platform != "sagemaker":
59+
self.skipTest("Only sagemaker supports rerank")
60+
61+
docs = [
62+
'Carson City is the capital city of the American state of Nevada.',
63+
'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
64+
'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
65+
'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
66+
67+
response = self.client.rerank(
68+
model=self.models["rerank_model"],
69+
query='What is the capital of the United States?',
70+
documents=docs,
71+
top_n=3,
72+
)
73+
74+
self.assertEqual(len(response.results), 3)
75+
5376
def test_embed(self) -> None:
5477
response = self.client.embed(
5578
model=self.models["embed_model"],

0 commit comments

Comments
 (0)