from datetime import datetime, timedelta
import json
import uuid
import asyncio
import random
from typing import Any, Dict, Optional
import os

from fastapi import HTTPException, Request
from dotenv import load_dotenv
import httpx

from api import validate
from api.config import (
    MODEL_MAPPING,
    get_headers_api_chat,
    get_headers_chat,
    BASE_URL,
    AGENT_MODE,
    TRENDING_AGENT_MODE,
    MODEL_PREFIXES,
)
from api.models import ChatRequest
from api.logger import setup_logger

# Initialize environment variables and logger
load_dotenv()
logger = setup_logger(__name__)

# Marker the upstream service prepends to response bodies; stripped before
# relaying to the client.  Deriving slice lengths from this constant replaces
# the duplicated magic number 21 used previously.
RESPONSE_MARKER = "$@$v=undefined-rv1$@$"

# Set request limit per minute from environment variable
REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))

# In-memory per-IP counters: {ip: {"count": int, "timestamp": datetime}}.
# NOTE(review): this dict grows without bound and is per-process only — fine
# for a single worker, but use a shared store (e.g. Redis) behind multiple
# workers. Confirm deployment topology.
request_counts = {}


def get_client_ip(request: Request) -> str:
    """Retrieve the IP address of the client making the request."""
    return request.client.host


def check_rate_limit(ip: str) -> None:
    """Enforce a fixed-window limit of REQUEST_LIMIT_PER_MINUTE requests per IP.

    The window is anchored at the first request seen from the IP and resets
    one minute later.

    Raises:
        HTTPException: 429 when the IP exceeds the per-minute limit.
    """
    current_time = datetime.now()
    ip_data = request_counts.get(ip)

    if ip_data is None:
        # First request from this IP: open a fresh one-minute window.
        request_counts[ip] = {"count": 1, "timestamp": current_time}
        logger.info(f"New IP {ip} added to request counts.")
        return

    if current_time - ip_data["timestamp"] < timedelta(minutes=1):
        # Still inside the current window: count the request and enforce the cap.
        ip_data["count"] += 1
        logger.info(f"IP {ip} made request number {ip_data['count']}.")
        if ip_data["count"] > REQUEST_LIMIT_PER_MINUTE:
            logger.warning(f"Rate limit exceeded for IP {ip}.")
            raise HTTPException(
                status_code=429,
                detail={"error": {"message": "Rate limit exceeded. Please wait and try again.", "type": "rate_limit"}},
            )
    else:
        # More than a minute has passed: reset the count and timestamp.
        request_counts[ip] = {"count": 1, "timestamp": current_time}
        logger.info(f"Request count reset for IP {ip}.")


def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build one OpenAI-style `chat.completion.chunk` payload.

    Args:
        content: Delta text for this chunk.
        model: Model name to echo back to the client.
        timestamp: Unix timestamp for the `created` field.
        finish_reason: None for intermediate chunks, "stop" for the final one.
    """
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
    }


def message_to_dict(message, model_prefix: Optional[str] = None):
    """Convert a request message to the upstream dict format.

    Prepends ``model_prefix`` to the text content when given.  When the
    content is a two-element list whose second element carries an
    ``image_url`` (presumably the OpenAI vision message shape — confirm
    against api.models), the base64 image is forwarded in a ``data`` field.
    """
    content = message.content if isinstance(message.content, str) else message.content[0]["text"]
    if model_prefix:
        content = f"{model_prefix} {content}"
    if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
        # Ensure base64 images are always included for all models.
        return {
            "role": message.role,
            "content": content,
            "data": {
                "imageBase64": message.content[1]["image_url"]["url"],
                "fileText": "",
                "title": "snapshot",
            },
        }
    return {"role": message.role, "content": content}


def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
    """Remove the model prefix from the response content if present."""
    if model_prefix and content.startswith(model_prefix):
        logger.debug(f"Stripping prefix '{model_prefix}' from content.")
        return content[len(model_prefix):].strip()
    return content


def get_referer_url() -> str:
    """Return the base URL for the referer without model-specific logic."""
    return BASE_URL


def _strip_response_marker(content: str) -> str:
    """Drop the upstream RESPONSE_MARKER prefix from *content* if present."""
    if content.startswith(RESPONSE_MARKER):
        return content[len(RESPONSE_MARKER):]
    return content


def _build_chat_payload(request: ChatRequest, model_prefix: str, validated_token) -> Dict[str, Any]:
    """Build the upstream /api/chat JSON body.

    Shared by the streaming and non-streaming paths, which previously
    duplicated this dict verbatim.
    """
    return {
        "agentMode": AGENT_MODE.get(request.model, {}),
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "trendingAgentMode": TRENDING_AGENT_MODE.get(request.model, {}),
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        "validated": validated_token,
        "visitFromDelta": False,
    }


async def process_streaming_response(request: ChatRequest, request_obj: Request):
    """Proxy a chat request upstream and yield SSE-formatted completion chunks.

    Yields ``data: {...}\\n\\n`` lines in the OpenAI chunk format, then a
    final "stop" chunk and the ``[DONE]`` sentinel.

    Raises:
        HTTPException: 429 on rate limit; upstream status or 500 on errors.
          NOTE(review): exceptions raised after the first yield cannot change
          the HTTP status of an already-started StreamingResponse — confirm
          how the route wraps this generator.
    """
    referer_url = get_referer_url()
    logger.info(f"Processing streaming response - Model: {request.model} - URL: {referer_url}")

    # Rate-limit by client IP before doing any upstream work.
    client_ip = get_client_ip(request_obj)
    check_rate_limit(client_ip)

    model_prefix = MODEL_PREFIXES.get(request.model, "")
    headers_api_chat = get_headers_api_chat(referer_url)

    validated_token = validate.getHid()  # Get the validated token from validate.py
    # NOTE(review): this logs the token value — confirm it is not a secret.
    logger.info(f"Retrieved validated token for IP {client_ip}: {validated_token}")

    if request.model == 'o1-preview':
        # Artificial delay to mimic the upstream model's latency profile.
        delay_seconds = random.randint(1, 60)
        logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview'")
        await asyncio.sleep(delay_seconds)

    json_data = _build_chat_payload(request, model_prefix, validated_token)

    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                "POST",
                f"{BASE_URL}/api/chat",
                headers=headers_api_chat,
                json=json_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    timestamp = int(datetime.now().timestamp())
                    if line:
                        cleaned_content = strip_model_prefix(_strip_response_marker(line), model_prefix)
                        yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, request.model, timestamp))}\n\n"

                # Fresh timestamp here: previously the loop variable was reused,
                # which raised NameError when the upstream sent zero lines.
                timestamp = int(datetime.now().timestamp())
                yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred (IP: {client_ip}): {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request (IP: {client_ip}): {e}")
            raise HTTPException(status_code=500, detail=str(e))


async def process_non_streaming_response(request: ChatRequest, request_obj: Request):
    """Proxy a chat request upstream and return one OpenAI-style completion dict.

    Accumulates the upstream stream into a single response body.

    Raises:
        HTTPException: 429 on rate limit; upstream status or 500 on errors.
    """
    referer_url = get_referer_url()
    logger.info(f"Processing non-streaming response - Model: {request.model} - URL: {referer_url}")

    # Rate-limit by client IP before doing any upstream work.
    client_ip = get_client_ip(request_obj)
    check_rate_limit(client_ip)

    model_prefix = MODEL_PREFIXES.get(request.model, "")
    headers_api_chat = get_headers_api_chat(referer_url)
    # (A headers_chat value built via get_headers_chat was previously computed
    # here but never used; removed as dead code.)

    validated_token = validate.getHid()

    if request.model == 'o1-preview':
        # Artificial delay to mimic the upstream model's latency profile.
        delay_seconds = random.randint(20, 60)
        logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview'")
        await asyncio.sleep(delay_seconds)

    json_data = _build_chat_payload(request, model_prefix, validated_token)

    full_response = ""
    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                method="POST",
                url=f"{BASE_URL}/api/chat",
                headers=headers_api_chat,
                json=json_data,
                # FIX: no timeout was set, so httpx's 5-second default applied
                # and long generations were cut off; 100 s matches the
                # streaming path.
                timeout=100,
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    full_response += chunk
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred (IP: {client_ip}): {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request (IP: {client_ip}): {e}")
            raise HTTPException(status_code=500, detail=str(e))

    cleaned_full_response = strip_model_prefix(_strip_response_marker(full_response), model_prefix)

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": cleaned_full_response},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }