test24 / api /utils.py
Niansuh's picture
Update api/utils.py
bdcb734 verified
raw
history blame
7.33 kB
from datetime import datetime
import json
import uuid
import asyncio
import random
import string
from typing import Any, Dict, Optional, List
import httpx
from fastapi import HTTPException
from api.config import (
MODEL_MAPPING,
AGENT_MODE,
TRENDING_AGENT_MODE,
get_headers,
BASE_URL,
)
from api.models import ChatRequest
from api.logger import setup_logger
logger = setup_logger(__name__)
# Helper function to create a random alphanumeric chat ID
def generate_chat_id(length: int = 7) -> str:
    """Return a random chat identifier made of *length* alphanumeric characters."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))
# Helper function to create chat completion data
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build a single OpenAI-style ``chat.completion.chunk`` payload.

    The chunk carries one assistant delta with the given *content*;
    *finish_reason* is ``None`` for intermediate chunks and e.g. ``'stop'``
    for the terminal one.
    """
    choice = {
        "index": 0,
        "delta": {"content": content, "role": "assistant"},
        "finish_reason": finish_reason,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [choice],
        "usage": None,
    }
# Function to convert message to dictionary format for API
def message_to_dict(message) -> Dict[str, Any]:
    """Convert a chat message object into the upstream API's dict format.

    String content passes through as-is; list content takes the text of its
    first part. When the content is a two-part list whose second part carries
    an ``image_url``, the image is attached as a base64 ``data`` payload.
    """
    raw = message.content
    text = raw if isinstance(raw, str) else raw[0]["text"]
    if isinstance(raw, list) and len(raw) == 2 and "image_url" in raw[1]:
        return {
            "role": message.role,
            "content": text,
            "data": {
                "imageBase64": raw[1]["image_url"]["url"],
                "fileText": "",
                "title": "snapshot",
            },
        }
    return {"role": message.role, "content": text}
# Function to retrieve agent modes for a specific model
def get_agent_modes(model: str) -> Dict[str, Any]:
    """Look up the agent and trending-agent configurations for *model*.

    Unknown models yield empty dicts for both keys, which the upstream API
    treats as "standard" (non-agent) mode.
    """
    modes = {
        "agentMode": AGENT_MODE.get(model, {}),
        "trendingAgentMode": TRENDING_AGENT_MODE.get(model, {}),
    }
    if modes["agentMode"] or modes["trendingAgentMode"]:
        logger.info(f"Applying agent configurations for model '{model}'")
    else:
        logger.info(f"Model '{model}' is not an agent model; defaulting to standard mode")
    return modes
# Process streaming response with headers from config.py
async def process_streaming_response(request: ChatRequest):
    """Stream chat completions from the upstream API as SSE ``data:`` lines.

    Yields one OpenAI-style ``chat.completion.chunk`` per upstream line,
    then a terminal 'stop' chunk and a ``[DONE]`` sentinel.

    Raises:
        HTTPException: with the upstream status code on HTTP errors, or
            500 on transport-level request failures.
    """
    chat_id = generate_chat_id()
    logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model}")

    # Retrieve agent modes based on the model
    agent_modes = get_agent_modes(request.model)
    headers = get_headers()
    json_data = {
        "agentMode": agent_modes["agentMode"],
        "trendingAgentMode": agent_modes["trendingAgentMode"],
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "id": chat_id,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
        "visitFromDelta": False,
    }
    logger.debug(f"Data Payload: {json_data}")  # Inspect payload for accuracy
    async with httpx.AsyncClient() as client:
        try:
            async with client.stream(
                "POST",
                f"{BASE_URL}/api/chat",
                headers=headers,
                json=json_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                # Initialize before the loop: the final 'stop' chunk below
                # references `timestamp`, which was previously a NameError
                # when the upstream stream yielded zero lines.
                timestamp = int(datetime.now().timestamp())
                async for line in response.aiter_lines():
                    timestamp = int(datetime.now().timestamp())
                    if line:
                        yield f"data: {json.dumps(create_chat_completion_data(line, request.model, timestamp))}\n\n"
                yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=500, detail=str(e))
# Process non-streaming response with headers from config.py
async def process_non_streaming_response(request: ChatRequest):
    """Collect the full upstream response and return one chat.completion dict.

    Accumulates the streamed upstream text into a single assistant message.

    Raises:
        HTTPException: with the upstream status code on HTTP errors, or
            500 on transport-level request failures.
    """
    chat_id = generate_chat_id()
    logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model}")

    # Retrieve agent modes based on the model
    agent_modes = get_agent_modes(request.model)
    headers = get_headers()
    json_data = {
        "agentMode": agent_modes["agentMode"],
        "trendingAgentMode": agent_modes["trendingAgentMode"],
        "clickedAnswer2": False,
        "clickedAnswer3": False,
        "clickedForceWebSearch": False,
        "codeModelMode": True,
        "githubToken": None,
        "id": chat_id,
        "isChromeExt": False,
        "isMicMode": False,
        "maxTokens": request.max_tokens,
        "messages": [message_to_dict(msg) for msg in request.messages],
        "mobileClient": False,
        "playgroundTemperature": request.temperature,
        "playgroundTopP": request.top_p,
        "previewToken": None,
        "userId": None,
        "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
        "userSystemPrompt": None,
        "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
        "visitFromDelta": False,
    }
    logger.debug(f"Data Payload: {json_data}")  # Inspect payload for accuracy
    # Collect streamed chunks into one string; parts + join avoids
    # quadratic string concatenation on long generations.
    parts: List[str] = []
    async with httpx.AsyncClient() as client:
        try:
            # timeout=100 matches process_streaming_response; httpx's 5 s
            # default is too short for long model generations.
            async with client.stream(
                method="POST",
                url=f"{BASE_URL}/api/chat",
                headers=headers,
                json=json_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    parts.append(chunk)
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=500, detail=str(e))
    full_response = "".join(parts)
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": full_response},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }