Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/ai/azure-ai-projects/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/ai/azure-ai-projects",
"Tag": "python/ai/azure-ai-projects_85a150439d"
"Tag": "python/ai/azure-ai-projects_f7878e759e"
}
55 changes: 55 additions & 0 deletions sdk/ai/azure-ai-projects/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@


class TestBase(AzureRecordedTestCase):
test_redteams_params = {
# cSpell:disable-next-line
"connection_name": "naposaniwestus3",
"connection_type": ConnectionType.AZURE_OPEN_AI,
"model_deployment_name": "gpt-4o-mini"
}

test_connections_params = {
"connection_name": "connection1",
Expand Down Expand Up @@ -116,6 +122,55 @@ def validate_connection(
assert connection.credentials.type == CredentialType.API_KEY
assert connection.credentials.api_key is not None

@classmethod
def validate_red_team_response(cls, response, expected_attack_strategies: int = -1, expected_risk_categories: int = -1):
"""Assert basic red team scan response properties."""
assert response is not None
assert hasattr(response, 'name')
assert hasattr(response, 'display_name')
assert hasattr(response, 'status')
assert hasattr(response, 'attack_strategies')
assert hasattr(response, 'risk_categories')
assert hasattr(response, 'target')
assert hasattr(response, 'properties')

# Validate attack strategies and risk categories
if expected_attack_strategies != -1:
assert len(response.attack_strategies) == expected_attack_strategies
if expected_risk_categories != -1:
assert len(response.risk_categories) == expected_risk_categories
assert response.status is not None
cls._assert_azure_ml_properties(response)

@classmethod
def _assert_azure_ml_properties(cls, response):
"""Assert Azure ML specific properties are present and valid."""
properties = response.properties
assert properties is not None, "Red team scan properties should not be None"

required_properties = [
'runType',
'redteaming',
'_azureml.evaluation_run',
'_azureml.evaluate_artifacts',
'AiStudioEvaluationUri'
]

for prop in required_properties:
assert prop in properties, f"Missing required property: {prop}"

# Validate specific property values
assert properties['runType'] == 'eval_run'
assert properties['_azureml.evaluation_run'] == 'evaluation.service'
assert 'instance_results.json' in properties['_azureml.evaluate_artifacts']
assert properties['redteaming'] == 'asr'

# Validate AI Studio URI format
ai_studio_uri = properties['AiStudioEvaluationUri']
assert ai_studio_uri.startswith('https://ai.azure.com/resource/build/redteaming/')
assert 'wsid=' in ai_studio_uri
assert 'tid=' in ai_studio_uri

@classmethod
def validate_deployment(
cls,
Expand Down
64 changes: 64 additions & 0 deletions sdk/ai/azure-ai-projects/tests/test_redteams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

from azure.ai.projects import AIProjectClient
from azure.ai.projects.models import (
RedTeam,
AzureOpenAIModelConfiguration,
AttackStrategy,
RiskCategory,
)
from test_base import TestBase, servicePreparer
from devtools_testutils import recorded_by_proxy


class TestRedTeams(TestBase):

# To run this test, use the following command in the \sdk\ai\azure-ai-projects folder:
# cls & pytest tests\test_redteams.py::TestRedTeams::test_red_teams -s
@servicePreparer()
@recorded_by_proxy
def test_red_teams(self, **kwargs):

endpoint = kwargs.pop("azure_ai_projects_tests_project_endpoint")
print("\n=====> Endpoint:", endpoint)

connection_name = self.test_redteams_params["connection_name"]
model_deployment_name = self.test_redteams_params["model_deployment_name"]

with AIProjectClient(
endpoint=endpoint,
credential=self.get_credential(AIProjectClient, is_async=False),
) as project_client:

# [START red_team_sample]
print("Creating a Red Team scan for direct model testing")

# Create target configuration for testing an Azure OpenAI model
target_config = AzureOpenAIModelConfiguration(model_deployment_name=f"{connection_name}/{model_deployment_name}")

# Create the Red Team configuration
red_team = RedTeam(
attack_strategies=[AttackStrategy.BASE64],
risk_categories=[RiskCategory.VIOLENCE],
display_name="redteamtest1", # Use a simpler name
target=target_config,
)

# Create and run the Red Team scan
red_team_response = project_client.red_teams.create(red_team=red_team)
print(f"Red Team scan created with scan name: {red_team_response.name}")
TestBase.validate_red_team_response(red_team_response, expected_attack_strategies=1, expected_risk_categories=1)

print("Getting Red Team scan details")
# Use the name returned by the create operation for the get call
get_red_team_response = project_client.red_teams.get(name=red_team_response.name)
print(f"Red Team scan status: {get_red_team_response.status}")
TestBase.validate_red_team_response(get_red_team_response, expected_attack_strategies=1, expected_risk_categories=1)

print("Listing all Red Team scans")
for scan in project_client.red_teams.list():
print(f"Found scan: {scan.name}, Status: {scan.status}")
TestBase.validate_red_team_response(scan)
64 changes: 64 additions & 0 deletions sdk/ai/azure-ai-projects/tests/test_redteams_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import (
RedTeam,
AzureOpenAIModelConfiguration,
AttackStrategy,
RiskCategory,
)
from test_base import TestBase, servicePreparer
from devtools_testutils.aio import recorded_by_proxy_async


class TestRedTeams(TestBase):

# To run this test, use the following command in the \sdk\ai\azure-ai-projects folder:
# cls & pytest tests\test_redteams.py::TestRedTeams::test_red_teams_async -s
@servicePreparer()
@recorded_by_proxy_async
async def test_red_teams_async(self, **kwargs):

endpoint = kwargs.pop("azure_ai_projects_tests_project_endpoint")
print("\n=====> Endpoint:", endpoint)

connection_name = self.test_redteams_params["connection_name"]
model_deployment_name = self.test_redteams_params["model_deployment_name"]

async with AIProjectClient(
endpoint=endpoint,
credential=self.get_credential(AIProjectClient, is_async=True),
) as project_client:

# [START red_team_sample]
print("Creating a Red Team scan for direct model testing")

# Create target configuration for testing an Azure OpenAI model
target_config = AzureOpenAIModelConfiguration(model_deployment_name=f"{connection_name}/{model_deployment_name}")

# Create the Red Team configuration
red_team = RedTeam(
attack_strategies=[AttackStrategy.BASE64],
risk_categories=[RiskCategory.VIOLENCE],
display_name="redteamtest1", # Use a simpler name
target=target_config,
)

# Create and run the Red Team scan
red_team_response = await project_client.red_teams.create(red_team=red_team)
print(f"Red Team scan created with scan name: {red_team_response.name}")
TestBase.validate_red_team_response(red_team_response, expected_attack_strategies=1, expected_risk_categories=1)

print("Getting Red Team scan details")
# Use the name returned by the create operation for the get call
get_red_team_response = await project_client.red_teams.get(name=red_team_response.name)
print(f"Red Team scan status: {get_red_team_response.status}")
TestBase.validate_red_team_response(get_red_team_response, expected_attack_strategies=1, expected_risk_categories=1)

print("Listing all Red Team scans")
async for scan in project_client.red_teams.list():
print(f"Found scan: {scan.name}, Status: {scan.status}")
TestBase.validate_red_team_response(scan)