File size: 4,343 Bytes
a71fd48
 
0d394b4
a71fd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d394b4
 
 
 
 
a71fd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import base64
import json
import logging
import time
from urllib.parse import urlencode

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from mediaflow_proxy.configs import settings


class EncryptionHandler:
    def __init__(self, secret_key: str):
        self.secret_key = secret_key.encode("utf-8").ljust(32)[:32]

    def encrypt_data(self, data: dict, expiration: int = None, ip: str = None) -> str:
        if expiration:
            data["exp"] = int(time.time()) + expiration
        if ip:
            data["ip"] = ip
        json_data = json.dumps(data).encode("utf-8")
        iv = get_random_bytes(16)
        cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
        encrypted_data = cipher.encrypt(pad(json_data, AES.block_size))
        return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8")

    def decrypt_data(self, token: str, client_ip: str) -> dict:
        try:
            encrypted_data = base64.urlsafe_b64decode(token.encode("utf-8"))
            iv = encrypted_data[:16]
            cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
            decrypted_data = unpad(cipher.decrypt(encrypted_data[16:]), AES.block_size)
            data = json.loads(decrypted_data)

            if "exp" in data:
                if data["exp"] < time.time():
                    raise HTTPException(status_code=401, detail="Token has expired")
                del data["exp"]  # Remove expiration from the data

            if "ip" in data:
                if data["ip"] != client_ip:
                    raise HTTPException(status_code=403, detail="IP address mismatch")
                del data["ip"]  # Remove IP from the data

            return data
        except Exception as e:
            raise HTTPException(status_code=401, detail="Invalid or expired token")


class EncryptionMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)
        self.encryption_handler = encryption_handler

    async def dispatch(self, request: Request, call_next):
        encrypted_token = request.query_params.get("token")
        if encrypted_token:
            try:
                client_ip = self.get_client_ip(request)
                decrypted_data = self.encryption_handler.decrypt_data(encrypted_token, client_ip)
                # Modify request query parameters with decrypted data
                query_params = dict(request.query_params)
                query_params.pop("token")  # Remove the encrypted token from query params
                query_params.update(decrypted_data)  # Add decrypted data to query params
                query_params["has_encrypted"] = True

                # Create a new request scope with updated query parameters
                new_query_string = urlencode(query_params)
                request.scope["query_string"] = new_query_string.encode()
                request._query_params = query_params
            except HTTPException as e:
                return JSONResponse(content={"error": str(e.detail)}, status_code=e.status_code)

        try:
            response = await call_next(request)
        except Exception as e:
            logging.exception("An error occurred while processing the request")
            return JSONResponse(content={"error": str(e)}, status_code=500)
        return response

    @staticmethod
    def get_client_ip(request: Request) -> str | None:
        """
        Extract the client's real IP address from the request headers or fallback to the client host.
        """
        x_forwarded_for = request.headers.get("X-Forwarded-For")
        if x_forwarded_for:
            # In some cases, this header can contain multiple IPs
            # separated by commas.
            # The first one is the original client's IP.
            return x_forwarded_for.split(",")[0].strip()
        # Fallback to X-Real-IP if X-Forwarded-For is not available
        x_real_ip = request.headers.get("X-Real-IP")
        if x_real_ip:
            return x_real_ip
        return request.client.host if request.client else "127.0.0.1"


encryption_handler = EncryptionHandler(settings.api_password)