from datetime import datetime, timedelta
import json
import uuid
import asyncio
import random
import os
from typing import Any, Dict, Optional

from fastapi import HTTPException, Request
from dotenv import load_dotenv
import httpx

from api import validate  # Import validate to use getHid
from api.config import (
    MODEL_MAPPING,
    get_headers_api_chat,
    get_headers_chat,
    BASE_URL,
    AGENT_MODE,
    TRENDING_AGENT_MODE,
    MODEL_PREFIXES,
    MODEL_REFERERS,
)
from api.models import ChatRequest
from api.logger import setup_logger

# Initialize environment variables and logger
load_dotenv()
logger = setup_logger(__name__)

# Set request limit per minute from environment variable
REQUEST_LIMIT_PER_MINUTE = int(os.getenv("REQUEST_LIMIT_PER_MINUTE", "10"))

# Dictionary to track IP addresses and request counts
request_counts = {}


# Function to get the IP address of the requester
def get_client_ip(request: Request) -> str:
    """Retrieve the IP address of the client making the request."""
    return request.client.host


# Function to limit requests per IP per minute
def check_rate_limit(ip: str):
    """Check if the IP has exceeded the request limit per minute."""
    current_time = datetime.now()
    if ip not in request_counts:
        request_counts[ip] = {"count": 1, "timestamp": current_time}
        logger.info(f"New IP {ip} added to request counts.")
    else:
        ip_data = request_counts[ip]
        # Reset the count if the timestamp is more than a minute old
        if current_time - ip_data["timestamp"] > timedelta(minutes=1):
            request_counts[ip] = {"count": 1, "timestamp": current_time}
            logger.info(f"Request count reset for IP {ip}.")
        else:
            # Increment the count and check if it exceeds the limit
            ip_data["count"] += 1
            logger.info(f"IP {ip} made request number {ip_data['count']}.")
            if ip_data["count"] > REQUEST_LIMIT_PER_MINUTE:
                logger.warning(f"Rate limit exceeded for IP {ip}.")
                raise HTTPException(
                    status_code=429,
                    detail=f"Rate limit exceeded. Maximum {REQUEST_LIMIT_PER_MINUTE} requests per minute allowed.",
                )
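

# NOTE: `get_referer_url` is called by both response handlers below but was
# missing from this file. This is a minimal reconstruction, assuming
# MODEL_REFERERS maps a model name to a referer path that is appended to
# BASE_URL; the exact mapping shape is an assumption, not confirmed here.
def get_referer_url(model: str) -> str:
    """Return the referer URL for a model, falling back to BASE_URL."""
    referer_path = MODEL_REFERERS.get(model, "")
    return f"{BASE_URL}{referer_path}"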


# Helper function to create chat completion data
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
    }


# Function to convert message to dictionary format, ensuring base64 data and optional model prefix
def message_to_dict(message, model_prefix: Optional[str] = None):
    content = message.content if isinstance(message.content, str) else message.content[0]["text"]
    if model_prefix:
        content = f"{model_prefix} {content}"
    if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
        # Ensure base64 images are always included for all models
        return {
            "role": message.role,
            "content": content,
            "data": {
                "imageBase64": message.content[1]["image_url"]["url"],
                "fileText": "",
                "title": "snapshot",
            },
        }
    return {"role": message.role, "content": content}


# Function to strip model prefix from content if present
def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
    """Remove the model prefix from the response content if present."""
    if model_prefix and content.startswith(model_prefix):
        logger.debug(f"Stripping prefix '{model_prefix}' from content.")
        return content[len(model_prefix):].strip()
    return content
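

# For reference, a sketch of the OpenAI-style SSE chunk emitted by the
# streaming handler below, as produced by create_chat_completion_data. The
# id, timestamp, and model values here are illustrative only:
#
#   data: {"id": "chatcmpl-<uuid>", "object": "chat.completion.chunk",
#          "created": 1700000000, "model": "gpt-4o",
#          "choices": [{"index": 0,
#                       "delta": {"content": "Hello", "role": "assistant"},
#                       "finish_reason": null}],
#          "usage": null}
#
# The final chunk carries finish_reason "stop" and is followed by
# "data: [DONE]".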


# Process streaming response with headers from config.py
async def process_streaming_response(request: ChatRequest, request_obj: Request):
    referer_url = get_referer_url(request.model)
    logger.info(f"Processing streaming response - Model: {request.model} - URL: {referer_url}")

    # Get the IP address and check the rate limit
    client_ip = get_client_ip(request_obj)
    check_rate_limit(client_ip)

    agent_mode = AGENT_MODE.get(request.model, {})
    trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
    model_prefix = MODEL_PREFIXES.get(request.model, "")

    headers_api_chat = get_headers_api_chat(referer_url)
    validated_token = validate.getHid()  # Get the validated token from validate.py
    logger.info(f"Retrieved validated token for IP {client_ip}: {validated_token}")

    if request.model == 'o1-preview':
        delay_seconds = random.randint(1, 60)
        logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview'")
        await asyncio.sleep(delay_seconds)

    json_data = {
        "agentMode": agent_mode,
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "trendingAgentMode": trending_agent_mode,
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        "validated": validated_token,
        "visitFromDelta": False,
    }

    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                "POST",
                f"{BASE_URL}/api/chat",
                headers=headers_api_chat,
                json=json_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                # Initialize so the final "stop" chunk has a timestamp even if
                # the upstream stream yields no lines
                timestamp = int(datetime.now().timestamp())
                async for line in response.aiter_lines():
                    timestamp = int(datetime.now().timestamp())
                    if line:
                        content = line
                        # Drop the upstream sentinel prefix before emitting
                        if content.startswith("$@$v=undefined-rv1$@$"):
                            content = content[21:]
                        cleaned_content = strip_model_prefix(content, model_prefix)
                        yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, request.model, timestamp))}\n\n"

                yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred (IP: {client_ip}): {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request (IP: {client_ip}): {e}")
            raise HTTPException(status_code=500, detail=str(e))


# Process non-streaming response with headers from config.py
async def process_non_streaming_response(request: ChatRequest, request_obj: Request):
    referer_url = get_referer_url(request.model)
    logger.info(f"Processing non-streaming response - Model: {request.model} - URL: {referer_url}")

    # Get the IP address and check the rate limit
    client_ip = get_client_ip(request_obj)
    check_rate_limit(client_ip)

    agent_mode = AGENT_MODE.get(request.model, {})
    trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
    model_prefix = MODEL_PREFIXES.get(request.model, "")

    headers_api_chat = get_headers_api_chat(referer_url)
    headers_chat = get_headers_chat(
        referer_url,
        next_action=str(uuid.uuid4()),
        next_router_state_tree=json.dumps([""]),
    )
    validated_token = validate.getHid()  # Get the validated token from validate.py

    if request.model == 'o1-preview':
        delay_seconds = random.randint(20, 60)
        logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview'")
        await asyncio.sleep(delay_seconds)

    json_data = {
        "agentMode": agent_mode,
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "trendingAgentMode": trending_agent_mode,
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        "validated": validated_token,
        "visitFromDelta": False,
    }

    full_response = ""
    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                method="POST",
                url=f"{BASE_URL}/api/chat",
                headers=headers_api_chat,
                json=json_data,
                timeout=100,  # match the streaming path's timeout
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    full_response += chunk
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred (IP: {client_ip}): {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request (IP: {client_ip}): {e}")
            raise HTTPException(status_code=500, detail=str(e))

    # Drop the upstream sentinel prefix before cleaning the response
    if full_response.startswith("$@$v=undefined-rv1$@$"):
        full_response = full_response[21:]

    cleaned_full_response = strip_model_prefix(full_response, model_prefix)

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": cleaned_full_response},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }
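

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one way these handlers
# might be wired into an OpenAI-compatible FastAPI route. The route path,
# router placement, and the `stream` field on ChatRequest are assumptions,
# not confirmed by this file.
#
# from fastapi import APIRouter, Request
# from fastapi.responses import StreamingResponse
#
# router = APIRouter()
#
# @router.post("/v1/chat/completions")
# async def chat_completions(request: ChatRequest, request_obj: Request):
#     if request.stream:  # assumed field on ChatRequest
#         return StreamingResponse(
#             process_streaming_response(request, request_obj),
#             media_type="text/event-stream",
#         )
#     return await process_non_streaming_response(request, request_obj)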