File size: 1,658 Bytes
96bd80a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import os
from datetime import timedelta
import redis
from fastapi import HTTPException, Request
from dotenv import load_dotenv
from api.logger import setup_logger
# Pull variables from a local .env file into os.environ before any of the
# os.getenv() lookups below run.
load_dotenv()

logger = setup_logger(__name__)

# Redis connection settings; both can be overridden through the environment.
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))

# decode_responses=True makes the client return str instead of bytes.
redis_client = redis.StrictRedis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    decode_responses=True,
)

# Maximum number of requests allowed per client IP per minute (default 10).
REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))
def get_client_ip(request: Request) -> str:
    """Return the IP address of the client making *request*.

    ``request.client`` is optional on a Starlette/FastAPI request: it is ``None``
    when the ASGI scope carries no client address. In that case ``"unknown"``
    is returned instead of raising ``AttributeError``. All such requests
    therefore share one rate-limit bucket.

    NOTE(review): this is the direct peer address. Behind a reverse proxy it
    will be the proxy's IP, so confirm the deployment does not need
    forwarded-header handling.

    Args:
        request: The incoming FastAPI request.

    Returns:
        The client host string, or ``"unknown"`` if it is not available.
    """
    client = request.client
    if client is None or not client.host:
        return "unknown"
    return client.host
def check_rate_limit(ip: str):
    """Count a request from *ip* and reject it if the per-minute limit is exceeded.

    Implements a fixed one-minute window per IP. The counter lives in Redis
    under ``rate_limit:<ip>``. It is created by the IP's first request and
    expires one minute later, which resets the count. Requests numbered
    ``1..REQUEST_LIMIT_PER_MINUTE`` within a window are allowed; later ones
    are rejected. Rejected requests are still counted, but this does not
    extend the window.

    Args:
        ip: Client IP address used as the rate-limit bucket.

    Raises:
        HTTPException: status 429 once the IP has exceeded
            ``REQUEST_LIMIT_PER_MINUTE`` requests in the current window.
    """
    redis_key = f"rate_limit:{ip}"
    window = timedelta(minutes=1)

    # INCR and TTL are sent as one MULTI/EXEC transaction. INCR is atomic, so
    # concurrent requests cannot both observe a missing key and overwrite
    # each other's count, as the previous GET-then-SET sequence could.
    with redis_client.pipeline() as pipe:
        pipe.incr(redis_key)
        pipe.ttl(redis_key)
        current_count, ttl = pipe.execute()

    # TTL == -1 means the key exists without an expiry. Either INCR just
    # created it, or an earlier run left it without one (e.g. INCR after the
    # key expired). Without an expiry the counter never resets and the IP
    # would stay blocked forever, so attach the window here.
    if ttl == -1:
        redis_client.expire(redis_key, window)

    if current_count > REQUEST_LIMIT_PER_MINUTE:
        logger.warning("Rate limit exceeded for IP %s.", ip)
        raise HTTPException(
            status_code=429,
            detail={"error": {"message": "Rate limit exceeded. Please wait and try again.", "type": "rate_limit"}},
        )

    if current_count == 1:
        logger.info("New IP %s added to request counts.", ip)
    else:
        logger.info("IP %s made request number %s.", ip, current_count)
|