Niansuh commited on
Commit
734aed4
·
verified ·
1 Parent(s): bcbf154

Update api/rpmlimits.py

Browse files
Files changed (1) hide show
  1. api/rpmlimits.py +15 -44
api/rpmlimits.py CHANGED
@@ -1,44 +1,15 @@
1
- from datetime import datetime, timedelta
2
- from fastapi import HTTPException, Request
3
- from api.logger import setup_logger
4
- import os
5
- from dotenv import load_dotenv
6
-
7
- load_dotenv()
8
- logger = setup_logger(__name__)
9
-
10
- # Set request limit per minute from environment variable
11
- REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))
12
-
13
- # Dictionary to track IP addresses and request counts
14
- request_counts = {}
15
-
16
- def get_client_ip(request: Request) -> str:
17
- """Retrieve the IP address of the client making the request."""
18
- return request.client.host
19
-
20
- def check_rate_limit(ip: str):
21
- """Check if the IP has exceeded the request limit per minute."""
22
- current_time = datetime.now()
23
-
24
- if ip not in request_counts:
25
- # New IP, initialize with a count and set the timestamp
26
- request_counts[ip] = {"count": 1, "timestamp": current_time}
27
- logger.info(f"New IP {ip} added to request counts.")
28
- else:
29
- ip_data = request_counts[ip]
30
- # Check if the timestamp is more than a minute old
31
- if current_time - ip_data["timestamp"] < timedelta(minutes=1):
32
- # If within the same minute, increment the count
33
- ip_data["count"] += 1
34
- logger.info(f"IP {ip} made request number {ip_data['count']}.")
35
- if ip_data["count"] > REQUEST_LIMIT_PER_MINUTE:
36
- logger.warning(f"Rate limit exceeded for IP {ip}.")
37
- raise HTTPException(
38
- status_code=429,
39
- detail={"error": {"message": "Rate limit exceeded. Please wait and try again.", "type": "rate_limit"}},
40
- )
41
- else:
42
- # If more than a minute has passed, reset the count and timestamp
43
- request_counts[ip] = {"count": 1, "timestamp": current_time}
44
- logger.info(f"Request count reset for IP {ip}.")
 
1
+ from slowapi import Limiter
2
+ from slowapi.util import get_remote_address
3
+ from slowapi.errors import RateLimitExceeded
4
+ from fastapi import FastAPI, Request
5
+ from starlette.responses import JSONResponse
6
+
7
+ # Initialize the limiter
8
+ limiter = Limiter(key_func=get_remote_address)
9
+
10
+ # Exception handler for rate limit exceeded
11
+ async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
12
+ return JSONResponse(
13
+ status_code=429,
14
+ content={"error": {"message": "Rate limit exceeded. Please wait and try again.", "type": "rate_limit"}}
15
+ )