[SECURITY FEATURE]: CSRF Token Protection System #543

@crivetimihai

🛡️ FEATURE: CSRF Token Protection System

Summary: Implement comprehensive CSRF (Cross-Site Request Forgery) protection for all state-changing operations in MCP Gateway. This includes the double-submit cookie pattern, per-session tokens, SameSite cookies, and automatic token validation for forms and API requests.

Depends on: Existing JWT authentication system

Implementation

1. Update config.py with CSRF Settings

# In mcpgateway/config.py

class Settings(BaseSettings):
    # ... existing settings ...
    
    # ===================================
    # CSRF Protection Settings
    # ===================================
    
    # CSRF Token Configuration
    csrf_enabled: bool = True
    csrf_token_name: str = "X-CSRF-Token"
    csrf_cookie_name: str = "csrf_token"
    csrf_header_name: str = "X-CSRF-Token"
    csrf_form_field_name: str = "csrf_token"
    csrf_token_length: int = 32
    csrf_token_expiry: int = 3600  # 1 hour in seconds
    
    # Cookie Settings
    csrf_cookie_secure: bool = True  # HTTPS only in production
    csrf_cookie_httponly: bool = False  # Must be readable by JS
    csrf_cookie_samesite: str = "Strict"
    csrf_cookie_path: str = "/"
    csrf_cookie_domain: Optional[str] = None
    
    # Exempted Paths (no CSRF check)
    csrf_exempt_paths: List[str] = [
        "/health",
        "/auth/login",  # Login endpoint generates CSRF token
        "/auth/refresh",
        "/docs",
        "/openapi.json",
        "/metrics"
    ]
    
    # Safe Methods (no CSRF check)
    csrf_safe_methods: List[str] = ["GET", "HEAD", "OPTIONS", "TRACE"]
    
    # Token Rotation
    csrf_rotate_on_login: bool = True
    csrf_rotate_on_error: bool = True
    
    # Additional Security
    csrf_check_referer: bool = True
    csrf_trusted_origins: List[str] = []  # Add trusted origins in production

2. Create CSRF Service

# Create mcpgateway/services/csrf_service.py

import hashlib
import hmac
import secrets
import time
from typing import Dict, Optional

from fastapi import Request, Response

from mcpgateway.config import settings

class CSRFService:
    """Service for managing CSRF tokens and validation."""
    
    def __init__(self):
        self.token_cache: Dict[str, float] = {}  # token -> expiry timestamp
    
    def _sign(self, user_id: str, session_id: str, random_data: str, issued_at: str) -> str:
        """Return a truncated HMAC-SHA256 over the fields the token is bound to."""
        message = f"{user_id}:{session_id}:{random_data}:{issued_at}".encode()
        key = str(settings.jwt_secret_key).encode()
        return hmac.new(key, message, hashlib.sha256).hexdigest()[:32]
        
    def generate_token(self, user_id: str, session_id: str) -> str:
        """Generate a new CSRF token for a user session."""
        random_data = secrets.token_urlsafe(settings.csrf_token_length)
        issued_at = str(int(time.time()))
        
        # The timestamp travels inside the token so that validate_token()
        # can recompute the HMAC and check the user/session binding.
        signature = self._sign(user_id, session_id, random_data, issued_at)
        token = f"{random_data}.{issued_at}.{signature}"
        
        # Cache token with expiry
        expiry = time.time() + settings.csrf_token_expiry
        self.token_cache[token] = expiry
        
        # Clean expired tokens periodically
        if len(self.token_cache) > 1000:
            self._cleanup_expired_tokens()
        
        return token
    
    def validate_token(
        self, 
        token: str, 
        user_id: str, 
        session_id: str
    ) -> bool:
        """Validate CSRF token format, user/session binding, and expiry."""
        if not token:
            return False
        
        # Check token format: random_data.issued_at.signature
        parts = token.split(".")
        if len(parts) != 3:
            return False
        random_data, issued_at, signature = parts
        
        # Verify the binding with a constant-time comparison
        expected = self._sign(user_id, session_id, random_data, issued_at)
        if not hmac.compare_digest(signature.encode(), expected.encode()):
            return False
        
        # Check that the token was issued by this process and has not expired
        expiry = self.token_cache.get(token)
        if expiry is None:
            return False
        
        if time.time() > expiry:
            del self.token_cache[token]
            return False
        
        return True
    
    async def extract_token_from_request(self, request: Request) -> Optional[str]:
        """Extract the submitted CSRF token from the request (header, form, or JSON body)."""
        # Check header first (for AJAX requests)
        token = request.headers.get(settings.csrf_header_name)
        if token:
            return token
        
        content_type = request.headers.get("content-type", "")
        
        # Check form data (URL-encoded or multipart)
        if (
            "application/x-www-form-urlencoded" in content_type
            or "multipart/form-data" in content_type
        ):
            form_data = await request.form()
            token = form_data.get(settings.csrf_form_field_name)
            if token:
                return token
        
        # Check JSON body for API requests
        if "application/json" in content_type:
            try:
                body = await request.json()
            except ValueError:
                # Malformed JSON: treat as "no token" and let the caller reject
                body = None
            if isinstance(body, dict):
                token = body.get(settings.csrf_form_field_name)
                if token:
                    return token
        
        # Deliberately no cookie fallback: the browser attaches the cookie to
        # forged cross-site requests too, so accepting the cookie by itself
        # would make every request pass the check.
        return None
    
    def set_csrf_cookie(
        self, 
        response: Response, 
        token: str,
        secure: Optional[bool] = None
    ) -> None:
        """Set CSRF token cookie."""
        response.set_cookie(
            key=settings.csrf_cookie_name,
            value=token,
            max_age=settings.csrf_token_expiry,
            secure=secure if secure is not None else settings.csrf_cookie_secure,
            httponly=settings.csrf_cookie_httponly,
            samesite=settings.csrf_cookie_samesite,
            path=settings.csrf_cookie_path,
            domain=settings.csrf_cookie_domain
        )
    
    def clear_csrf_cookie(self, response: Response) -> None:
        """Clear CSRF token cookie."""
        response.delete_cookie(
            key=settings.csrf_cookie_name,
            path=settings.csrf_cookie_path,
            domain=settings.csrf_cookie_domain
        )
    
    def check_referer(self, request: Request) -> bool:
        """Validate referer header for additional security."""
        if not settings.csrf_check_referer:
            return True
        
        referer = request.headers.get("referer")
        if not referer:
            return False
        
        # Parse referer
        from urllib.parse import urlparse
        referer_parsed = urlparse(referer)
        
        # Check against host
        host = request.headers.get("host")
        if referer_parsed.netloc == host:
            return True
        
        # Check against trusted origins
        for origin in settings.csrf_trusted_origins:
            origin_parsed = urlparse(origin)
            if referer_parsed.netloc == origin_parsed.netloc:
                return True
        
        return False
    
    def _cleanup_expired_tokens(self) -> None:
        """Remove expired tokens from cache."""
        current_time = time.time()
        expired_tokens = [
            token for token, expiry in self.token_cache.items()
            if current_time > expiry
        ]
        for token in expired_tokens:
            del self.token_cache[token]

# Global instance
csrf_service = CSRFService()
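
Note that `token_cache` lives in one process, so tokens issued by one worker would be unknown to another in a multi-worker deployment. One possible alternative is a stateless token whose HMAC and issue time can be verified without a shared cache. The sketch below uses only the standard library; the function names `sign_token` and `verify_token` and the three-part token layout are assumptions for illustration, not the gateway's API.

```python
import hashlib
import hmac
import secrets
import time
from typing import Optional


def sign_token(secret: bytes, user_id: str, session_id: str,
               now: Optional[float] = None) -> str:
    """Build a token of the form nonce.issued_at.signature (hypothetical layout)."""
    nonce = secrets.token_urlsafe(32)  # URL-safe alphabet, never contains "."
    issued_at = int(time.time() if now is None else now)
    message = f"{user_id}:{session_id}:{nonce}:{issued_at}".encode()
    signature = hmac.new(secret, message, hashlib.sha256).hexdigest()
    return f"{nonce}.{issued_at}.{signature}"


def verify_token(secret: bytes, token: str, user_id: str, session_id: str,
                 max_age: int, now: Optional[float] = None) -> bool:
    """Check the signature, the user/session binding, and the token age."""
    parts = token.split(".")
    if len(parts) != 3:
        return False
    nonce, issued_at_text, signature = parts
    if not (issued_at_text.isascii() and issued_at_text.isdigit()):
        return False
    message = f"{user_id}:{session_id}:{nonce}:{issued_at_text}".encode()
    expected = hmac.new(secret, message, hashlib.sha256).hexdigest()
    # Constant-time comparison avoids leaking how many characters matched
    if not hmac.compare_digest(signature.encode(), expected.encode()):
        return False
    current = time.time() if now is None else now
    return current - int(issued_at_text) <= max_age
```

The trade-off is that stateless tokens cannot be revoked individually before they expire; a forced rotation would require changing the secret.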

3. Create CSRF Middleware

# Create mcpgateway/middleware/csrf_middleware.py

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from mcpgateway.config import settings
from mcpgateway.services.csrf_service import csrf_service
from mcpgateway.dependencies import get_current_user_optional

class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce CSRF protection."""
    
    async def dispatch(self, request: Request, call_next):
        # Skip CSRF check if disabled
        if not settings.csrf_enabled:
            return await call_next(request)
        
        # Skip for exempted paths
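        # Match whole path segments so "/docs" does not also exempt "/docs-internal"
        if any(
            request.url.path == path or request.url.path.startswith(path.rstrip("/") + "/")
            for path in settings.csrf_exempt_paths
        ):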
            return await call_next(request)
        
        # Skip for safe methods
        if request.method in settings.csrf_safe_methods:
            response = await call_next(request)
            
            # Add CSRF token to response for authenticated users
            user = await get_current_user_optional(request)
            if user:
                await self._add_csrf_token_to_response(request, response, user)
            
            return response
        
        # For state-changing requests, validate CSRF token
        try:
            # Get current user
            user = await get_current_user_optional(request)
            if not user:
                # No CSRF check for unauthenticated requests
                return await call_next(request)
            
            # Extract token from request
            token = await csrf_service.extract_token_from_request(request)
            if not token:
                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token missing"}
                )
            
            # Get session ID from JWT or session
            session_id = getattr(user, "session_id", str(user.id))
            
            # Validate token
            if not csrf_service.validate_token(token, str(user.id), session_id):
                return JSONResponse(
                    status_code=403,
                    content={"detail": "Invalid CSRF token"}
                )
            
            # Check referer
            if not csrf_service.check_referer(request):
                return JSONResponse(
                    status_code=403,
                    content={"detail": "Invalid referer"}
                )
            
            # Process request
            response = await call_next(request)
            
            # Rotate token if needed
            if settings.csrf_rotate_on_error and response.status_code >= 400:
                await self._add_csrf_token_to_response(request, response, user, rotate=True)
            
            return response
            
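        except Exception:
            # Do not echo exception details to the client; they can leak internals.
            # Note that this handler also sees errors raised by call_next().
            return JSONResponse(
                status_code=403,
                content={"detail": "CSRF validation error"}
            )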
    
    async def _add_csrf_token_to_response(
        self, 
        request: Request, 
        response: Response, 
        user,
        rotate: bool = False
    ):
        """Add CSRF token to response headers and cookies."""
        # Check if we need a new token
        existing_token = request.cookies.get(settings.csrf_cookie_name)
        session_id = getattr(user, "session_id", str(user.id))
        
        if not existing_token or rotate:
            # Generate new token
            new_token = csrf_service.generate_token(str(user.id), session_id)
            
            # Set cookie
            csrf_service.set_csrf_cookie(response, new_token)
            
            # Add to response header for easy JS access
            response.headers[settings.csrf_header_name] = new_token

4. Update Authentication Endpoints

# Update mcpgateway/routers/auth.py

from mcpgateway.services.csrf_service import csrf_service

@router.post("/login")
async def login(
    response: Response,  # parameters without defaults must come first
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db)
):
    """Login endpoint that generates CSRF token."""
    # ... existing authentication logic ...
    
    # After successful authentication
    if user and access_token:
        # Derive the session ID the same way the middleware does, so a
        # session-bound token issued here validates on later requests
        session_id = getattr(user, "session_id", str(user.id))
        csrf_token = csrf_service.generate_token(str(user.id), session_id)
        
        # Set CSRF cookie
        csrf_service.set_csrf_cookie(response, csrf_token)
        
        # Include CSRF token in response
        return {
            "access_token": access_token,
            "token_type": "bearer",
            "csrf_token": csrf_token,  # For SPA applications
            "user": {
                "id": user.id,
                "username": user.username,
                "is_admin": user.is_admin
            }
        }

@router.post("/logout")
async def logout(
    response: Response,
    current_user: User = Depends(get_current_user)
):
    """Logout endpoint that clears CSRF token."""
    # Clear CSRF cookie
    csrf_service.clear_csrf_cookie(response)
    
    # ... existing logout logic ...
    
    return {"message": "Logged out successfully"}

@router.get("/csrf-token")
async def get_csrf_token(
    response: Response,
    current_user: User = Depends(get_current_user)
):
    """Get a fresh CSRF token."""
    session_id = getattr(current_user, "session_id", str(current_user.id))
    csrf_token = csrf_service.generate_token(str(current_user.id), session_id)
    
    # Set cookie and header
    csrf_service.set_csrf_cookie(response, csrf_token)
    response.headers[settings.csrf_header_name] = csrf_token
    
    return {"csrf_token": csrf_token}

5. Update Frontend JavaScript

// Update admin.js with CSRF support

// Global CSRF token management
let csrfToken = null;

// Get CSRF token from cookie
function getCSRFToken() {
    const name = 'csrf_token=';
    const decodedCookie = decodeURIComponent(document.cookie);
    const ca = decodedCookie.split(';');
    for(let i = 0; i < ca.length; i++) {
        let c = ca[i];
        while (c.charAt(0) == ' ') {
            c = c.substring(1);
        }
        if (c.indexOf(name) == 0) {
            return c.substring(name.length, c.length);
        }
    }
    return null;
}

// Fetch new CSRF token
async function refreshCSRFToken() {
    try {
        const response = await fetch(`${window.ROOT_PATH}/auth/csrf-token`, {
            credentials: 'include',
            headers: {
                'Authorization': `Bearer ${getAuthToken()}`
            }
        });
        
        if (response.ok) {
            const data = await response.json();
            csrfToken = data.csrf_token;
            return csrfToken;
        }
    } catch (error) {
        console.error('Failed to refresh CSRF token:', error);
    }
    return null;
}

// Enhanced fetch with CSRF token
async function fetchWithCSRF(url, options = {}) {
    // Get CSRF token
    if (!csrfToken) {
        csrfToken = getCSRFToken() || await refreshCSRFToken();
    }
    
    // Add CSRF token to headers for state-changing requests
    const method = (options.method || 'GET').toUpperCase();
    if (['POST', 'PUT', 'DELETE', 'PATCH'].includes(method)) {
        options.headers = {
            ...options.headers,
            'X-CSRF-Token': csrfToken
        };
        
        // For form data, add CSRF token as field
        if (options.body instanceof FormData) {
            options.body.append('csrf_token', csrfToken);
        }
        // For JSON requests, add to body
        else if (options.headers?.['Content-Type'] === 'application/json' && options.body) {
            const body = JSON.parse(options.body);
            body.csrf_token = csrfToken;
            options.body = JSON.stringify(body);
        }
    }
    
    // Make request
    const response = await fetchWithTimeout(url, options);
    
    // If CSRF token invalid, refresh and retry once
    if (response.status === 403) {
        // Read a clone so the caller can still consume the original body
        const text = await response.clone().text();
        if (text.includes('CSRF')) {
            csrfToken = await refreshCSRFToken();
            if (csrfToken && !options._retried) {
                options._retried = true;
                return fetchWithCSRF(url, options);
            }
        }
    }
    
    return response;
}

// Update all form submissions to use CSRF
async function handleGatewayFormSubmit(e) {
    e.preventDefault();
    
    const form = e.target;
    const formData = new FormData(form);
    
    try {
        // ... validation logic ...
        
        const response = await fetchWithCSRF(
            `${window.ROOT_PATH}/admin/gateways`,
            {
                method: "POST",
                body: formData,
            }
        );
        
        // ... handle response ...
    } catch (error) {
        // ... error handling ...
    }
}

// Initialize CSRF token on page load
document.addEventListener('DOMContentLoaded', async () => {
    // Get initial CSRF token
    csrfToken = getCSRFToken();
    if (!csrfToken && isAuthenticated()) {
        csrfToken = await refreshCSRFToken();
    }
    
    // ... rest of initialization ...
});

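// Add CSRF token to state-changing fetch() calls automatically.
// Safe-method requests and requests that already carry the header (those
// issued by fetchWithCSRF itself) go straight to the native fetch; otherwise
// fetchWithCSRF and refreshCSRFToken would re-enter this wrapper forever.
const originalFetch = window.fetch;
window.fetch = function(input, init) {
    const method = ((init && init.method) || 'GET').toUpperCase();
    const isUnsafe = ['POST', 'PUT', 'DELETE', 'PATCH'].includes(method);
    const hasToken = Boolean(init && init.headers && 'X-CSRF-Token' in init.headers);
    if (isUnsafe && !hasToken) {
        return fetchWithCSRF(input, init);
    }
    return originalFetch.call(window, input, init);
};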

6. Update HTML Templates

<!-- Update admin.html forms with CSRF token fields -->

<!-- Add CSRF meta tag for easy access -->
<meta name="csrf-token" content="">

<!-- Update forms to include CSRF field -->
<form id="add-gateway-form" method="POST">
    <input type="hidden" name="csrf_token" value="">
    <!-- ... other form fields ... -->
</form>

<script>
// Populate CSRF token in forms
document.addEventListener('DOMContentLoaded', () => {
    const csrfToken = getCSRFToken();
    
    // Update meta tag
    const metaTag = document.querySelector('meta[name="csrf-token"]');
    if (metaTag && csrfToken) {
        metaTag.content = csrfToken;
    }
    
    // Update all forms
    document.querySelectorAll('input[name="csrf_token"]').forEach(input => {
        input.value = csrfToken;
    });
});
</script>

7. Add CSRF Decorator for Additional Protection

# Create mcpgateway/decorators/csrf_decorator.py

from functools import wraps
from fastapi import Request, HTTPException

from mcpgateway.config import settings  # needed for settings.csrf_enabled below
from mcpgateway.services.csrf_service import csrf_service
from mcpgateway.dependencies import get_current_user

def require_csrf(exempt: bool = False):
    """Decorator to enforce CSRF protection on specific endpoints."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            if exempt or not settings.csrf_enabled:
                return await func(*args, **kwargs)
            
            # Find request object in args/kwargs
            request = None
            for arg in args:
                if isinstance(arg, Request):
                    request = arg
                    break
            if not request:
                request = kwargs.get('request')
            
            if not request:
                raise HTTPException(
                    status_code=500,
                    detail="CSRF protection requires Request object"
                )
            
            # Validate CSRF token
            token = await csrf_service.extract_token_from_request(request)
            user = await get_current_user(request)
            session_id = getattr(user, "session_id", str(user.id))
            
            if not csrf_service.validate_token(token, str(user.id), session_id):
                raise HTTPException(
                    status_code=403,
                    detail="Invalid or missing CSRF token"
                )
            
            return await func(*args, **kwargs)
        
        return wrapper
    return decorator

# Usage example
@router.post("/admin/gateways")
@require_csrf()
async def create_gateway(
    request: Request,
    gateway_data: GatewayCreate,
    current_user: User = Depends(get_current_admin)
):
    ...  # endpoint logic

8. Add CSRF Configuration UI

# Add to admin endpoints
@router.get("/admin/security/csrf")
async def get_csrf_config(
    current_user: User = Depends(get_current_admin)
):
    """Get CSRF configuration."""
    return {
        "enabled": settings.csrf_enabled,
        "token_name": settings.csrf_token_name,
        "header_name": settings.csrf_header_name,
        "cookie_settings": {
            "secure": settings.csrf_cookie_secure,
            "httponly": settings.csrf_cookie_httponly,
            "samesite": settings.csrf_cookie_samesite
        },
        "token_expiry": settings.csrf_token_expiry,
        "safe_methods": settings.csrf_safe_methods,
        "exempt_paths": settings.csrf_exempt_paths
    }

@router.post("/admin/security/csrf/rotate")
@require_csrf()
async def rotate_csrf_tokens(
    request: Request,  # required: require_csrf() looks up the Request in the arguments
    response: Response,
    current_user: User = Depends(get_current_admin)
):
    """Force rotation of all CSRF tokens."""
    # Clear token cache
    csrf_service.token_cache.clear()
    
    # Generate new token for admin
    session_id = getattr(current_user, "session_id", str(current_user.id))
    new_token = csrf_service.generate_token(str(current_user.id), session_id)
    csrf_service.set_csrf_cookie(response, new_token)
    
    return {
        "message": "CSRF tokens rotated successfully",
        "new_token": new_token
    }

9. Add Tests

# tests/test_csrf_protection.py

import pytest
from fastapi.testclient import TestClient

def test_csrf_token_generation(client: TestClient, auth_headers):
    """Test CSRF token is generated on login."""
    response = client.post(
        "/auth/login",
        data={"username": "admin", "password": "password"}
    )
    assert response.status_code == 200
    assert "csrf_token" in response.json()
    assert "csrf_token" in response.cookies

def test_csrf_protection_enforced(client: TestClient, auth_headers):
    """Test CSRF protection blocks requests without token."""
    # Try to create gateway without CSRF token
    response = client.post(
        "/admin/gateways",
        headers=auth_headers,
        json={"name": "test", "url": "http://example.com"}
    )
    assert response.status_code == 403
    assert "CSRF token missing" in response.text

def test_csrf_token_validation(client: TestClient, auth_headers):
    """Test valid CSRF token allows request."""
    # Get CSRF token
    csrf_response = client.get("/auth/csrf-token", headers=auth_headers)
    csrf_token = csrf_response.json()["csrf_token"]
    
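    # Make request with CSRF token. TestClient sends no Referer by default,
    # so supply one matching the test host while csrf_check_referer is enabled.
    headers = {
        **auth_headers,
        "X-CSRF-Token": csrf_token,
        "Referer": "http://testserver/",
    }
    response = client.post(
        "/admin/gateways",
        headers=headers,
        json={"name": "test", "url": "http://example.com"}
    )
    assert response.status_code in [200, 201]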

def test_csrf_safe_methods_exempt(client: TestClient, auth_headers):
    """Test safe methods don't require CSRF token."""
    response = client.get("/admin/gateways", headers=auth_headers)
    assert response.status_code == 200
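
The tests above do not exercise token expiry. One way to test expiry without sleeping is to inject the clock. The sketch below uses a small stand-in class, `ExpiringTokenStore`, which is hypothetical and only mirrors the token-to-expiry cache idea; the real `CSRFService` reads `time.time()` directly and would need a similar hook (or monkeypatching) to be tested this way.

```python
import time
from typing import Callable, Dict


class ExpiringTokenStore:
    """Hypothetical stand-in for a token -> expiry cache with an injectable clock."""

    def __init__(self, ttl_seconds: float, clock: Callable[[], float] = time.time):
        self._ttl = ttl_seconds
        self._clock = clock
        self._expiry: Dict[str, float] = {}

    def add(self, token: str) -> None:
        self._expiry[token] = self._clock() + self._ttl

    def is_valid(self, token: str) -> bool:
        expiry = self._expiry.get(token)
        if expiry is None:
            return False
        if self._clock() > expiry:
            del self._expiry[token]  # drop expired entries lazily
            return False
        return True


def test_token_expires_after_ttl():
    now = [1000.0]  # mutable cell so the lambda sees updates
    store = ExpiringTokenStore(ttl_seconds=3600, clock=lambda: now[0])
    store.add("abc")
    assert store.is_valid("abc")
    now[0] += 3601  # advance the fake clock past the TTL
    assert not store.is_valid("abc")
    assert not store.is_valid("never-issued")
```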

10. Update .env.example

#####################################
# CSRF Protection Settings
#####################################

# Enable CSRF protection
CSRF_ENABLED=true

# Token configuration
CSRF_TOKEN_NAME=X-CSRF-Token
CSRF_COOKIE_NAME=csrf_token
CSRF_TOKEN_LENGTH=32
CSRF_TOKEN_EXPIRY=3600  # 1 hour

# Cookie settings
CSRF_COOKIE_SECURE=true  # Set to false for development
CSRF_COOKIE_HTTPONLY=false  # Must be false for JS access
CSRF_COOKIE_SAMESITE=Strict

# Security settings
CSRF_CHECK_REFERER=true
CSRF_ROTATE_ON_LOGIN=true
CSRF_ROTATE_ON_ERROR=true

# Trusted origins (comma-separated)
CSRF_TRUSTED_ORIGINS=https://app.example.com,https://admin.example.com

# Exempted paths (comma-separated; keep in sync with the defaults in config.py)
CSRF_EXEMPT_PATHS=/health,/auth/login,/auth/refresh,/docs,/openapi.json,/metrics
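
The comma-separated values above may need explicit parsing: depending on the pydantic version, list-typed settings are often expected as JSON arrays in environment variables. A minimal sketch of a parser that accepts both forms follows; the helper name `parse_list_setting` is hypothetical, and wiring it into `Settings` (for example through a validator) is left as an assumption about how the project handles list settings.

```python
import json
from typing import List


def parse_list_setting(raw: str) -> List[str]:
    """Parse a list setting given as a JSON array or a comma-separated string."""
    raw = raw.strip()
    if not raw:
        return []
    if raw.startswith("["):
        # JSON form, e.g. '["/health", "/docs"]'
        return [str(item) for item in json.loads(raw)]
    # Comma-separated form, e.g. '/health,/docs'
    return [item.strip() for item in raw.split(",") if item.strip()]
```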

Security Benefits

  1. Protection Against CSRF Attacks: Prevents unauthorized state-changing requests
  2. Double-Submit Cookie Pattern: Validates token in both cookie and header/body
  3. Session Binding: Tokens are bound to specific user sessions
  4. Automatic Rotation: Tokens rotate on login and errors
  5. Referer Validation: Additional layer of security
  6. SameSite Cookies: Modern browser protection
  7. Flexible Integration: Works with forms, AJAX, and APIs
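
The referer validation listed above boils down to comparing the host part of the Referer (or Origin) header against the request host and the trusted origins. A standalone sketch of that comparison, assuming exact host matching and using a hypothetical function name, could look like this:

```python
from typing import Iterable, Optional
from urllib.parse import urlsplit


def is_same_or_trusted_origin(source: Optional[str], host: str,
                              trusted_origins: Iterable[str]) -> bool:
    """Return True if `source` (a Referer or Origin value) points at `host`
    or at one of `trusted_origins`. Matching is exact on host[:port]."""
    if not source:
        return False
    source_netloc = urlsplit(source).netloc.lower()
    if not source_netloc:
        return False  # not an absolute URL
    if source_netloc == host.lower():
        return True
    return any(
        source_netloc == urlsplit(origin).netloc.lower()
        for origin in trusted_origins
    )
```

Exact matching is used on purpose: a suffix or substring test would accept hosts such as `gw.example.com.evil.net`.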

Migration Guide

  1. Enable CSRF Protection:

    CSRF_ENABLED=true
  2. Update Frontend Code:

    • Replace fetch() with fetchWithCSRF()
    • Add CSRF tokens to all forms
    • Handle 403 errors for token refresh
  3. Configure Trusted Origins:

    CSRF_TRUSTED_ORIGINS=https://your-domain.com
  4. Test Protection:

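    # Should fail without CSRF token
    curl -X POST https://your-app/admin/gateways \
         -H "Authorization: Bearer $TOKEN" \
         -H "Content-Type: application/json" \
         -d '{"name":"test"}'
    
    # Should succeed with CSRF token (curl sends no Referer by default,
    # so one is supplied while CSRF_CHECK_REFERER=true)
    curl -X POST https://your-app/admin/gateways \
         -H "Authorization: Bearer $TOKEN" \
         -H "Content-Type: application/json" \
         -H "X-CSRF-Token: $CSRF_TOKEN" \
         -H "Referer: https://your-app/" \
         -d '{"name":"test"}'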

Labels: enhancement, security, triage
