-
Notifications
You must be signed in to change notification settings - Fork 246
Open
Labels
enhancement (New feature or request) · security (Improves security) · triage (Issues / Features awaiting triage)
Milestone
Description
🛡️ FEATURE: CSRF Token Protection System
Summary: Implement comprehensive CSRF (Cross-Site Request Forgery) protection for all state-changing operations in MCP Gateway. This includes double-submit cookie pattern, per-session tokens, SameSite cookies, and automatic token validation for forms and API requests.
Depends on: Existing JWT authentication system
Implementation
1. Update `config.py` with CSRF Settings
# In mcpgateway/config.py
# NOTE(review): excerpt only -- the class continues with the pre-existing settings,
# and Optional/List must be imported from typing at the top of config.py.
class Settings(BaseSettings):
    # ... existing settings ...
    # ===================================
    # CSRF Protection Settings
    # ===================================
    # CSRF Token Configuration
    # Master switch: CSRFMiddleware and @require_csrf skip every check when False.
    csrf_enabled: bool = True
    # Not read by the CSRF service/middleware shown in this issue; only reported by
    # GET /admin/security/csrf. Duplicates csrf_header_name -- TODO confirm it is needed.
    csrf_token_name: str = "X-CSRF-Token"
    # Cookie that delivers the token to the browser (read by getCSRFToken in admin.js).
    csrf_cookie_name: str = "csrf_token"
    # Request header checked first for the submitted token; also set on responses.
    csrf_header_name: str = "X-CSRF-Token"
    # Form field / JSON key searched when the header is absent.
    csrf_form_field_name: str = "csrf_token"
    # Passed to secrets.token_urlsafe(), i.e. number of random BYTES (the encoded
    # string is longer than this).
    csrf_token_length: int = 32
    # Token lifetime in seconds; also used as the cookie max_age.
    csrf_token_expiry: int = 3600 # 1 hour in seconds
    # Cookie Settings
    csrf_cookie_secure: bool = True # HTTPS only in production
    csrf_cookie_httponly: bool = False # Must be readable by JS
    # Passed straight to Response.set_cookie(samesite=...).
    csrf_cookie_samesite: str = "Strict"
    csrf_cookie_path: str = "/"
    csrf_cookie_domain: Optional[str] = None
    # Exempted Paths (no CSRF check)
    # NOTE(review): pydantic settings read List[str] fields from environment
    # variables as JSON arrays (e.g. ["/health","/docs"]), not comma-separated
    # strings -- confirm against the pydantic version in use.
    csrf_exempt_paths: List[str] = [
        "/health",
        "/auth/login", # Login endpoint generates CSRF token
        "/auth/refresh",
        "/docs",
        "/openapi.json",
        "/metrics"
    ]
    # Safe Methods (no CSRF check)
    csrf_safe_methods: List[str] = ["GET", "HEAD", "OPTIONS", "TRACE"]
    # Token Rotation
    # Not referenced by the code shown in this issue; /auth/login always issues a new token.
    csrf_rotate_on_login: bool = True
    # When True the middleware issues a fresh token on responses with status >= 400.
    csrf_rotate_on_error: bool = True
    # Additional Security
    # When True, state-changing requests must come from the same host or a trusted origin.
    csrf_check_referer: bool = True
    # Compared by host[:port] (netloc) against the request's Origin/Referer.
    # NOTE(review): same JSON-array env parsing caveat as csrf_exempt_paths.
    csrf_trusted_origins: List[str] = [] # Add trusted origins in production
2. Create CSRF Service
# Create mcpgateway/services/csrf_service.py
import hashlib
import hmac
import json
import secrets
import time
from typing import Optional, Dict, Tuple
from urllib.parse import urlparse

from fastapi import Request, Response, HTTPException
from jose import jwt

from mcpgateway.config import settings
from mcpgateway.exceptions import CSRFError

# Body content types that may carry the token as a form field.
_FORM_CONTENT_TYPES = ("application/x-www-form-urlencoded", "multipart/form-data")

# Number of remembered tokens above which expired entries are purged.
_CLEANUP_THRESHOLD = 1000


class CSRFService:
    """Issue and verify CSRF tokens bound to a user session.

    Tokens have the form ``<nonce>.<issued_at>.<signature>``:

    * ``nonce`` - ``secrets.token_urlsafe(settings.csrf_token_length)``;
    * ``issued_at`` - integer Unix time at issuance;
    * ``signature`` - hex HMAC-SHA256, keyed with ``settings.jwt_secret_key``,
      over ``(user_id, session_id, nonce, issued_at)``.

    ``validate_token`` recomputes the signature from the ``user_id`` and
    ``session_id`` it is given. A token therefore validates only for the exact
    pair it was generated for, so callers must derive ``session_id`` the same
    way when generating and when validating.

    Issued tokens are also remembered in ``token_cache`` (token -> expiry).
    A token absent from the cache is rejected, so clearing the cache revokes
    every outstanding token.

    NOTE(review): ``token_cache`` lives in process memory. Tokens issued by
    another worker process, or before a restart, are unknown here and will be
    rejected. Multi-worker deployments need a shared store.
    """

    def __init__(self):
        # token -> Unix time after which the token is no longer accepted
        self.token_cache: Dict[str, float] = {}

    @staticmethod
    def _sign(user_id: str, session_id: str, nonce: str, issued_at: int) -> str:
        """Return the hex HMAC-SHA256 signature over a token's bound fields."""
        # A JSON array is an unambiguous encoding even if an id contains ":".
        message = json.dumps([str(user_id), str(session_id), nonce, issued_at]).encode()
        # NOTE(review): assumes jwt_secret_key is a plain str; if it is a pydantic
        # SecretStr, use get_secret_value() instead of str().
        key = str(settings.jwt_secret_key).encode()
        return hmac.new(key, message, hashlib.sha256).hexdigest()

    def generate_token(self, user_id: str, session_id: str) -> str:
        """Create, remember and return a new token for ``(user_id, session_id)``.

        Args:
            user_id: Identifier of the authenticated user.
            session_id: Identifier of the user's session. It must be derived the
                same way later when the token is validated.

        Returns:
            The token string ``<nonce>.<issued_at>.<signature>``.
        """
        nonce = secrets.token_urlsafe(settings.csrf_token_length)
        issued_at = int(time.time())
        signature = self._sign(user_id, session_id, nonce, issued_at)
        token = f"{nonce}.{issued_at}.{signature}"

        self.token_cache[token] = issued_at + settings.csrf_token_expiry
        # Opportunistic cleanup keeps the cache from growing without bound.
        if len(self.token_cache) > _CLEANUP_THRESHOLD:
            self._cleanup_expired_tokens()
        return token

    def validate_token(
        self,
        token: Optional[str],
        user_id: str,
        session_id: str
    ) -> bool:
        """Return True only if ``token`` is authentic, unexpired and unrevoked.

        Args:
            token: The token submitted with the request (may be None/empty).
            user_id: Identifier of the user making the request.
            session_id: Session identifier, derived as at generation time.

        Returns:
            True when the signature matches ``(user_id, session_id)``, the token
            is still in ``token_cache`` and its expiry has not passed.
        """
        if not token or not isinstance(token, str):
            return False

        parts = token.split(".")
        if len(parts) != 3:
            return False
        nonce, issued_raw, signature = parts
        try:
            issued_at = int(issued_raw)
        except ValueError:
            return False

        expected = self._sign(user_id, session_id, nonce, issued_at)
        # Compare as bytes: compare_digest raises TypeError on non-ASCII str,
        # and the signature part is attacker-controlled.
        if not hmac.compare_digest(signature.encode(), expected.encode()):
            return False

        expiry = self.token_cache.get(token)
        if expiry is None:
            return False
        if time.time() > expiry:
            self.token_cache.pop(token, None)
            return False
        return True

    async def extract_token_from_request(self, request: Request) -> Optional[str]:
        """Return the CSRF token submitted with ``request``, or None.

        The token is looked up, in order, in:

        1. the ``settings.csrf_header_name`` header;
        2. the ``settings.csrf_form_field_name`` field of a urlencoded or
           multipart body;
        3. the same key of a JSON object body.

        The CSRF cookie is deliberately NOT consulted. Browsers attach cookies
        to cross-site requests automatically, so accepting the cookie as the
        submitted token would let any forged request pass.
        """
        token = request.headers.get(settings.csrf_header_name)
        if token:
            return token

        content_type = request.headers.get("content-type", "")

        if any(form_type in content_type for form_type in _FORM_CONTENT_TYPES):
            # Read the raw body first: Starlette caches it on the request and
            # form() then parses the cached bytes instead of draining the stream.
            # NOTE(review): confirm downstream handlers can still read the body
            # under the deployed Starlette version's BaseHTTPMiddleware.
            await request.body()
            form_data = await request.form()
            value = form_data.get(settings.csrf_form_field_name)
            # Multipart fields may be uploads; only a non-empty string is a token.
            return value if isinstance(value, str) and value else None

        if "application/json" in content_type:
            try:
                body = await request.json()
            except ValueError:  # malformed JSON / undecodable bytes
                return None
            if isinstance(body, dict):
                value = body.get(settings.csrf_form_field_name)
                if isinstance(value, str) and value:
                    return value

        return None

    def set_csrf_cookie(
        self,
        response: Response,
        token: str,
        secure: Optional[bool] = None
    ) -> None:
        """Set the CSRF token cookie on ``response``.

        Args:
            response: Response to attach the cookie to.
            token: Token value.
            secure: Overrides ``settings.csrf_cookie_secure`` when not None.
        """
        response.set_cookie(
            key=settings.csrf_cookie_name,
            value=token,
            max_age=settings.csrf_token_expiry,
            secure=secure if secure is not None else settings.csrf_cookie_secure,
            httponly=settings.csrf_cookie_httponly,
            samesite=settings.csrf_cookie_samesite,
            path=settings.csrf_cookie_path,
            domain=settings.csrf_cookie_domain
        )

    def clear_csrf_cookie(self, response: Response) -> None:
        """Delete the CSRF token cookie on ``response``."""
        response.delete_cookie(
            key=settings.csrf_cookie_name,
            path=settings.csrf_cookie_path,
            domain=settings.csrf_cookie_domain
        )

    def check_referer(self, request: Request) -> bool:
        """Verify that the request originates from this host or a trusted origin.

        The ``Origin`` header is preferred and ``Referer`` is the fallback. The
        source's host[:port] must equal the ``Host`` header or the netloc of an
        entry in ``settings.csrf_trusted_origins``. A missing source, or the
        opaque ``Origin: null``, is rejected.

        Returns:
            True when the check passes or ``settings.csrf_check_referer`` is off.
        """
        if not settings.csrf_check_referer:
            return True

        source = request.headers.get("origin") or request.headers.get("referer")
        if not source or source == "null":
            return False

        source_netloc = urlparse(source).netloc
        if not source_netloc:
            return False

        if source_netloc == request.headers.get("host"):
            return True

        return any(
            source_netloc == urlparse(origin).netloc
            for origin in settings.csrf_trusted_origins
        )

    def _cleanup_expired_tokens(self) -> None:
        """Remove every cached token whose expiry time has passed."""
        now = time.time()
        expired = [token for token, expiry in self.token_cache.items() if now > expiry]
        for token in expired:
            del self.token_cache[token]


# Global instance
csrf_service = CSRFService()
3. Create CSRF Middleware
# Create mcpgateway/middleware/csrf_middleware.py
import logging
from typing import Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from mcpgateway.config import settings
from mcpgateway.services.csrf_service import csrf_service
from mcpgateway.dependencies import get_current_user_optional

logger = logging.getLogger(__name__)


def _is_exempt_path(path: str) -> bool:
    """Return True when ``path`` is an exempt path or lies beneath one.

    Matching is on whole path segments. ``/docs`` exempts ``/docs`` and
    ``/docs/oauth2-redirect`` but not ``/docs-admin``.
    """
    for exempt in settings.csrf_exempt_paths:
        base = exempt.rstrip("/")
        if path == exempt or path == base or path.startswith(base + "/"):
            return True
    return False


def _forbidden(detail: str) -> JSONResponse:
    """Build the 403 JSON response used for every CSRF rejection."""
    return JSONResponse(status_code=403, content={"detail": detail})


class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce CSRF protection.

    * Disabled entirely when ``settings.csrf_enabled`` is False, and skipped for
      exempt paths.
    * Safe methods pass through. Authenticated users get a CSRF token on the
      response when they lack one.
    * State-changing requests from authenticated users must carry a valid
      token and pass the Origin/Referer check, otherwise a 403 is returned.
      Unauthenticated state-changing requests are passed through unchecked.
    """

    async def dispatch(self, request: Request, call_next):
        """Apply the CSRF policy to ``request`` before/after ``call_next``."""
        if not settings.csrf_enabled or _is_exempt_path(request.url.path):
            return await call_next(request)

        if request.method in settings.csrf_safe_methods:
            response = await call_next(request)
            # Hand authenticated users a token for their next state-changing request.
            user = await get_current_user_optional(request)
            if user:
                await self._add_csrf_token_to_response(request, response, user)
            return response

        # Validation phase only. Unexpected failures here are fail-closed with a
        # generic 403; details go to the log, not to the client. The downstream
        # handler runs OUTSIDE this try so its errors are not reported as CSRF
        # failures.
        try:
            user = await get_current_user_optional(request)
            rejection = await self._validate(request, user) if user else None
        except Exception:
            logger.exception(
                "CSRF validation error on %s %s", request.method, request.url.path
            )
            return _forbidden("CSRF validation error")
        if rejection is not None:
            return rejection

        response = await call_next(request)

        # Rotate token if needed
        if user and settings.csrf_rotate_on_error and response.status_code >= 400:
            await self._add_csrf_token_to_response(request, response, user, rotate=True)
        return response

    async def _validate(self, request: Request, user) -> Optional[JSONResponse]:
        """Return a 403 response if ``request`` fails CSRF checks, else None."""
        token = await csrf_service.extract_token_from_request(request)
        if not token:
            return _forbidden("CSRF token missing")

        # Get session ID from JWT or session
        session_id = getattr(user, "session_id", str(user.id))
        if not csrf_service.validate_token(token, str(user.id), session_id):
            return _forbidden("Invalid CSRF token")

        if not csrf_service.check_referer(request):
            return _forbidden("Invalid referer")
        return None

    async def _add_csrf_token_to_response(
        self,
        request: Request,
        response: Response,
        user,
        rotate: bool = False
    ):
        """Issue a new token (cookie + header) when needed.

        A token is issued when ``rotate`` is True, or when the request's cookie
        token is missing or no longer validates for this user/session (for
        example, expired, revoked, or issued to a different user).
        """
        existing_token = request.cookies.get(settings.csrf_cookie_name)
        session_id = getattr(user, "session_id", str(user.id))
        needs_token = (
            rotate
            or not existing_token
            or not csrf_service.validate_token(existing_token, str(user.id), session_id)
        )
        if needs_token:
            new_token = csrf_service.generate_token(str(user.id), session_id)
            csrf_service.set_csrf_cookie(response, new_token)
            # Also expose the token as a header for easy JS access.
            response.headers[settings.csrf_header_name] = new_token
4. Update Authentication Endpoints
# Update mcpgateway/routers/auth.py
from mcpgateway.services.csrf_service import csrf_service
@router.post("/login")
async def login(
    response: Response,
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
):
    """Authenticate a user and issue a CSRF token together with the access token.

    ``response`` is declared first because Python rejects a parameter without a
    default after parameters that have defaults. FastAPI injects endpoint
    parameters by name, so this order is not part of the HTTP interface.

    On success the CSRF token is set as a cookie on ``response``. It is also
    returned in the JSON body for SPA clients.
    """
    # ... existing authentication logic ...
    # After successful authentication
    if user and access_token:
        # Bind the token to the same session identifier that the CSRF middleware
        # and /auth/csrf-token derive at validation time. A random identifier
        # that is never stored (e.g. a fresh uuid4) cannot be reproduced later,
        # so a token bound to it could never be verified against the session.
        session_id = getattr(user, "session_id", str(user.id))
        csrf_token = csrf_service.generate_token(str(user.id), session_id)
        # Set CSRF cookie
        csrf_service.set_csrf_cookie(response, csrf_token)
        # Include CSRF token in response
        return {
            "access_token": access_token,
            "token_type": "bearer",
            "csrf_token": csrf_token,  # For SPA applications
            "user": {
                "id": user.id,
                "username": user.username,
                "is_admin": user.is_admin
            }
        }
@router.post("/logout")
async def logout(
    response: Response,
    current_user: User = Depends(get_current_user),
):
    """Log the user out and expire the CSRF cookie on the client.

    Args:
        response: Response onto which the cookie deletion is written.
        current_user: The authenticated user resolved by ``get_current_user``.

    Returns:
        A confirmation message.
    """
    csrf_service.clear_csrf_cookie(response)  # expire the CSRF token cookie
    # ... existing logout logic ...
    result = {"message": "Logged out successfully"}
    return result
@router.get("/csrf-token")
async def get_csrf_token(
    response: Response,
    current_user: User = Depends(get_current_user),
):
    """Issue a fresh CSRF token for the current user.

    The token is written to the CSRF cookie and the CSRF response header, and
    is also returned in the body.

    Returns:
        ``{"csrf_token": <token>}``.
    """
    uid = str(current_user.id)
    fresh = csrf_service.generate_token(uid, getattr(current_user, "session_id", uid))
    csrf_service.set_csrf_cookie(response, fresh)
    response.headers[settings.csrf_header_name] = fresh
    return {"csrf_token": fresh}
5. Update Frontend JavaScript
// Update admin.js with CSRF support
// Global CSRF token management
let csrfToken = null;

/**
 * Read the CSRF token from the `csrf_token` cookie.
 *
 * The cookie string is split first and only the matching value is
 * URI-decoded. A malformed escape sequence in an unrelated cookie therefore
 * cannot make this throw, and a decoded ';' cannot corrupt the split.
 *
 * @returns {string|null} the token, or null when the cookie is absent.
 */
function getCSRFToken() {
    const prefix = 'csrf_token=';
    for (const rawPart of document.cookie.split(';')) {
        const part = rawPart.trim();
        if (!part.startsWith(prefix)) {
            continue;
        }
        const value = part.substring(prefix.length);
        try {
            return decodeURIComponent(value);
        } catch (e) {
            return value; // not URI-encoded: use the raw value
        }
    }
    return null;
}
// Fetch new CSRF token
/**
 * Ask the server for a fresh CSRF token and store it in `csrfToken`.
 *
 * @returns {Promise<string|null>} the new token, or null if the request failed.
 */
async function refreshCSRFToken() {
    try {
        const reply = await fetch(`${window.ROOT_PATH}/auth/csrf-token`, {
            credentials: 'include',
            headers: { 'Authorization': `Bearer ${getAuthToken()}` },
        });
        if (!reply.ok) {
            return null;
        }
        csrfToken = (await reply.json()).csrf_token;
        return csrfToken;
    } catch (error) {
        console.error('Failed to refresh CSRF token:', error);
        return null;
    }
}
// Enhanced fetch with CSRF token
/**
 * fetch() wrapper that attaches the CSRF token to state-changing requests.
 *
 * For POST/PUT/DELETE/PATCH the token is sent in the `X-CSRF-Token` header.
 * It is also placed in FormData bodies (field `csrf_token`) and in JSON
 * object bodies (key `csrf_token`). If the server answers 403 with a
 * CSRF-related message, the token is refreshed and the request retried once.
 *
 * @param {string} url
 * @param {RequestInit} [options]
 * @returns {Promise<Response>} the response, with its body still unread.
 */
async function fetchWithCSRF(url, options = {}) {
    if (!csrfToken) {
        csrfToken = getCSRFToken() || await refreshCSRFToken();
    }

    // Shallow copy so the caller's options object is not mutated.
    const opts = { ...options };
    const method = (opts.method || 'GET').toUpperCase();

    if (['POST', 'PUT', 'DELETE', 'PATCH'].includes(method)) {
        // Normalise to a Headers object. Spreading a Headers instance into an
        // object literal would silently drop all of its entries.
        const headers = new Headers(opts.headers || {});
        headers.set('X-CSRF-Token', csrfToken);
        opts.headers = headers;

        if (opts.body instanceof FormData) {
            // set() rather than append(): the form may already carry a hidden
            // csrf_token field, and a retry must not add a duplicate.
            opts.body.set('csrf_token', csrfToken);
        } else if ((headers.get('Content-Type') || '').includes('application/json')
                   && typeof opts.body === 'string') {
            try {
                const body = JSON.parse(opts.body);
                if (body && typeof body === 'object' && !Array.isArray(body)) {
                    body.csrf_token = csrfToken;
                    opts.body = JSON.stringify(body);
                }
            } catch (e) {
                // Not valid JSON: send the body unchanged (the header carries the token).
            }
        }
    }

    const response = await fetchWithTimeout(url, opts);

    // If the CSRF token was rejected, refresh it and retry once.
    if (response.status === 403 && !opts._retried) {
        // Read a clone so the caller still receives an unconsumed body.
        const text = await response.clone().text();
        if (text.includes('CSRF')) {
            csrfToken = await refreshCSRFToken();
            if (csrfToken) {
                return fetchWithCSRF(url, { ...opts, _retried: true });
            }
        }
    }
    return response;
}
// Update all form submissions to use CSRF
/**
 * Submit the add-gateway form through fetchWithCSRF instead of a native post.
 * @param {SubmitEvent} event
 */
async function handleGatewayFormSubmit(event) {
    event.preventDefault();
    const form = event.target;
    const payload = new FormData(form);
    const endpoint = `${window.ROOT_PATH}/admin/gateways`;
    try {
        // ... validation logic ...
        const response = await fetchWithCSRF(endpoint, {
            method: "POST",
            body: payload,
        });
        // ... handle response ...
    } catch (error) {
        // ... error handling ...
    }
}
// Initialize CSRF token on page load
document.addEventListener('DOMContentLoaded', async () => {
    // Prefer the cookie; only ask the server when it is missing and the user is logged in.
    csrfToken = getCSRFToken();
    const needsServerToken = !csrfToken && isAuthenticated();
    if (needsServerToken) {
        csrfToken = await refreshCSRFToken();
    }
    // ... rest of initialization ...
});
// Add CSRF token to all AJAX requests automatically
const originalFetch = window.fetch;
/**
 * Route state-changing fetch() calls through fetchWithCSRF.
 *
 * Safe methods (e.g. the GET issued by refreshCSRFToken) go straight to the
 * native fetch, since they need no token. Requests that already carry an
 * X-CSRF-Token header are also passed straight through. Without both guards,
 * fetchWithCSRF -> fetchWithTimeout -> window.fetch (if fetchWithTimeout
 * delegates to window.fetch -- TODO confirm) would re-enter fetchWithCSRF
 * and recurse forever.
 */
window.fetch = function (...args) {
    const [resource, init] = args;
    if (init && typeof init === 'object') {
        const method = (init.method || 'GET').toUpperCase();
        const isStateChanging = ['POST', 'PUT', 'DELETE', 'PATCH'].includes(method);
        const alreadyProtected = Boolean(init.headers)
            && new Headers(init.headers).has('X-CSRF-Token');
        if (isStateChanging && !alreadyProtected) {
            return fetchWithCSRF(resource, init);
        }
    }
    return originalFetch.apply(this, args);
};
6. Update HTML Templates
<!-- Update admin.html forms with CSRF token fields -->
<!-- Add CSRF meta tag for easy access -->
<meta name="csrf-token" content="">
<!-- Update forms to include CSRF field -->
<form id="add-gateway-form" method="POST">
    <input type="hidden" name="csrf_token" value="">
    <!-- ... other form fields ... -->
</form>
<script>
// Copy the cookie's CSRF token into the meta tag and every hidden csrf_token field.
document.addEventListener('DOMContentLoaded', function () {
    const token = getCSRFToken();

    const meta = document.querySelector('meta[name="csrf-token"]');
    if (token && meta) {
        meta.content = token;
    }

    for (const field of document.querySelectorAll('input[name="csrf_token"]')) {
        field.value = token;
    }
});
</script>
7. Add CSRF Decorator for Additional Protection
# Create mcpgateway/decorators/csrf_decorator.py
from functools import wraps
from fastapi import Request, HTTPException
from mcpgateway.config import settings
from mcpgateway.services.csrf_service import csrf_service
from mcpgateway.dependencies import get_current_user


def _find_request(args, kwargs):
    """Return the first ``Request`` among positional, then keyword, arguments."""
    for value in (*args, *kwargs.values()):
        if isinstance(value, Request):
            return value
    return None


def require_csrf(exempt: bool = False):
    """Build a decorator that enforces CSRF validation on an async endpoint.

    Args:
        exempt: When True the decorated endpoint runs without any check.
            Checks are also skipped when ``settings.csrf_enabled`` is False.

    Returns:
        A decorator for ``async def`` endpoints. The endpoint must receive a
        ``fastapi.Request`` argument (under any parameter name).

    The wrapper raises:
        HTTPException(500): no ``Request`` argument was passed.
        HTTPException(403): the submitted token is missing or invalid.
    """
    def decorator(func):
        # functools.wraps sets __wrapped__, which inspect.signature follows, so
        # signature-inspecting frameworks still see ``func``'s parameters.
        @wraps(func)
        async def wrapper(*args, **kwargs):
            if exempt or not settings.csrf_enabled:
                return await func(*args, **kwargs)

            request = _find_request(args, kwargs)
            if request is None:
                raise HTTPException(
                    status_code=500,
                    detail="CSRF protection requires Request object"
                )

            token = await csrf_service.extract_token_from_request(request)
            # NOTE(review): get_current_user is called directly with the Request
            # here; confirm its signature supports that outside FastAPI's DI.
            user = await get_current_user(request)
            session_id = getattr(user, "session_id", str(user.id))
            if not csrf_service.validate_token(token, str(user.id), session_id):
                raise HTTPException(
                    status_code=403,
                    detail="Invalid or missing CSRF token"
                )
            return await func(*args, **kwargs)
        return wrapper
    return decorator
# Usage example
# @router.post must be the outermost decorator so that the CSRF-checking
# wrapper produced by @require_csrf() is what gets registered as the route.
@router.post("/admin/gateways")
@require_csrf()
async def create_gateway(
    request: Request,
    gateway_data: GatewayCreate,
    current_user: User = Depends(get_current_admin),
):
    """Create a gateway.

    ``request`` is required so that ``@require_csrf`` can locate the Request
    and validate its CSRF token.
    """
    # ... endpoint logic ...
    ...
8. Add CSRF Configuration UI
# Add to admin endpoints
@router.get("/admin/security/csrf")
async def get_csrf_config(
    current_user: User = Depends(get_current_admin),
):
    """Report the active CSRF settings to an administrator.

    Returns:
        A dict with the enable flag, token/header names, cookie flags, token
        expiry, safe methods and exempt paths taken from ``settings``.
    """
    cookie_settings = {
        "secure": settings.csrf_cookie_secure,
        "httponly": settings.csrf_cookie_httponly,
        "samesite": settings.csrf_cookie_samesite,
    }
    return {
        "enabled": settings.csrf_enabled,
        "token_name": settings.csrf_token_name,
        "header_name": settings.csrf_header_name,
        "cookie_settings": cookie_settings,
        "token_expiry": settings.csrf_token_expiry,
        "safe_methods": settings.csrf_safe_methods,
        "exempt_paths": settings.csrf_exempt_paths,
    }
@router.post("/admin/security/csrf/rotate")
@require_csrf()
async def rotate_csrf_tokens(
    request: Request,
    response: Response,
    current_user: User = Depends(get_current_admin),
):
    """Revoke every outstanding CSRF token and issue a fresh one to the caller.

    ``request`` is required because ``@require_csrf`` looks for a ``Request``
    argument and raises HTTP 500 when none is passed. Without it this endpoint
    could never succeed.

    NOTE(review): clearing ``csrf_service.token_cache`` affects only this
    process's in-memory cache.

    Returns:
        A confirmation message and the caller's new token.
    """
    # Clear token cache: every previously issued token is now rejected.
    csrf_service.token_cache.clear()
    # Generate new token for admin
    session_id = getattr(current_user, "session_id", str(current_user.id))
    new_token = csrf_service.generate_token(str(current_user.id), session_id)
    csrf_service.set_csrf_cookie(response, new_token)
    # Mirror /auth/csrf-token: also expose the token as a response header.
    response.headers[settings.csrf_header_name] = new_token
    return {
        "message": "CSRF tokens rotated successfully",
        "new_token": new_token
    }
9. Add Tests
# tests/test_csrf_protection.py
import pytest
from fastapi.testclient import TestClient

# TestClient sends requests to http://testserver. A Referer on that host
# satisfies the Origin/Referer check enabled by CSRF_CHECK_REFERER=true,
# which would otherwise reject every state-changing request with 403.
SAME_ORIGIN = {"Referer": "http://testserver/admin"}


def test_csrf_token_generation(client: TestClient, auth_headers):
    """Test CSRF token is generated on login."""
    response = client.post(
        "/auth/login",
        data={"username": "admin", "password": "password"}
    )
    assert response.status_code == 200
    assert "csrf_token" in response.json()
    assert "csrf_token" in response.cookies


def test_csrf_protection_enforced(client: TestClient, auth_headers):
    """Test CSRF protection blocks requests without token."""
    response = client.post(
        "/admin/gateways",
        headers={**auth_headers, **SAME_ORIGIN},
        json={"name": "test", "url": "http://example.com"}
    )
    assert response.status_code == 403
    assert "CSRF token missing" in response.text


def test_csrf_forged_token_rejected(client: TestClient, auth_headers):
    """Test a token that was never issued is rejected."""
    headers = {**auth_headers, **SAME_ORIGIN, "X-CSRF-Token": "forged.0.deadbeef"}
    response = client.post(
        "/admin/gateways",
        headers=headers,
        json={"name": "test", "url": "http://example.com"}
    )
    assert response.status_code == 403
    assert "Invalid CSRF token" in response.text


def test_csrf_token_validation(client: TestClient, auth_headers):
    """Test valid CSRF token allows request."""
    csrf_response = client.get("/auth/csrf-token", headers=auth_headers)
    csrf_token = csrf_response.json()["csrf_token"]

    headers = {**auth_headers, **SAME_ORIGIN, "X-CSRF-Token": csrf_token}
    response = client.post(
        "/admin/gateways",
        headers=headers,
        json={"name": "test", "url": "http://example.com"}
    )
    assert response.status_code in [200, 201]


def test_csrf_safe_methods_exempt(client: TestClient, auth_headers):
    """Test safe methods don't require CSRF token."""
    response = client.get("/admin/gateways", headers=auth_headers)
    assert response.status_code == 200
10. Update .env.example
#####################################
# CSRF Protection Settings
#####################################
# Enable CSRF protection
CSRF_ENABLED=true

# Token configuration
CSRF_TOKEN_NAME=X-CSRF-Token
CSRF_HEADER_NAME=X-CSRF-Token
CSRF_COOKIE_NAME=csrf_token
CSRF_FORM_FIELD_NAME=csrf_token
CSRF_TOKEN_LENGTH=32
# Token lifetime in seconds (3600 = 1 hour)
CSRF_TOKEN_EXPIRY=3600

# Cookie settings
# Set CSRF_COOKIE_SECURE=false only for local development over plain HTTP
CSRF_COOKIE_SECURE=true
# Must stay false: the frontend reads the token from this cookie
CSRF_COOKIE_HTTPONLY=false
CSRF_COOKIE_SAMESITE=Strict

# Security settings
CSRF_CHECK_REFERER=true
CSRF_ROTATE_ON_LOGIN=true
CSRF_ROTATE_ON_ERROR=true

# List-valued settings are parsed as JSON arrays, not comma-separated strings
# Trusted origins
CSRF_TRUSTED_ORIGINS=["https://app.example.com","https://admin.example.com"]
# Exempted paths (defaults shown)
CSRF_EXEMPT_PATHS=["/health","/auth/login","/auth/refresh","/docs","/openapi.json","/metrics"]
Security Benefits
- Protection Against CSRF Attacks: Prevents unauthorized state-changing requests
- Double-Submit Cookie Pattern: The token is delivered in a cookie and must be echoed back in the `X-CSRF-Token` header or a form/JSON field
- Session Binding: Tokens are bound to specific user sessions
- Automatic Rotation: Tokens rotate on login and errors
- Referer Validation: Additional layer of security
- SameSite Cookies: Modern browser protection
- Flexible Integration: Works with forms, AJAX, and APIs
Migration Guide
-
Enable CSRF Protection:
CSRF_ENABLED=true
-
Update Frontend Code:
- Replace `fetch()` with `fetchWithCSRF()`
- Add CSRF tokens to all forms
- Handle 403 errors for token refresh
-
Configure Trusted Origins:
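CSRF_TRUSTED_ORIGINS=["https://your-domain.com"]
(List-valued settings such as this one are parsed as JSON arrays, not comma-separated strings.)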
-
Test Protection:
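# Should fail without CSRF token
curl -X POST https://your-app/admin/gateways \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"name":"test","url":"http://example.com"}'

# Should succeed with CSRF token (with CSRF_CHECK_REFERER=true the Referer must match the host)
curl -X POST https://your-app/admin/gateways \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -H "Referer: https://your-app/admin" \
  -H "X-CSRF-Token: $CSRF_TOKEN" \
  -d '{"name":"test","url":"http://example.com"}'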
Metadata
Metadata
Assignees
Labels
enhancement (New feature or request) · security (Improves security) · triage (Issues / Features awaiting triage)