Create rpmlimits.py
Browse files- api/rpmlimits.py +42 -0
api/rpmlimits.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from datetime import timedelta
|
3 |
+
import redis
|
4 |
+
from fastapi import HTTPException, Request
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from api.logger import setup_logger
|
7 |
+
|
8 |
+
load_dotenv()
|
9 |
+
logger = setup_logger(__name__)
|
10 |
+
|
11 |
+
# Initialize Redis connection
|
12 |
+
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
|
13 |
+
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
|
14 |
+
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
|
15 |
+
|
16 |
+
# Set request limit per minute from environment variable
|
17 |
+
REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))
|
18 |
+
|
19 |
+
def get_client_ip(request: Request) -> str:
|
20 |
+
"""Retrieve the IP address of the client making the request."""
|
21 |
+
return request.client.host
|
22 |
+
|
23 |
+
def check_rate_limit(ip: str):
|
24 |
+
"""Check if the IP has exceeded the request limit per minute."""
|
25 |
+
redis_key = f"rate_limit:{ip}"
|
26 |
+
current_count = redis_client.get(redis_key)
|
27 |
+
|
28 |
+
if current_count is None:
|
29 |
+
# New IP, initialize with a count and set an expiration of 1 minute
|
30 |
+
redis_client.set(redis_key, 1, ex=timedelta(minutes=1))
|
31 |
+
logger.info(f"New IP {ip} added to request counts.")
|
32 |
+
else:
|
33 |
+
current_count = int(current_count)
|
34 |
+
if current_count >= REQUEST_LIMIT_PER_MINUTE:
|
35 |
+
logger.warning(f"Rate limit exceeded for IP {ip}.")
|
36 |
+
raise HTTPException(
|
37 |
+
status_code=429,
|
38 |
+
detail={"error": {"message": "Rate limit exceeded. Please wait and try again.", "type": "rate_limit"}},
|
39 |
+
)
|
40 |
+
# Increment the request count for the IP
|
41 |
+
redis_client.incr(redis_key)
|
42 |
+
logger.info(f"IP {ip} made request number {current_count + 1}.")
|