import os
from datetime import timedelta

import redis
from dotenv import load_dotenv
from fastapi import HTTPException, Request

from api.logger import setup_logger

load_dotenv()

logger = setup_logger(__name__)

# Initialize Redis connection
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)

# Set request limit per minute from environment variable
REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))

# Length of the fixed rate-limit window; each IP's counter key expires after this.
RATE_LIMIT_WINDOW = timedelta(minutes=1)


def get_client_ip(request: Request) -> str:
    """Return the IP address of the client making the request.

    ``request.client`` may be ``None`` when the server supplies no peer
    address; such requests are grouped under the ``"unknown"`` key instead of
    raising ``AttributeError``.
    """
    client = request.client
    return client.host if client is not None else "unknown"


def check_rate_limit(ip: str) -> None:
    """Count a request from ``ip`` and reject it if the per-window limit is exceeded.

    Uses a fixed window of ``RATE_LIMIT_WINDOW`` per IP, stored in Redis under
    ``rate_limit:<ip>``. The first ``REQUEST_LIMIT_PER_MINUTE`` requests in a
    window are allowed; later ones are rejected until the key expires.

    Args:
        ip: Client identifier used as the rate-limit bucket key.

    Raises:
        HTTPException: 429 with a ``Retry-After`` header when the limit is exceeded.
    """
    redis_key = f"rate_limit:{ip}"

    # INCR + TTL run in a single MULTI/EXEC transaction, so the update is atomic.
    # The previous GET-then-SET/INCR sequence had two races:
    #   - concurrent first requests could both reset the counter;
    #   - a key expiring between GET and INCR was recreated by INCR *without*
    #     a TTL, blocking that IP forever once it hit the limit.
    with redis_client.pipeline() as pipe:
        pipe.incr(redis_key)
        pipe.ttl(redis_key)
        count, ttl = pipe.execute()

    if ttl < 0:
        # No expiry on the key: INCR just created it (first request in this
        # window), or it is a legacy key orphaned without a TTL. Start the window.
        redis_client.expire(redis_key, RATE_LIMIT_WINDOW)
        ttl = int(RATE_LIMIT_WINDOW.total_seconds())

    # Rejected requests are still counted, but the TTL is never refreshed here,
    # so they do not extend the window.
    if count > REQUEST_LIMIT_PER_MINUTE:
        logger.warning(f"Rate limit exceeded for IP {ip}.")
        raise HTTPException(
            status_code=429,
            detail={"error": {"message": "Rate limit exceeded. Please wait and try again.", "type": "rate_limit"}},
            headers={"Retry-After": str(max(ttl, 1))},
        )

    if count == 1:
        logger.info(f"New IP {ip} added to request counts.")
    else:
        logger.info(f"IP {ip} made request number {count}.")