test24 / api /utils.py
Niansuh's picture
Update api/utils.py
3e0f1a7 verified
raw
history blame
10 kB
from datetime import datetime
import json
import uuid
import asyncio
import random
import string
from typing import Any, Dict, Optional
import httpx
from fastapi import HTTPException
from api.config import (
MODEL_MAPPING,
headers,
BASE_URL,
AGENT_MODE,
TRENDING_AGENT_MODE,
MODEL_PREFIXES
)
from api.models import ChatRequest
from api.logger import setup_logger
from api import validate
# Module logger (created via api.logger.setup_logger) used by the functions below.
logger = setup_logger(__name__)
# Random alphanumeric identifier for a single upstream chat request.
def generate_chat_id(length: int = 7) -> str:
    """Return a random ID made of ``length`` ASCII letters and digits.

    Draws from the module-level ``random`` generator, which is not
    cryptographically secure.
    """
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return "".join(picked)
# Build one OpenAI-style streaming chunk.
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Return a ``chat.completion.chunk`` dict carrying ``content``.

    Args:
        content: Text placed in the single choice's ``delta``.
        model: Model name echoed back to the client.
        timestamp: Value for the ``created`` field (callers in this module
            pass integer Unix seconds).
        finish_reason: ``None`` for intermediate chunks; e.g. ``"stop"`` for
            the terminating chunk.

    Returns:
        A JSON-serialisable dict with a fresh ``chatcmpl-<uuid4>`` id and
        ``usage`` set to ``None``.
    """
    delta = {"content": content, "role": "assistant"}
    choice = {"index": 0, "delta": delta, "finish_reason": finish_reason}
    return dict(
        id="chatcmpl-" + str(uuid.uuid4()),
        object="chat.completion.chunk",
        created=timestamp,
        model=model,
        choices=[choice],
        usage=None,
    )
# Function to convert message to dictionary format, ensuring base64 data and optional model prefix
def message_to_dict(message, model_prefix: Optional[str] = None) -> Dict[str, Any]:
    """Convert a chat message into the upstream payload format.

    ``message.content`` may be a plain string or a list of content parts
    (dicts). For a list, the ``"text"`` of every part that has one is joined
    with newlines. The first part carrying an ``"image_url"`` key is forwarded
    as ``data.imageBase64`` (presumably a base64 data URL — confirm with
    callers). Any other content (e.g. ``None``) is passed through unchanged.

    Args:
        message: Object exposing ``role`` and ``content`` attributes.
        model_prefix: Optional marker prepended, space-separated, to the text.
            It is not applied when the content is ``None``.

    Returns:
        ``{"role": ..., "content": ...}``, plus a ``"data"`` entry when an
        image part is present.
    """
    raw = message.content
    image_part = None

    if isinstance(raw, list):
        # Scan every part instead of assuming text at index 0 and an image at
        # index 1. Positional access raised IndexError on an empty list and
        # KeyError when the image came first, and it dropped any extra text.
        texts = []
        for part in raw:
            if "text" in part:
                texts.append(part["text"])
            if image_part is None and "image_url" in part:
                image_part = part
        content = "\n".join(texts)
    else:
        content = raw

    # Skip None so we never send the literal string "<prefix> None" upstream.
    if model_prefix and content is not None:
        content = f"{model_prefix} {content}"

    result: Dict[str, Any] = {"role": message.role, "content": content}
    if image_part is not None:
        # Ensure base64 images are always included for all models
        result["data"] = {
            "imageBase64": image_part["image_url"]["url"],
            "fileText": "",
            "title": "snapshot",
        }
    return result
# Undo the model prefix that message_to_dict adds to outgoing messages.
def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
    """Return ``content`` with a leading ``model_prefix`` removed.

    If ``model_prefix`` is empty/``None`` or ``content`` does not start with
    it, ``content`` is returned unchanged. Otherwise the prefix is removed
    and the remainder is ``.strip()``-ed, which also trims trailing
    whitespace.
    """
    if not model_prefix or not content.startswith(model_prefix):
        return content
    logger.debug(f"Stripping prefix '{model_prefix}' from content.")
    remainder = content[len(model_prefix):]
    return remainder.strip()
# Process streaming response
async def process_streaming_response(request: ChatRequest):
    """Forward ``request`` upstream and yield OpenAI-style SSE chunks.

    POSTs the built payload to ``{BASE_URL}/api/chat`` as a streamed request.
    It yields one ``chat.completion.chunk`` event per non-empty upstream line,
    then a ``finish_reason="stop"`` chunk and the ``data: [DONE]`` sentinel.

    Requests for ``o1-preview`` are first delayed by a random 1-60 seconds.
    If a line reports an invalid or expired ``hid``, the hid is refreshed via
    ``validate.getHid(True)``. A notice is then streamed instead of the rest
    of the upstream response.

    Raises:
        HTTPException: with the upstream status code on an HTTP error status,
            or 500 on a transport-level request failure.
    """
    # Marker the upstream may put at the start of a line; removed before
    # forwarding.
    version_marker = "$@$v=undefined-rv1$@$"

    chat_id = generate_chat_id()
    referer_url = BASE_URL
    logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")

    agent_mode = AGENT_MODE.get(request.model, {})
    trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
    model_prefix = MODEL_PREFIXES.get(request.model, "")

    # Copy so the imported module-level headers dict is not mutated per request.
    headers_api_chat = headers.copy()
    headers_api_chat['Referer'] = referer_url
    headers_api_chat['Cookie'] = f'hid={validate.getHid()}'
    logger.debug(f"Headers being sent: {headers_api_chat}")

    if request.model == 'o1-preview':
        # NOTE(review): the non-streaming path draws this delay from
        # randint(20, 60); confirm which range is intended.
        delay_seconds = random.randint(1, 60)
        logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Chat ID: {chat_id})")
        await asyncio.sleep(delay_seconds)

    json_data = {
        "agentMode": agent_mode,
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "id": chat_id,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "trendingAgentMode": trending_agent_mode,
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        # Remove 'validated' if not required
        # "validated": validate.getHid(),
        "visitFromDelta": False,
    }
    logger.debug(f"JSON payload being sent: {json.dumps(json_data)}")

    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                "POST",
                f"{BASE_URL}/api/chat",
                headers=headers_api_chat,
                json=json_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                # Set before the loop so the closing "stop" chunk still has a
                # timestamp when the upstream body has no lines. Previously
                # that case raised UnboundLocalError.
                timestamp = int(datetime.now().timestamp())
                # NOTE(review): in recent httpx versions aiter_lines() yields
                # lines without their line endings, so line breaks between
                # upstream lines are not forwarded. Confirm this is intended.
                async for line in response.aiter_lines():
                    timestamp = int(datetime.now().timestamp())
                    if not line:
                        continue
                    content = line
                    logger.debug(f"Received content: {content}")
                    # The upstream reports a stale session cookie in-band.
                    # Refresh it and tell the user to retry. The notice reads
                    # "hid refreshed, just start the conversation again".
                    if "Invalid or expired 'hid'" in content:
                        logger.warning("Invalid or expired 'hid' detected. Refreshing 'hid'.")
                        validate.getHid(True)
                        content = "hid已刷新,重新对话即可\n"
                        yield f"data: {json.dumps(create_chat_completion_data(content, request.model, timestamp))}\n\n"
                        break
                    if content.startswith(version_marker):
                        content = content[len(version_marker):]
                    cleaned_content = strip_model_prefix(content, model_prefix)
                    yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, request.model, timestamp))}\n\n"
                yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e
# Process non-streaming response
async def process_non_streaming_response(request: ChatRequest):
    """Forward ``request`` upstream and return one OpenAI-style completion.

    Sends a single POST to ``{BASE_URL}/api/chat`` and reads the whole body.
    The body is wrapped in a ``chat.completion`` dict with one ``"stop"``
    choice and ``usage`` set to ``None``.

    Requests for ``o1-preview`` are first delayed by a random 20-60 seconds.
    If the body reports an invalid or expired ``hid``, the hid is refreshed
    via ``validate.getHid(True)`` and a notice is returned as the content.

    Raises:
        HTTPException: with the upstream status code on an HTTP error status,
            or 500 on a transport-level request failure.
    """
    # Marker the upstream may put at the start of the body; removed before
    # returning.
    version_marker = "$@$v=undefined-rv1$@$"

    chat_id = generate_chat_id()
    referer_url = BASE_URL
    logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")

    agent_mode = AGENT_MODE.get(request.model, {})
    trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
    model_prefix = MODEL_PREFIXES.get(request.model, "")

    # Copy so the imported module-level headers dict is not mutated per request.
    headers_api_chat = headers.copy()
    headers_api_chat['Referer'] = referer_url
    headers_api_chat['Cookie'] = f'hid={validate.getHid()}'
    logger.debug(f"Headers being sent: {headers_api_chat}")

    if request.model == 'o1-preview':
        # NOTE(review): the streaming path draws this delay from
        # randint(1, 60); confirm which range is intended.
        delay_seconds = random.randint(20, 60)
        logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Chat ID: {chat_id})")
        await asyncio.sleep(delay_seconds)

    json_data = {
        "agentMode": agent_mode,
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "id": chat_id,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "trendingAgentMode": trending_agent_mode,
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        # Remove 'validated' if not required
        # "validated": validate.getHid(),
        "visitFromDelta": False,
    }
    logger.debug(f"JSON payload being sent: {json.dumps(json_data)}")

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                f"{BASE_URL}/api/chat",
                headers=headers_api_chat,
                json=json_data,
                timeout=100,
            )
            response.raise_for_status()
            full_response = response.text
            logger.debug(f"Full response received: {full_response}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    # The upstream reports a stale session cookie in-band. Refresh it and
    # return a notice reading "hid refreshed, just start the conversation
    # again".
    if "Invalid or expired 'hid'" in full_response:
        logger.warning("Invalid or expired 'hid' detected. Refreshing 'hid'.")
        validate.getHid(True)
        full_response = "hid已刷新,重新对话即可"
    if full_response.startswith(version_marker):
        full_response = full_response[len(version_marker):]
    cleaned_full_response = strip_model_prefix(full_response, model_prefix)

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": cleaned_full_response},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }