File size: 1,658 Bytes
96bd80a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
from datetime import timedelta
import redis
from fastapi import HTTPException, Request
from dotenv import load_dotenv
from api.logger import setup_logger

# Populate os.environ from a local .env file (python-dotenv) before any of
# the os.getenv() lookups below are evaluated.
load_dotenv()
logger = setup_logger(__name__)

# --- Redis connection ------------------------------------------------------
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
# decode_responses=True: replies come back as str rather than bytes.
redis_client = redis.StrictRedis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    decode_responses=True,
)

# --- Rate limiting ---------------------------------------------------------
# Per-IP request budget for each one-minute window; override via the
# REQUEST_LIMIT_PER_MINUTE environment variable.
REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))

def get_client_ip(request: Request) -> str:
    """Return the IP address of the client that sent *request*.

    Starlette/FastAPI exposes ``request.client`` as ``None`` when the ASGI
    scope carries no ``client`` entry. Rather than failing with an
    ``AttributeError`` (surfacing as a 500), fall back to the placeholder
    ``"unknown"``. Callers that key on this value, such as the rate limiter,
    will then treat all such requests as one client.

    NOTE(review): behind a reverse proxy this is the proxy's address unless
    the ASGI server is configured to honour forwarding headers — confirm the
    deployment setup.

    Returns:
        The client host reported by the ASGI server, or ``"unknown"``.
    """
    client = request.client
    if client is None:
        return "unknown"
    return client.host

def check_rate_limit(ip: str):
    """Record a request from *ip* and reject it if the per-minute limit is exceeded.

    Implements a fixed one-minute window per IP, stored in Redis under
    ``rate_limit:<ip>``. The first request of a window creates the counter
    with a one-minute expiry, and every request (including rejected ones)
    increments it. At most ``REQUEST_LIMIT_PER_MINUTE`` requests are
    accepted per window.

    The "create with expiry" and "increment" steps run in a single
    MULTI/EXEC transaction. Issuing GET and then SET/INCR as separate
    commands is racy:

    * concurrent first requests can both see "no key" and overwrite each
      other's count; and
    * if the key expires between GET and INCR, INCR recreates it *without* a
      TTL, so the counter never resets and the IP is eventually locked out
      permanently.

    Args:
        ip: Client identifier to rate-limit (e.g. from ``get_client_ip``).

    Raises:
        HTTPException: status 429 once the IP exceeds the limit in the
            current window.
    """
    redis_key = f"rate_limit:{ip}"

    pipe = redis_client.pipeline(transaction=True)
    # NX: only create the counter (with its 1-minute TTL) if it doesn't exist;
    # an existing window keeps its original expiry.
    pipe.set(redis_key, 0, ex=timedelta(minutes=1), nx=True)
    # INCR leaves an existing TTL untouched, so the window end never moves.
    pipe.incr(redis_key)
    _, current_count = pipe.execute()
    current_count = int(current_count)

    # current_count is post-increment: requests 1..LIMIT pass, LIMIT+1.. fail.
    if current_count > REQUEST_LIMIT_PER_MINUTE:
        logger.warning(f"Rate limit exceeded for IP {ip}.")
        raise HTTPException(
            status_code=429,
            detail={
                "error": {
                    "message": "Rate limit exceeded. Please wait and try again.",
                    "type": "rate_limit",
                }
            },
        )

    if current_count == 1:
        logger.info(f"New IP {ip} added to request counts.")
    else:
        logger.info(f"IP {ip} made request number {current_count}.")