File size: 2,961 Bytes
b7a7f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from functools import wraps

from core.db import redis_throttle_client
from fastapi import HTTPException
from typing import Callable


def ip_throttle(rate: int, per: int = 86400) -> Callable:
    """
    Decorator used to throttle incoming requests based on IP Address

    Requests are counted in a fixed window: the first request from an IP for
    the decorated endpoint creates a Redis counter that expires after ``per``
    seconds. Once more than ``rate`` requests have been counted in that
    window, further requests are rejected until the counter expires.

    The decorated endpoint must receive the incoming request as the
    ``request`` keyword argument.

    Parameters
    ----------
    rate : int, required
        Maximum number of requests allowed per window.

    per : int, optional
        Rate limit expiry time (in seconds) (default is 86400 ie. 1 day)


    Raises
    ------
    HTTPException(status_code=429)
        If throttle limit is reached and the request is blocked
    """

    def outer_wrapper(func):
        @wraps(func)
        async def inner_wrapper(*args, **kwargs):
            # FIXME - Proxy might mess this up; behind a reverse proxy this may
            # need the X-Forwarded-For header instead. Works locally.
            client_ip = kwargs.get("request").client.host
            identifier = f"ip_th_{client_ip}_{func.__name__}_{per}"
            store = redis_throttle_client.client

            # INCR is atomic and creates a missing key at 1, so concurrent
            # requests cannot all slip past a shared read-then-write check.
            # Counting up front also means a window that expires while the
            # handler runs can no longer be recreated without a TTL.
            current_count = await store.incr(identifier)
            if current_count == 1:
                # First request of a new window: start the expiry clock.
                await store.expire(identifier, per)

            if current_count > rate:
                # A counter left without a TTL (e.g. EXPIRE never ran) would
                # block this IP forever; TTL -1 means "no expiry", so re-arm it.
                if await store.ttl(identifier) == -1:
                    await store.expire(identifier, per)
                raise HTTPException(
                    status_code=429, detail="Error ID: 133"
                )  # Too many requests!

            return await func(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper


def user_throttle(rate: int, per: int = 86400) -> Callable:
    """
    Decorator used to throttle incoming requests based on User

    Requests are counted in a fixed window: the first request from a user for
    the decorated endpoint creates a Redis counter that expires after ``per``
    seconds. Once more than ``rate`` requests have been counted in that
    window, further requests are rejected until the counter expires.

    The decorated endpoint must receive the authenticated user as the
    ``current_user`` keyword argument.

    Parameters
    ----------
    rate : int, required
        Maximum number of requests allowed per window.

    per : int, optional
        Rate limit expiry time (in seconds) (default is 86400 ie. 1 day)


    Raises
    ------
    HTTPException(status_code=429)
        If no user is present ("Error ID: 134"), or if the throttle limit is
        reached and the request is blocked ("Error ID: 135")
    """

    def outer_wrapper(func):
        @wraps(func)
        async def inner_wrapper(*args, **kwargs):
            client = kwargs.get("current_user")

            if not client:
                # NOTE(review): 401 is the conventional status for this case;
                # 429 is kept because existing clients may rely on it.
                raise HTTPException(
                    status_code=429, detail="Error ID: 134"
                )  # User not logged in!

            # NOTE(review): the key embeds str(current_user); presumably that
            # is a stable per-user value - verify against the user model.
            identifier = f"user_th_{client}_{func.__name__}"
            store = redis_throttle_client.client

            # INCR is atomic and creates a missing key at 1, so concurrent
            # requests cannot all pass a shared read-then-write check, and a
            # key expiring between a read and an increment can no longer be
            # recreated without a TTL.
            current_count = await store.incr(identifier)
            if current_count == 1:
                # First request of a new window: start the expiry clock.
                await store.expire(identifier, per)

            if current_count > rate:
                # A counter left without a TTL (e.g. EXPIRE never ran) would
                # block this user forever; TTL -1 means "no expiry", so re-arm.
                if await store.ttl(identifier) == -1:
                    await store.expire(identifier, per)
                raise HTTPException(
                    status_code=429, detail="Error ID: 135"
                )  # Too many requests!

            return await func(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper