Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/google/adk/auth/auth_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class OAuth2Auth(BaseModelWithConfig):
refresh_token: Optional[str] = None
expires_at: Optional[int] = None
expires_in: Optional[int] = None
audience: Optional[str] = None
prompt: Optional[str] = None


class ServiceAccountCredential(BaseModelWithConfig):
Expand Down
6 changes: 5 additions & 1 deletion src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,12 @@ def generate_auth_uri(
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
params = {"access_type": "offline"}
params["prompt"] = auth_credential.oauth2.prompt or "consent"
if auth_credential.oauth2.audience:
params["audience"] = auth_credential.oauth2.audience
uri, state = client.create_authorization_url(
url=authorization_endpoint, access_type="offline", prompt="consent"
url=authorization_endpoint, **params
)
exchanged_auth_credential = auth_credential.model_copy(deep=True)
exchanged_auth_credential.oauth2.auth_uri = uri
Expand Down
24 changes: 23 additions & 1 deletion tests/unittests/auth/test_auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def __init__(
self.state = state

def create_authorization_url(self, url, **kwargs):
return f"{url}?client_id={self.client_id}&scope={self.scope}", "mock_state"
params = f"client_id={self.client_id}&scope={self.scope}"
if kwargs.get("audience"):
params += f"&audience={kwargs.get('audience')}"
return f"{url}?{params}", "mock_state"

def fetch_token(
self,
Expand Down Expand Up @@ -225,8 +228,27 @@ def test_generate_auth_uri_oauth2(self, auth_config):
"https://example.com/oauth2/authorize"
)
assert "client_id=mock_client_id" in result.oauth2.auth_uri
assert "audience" not in result.oauth2.auth_uri
assert result.oauth2.state == "mock_state"

@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
def test_generate_auth_uri_with_audience_and_prompt(
self, openid_auth_scheme, oauth2_credentials
):
"""Test generating an auth URI with audience and prompt."""
oauth2_credentials.oauth2.audience = "test_audience"
exchanged = oauth2_credentials.model_copy(deep=True)

config = AuthConfig(
auth_scheme=openid_auth_scheme,
raw_auth_credential=oauth2_credentials,
exchanged_auth_credential=exchanged,
)
handler = AuthHandler(config)
result = handler.generate_auth_uri()

assert "audience=test_audience" in result.oauth2.auth_uri

@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
def test_generate_auth_uri_openid(
self, openid_auth_scheme, oauth2_credentials
Expand Down