Niansuh commited on
Commit
ed5df64
·
verified ·
1 Parent(s): 0a7708d

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +57 -10
api/utils.py CHANGED
@@ -1,13 +1,15 @@
1
- from datetime import datetime
2
  import json
3
  import uuid
4
  import asyncio
5
  import random
6
  import string
7
  from typing import Any, Dict, Optional
 
 
 
8
 
9
  import httpx
10
- from fastapi import HTTPException
11
  from api import validate # Import validate to use getHid
12
  from api.config import (
13
  MODEL_MAPPING,
@@ -22,13 +24,51 @@ from api.config import (
22
  from api.models import ChatRequest
23
  from api.logger import setup_logger
24
 
 
 
25
  logger = setup_logger(__name__)
26
 
 
 
 
 
 
 
27
  # Helper function to create a random alphanumeric chat ID
28
  def generate_chat_id(length: int = 7) -> str:
29
  characters = string.ascii_letters + string.digits
30
  return ''.join(random.choices(characters, k=length))
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Helper function to create chat completion data
33
  def create_chat_completion_data(
34
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
@@ -82,19 +122,22 @@ def get_referer_url(chat_id: str, model: str) -> str:
82
  return BASE_URL
83
 
84
  # Process streaming response with headers from config.py
85
- async def process_streaming_response(request: ChatRequest):
86
  chat_id = generate_chat_id()
87
  referer_url = get_referer_url(chat_id, request.model)
88
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
89
 
 
 
 
 
90
  agent_mode = AGENT_MODE.get(request.model, {})
91
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
92
  model_prefix = MODEL_PREFIXES.get(request.model, "")
93
 
94
  headers_api_chat = get_headers_api_chat(referer_url)
95
  validated_token = validate.getHid() # Get the validated token from validate.py
96
- logger.info(f"Retrieved validated token: {validated_token}")
97
-
98
 
99
  if request.model == 'o1-preview':
100
  delay_seconds = random.randint(1, 60)
@@ -147,18 +190,22 @@ async def process_streaming_response(request: ChatRequest):
147
  yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
148
  yield "data: [DONE]\n\n"
149
  except httpx.HTTPStatusError as e:
150
- logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
151
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
152
  except httpx.RequestError as e:
153
- logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
154
  raise HTTPException(status_code=500, detail=str(e))
155
 
156
  # Process non-streaming response with headers from config.py
157
- async def process_non_streaming_response(request: ChatRequest):
158
  chat_id = generate_chat_id()
159
  referer_url = get_referer_url(chat_id, request.model)
160
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
161
 
 
 
 
 
162
  agent_mode = AGENT_MODE.get(request.model, {})
163
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
164
  model_prefix = MODEL_PREFIXES.get(request.model, "")
@@ -206,10 +253,10 @@ async def process_non_streaming_response(request: ChatRequest):
206
  async for chunk in response.aiter_text():
207
  full_response += chunk
208
  except httpx.HTTPStatusError as e:
209
- logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
210
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
211
  except httpx.RequestError as e:
212
- logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
213
  raise HTTPException(status_code=500, detail=str(e))
214
  if full_response.startswith("$@$v=undefined-rv1$@$"):
215
  full_response = full_response[21:]
 
1
+ from datetime import datetime, timedelta
2
  import json
3
  import uuid
4
  import asyncio
5
  import random
6
  import string
7
  from typing import Any, Dict, Optional
8
+ import os
9
+ from fastapi import HTTPException, Request
10
+ from dotenv import load_dotenv
11
 
12
  import httpx
 
13
  from api import validate # Import validate to use getHid
14
  from api.config import (
15
  MODEL_MAPPING,
 
24
  from api.models import ChatRequest
25
  from api.logger import setup_logger
26
 
27
+ # Initialize environment variables and logger
28
+ load_dotenv()
29
  logger = setup_logger(__name__)
30
 
31
+ # Set request limit per minute from environment variable
32
+ REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))
33
+
34
+ # Dictionary to track IP addresses and request counts
35
+ request_counts = {}
36
+
37
  # Helper function to create a random alphanumeric chat ID
38
  def generate_chat_id(length: int = 7) -> str:
39
  characters = string.ascii_letters + string.digits
40
  return ''.join(random.choices(characters, k=length))
41
 
42
+ # Function to get the IP address of the requester
43
+ def get_client_ip(request: Request) -> str:
44
+ """Retrieve the IP address of the client making the request."""
45
+ return request.client.host
46
+
47
+ # Function to limit requests per IP per minute
48
+ def check_rate_limit(ip: str):
49
+ """Check if the IP has exceeded the request limit per minute."""
50
+ current_time = datetime.now()
51
+ if ip not in request_counts:
52
+ request_counts[ip] = {"count": 1, "timestamp": current_time}
53
+ logger.info(f"New IP {ip} added to request counts.")
54
+ else:
55
+ ip_data = request_counts[ip]
56
+ # Reset the count if the timestamp is more than a minute old
57
+ if current_time - ip_data["timestamp"] > timedelta(minutes=1):
58
+ request_counts[ip] = {"count": 1, "timestamp": current_time}
59
+ logger.info(f"Request count reset for IP {ip}.")
60
+ else:
61
+ # Increment the count and check if it exceeds the limit
62
+ ip_data["count"] += 1
63
+ logger.info(f"IP {ip} made request number {ip_data['count']}.")
64
+
65
+ if ip_data["count"] > REQUEST_LIMIT_PER_MINUTE:
66
+ logger.warning(f"Rate limit exceeded for IP {ip}.")
67
+ raise HTTPException(
68
+ status_code=429,
69
+ detail=f"Rate limit exceeded. Maximum {REQUEST_LIMIT_PER_MINUTE} requests per minute allowed."
70
+ )
71
+
72
  # Helper function to create chat completion data
73
  def create_chat_completion_data(
74
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
 
122
  return BASE_URL
123
 
124
  # Process streaming response with headers from config.py
125
+ async def process_streaming_response(request: ChatRequest, request_obj: Request):
126
  chat_id = generate_chat_id()
127
  referer_url = get_referer_url(chat_id, request.model)
128
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
129
 
130
+ # Get the IP address and check rate limit
131
+ client_ip = get_client_ip(request_obj)
132
+ check_rate_limit(client_ip)
133
+
134
  agent_mode = AGENT_MODE.get(request.model, {})
135
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
136
  model_prefix = MODEL_PREFIXES.get(request.model, "")
137
 
138
  headers_api_chat = get_headers_api_chat(referer_url)
139
  validated_token = validate.getHid() # Get the validated token from validate.py
140
+ logger.info(f"Retrieved validated token for IP {client_ip}: {validated_token}")
 
141
 
142
  if request.model == 'o1-preview':
143
  delay_seconds = random.randint(1, 60)
 
190
  yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
191
  yield "data: [DONE]\n\n"
192
  except httpx.HTTPStatusError as e:
193
+ logger.error(f"HTTP error occurred for Chat ID {chat_id} (IP: {client_ip}): {e}")
194
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
195
  except httpx.RequestError as e:
196
+ logger.error(f"Error occurred during request for Chat ID {chat_id} (IP: {client_ip}): {e}")
197
  raise HTTPException(status_code=500, detail=str(e))
198
 
199
  # Process non-streaming response with headers from config.py
200
+ async def process_non_streaming_response(request: ChatRequest, request_obj: Request):
201
  chat_id = generate_chat_id()
202
  referer_url = get_referer_url(chat_id, request.model)
203
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
204
 
205
+ # Get the IP address and check rate limit
206
+ client_ip = get_client_ip(request_obj)
207
+ check_rate_limit(client_ip)
208
+
209
  agent_mode = AGENT_MODE.get(request.model, {})
210
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
211
  model_prefix = MODEL_PREFIXES.get(request.model, "")
 
253
  async for chunk in response.aiter_text():
254
  full_response += chunk
255
  except httpx.HTTPStatusError as e:
256
+ logger.error(f"HTTP error occurred for Chat ID {chat_id} (IP: {client_ip}): {e}")
257
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
258
  except httpx.RequestError as e:
259
+ logger.error(f"Error occurred during request for Chat ID {chat_id} (IP: {client_ip}): {e}")
260
  raise HTTPException(status_code=500, detail=str(e))
261
  if full_response.startswith("$@$v=undefined-rv1$@$"):
262
  full_response = full_response[21:]