Niansuh committed on
Commit
922e6b4
·
verified ·
1 Parent(s): 96bd80a

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +4 -52
api/utils.py CHANGED
@@ -1,13 +1,12 @@
1
- from datetime import datetime, timedelta
2
  import json
3
  import uuid
4
  import asyncio
5
  import random
6
  from typing import Any, Dict, Optional
7
- import os
8
  from fastapi import HTTPException, Request
9
  from dotenv import load_dotenv
10
- import httpx
11
  from api import validate
12
  from api.config import (
13
  MODEL_MAPPING,
@@ -20,49 +19,12 @@ from api.config import (
20
  )
21
  from api.models import ChatRequest
22
  from api.logger import setup_logger
 
23
 
24
  # Initialize environment variables and logger
25
  load_dotenv()
26
  logger = setup_logger(__name__)
27
 
28
- # Set request limit per minute from environment variable
29
- REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))
30
-
31
- # Dictionary to track IP addresses and request counts
32
- request_counts = {}
33
-
34
- # Function to get the IP address of the requester
35
- def get_client_ip(request: Request) -> str:
36
- """Retrieve the IP address of the client making the request."""
37
- return request.client.host
38
-
39
- # Function to limit requests per IP per minute
40
- def check_rate_limit(ip: str):
41
- """Check if the IP has exceeded the request limit per minute."""
42
- current_time = datetime.now()
43
- if ip not in request_counts:
44
- # If the IP is new, initialize its counter and timestamp
45
- request_counts[ip] = {"count": 1, "timestamp": current_time}
46
- logger.info(f"New IP {ip} added to request counts.")
47
- else:
48
- ip_data = request_counts[ip]
49
- # Check if the timestamp is more than a minute old
50
- if current_time - ip_data["timestamp"] < timedelta(minutes=1):
51
- # If within the same minute, increment the count
52
- ip_data["count"] += 1
53
- logger.info(f"IP {ip} made request number {ip_data['count']}.")
54
- if ip_data["count"] > REQUEST_LIMIT_PER_MINUTE:
55
- logger.warning(f"Rate limit exceeded for IP {ip}.")
56
- raise HTTPException(
57
- status_code=429,
58
- detail={"error": {"message": "Rate limit exceeded. Please wait and try again.", "type": "rate_limit"}},
59
- )
60
- else:
61
- # If more than a minute has passed, reset the count and timestamp
62
- request_counts[ip] = {"count": 1, "timestamp": current_time}
63
- logger.info(f"Request count reset for IP {ip}.")
64
-
65
- # Helper function to create chat completion data
66
  def create_chat_completion_data(
67
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
68
  ) -> Dict[str, Any]:
@@ -81,13 +43,11 @@ def create_chat_completion_data(
81
  "usage": None,
82
  }
83
 
84
- # Function to convert message to dictionary format, ensuring base64 data and optional model prefix
85
  def message_to_dict(message, model_prefix: Optional[str] = None):
86
  content = message.content if isinstance(message.content, str) else message.content[0]["text"]
87
  if model_prefix:
88
  content = f"{model_prefix} {content}"
89
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
90
- # Ensure base64 images are always included for all models
91
  return {
92
  "role": message.role,
93
  "content": content,
@@ -99,25 +59,19 @@ def message_to_dict(message, model_prefix: Optional[str] = None):
99
  }
100
  return {"role": message.role, "content": content}
101
 
102
- # Function to strip model prefix from content if present
103
  def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
104
- """Remove the model prefix from the response content if present."""
105
  if model_prefix and content.startswith(model_prefix):
106
  logger.debug(f"Stripping prefix '{model_prefix}' from content.")
107
  return content[len(model_prefix):].strip()
108
  return content
109
 
110
- # Simplified function to get the base referer URL
111
  def get_referer_url() -> str:
112
- """Return the base URL for the referer without model-specific logic."""
113
  return BASE_URL
114
 
115
- # Process streaming response with headers from config.py
116
  async def process_streaming_response(request: ChatRequest, request_obj: Request):
117
  referer_url = get_referer_url()
118
  logger.info(f"Processing streaming response - Model: {request.model} - URL: {referer_url}")
119
 
120
- # Get the IP address and check rate limit
121
  client_ip = get_client_ip(request_obj)
122
  check_rate_limit(client_ip)
123
 
@@ -126,7 +80,7 @@ async def process_streaming_response(request: ChatRequest, request_obj: Request)
126
  model_prefix = MODEL_PREFIXES.get(request.model, "")
127
 
128
  headers_api_chat = get_headers_api_chat(referer_url)
129
- validated_token = validate.getHid() # Get the validated token from validate.py
130
  logger.info(f"Retrieved validated token for IP {client_ip}: {validated_token}")
131
 
132
  if request.model == 'o1-preview':
@@ -185,12 +139,10 @@ async def process_streaming_response(request: ChatRequest, request_obj: Request)
185
  logger.error(f"Error occurred during request (IP: {client_ip}): {e}")
186
  raise HTTPException(status_code=500, detail=str(e))
187
 
188
- # Process non-streaming response with headers from config.py
189
  async def process_non_streaming_response(request: ChatRequest, request_obj: Request):
190
  referer_url = get_referer_url()
191
  logger.info(f"Processing non-streaming response - Model: {request.model} - URL: {referer_url}")
192
 
193
- # Get the IP address and check rate limit
194
  client_ip = get_client_ip(request_obj)
195
  check_rate_limit(client_ip)
196
 
 
1
+ from datetime import datetime
2
  import json
3
  import uuid
4
  import asyncio
5
  import random
6
  from typing import Any, Dict, Optional
7
+ import httpx
8
  from fastapi import HTTPException, Request
9
  from dotenv import load_dotenv
 
10
  from api import validate
11
  from api.config import (
12
  MODEL_MAPPING,
 
19
  )
20
  from api.models import ChatRequest
21
  from api.logger import setup_logger
22
+ from api.rpmlimits import check_rate_limit, get_client_ip # Import rate limit functions
23
 
24
  # Initialize environment variables and logger
25
  load_dotenv()
26
  logger = setup_logger(__name__)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def create_chat_completion_data(
29
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
30
  ) -> Dict[str, Any]:
 
43
  "usage": None,
44
  }
45
 
 
46
  def message_to_dict(message, model_prefix: Optional[str] = None):
47
  content = message.content if isinstance(message.content, str) else message.content[0]["text"]
48
  if model_prefix:
49
  content = f"{model_prefix} {content}"
50
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
 
51
  return {
52
  "role": message.role,
53
  "content": content,
 
59
  }
60
  return {"role": message.role, "content": content}
61
 
 
62
  def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
 
63
  if model_prefix and content.startswith(model_prefix):
64
  logger.debug(f"Stripping prefix '{model_prefix}' from content.")
65
  return content[len(model_prefix):].strip()
66
  return content
67
 
 
68
  def get_referer_url() -> str:
 
69
  return BASE_URL
70
 
 
71
  async def process_streaming_response(request: ChatRequest, request_obj: Request):
72
  referer_url = get_referer_url()
73
  logger.info(f"Processing streaming response - Model: {request.model} - URL: {referer_url}")
74
 
 
75
  client_ip = get_client_ip(request_obj)
76
  check_rate_limit(client_ip)
77
 
 
80
  model_prefix = MODEL_PREFIXES.get(request.model, "")
81
 
82
  headers_api_chat = get_headers_api_chat(referer_url)
83
+ validated_token = validate.getHid()
84
  logger.info(f"Retrieved validated token for IP {client_ip}: {validated_token}")
85
 
86
  if request.model == 'o1-preview':
 
139
  logger.error(f"Error occurred during request (IP: {client_ip}): {e}")
140
  raise HTTPException(status_code=500, detail=str(e))
141
 
 
142
  async def process_non_streaming_response(request: ChatRequest, request_obj: Request):
143
  referer_url = get_referer_url()
144
  logger.info(f"Processing non-streaming response - Model: {request.model} - URL: {referer_url}")
145
 
 
146
  client_ip = get_client_ip(request_obj)
147
  check_rate_limit(client_ip)
148