| 
9 | 9 | embed_job = os.path.join(package_dir, 'embed_job.jsonl')  | 
10 | 10 | 
 
  | 
11 | 11 | 
 
  | 
12 |  | -models = {  | 
 | 12 | +model_mapping = {  | 
13 | 13 |     "bedrock": {  | 
14 | 14 |         "chat_model": "cohere.command-r-plus-v1:0",  | 
15 | 15 |         "embed_model": "cohere.embed-multilingual-v3",  | 
 | 
19 | 19 |         "chat_model": "cohere.command-r-plus-v1:0",  | 
20 | 20 |         "embed_model": "cohere.embed-multilingual-v3",  | 
21 | 21 |         "generate_model": "cohere-command-light",  | 
 | 22 | +        "rerank_model": "rerank",  | 
22 | 23 |     },  | 
23 | 24 | }  | 
24 | 25 | 
 
  | 
25 | 26 | 
 
  | 
26 | 27 | @parameterized_class([  | 
27 | 28 |     {  | 
 | 29 | +        "platform": "bedrock",  | 
28 | 30 |         "client": cohere.BedrockClient(  | 
29 | 31 |             timeout=10000,  | 
30 | 32 |             aws_region="us-east-1",  | 
31 | 33 |             aws_access_key="...",  | 
32 | 34 |             aws_secret_key="...",  | 
33 | 35 |             aws_session_token="...",  | 
34 | 36 |         ),  | 
35 |  | -        "models": models["bedrock"],  | 
 | 37 | +        "models": model_mapping["bedrock"],  | 
36 | 38 |     },  | 
37 | 39 |     {  | 
 | 40 | +        "platform": "sagemaker",  | 
38 | 41 |         "client": cohere.SagemakerClient(  | 
39 | 42 |             timeout=10000,  | 
40 | 43 |             aws_region="us-east-1",  | 
41 | 44 |             aws_access_key="...",  | 
42 | 45 |             aws_secret_key="...",  | 
43 | 46 |             aws_session_token="...",  | 
44 | 47 |         ),  | 
45 |  | -        "models": models["sagemaker"],  | 
 | 48 | +        "models": model_mapping["sagemaker"],  | 
46 | 49 |     }  | 
47 | 50 | ])  | 
48 | 51 | @unittest.skip("skip tests until they work in CI")  | 
49 | 52 | class TestClient(unittest.TestCase):  | 
 | 53 | +    platform: str  | 
50 | 54 |     client: cohere.AwsClient  | 
51 | 55 |     models: typing.Dict[str, str]  | 
52 | 56 | 
 
  | 
 | 57 | +    def test_rerank(self) -> None:  | 
 | 58 | +        if self.platform != "sagemaker":  | 
 | 59 | +            self.skipTest("Only sagemaker supports rerank")  | 
 | 60 | + | 
 | 61 | +        docs = [  | 
 | 62 | +            'Carson City is the capital city of the American state of Nevada.',  | 
 | 63 | +            'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',  | 
 | 64 | +            'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',  | 
 | 65 | +            'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']  | 
 | 66 | + | 
 | 67 | +        response = self.client.rerank(  | 
 | 68 | +            model=self.models["rerank_model"],  | 
 | 69 | +            query='What is the capital of the United States?',  | 
 | 70 | +            documents=docs,  | 
 | 71 | +            top_n=3,  | 
 | 72 | +        )  | 
 | 73 | + | 
 | 74 | +        self.assertEqual(len(response.results), 3)  | 
 | 75 | + | 
53 | 76 |     def test_embed(self) -> None:  | 
54 | 77 |         response = self.client.embed(  | 
55 | 78 |             model=self.models["embed_model"],  | 
 | 
0 commit comments