Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,28 @@
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }


def get_available_deployments() -> list[type[BaseDeployment]]:
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())
def get_available_deployments() -> dict[str, type[BaseDeployment]]:
installed_deployments = ALL_MODEL_DEPLOYMENTS.copy()

if Settings().get("feature_flags.use_community_features"):
try:
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())

installed_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
except ImportError as e:
logger.warning(
event="[Deployments] No available community deployments have been configured", ex=e
)

enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
if enabled_deployment_ids:
return [
deployment
for deployment in installed_deployments
if deployment.id() in enabled_deployment_ids
]
return {
key: value
for key, value in installed_deployments.items()
if value.id() in enabled_deployment_ids
}

return installed_deployments

Expand Down
2 changes: 1 addition & 1 deletion src/backend/crud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def update_agent(
if agent.is_private and agent.user_id != user_id:
return None

new_agent_cleaned = new_agent.dict(exclude_unset=True, exclude_none=True)
new_agent_cleaned = new_agent.model_dump(exclude_unset=True, exclude_none=True)

for attr, value in new_agent_cleaned.items():
setattr(agent, attr, value)
Expand Down
4 changes: 2 additions & 2 deletions src/backend/scripts/cli/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def core_env_var_prompt(secrets):


def deployment_prompt(secrets, configs):
for secret in configs.env_vars:
for secret in configs.env_vars():
value = secrets.get(secret)

if not value:
Expand Down Expand Up @@ -149,7 +149,7 @@ def select_deployments_prompt(deployments, _):

deployments = inquirer.checkbox(
"Select the model deployments you want to set up",
choices=[deployment.value for deployment in deployments.keys()],
choices=[deployment for deployment in deployments.keys()],
default=["Cohere Platform"],
validate=lambda _, x: len(x) > 0,
)
Expand Down
12 changes: 7 additions & 5 deletions src/backend/services/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_db_deployment(session: DBSessionDep, deployment: DeploymentDefinition

def get_default_deployment(**kwargs) -> BaseDeployment:
try:
fallback = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.is_available)
fallback = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.is_available())
except StopIteration:
raise NoAvailableDeploymentsError()

Expand All @@ -47,7 +47,7 @@ def get_default_deployment(**kwargs) -> BaseDeployment:
return next(
(
d
for d in AVAILABLE_MODEL_DEPLOYMENTS
for d in AVAILABLE_MODEL_DEPLOYMENTS.values()
if d.id() == default_deployment
),
fallback,
Expand All @@ -63,7 +63,9 @@ def get_deployment_by_name(session: DBSessionDep, deployment_name: str, **kwargs
definition = get_deployment_definition_by_name(session, deployment_name)

try:
return next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.__name__ == definition.class_name)(db_id=definition.id, **definition.config, **kwargs)
return next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.__name__ == definition.class_name)(
db_id=definition.id, **definition.config, **kwargs
)
except StopIteration:
raise DeploymentNotFoundError(deployment_id=deployment_name)

Expand All @@ -73,7 +75,7 @@ def get_deployment_definition(session: DBSessionDep, deployment_id: str) -> Depl
return DeploymentDefinition.from_db_deployment(db_deployment)

try:
deployment = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.id() == deployment_id)
deployment = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.id() == deployment_id)
except StopIteration:
raise DeploymentNotFoundError(deployment_id=deployment_id)

Expand Down Expand Up @@ -101,7 +103,7 @@ def get_deployment_definitions(session: DBSessionDep) -> list[DeploymentDefiniti

installed_deployments = [
deployment.to_deployment_definition()
for deployment in AVAILABLE_MODEL_DEPLOYMENTS
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values()
if deployment.name() not in db_deployments
]

Expand Down
4 changes: 3 additions & 1 deletion src/backend/services/request_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep
detail=f"Deployment {deployment} not found or is not available in the Database.",
)

deployment_config = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.__name__ == found.class_name).to_deployment_definition()
deployment_config = next(
d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.__name__ == found.class_name
).to_deployment_definition()
deployment_model = next(
(
model_db
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def mock_available_model_deployments(request):
MockBedrockDeployment.name(): MockBedrockDeployment,
}

with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", list(MOCKED_DEPLOYMENTS.values())) as mock:
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", MOCKED_DEPLOYMENTS) as mock:
yield mock

@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/integration/routers/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_list_deployments_no_available_models_404(
session_client: TestClient, session: Session
) -> None:
session.query(Deployment).delete()
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", []):
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", {}):
response = session_client.get("/v1/deployments")
assert response.status_code == 404
assert response.json() == {
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,5 @@ def mock_available_model_deployments(request):
MockSingleContainerDeployment.name(): MockSingleContainerDeployment,
}

with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", list(MOCKED_DEPLOYMENTS.values())) as mock:
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", MOCKED_DEPLOYMENTS) as mock:
yield mock
7 changes: 5 additions & 2 deletions src/backend/tests/unit/services/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_all_tools_have_id() -> None:
assert tool.value.ID is not None

def test_get_default_deployment_none_available() -> None:
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", []):
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", {}):
with pytest.raises(NoAvailableDeploymentsError):
deployment_service.get_default_deployment()

Expand Down Expand Up @@ -106,7 +106,10 @@ def test_get_deployment_definitions_with_db_deployments(session, mock_available_
id="db-mock-cohere-platform-id",
)
with patch("backend.crud.deployment.get_deployments", return_value=[mock_cohere_deployment]):
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", [MockCohereDeployment, MockAzureDeployment]):
with patch(
"backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS",
{ MockCohereDeployment.name(): MockCohereDeployment, MockAzureDeployment.name(): MockAzureDeployment }
):
definitions = deployment_service.get_deployment_definitions(session)

assert len(definitions) == 2
Expand Down
Loading