|
import os |
|
from datetime import timedelta |
|
import redis |
|
from fastapi import HTTPException, Request |
|
from dotenv import load_dotenv |
|
from api.logger import setup_logger |
|
|
|
# Load variables from a .env file into the process environment first, so the
# os.getenv lookups below can see them.
load_dotenv()

logger = setup_logger(__name__)

# --- Redis connection -------------------------------------------------------
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(
    os.getenv("REDIS_PORT", 6379)
)

# decode_responses=True: values read back from Redis are str, not bytes.
redis_client = redis.StrictRedis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    decode_responses=True,
)

# --- Rate limiting ----------------------------------------------------------
# Maximum number of requests allowed per client IP in each one-minute window.
REQUEST_LIMIT_PER_MINUTE = int(
    os.getenv("REQUEST_LIMIT_PER_MINUTE", "10")
)
|
|
|
def get_client_ip(request: Request) -> str:
    """Return the client host address reported for the incoming request.

    ``request.client`` may be ``None`` when the ASGI server does not report a
    peer address (for example, with some test clients). In that case this
    returns the placeholder ``"unknown"`` instead of raising
    ``AttributeError``. All such requests therefore share one rate-limit
    bucket.

    NOTE(review): this is whatever the ASGI server reports as the peer. Behind
    a reverse proxy it may be the proxy's address unless the server is
    configured to honour forwarded headers. Those headers are not read here
    because clients can spoof them.
    """
    client = request.client
    if client is None:
        return "unknown"
    return client.host
|
|
|
def check_rate_limit(ip: str):
    """Count one request for ``ip`` and reject it if the per-minute budget is spent.

    The count lives in the Redis key ``rate_limit:{ip}``. The key is created
    with a one-minute expiry on the first request of a window, and that
    expiry is never extended. Up to ``REQUEST_LIMIT_PER_MINUTE`` requests are
    allowed per window; later ones are rejected until the key expires.
    Rejected requests are still counted, but this does not delay the reset,
    because the TTL is fixed when the key is created.

    Args:
        ip: Client identifier used to build the Redis key.

    Raises:
        HTTPException: 429 when the limit for the current window is exceeded.
            The ``Retry-After`` header gives the seconds left in the window.
    """
    redis_key = f"rate_limit:{ip}"
    window = timedelta(minutes=1)

    # All commands run in one MULTI/EXEC transaction, so concurrent requests
    # cannot interleave. The previous GET-then-SET/INCR sequence could lose
    # counts. An INCR that landed just after the key expired also recreated
    # the key without a TTL, blocking the IP permanently.
    with redis_client.pipeline(transaction=True) as pipe:
        # Create the counter, with its expiry, only if it does not exist yet.
        # INCR keeps an existing TTL, so the window is never extended.
        pipe.set(redis_key, 0, ex=window, nx=True)
        pipe.incr(redis_key)
        pipe.ttl(redis_key)
        _, current_count, seconds_left = pipe.execute()

    if seconds_left < 0:
        # A counter with no TTL (e.g. left behind by the old non-atomic code)
        # would never reset; give it a fresh window.
        redis_client.expire(redis_key, window)
        seconds_left = int(window.total_seconds())

    if current_count > REQUEST_LIMIT_PER_MINUTE:
        logger.warning(f"Rate limit exceeded for IP {ip}.")
        raise HTTPException(
            status_code=429,
            detail={"error": {"message": "Rate limit exceeded. Please wait and try again.", "type": "rate_limit"}},
            # TTL is reported in whole seconds and can round down to 0.
            # Always advertise at least one second.
            headers={"Retry-After": str(max(int(seconds_left), 1))},
        )

    if current_count == 1:
        logger.info(f"New IP {ip} added to request counts.")
    else:
        logger.info(f"IP {ip} made request number {current_count}.")
|
|