|
|
|
|
|
from datetime import datetime
|
|
import json
|
|
import uuid
|
|
import asyncio
|
|
import random
|
|
import string
|
|
from typing import Any, Dict, Optional
|
|
|
|
import httpx
|
|
from fastapi import HTTPException
|
|
from api.config import (
|
|
MODEL_MAPPING,
|
|
get_headers_api_chat,
|
|
get_headers_chat,
|
|
BASE_URL,
|
|
AGENT_MODE,
|
|
TRENDING_AGENT_MODE,
|
|
MODEL_PREFIXES,
|
|
MODEL_REFERERS
|
|
)
|
|
from api.models import ChatRequest
|
|
from api.logger import setup_logger
|
|
from api.validate import getHid
|
|
|
|
logger = setup_logger(__name__)
|
|
|
|
|
|
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build one OpenAI-style ``chat.completion.chunk`` payload.

    Args:
        content: Delta text carried by this chunk.
        model: Model identifier echoed back to the client.
        timestamp: Unix timestamp (seconds) for the ``created`` field.
        finish_reason: ``"stop"`` on the terminal chunk, otherwise ``None``.

    Returns:
        A dict shaped like an OpenAI streaming chunk, with a fresh
        ``chatcmpl-<uuid4>`` id and ``usage`` left as ``None``.
    """
    choice = {
        "index": 0,
        "delta": {"content": content, "role": "assistant"},
        "finish_reason": finish_reason,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [choice],
        "usage": None,
    }
|
|
|
|
|
|
def message_to_dict(message, model_prefix: Optional[str] = None):
    """Convert an incoming chat message into the upstream wire format.

    String content is passed through (optionally prefixed with
    *model_prefix*); a two-element list whose second element carries an
    ``image_url`` entry is flattened into a text-plus-image payload with
    the image embedded under ``data.imageBase64``.
    """
    raw = message.content
    text = raw if isinstance(raw, str) else raw[0]["text"]
    if model_prefix:
        text = f"{model_prefix} {text}"

    is_image_message = (
        isinstance(raw, list) and len(raw) == 2 and "image_url" in raw[1]
    )
    if is_image_message:
        return {
            "role": message.role,
            "content": text,
            "data": {
                "imageBase64": raw[1]["image_url"]["url"],
                "fileText": "",
                "title": "snapshot",
            },
        }
    return {"role": message.role, "content": text}
|
|
|
|
|
|
def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
    """Remove the model prefix from the response content if present."""
    # Guard clause: nothing to strip when no prefix is configured or the
    # content does not actually start with it.
    if not model_prefix or not content.startswith(model_prefix):
        return content
    logger.debug(f"Stripping prefix '{model_prefix}' from content.")
    return content[len(model_prefix):].strip()
|
|
|
|
|
|
async def process_streaming_response(request: ChatRequest):
    """Stream a chat completion from the upstream ``/api/chat`` endpoint.

    Async generator that forwards each upstream line as an OpenAI-style
    SSE ``chat.completion.chunk`` (``data: {...}\\n\\n``), then emits a
    final ``stop`` chunk and ``data: [DONE]``.

    Raises:
        HTTPException: 500 when the validation h-value is unavailable or
            on transport errors; the upstream status code on HTTP errors.
            NOTE(review): exceptions raised after the first ``yield`` occur
            mid-stream, so the client may not see them as an HTTP status —
            confirm how the calling StreamingResponse handles this.
    """
    # Request id used only for log correlation; each chunk gets its own id.
    request_id = f"chatcmpl-{uuid.uuid4()}"
    logger.info(f"Processing request with ID: {request_id} - Model: {request.model}")

    # Per-model upstream configuration; empty dict / "" when unmapped.
    agent_mode = AGENT_MODE.get(request.model, {})
    trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
    model_prefix = MODEL_PREFIXES.get(request.model, "")

    headers_api_chat = get_headers_api_chat(BASE_URL)

    if request.model == 'o1-preview':
        # Random pre-request delay for this model — presumably to avoid
        # upstream rate limiting. NOTE(review): range here is (1, 60) but
        # the non-streaming path uses (20, 60) — confirm which is intended.
        delay_seconds = random.randint(1, 60)
        logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Request ID: {request_id})")
        await asyncio.sleep(delay_seconds)

    # Anti-bot validation token required by the upstream API.
    h_value = await getHid()
    if not h_value:
        logger.error("Failed to retrieve h-value for validation.")
        raise HTTPException(status_code=500, detail="Validation failed due to missing h-value.")

    # Upstream request body; most flags are fixed, only model-dependent
    # fields and the user's sampling parameters vary.
    json_data = {
        "agentMode": agent_mode,
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "id": None,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "trendingAgentMode": trending_agent_mode,
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        "validated": h_value,
        "visitFromDelta": False,
    }

    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                "POST",
                f"{BASE_URL}/api/chat",
                headers=headers_api_chat,
                json=json_data,
                timeout=100,  # presumably seconds (httpx default unit) — generous for long generations
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    timestamp = int(datetime.now().timestamp())
                    if line:
                        content = line
                        # Upstream sometimes prepends this sentinel; drop it
                        # (21 == len of the marker string).
                        if content.startswith("$@$v=undefined-rv1$@$"):
                            content = content[21:]
                        cleaned_content = strip_model_prefix(content, model_prefix)
                        yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, request.model, timestamp))}\n\n"

                # Terminal 'stop' chunk followed by the SSE done marker.
                # NOTE(review): `timestamp` is only bound inside the loop, so an
                # upstream stream with zero lines would raise NameError here —
                # confirm whether that can happen.
                yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred for Request ID {request_id}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request for Request ID {request_id}: {e}")
            raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
async def process_non_streaming_response(request: ChatRequest):
    """Execute a chat completion and return a single OpenAI-style response.

    Mirrors ``process_streaming_response`` but accumulates the upstream
    stream into one body and returns a complete ``chat.completion`` dict
    with a single ``stop`` choice.

    Raises:
        HTTPException: 500 when the validation h-value is unavailable or
            on transport errors; the upstream status code on HTTP errors.
    """
    # Request id used only for log correlation; the response gets its own id.
    request_id = f"chatcmpl-{uuid.uuid4()}"
    logger.info(f"Processing request with ID: {request_id} - Model: {request.model}")

    # Per-model upstream configuration; empty dict / "" when unmapped.
    agent_mode = AGENT_MODE.get(request.model, {})
    trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
    model_prefix = MODEL_PREFIXES.get(request.model, "")

    headers_api_chat = get_headers_api_chat(BASE_URL)
    # NOTE(review): the original also built headers via get_headers_chat()
    # here, but the result was never used anywhere in this function, so the
    # dead call has been removed.

    if request.model == 'o1-preview':
        # Random pre-request delay for this model — presumably to avoid
        # upstream rate limiting. NOTE(review): range here is (20, 60) but
        # the streaming path uses (1, 60) — confirm which is intended.
        delay_seconds = random.randint(20, 60)
        logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Request ID: {request_id})")
        await asyncio.sleep(delay_seconds)

    # Anti-bot validation token required by the upstream API.
    h_value = await getHid()
    if not h_value:
        logger.error("Failed to retrieve h-value for validation.")
        raise HTTPException(status_code=500, detail="Validation failed due to missing h-value.")

    # Upstream request body; most flags are fixed, only model-dependent
    # fields and the user's sampling parameters vary.
    json_data = {
        "agentMode": agent_mode,
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "id": None,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "trendingAgentMode": trending_agent_mode,
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        "validated": h_value,
        "visitFromDelta": False,
    }

    full_response = ""
    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                method="POST",
                url=f"{BASE_URL}/api/chat",
                headers=headers_api_chat,
                json=json_data,
                # Match the streaming path's timeout; without this, httpx's
                # 5-second default would abort long generations.
                timeout=100,
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    full_response += chunk
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred for Request ID {request_id}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request for Request ID {request_id}: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    # Upstream sometimes prepends this sentinel marker; strip it by the
    # marker's length rather than a hard-coded character count.
    marker = "$@$v=undefined-rv1$@$"
    if full_response.startswith(marker):
        full_response = full_response[len(marker):]

    cleaned_full_response = strip_model_prefix(full_response, model_prefix)

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": cleaned_full_response},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }
|
|
|