Skip to content

Commit e67e226

Browse files
malexwmisspia-coheretianjing-liBeatrixCoherelusmoura
authored
frontend: Login, logout, and account creation (#179)
* add login page components * Add Register page and hooks for auth * Add the register page and connect all the frontend elements * Redirect to /login if the token expires and clean up some console errors * Add error messages for failed logins * frontend: Add Single Sign-on to Toolkit (#227) * Add Google SSO login plus OpenID components * Dynamically set SSO login buttons and show or hide username and password based on auth_strategies --------- Co-authored-by: Tianjing Li <[email protected]> * fix startup event * Fix build errors and update the API client * Fix tests by adding missing env vars Add test OIDC_WELL_KNOWN_ENDPOINT var to fixtures * Add a walkthrough guide of the toolkit (#251) * GUide * Chang * Change * Change * Update docs/walkthrough/walkthrough.md Co-authored-by: Luísa Moura <[email protected]> * Update walkthrough.md --------- Co-authored-by: Luísa Moura <[email protected]> * coral-web: fix agent info panel opening by default (#253) cast isEditAgentPanelOpen to boolean * [backend] enforce agent update with user-id (#246) * updates * remove client changes * remove logs * use better header user id check * fix validators * typo * Metrics: add middleware (#185) * Metrics: add middleware * add chat calls * merge * lint * make it async * add user id * add more fields * add retry and duration * add meta * comments * fix tests * improve error handling * rename fields * match spec * comments * clean code * only create loop when theres endpoint * add assistant id to chat * feat(toolkit): show assistant welcome message (#255) * feat(toolkit): show assistant welcome message * feat(toolkit): show assistant welcome message --------- Co-authored-by: misspia-cohere <[email protected]> Co-authored-by: Tianjing Li <[email protected]> Co-authored-by: Beatrix De Wilde <[email protected]> Co-authored-by: Luísa Moura <[email protected]> Co-authored-by: misspia-cohere <[email protected]> Co-authored-by: Scott <[email protected]> Co-authored-by: Khalil Najjar <[email protected]>
1 parent 338a9fa commit e67e226

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+8455
-169
lines changed

.env-template

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ USE_AGENTS_VIEW=False
4242
USE_COMMUNITY_FEATURES='True'
4343

4444
# For setting up authentication, see: docs/auth_guide.md
45-
JWT_SECRET_KEY=<See auth.guide.md on how to generate a secure one>
45+
AUTH_SECRET_KEY=<See auth.guide.md on how to generate a secure one>
46+
# Required for specifying Redirect URI
47+
FRONTEND_HOSTNAME=http://localhost:4000
4648

4749
# Google OAuth
4850
GOOGLE_CLIENT_ID=<GOOGLE_CLIENT_ID>
@@ -51,4 +53,4 @@ GOOGLE_CLIENT_SECRET=<GOOGLE_CLIENT_SECRET>
5153
# OpenID Connect
5254
OIDC_CLIENT_ID=<OIDC_CLIENT_ID>
5355
OIDC_CLIENT_SECRET=<OIDC_CLIENT_SECRET>
54-
OIDC_CONFIG_ENDPOINT=<OIDC_CONFIG_ENDPOINT>
56+
OIDC_WELL_KNOWN_ENDPOINT=<OIDC_WELL_KNOWN_ENDPOINT>

src/backend/config/auth.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,12 @@ def is_authentication_enabled() -> bool:
2121
return True
2222

2323
return False
24+
25+
26+
async def get_auth_strategy_endpoints() -> None:
27+
"""
28+
Fetches the endpoints for each enabled strategy.
29+
"""
30+
for strategy in ENABLED_AUTH_STRATEGY_MAPPING.values():
31+
if hasattr(strategy, "get_endpoints"):
32+
await strategy.get_endpoints()

src/backend/config/routers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class RouterName(StrEnum):
9494
],
9595
"auth": [
9696
Depends(get_session),
97-
Depends(validate_authorization),
97+
# TODO: Add if the router's have to have authorization
98+
# Depends(validate_authorization),
9899
],
99100
},
100101
}

src/backend/main.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import asyncio
2+
import os
13
from contextlib import asynccontextmanager
24

35
from alembic.command import upgrade
46
from alembic.config import Config
57
from dotenv import load_dotenv
68
from fastapi import FastAPI, HTTPException
79
from fastapi.middleware.cors import CORSMiddleware
10+
from starlette.middleware.sessions import SessionMiddleware
811

9-
from backend.config.auth import is_authentication_enabled
12+
from backend.config.auth import get_auth_strategy_endpoints, is_authentication_enabled
1013
from backend.config.routers import ROUTER_DEPENDENCIES
1114
from backend.routers.agent import router as agent_router
1215
from backend.routers.auth import router as auth_router
@@ -23,17 +26,10 @@
2326

2427
# CORS Origins
2528
ORIGINS = ["*"]
26-
# Session expiration time in seconds, set to None to last only browser session
27-
SESSION_EXPIRY = 60 * 60 * 24 * 7 # A week
28-
29-
30-
@asynccontextmanager
31-
async def lifespan(app: FastAPI):
32-
yield
3329

3430

3531
def create_app():
36-
app = FastAPI(lifespan=lifespan)
32+
app = FastAPI()
3733

3834
routers = [
3935
auth_router,
@@ -50,6 +46,10 @@ def create_app():
5046
# These values must be set in config/routers.py
5147
dependencies_type = "default"
5248
if is_authentication_enabled():
49+
# Required to save temporary OAuth state in session
50+
app.add_middleware(
51+
SessionMiddleware, secret_key=os.environ.get("AUTH_SECRET_KEY")
52+
)
5353
dependencies_type = "auth"
5454
for router in routers:
5555
if getattr(router, "name", "") in ROUTER_DEPENDENCIES.keys():
@@ -76,6 +76,15 @@ def create_app():
7676
app = create_app()
7777

7878

79+
@app.on_event("startup")
80+
async def startup_event():
81+
"""
82+
Retrieves all the Auth provider endpoints if authentication is enabled.
83+
"""
84+
if is_authentication_enabled():
85+
await get_auth_strategy_endpoints()
86+
87+
7988
@app.get("/health")
8089
async def health():
8190
"""

src/backend/routers/auth.py

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,39 @@ def get_strategies() -> list[ListAuthStrategy]:
3232
List[dict]: List of dictionaries containing the enabled auth strategy names.
3333
"""
3434
strategies = []
35-
for key in ENABLED_AUTH_STRATEGY_MAPPING.keys():
36-
strategies.append({"strategy": key})
35+
for strategy_name, strategy_instance in ENABLED_AUTH_STRATEGY_MAPPING.items():
36+
strategies.append(
37+
{
38+
"strategy": strategy_name,
39+
"client_id": (
40+
strategy_instance.get_client_id()
41+
if hasattr(strategy_instance, "get_client_id")
42+
else None
43+
),
44+
"authorization_endpoint": (
45+
strategy_instance.get_authorization_endpoint()
46+
if hasattr(strategy_instance, "get_authorization_endpoint")
47+
else None
48+
),
49+
}
50+
)
3751

3852
return strategies
3953

4054

4155
@router.post("/login", response_model=Union[JWTResponse, None])
4256
async def login(request: Request, login: Login, session: DBSessionDep):
4357
"""
44-
Logs user in and either:
45-
- (Basic email/password authentication) Verifies their credentials, retrieves the user and returns a JWT token.
46-
- (OAuth) Redirects to the /auth endpoint.
58+
Logs user in, performing basic email/password auth.
59+
Verifies their credentials, retrieves the user and returns a JWT token.
4760
4861
Args:
4962
request (Request): current Request object.
5063
login (Login): Login payload.
5164
session (DBSessionDep): Database session.
5265
5366
Returns:
54-
dict: JWT token on basic auth success
55-
or
56-
Redirect: to /auth endpoint
67+
dict: JWT token on Basic auth success
5768
5869
Raises:
5970
HTTPException: If the strategy or payload are invalid, or if the login fails.
@@ -76,27 +87,20 @@ async def login(request: Request, login: Login, session: DBSessionDep):
7687
detail=f"Missing the following keys in the payload: {missing_keys}.",
7788
)
7889

79-
# Login with redirect to /auth
80-
if strategy.SHOULD_AUTH_REDIRECT:
81-
# Fetch endpoint with method name
82-
redirect_uri = request.url_for(strategy.REDIRECT_METHOD_NAME)
83-
return await strategy.login(request, redirect_uri)
84-
# Login with email/password and set session directly
85-
else:
86-
user = strategy.login(session, payload)
87-
if not user:
88-
raise HTTPException(
89-
status_code=401,
90-
detail=f"Error performing {strategy_name} authentication with payload: {payload}.",
91-
)
90+
user = strategy.login(session, payload)
91+
if not user:
92+
raise HTTPException(
93+
status_code=401,
94+
detail=f"Error performing {strategy_name} authentication with payload: {payload}.",
95+
)
9296

93-
token = JWTService().create_and_encode_jwt(user)
97+
token = JWTService().create_and_encode_jwt(user)
9498

95-
return {"token": token}
99+
return {"token": token}
96100

97101

98102
@router.get("/google/auth", response_model=JWTResponse)
99-
async def google_authenticate(request: Request, session: DBSessionDep):
103+
async def google_authorize(request: Request, session: DBSessionDep):
100104
"""
101105
Callback authentication endpoint used for Google OAuth after redirecting to
102106
the service's login screen.
@@ -112,11 +116,11 @@ async def google_authenticate(request: Request, session: DBSessionDep):
112116
"""
113117
strategy_name = GoogleOAuth.NAME
114118

115-
return await authenticate(request, session, strategy_name)
119+
return await authorize(request, session, strategy_name)
116120

117121

118122
@router.get("/oidc/auth", response_model=JWTResponse)
119-
async def oidc_authenticate(request: Request, session: DBSessionDep):
123+
async def oidc_authorize(request: Request, session: DBSessionDep):
120124
"""
121125
Callback authentication endpoint used for OIDC after redirecting to
122126
the service's login screen.
@@ -132,7 +136,8 @@ async def oidc_authenticate(request: Request, session: DBSessionDep):
132136
"""
133137
strategy_name = OpenIDConnect.NAME
134138

135-
return await authenticate(request, session, strategy_name)
139+
# TODO: Merge authorize endpoints into single one
140+
return await authorize(request, session, strategy_name)
136141

137142

138143
@router.get("/logout", response_model=Logout)
@@ -157,7 +162,7 @@ async def logout(
157162
return {}
158163

159164

160-
async def authenticate(
165+
async def authorize(
161166
request: Request, session: DBSessionDep, strategy_name: str
162167
) -> JWTResponse:
163168
if not is_enabled_authentication_strategy(strategy_name):
@@ -168,22 +173,20 @@ async def authenticate(
168173
strategy = ENABLED_AUTH_STRATEGY_MAPPING[strategy_name]
169174

170175
try:
171-
token = await strategy.authenticate(request)
176+
userinfo = await strategy.authorize(request)
172177
except OAuthError as e:
173178
raise HTTPException(
174-
status_code=401,
175-
detail=f"Could not authenticate, failed with error: {str(e)}",
179+
status_code=400,
180+
detail=f"Could not fetch access token from provider, failed with error: {str(e)}",
176181
)
177182

178-
token_user = token.get("userinfo")
179-
180-
if not token_user:
183+
if not userinfo:
181184
raise HTTPException(
182185
status_code=401, detail=f"Could not get user from auth token: {token}."
183186
)
184187

185188
# Get or create user, then set session user
186-
user = get_or_create_user(session, token_user)
189+
user = get_or_create_user(session, userinfo)
187190

188191
token = JWTService().create_and_encode_jwt(user)
189192

src/backend/schemas/auth.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class Logout(BaseModel):
1717

1818
class ListAuthStrategy(BaseModel):
1919
strategy: str
20+
client_id: str | None
21+
authorization_endpoint: str | None
2022

2123

2224
class JWTResponse(BaseModel):

src/backend/services/auth/jwt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ class JWTService:
1616
ALGORITHM = "HS256"
1717

1818
def __init__(self):
19-
secret_key = os.environ.get("JWT_SECRET_KEY")
19+
secret_key = os.environ.get("AUTH_SECRET_KEY")
2020

2121
if not secret_key:
2222
raise ValueError(
23-
"JWT_SECRET_KEY environment variable is missing, and is required to enable authentication."
23+
"AUTH_SECRET_KEY environment variable is missing, and is required to enable authentication."
2424
)
2525

2626
self.secret_key = secret_key

src/backend/services/auth/strategies/base.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@ class BaseAuthenticationStrategy:
88
99
Attributes:
1010
NAME (str): The name of the strategy.
11-
SHOULD_AUTH_REDIRECT (str): Whether the strategy requires a redirect to the /auth endpoint after login.
1211
"""
1312

1413
NAME = "Base"
15-
SHOULD_AUTH_REDIRECT = False
1614

1715
@staticmethod
1816
def get_required_payload(self) -> List[str]:
@@ -24,35 +22,56 @@ def get_required_payload(self) -> List[str]:
2422
@abstractmethod
2523
def login(self, **kwargs: Any):
2624
"""
27-
Login logic: dealing with checking credentials, returning user object
28-
to store into session if finished. For OAuth strategies, the next step
29-
will be to authenticate.
25+
Check email/password credentials and return JWT token.
3026
"""
3127
...
3228

3329

34-
class BaseOAuthStrategy(BaseAuthenticationStrategy):
30+
class BaseOAuthStrategy:
3531
"""
3632
Base strategy for OAuth, abstract class that should be inherited from.
3733
3834
Attributes:
3935
NAME (str): The name of the strategy.
40-
SHOULD_AUTH_REDIRECT (str): Whether the strategy requires a redirect to the /auth endpoint after login.
41-
REDIRECT_METHOD_NAME (str | None): The router method name that should be used for redirect callback.
4236
"""
4337

44-
SHOULD_AUTH_REDIRECT = True
45-
REDIRECT_METHOD_NAME = None
38+
NAME = None
4639

47-
def __init__subclass(cls, **kwargs):
48-
super().__init__subclass__(**kwargs)
49-
if cls.REDIRECT_METHOD_NAME is None:
50-
raise ValueError(
51-
f"{cls.__name__} must have a REDIRECT_METHOD_NAME defined, and a corresponding router definition."
52-
)
40+
def __init__(self, *args, **kwargs):
41+
super().__init__(*args, **kwargs)
42+
self._post_init_check()
43+
44+
def _post_init_check(self):
45+
if any(
46+
[
47+
self.NAME is None,
48+
]
49+
):
50+
raise ValueError(f"{self.__name__} must have NAME parameter(s) defined.")
51+
52+
@abstractmethod
53+
def get_client_id(self, **kwargs: Any):
54+
"""
55+
Retrieves the OAuth app's client ID
56+
"""
57+
...
58+
59+
@abstractmethod
60+
def get_authorization_endpoint(self, **kwargs: Any):
61+
"""
62+
Retrieves the OAuth app's authorization endpoint.
63+
"""
64+
...
65+
66+
@abstractmethod
67+
async def get_endpoints(self, **kwargs: Any):
68+
"""
69+
Retrieves the /token and /userinfo endpoints.
70+
"""
71+
...
5372

5473
@abstractmethod
55-
def authenticate(self, **kwargs: Any):
74+
async def authorize(self, **kwargs: Any):
5675
"""
5776
Authentication logic: dealing with user data and returning it
5877
to set the current user session for OAuth strategies.

0 commit comments

Comments
 (0)