|
92 | 92 | rate_limit_storage = defaultdict(list)
|
93 | 93 |
|
94 | 94 |
|
95 |
| -# Rate limiting decorator |
96 | 95 | def rate_limit(requests_per_minute: int = None):
|
97 |
| - """Rate limiting decorator for admin endpoints. |
98 |
| -
|
99 |
| - Args: |
100 |
| - requests_per_minute: Maximum requests per minute (default: uses config value) |
101 |
| -
|
102 |
| - Returns: |
103 |
| - Decorator function that applies rate limiting to the wrapped endpoint |
104 |
| -
|
105 |
| - Examples: |
106 |
| - >>> # Test that rate_limit is callable |
107 |
| - >>> from mcpgateway.admin import rate_limit |
108 |
| - >>> callable(rate_limit) |
109 |
| - True |
110 |
| - >>> # Test that it returns a decorator function |
111 |
| - >>> import inspect |
112 |
| - >>> decorator = rate_limit(30) |
113 |
| - >>> inspect.isfunction(decorator) |
114 |
| - True |
115 |
| - """ |
116 |
| - |
117 | 96 | def decorator(func):
|
118 |
| - """Inner decorator function that wraps the endpoint with rate limiting. |
119 |
| -
|
120 |
| - Args: |
121 |
| - func: The FastAPI endpoint function to wrap |
122 |
| -
|
123 |
| - Returns: |
124 |
| - The wrapped function with rate limiting applied |
125 |
| - """ |
126 |
| - |
127 | 97 | @wraps(func)
|
128 | 98 | async def wrapper(*args, request: Request = None, **kwargs):
|
129 |
| - """Wrapper function that applies rate limiting logic. |
130 |
| -
|
131 |
| - Args: |
132 |
| - *args: Variable length argument list passed to wrapped function |
133 |
| - request: FastAPI Request object containing client information |
134 |
| - **kwargs: Arbitrary keyword arguments passed to wrapped function |
135 |
| -
|
136 |
| - Returns: |
137 |
| - The result of the wrapped function call |
138 |
| -
|
139 |
| - Raises: |
140 |
| - HTTPException: When rate limit is exceeded (429 Too Many Requests) |
141 |
| - """ |
142 |
| - # Get the rate limit from parameter or config |
| 99 | + # use configured limit if none provided |
143 | 100 | limit = requests_per_minute or settings.validation_max_requests_per_minute
|
144 | 101 |
|
145 |
| - # Get client identifier (IP address) |
146 |
| - client_ip = request.client.host if request and request.client else "unknown" |
| 102 | + # request can be None in some edge cases (e.g., tests) |
| 103 | + client_ip = (request.client.host if request and request.client else "unknown") |
147 | 104 | current_time = time.time()
|
148 | 105 | minute_ago = current_time - 60
|
149 | 106 |
|
150 |
| - # Clean old entries and get current requests |
151 |
| - rate_limit_storage[client_ip] = [timestamp for timestamp in rate_limit_storage[client_ip] if timestamp > minute_ago] |
| 107 | + # prune old timestamps |
| 108 | + rate_limit_storage[client_ip] = [ |
| 109 | + ts for ts in rate_limit_storage[client_ip] if ts > minute_ago |
| 110 | + ] |
152 | 111 |
|
153 |
| - # Check rate limit |
| 112 | + # enforce |
154 | 113 | if len(rate_limit_storage[client_ip]) >= limit:
|
155 |
| - logger.warning(f"Rate limit exceeded for IP {client_ip} on endpoint {func.__name__}") |
156 |
| - raise HTTPException(status_code=429, detail=f"Rate limit exceeded. Maximum {limit} requests per minute.") |
| 114 | + logger.warning( |
| 115 | + f"Rate limit exceeded for IP {client_ip} on endpoint {func.__name__}" |
| 116 | + ) |
| 117 | + raise HTTPException( |
| 118 | + status_code=429, |
| 119 | + detail=f"Rate limit exceeded. Maximum {limit} requests per minute.", |
| 120 | + ) |
157 | 121 |
|
158 |
| - # Add current request timestamp |
159 | 122 | rate_limit_storage[client_ip].append(current_time)
|
160 | 123 |
|
161 |
| - # Call the original function |
162 |
| - return await func(*args, **kwargs) |
163 |
| - |
| 124 | + # IMPORTANT: forward request to the real endpoint |
| 125 | + return await func(*args, request=request, **kwargs) |
164 | 126 | return wrapper
|
165 |
| - |
166 | 127 | return decorator
|
167 | 128 |
|
168 | 129 |
|
| 130 | + |
169 | 131 | admin_router = APIRouter(prefix="/admin", tags=["Admin UI"])
|
170 | 132 |
|
171 | 133 | ####################
|
@@ -4344,3 +4306,93 @@ async def admin_list_tags(
|
4344 | 4306 | except Exception as e:
|
4345 | 4307 | logger.error(f"Failed to retrieve tags for admin: {str(e)}")
|
4346 | 4308 | raise HTTPException(status_code=500, detail=f"Failed to retrieve tags: {str(e)}")
|
| 4309 | + |
| 4310 | +# admin.py |
| 4311 | + |
| 4312 | +@admin_router.post("/tools/import/") |
| 4313 | +@admin_router.post("/tools/import") |
| 4314 | +@rate_limit(requests_per_minute=10) |
| 4315 | +async def admin_import_tools( |
| 4316 | + request: Request, |
| 4317 | + db: Session = Depends(get_db), |
| 4318 | + user: str = Depends(require_auth), |
| 4319 | +) -> JSONResponse: |
| 4320 | + logger.debug("bulk tool import: user=%s", user) |
| 4321 | + try: |
| 4322 | + # ---------- robust payload parsing ---------- |
| 4323 | + ctype = (request.headers.get("content-type") or "").lower() |
| 4324 | + if "application/json" in ctype: |
| 4325 | + try: |
| 4326 | + payload = await request.json() |
| 4327 | + except Exception as ex: |
| 4328 | + logger.exception("Invalid JSON body") |
| 4329 | + return JSONResponse({"success": False, "message": f"Invalid JSON: {ex}"}, status_code=422) |
| 4330 | + else: |
| 4331 | + try: |
| 4332 | + form = await request.form() |
| 4333 | + except Exception as ex: |
| 4334 | + logger.exception("Invalid form body") |
| 4335 | + return JSONResponse({"success": False, "message": f"Invalid form data: {ex}"}, status_code=422) |
| 4336 | + raw = form.get("tools_json") or form.get("json") or form.get("payload") |
| 4337 | + if not raw: |
| 4338 | + return JSONResponse({"success": False, "message": "Missing tools_json/json/payload form field."}, status_code=422) |
| 4339 | + try: |
| 4340 | + payload = json.loads(raw) |
| 4341 | + except Exception as ex: |
| 4342 | + logger.exception("Invalid JSON in form field") |
| 4343 | + return JSONResponse({"success": False, "message": f"Invalid JSON: {ex}"}, status_code=422) |
| 4344 | + |
| 4345 | + if not isinstance(payload, list): |
| 4346 | + return JSONResponse({"success": False, "message": "Payload must be a JSON array of tools."}, status_code=422) |
| 4347 | + |
| 4348 | + MAX_BATCH = 200 |
| 4349 | + if len(payload) > MAX_BATCH: |
| 4350 | + return JSONResponse({"success": False, "message": f"Too many tools ({len(payload)}). Max {MAX_BATCH}."}, status_code=413) |
| 4351 | + |
| 4352 | + created, errors = [], [] |
| 4353 | + |
| 4354 | + # ---------- import loop ---------- |
| 4355 | + for i, item in enumerate(payload): |
| 4356 | + name = (item or {}).get("name") |
| 4357 | + try: |
| 4358 | + tool = ToolCreate(**item) # pydantic validation |
| 4359 | + await tool_service.register_tool(db, tool) |
| 4360 | + created.append({"index": i, "name": name}) |
| 4361 | + except IntegrityError as ex: |
| 4362 | + # The formatter can itself throw; guard it. |
| 4363 | + try: |
| 4364 | + formatted = ErrorFormatter.format_database_error(ex) |
| 4365 | + except Exception: |
| 4366 | + formatted = {"message": str(ex)} |
| 4367 | + errors.append({"index": i, "name": name, "error": formatted}) |
| 4368 | + except (ValidationError, CoreValidationError) as ex: |
| 4369 | + # Ditto: guard the formatter |
| 4370 | + try: |
| 4371 | + formatted = ErrorFormatter.format_validation_error(ex) |
| 4372 | + except Exception: |
| 4373 | + formatted = {"message": str(ex)} |
| 4374 | + errors.append({"index": i, "name": name, "error": formatted}) |
| 4375 | + except ToolError as ex: |
| 4376 | + errors.append({"index": i, "name": name, "error": {"message": str(ex)}}) |
| 4377 | + except Exception as ex: |
| 4378 | + logger.exception("Unexpected error importing tool %r at index %d", name, i) |
| 4379 | + errors.append({"index": i, "name": name, "error": {"message": str(ex)}}) |
| 4380 | + |
| 4381 | + return JSONResponse( |
| 4382 | + { |
| 4383 | + "success": len(errors) == 0, |
| 4384 | + "created_count": len(created), |
| 4385 | + "failed_count": len(errors), |
| 4386 | + "created": created, |
| 4387 | + "errors": errors, |
| 4388 | + }, |
| 4389 | + status_code=200, |
| 4390 | + ) |
| 4391 | + |
| 4392 | + except HTTPException as ex: |
| 4393 | + # let FastAPI semantics (e.g., auth) pass through |
| 4394 | + raise |
| 4395 | + except Exception as ex: |
| 4396 | + # absolute catch-all: report instead of crashing |
| 4397 | + logger.exception("Fatal error in admin_import_tools") |
| 4398 | + return JSONResponse({"success": False, "message": str(ex)}, status_code=500) |
0 commit comments