test24 / api /utils.py
Niansuh's picture
Update api/utils.py
da8ba32 verified
raw
history blame
5.65 kB
from datetime import datetime
import json
import uuid
import asyncio
import random
import string
from typing import Any, Dict, Optional, AsyncGenerator
import httpx
from fastapi import HTTPException
from api.config import (
models,
model_aliases,
ALLOWED_MODELS,
MODEL_MAPPING,
get_headers_api_chat,
BASE_URL,
)
from api.models import ChatRequest, Message
from api.logger import setup_logger
from api.providers.gizai import GizAI # Import the GizAI provider
logger = setup_logger(__name__)
# Helper function to create a random alphanumeric chat ID
def generate_chat_id(length: int = 7) -> str:
    """Return a random chat ID built from *length* ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))
# Helper function to create chat completion data
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build one OpenAI-style ``chat.completion.chunk`` payload.

    The single choice carries *content* as an assistant delta; *finish_reason*
    stays ``None`` for intermediate chunks and is set (e.g. ``"stop"``) on the
    final one.
    """
    choice = {
        "index": 0,
        "delta": {"content": content, "role": "assistant"},
        "finish_reason": finish_reason,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [choice],
        "usage": None,
    }
# Function to convert message to dictionary format, ensuring base64 data
def message_to_dict(message: "Message") -> Dict[str, Any]:
    """Convert a Message into the provider's request dict.

    Supported content shapes:
    - a plain string -> used verbatim as ``content``
    - a list whose first element is a ``{"text": ...}`` dict -> that text
    - a two-element list whose second element carries ``image_url`` -> the
      base64 image URL is attached under ``data`` (always included so every
      model receives the image payload)

    Fix vs. original: an empty content list no longer raises IndexError, and
    the second element is type-checked before ``"image_url" in ...``.
    """
    raw = message.content
    if isinstance(raw, str):
        content = raw
    elif isinstance(raw, list) and raw and isinstance(raw[0], dict):
        # Missing "text" key degrades to "" (same as the original behavior).
        content = raw[0].get("text", "")
    else:
        content = ""
    # Image payload: [text_part, {"image_url": {"url": <base64 data URL>}}]
    if (
        isinstance(raw, list)
        and len(raw) == 2
        and isinstance(raw[1], dict)
        and "image_url" in raw[1]
    ):
        return {
            "role": message.role,
            "content": content,
            "data": {
                "imageBase64": raw[1]["image_url"]["url"],
                "fileText": "",
                "title": "snapshot",
            },
        }
    return {"role": message.role, "content": content}
# Function to resolve model aliases
def resolve_model(model: str) -> str:
    """Resolve *model* to a canonical name.

    Known names pass through, aliases are mapped via ``model_aliases``, and
    anything else falls back to ``GizAI.default_model`` with a warning.
    """
    if model in MODEL_MAPPING:
        return model
    if model in model_aliases:
        return model_aliases[model]
    logger.warning(f"Model '{model}' not recognized. Using default model '{GizAI.default_model}'.")
    return GizAI.default_model
# Process streaming response with GizAI provider
async def process_streaming_response(request: ChatRequest) -> AsyncGenerator[str, None]:
    """Stream the GizAI response as Server-Sent-Events ``data:`` lines.

    Text chunks are wrapped in OpenAI-style ``chat.completion.chunk`` payloads;
    non-string chunks are treated as image responses (``.images`` / ``.alt``).
    Ends with an empty ``stop`` chunk followed by ``[DONE]``.

    Fix vs. original: the original tested ``isinstance(response, ImageResponse)``
    but ``ImageResponse`` is never imported in this module, so every image chunk
    raised NameError. We branch on ``isinstance(response, str)`` instead and
    duck-type the image case.
    """
    chat_id = generate_chat_id()
    resolved_model = resolve_model(request.model)
    logger.info(f"Generated Chat ID: {chat_id} - Model: {resolved_model}")
    gizai_provider = GizAI()
    async for response in gizai_provider.create_async_generator(
        model=resolved_model,
        messages=request.messages,
        # 'proxy' may not be declared on ChatRequest — degrade to None instead
        # of raising AttributeError (the original source flagged this assumption).
        proxy=getattr(request, "proxy", None),
    ):
        timestamp = int(datetime.now().timestamp())
        if isinstance(response, str):
            # Plain text delta.
            yield f"data: {json.dumps(create_chat_completion_data(response, resolved_model, timestamp))}\n\n"
        else:
            # Image-style response object exposing .images and .alt.
            yield f"data: {json.dumps({'image_url': response.images, 'alt': response.alt})}\n\n"
    # Terminal empty chunk with finish_reason="stop", then the SSE sentinel.
    timestamp = int(datetime.now().timestamp())
    yield f"data: {json.dumps(create_chat_completion_data('', resolved_model, timestamp, 'stop'))}\n\n"
    yield "data: [DONE]\n\n"
# Process non-streaming response with GizAI provider
async def process_non_streaming_response(request: ChatRequest) -> Dict[str, Any]:
    """Collect the full GizAI response and return one OpenAI-style completion.

    Fixes vs. original:
    - ``ImageResponse`` was referenced but never imported (NameError on image
      chunks); we branch on ``isinstance(response, str)`` and duck-type instead.
    - When every chunk is text, ``choices[0].message.content`` is now a joined
      string (the OpenAI-compatible shape) rather than a Python list; mixed
      text/image output keeps the list form so image URLs are not lost.
    """
    chat_id = generate_chat_id()
    resolved_model = resolve_model(request.model)
    logger.info(f"Generated Chat ID: {chat_id} - Model: {resolved_model}")
    gizai_provider = GizAI()
    responses = []
    async for response in gizai_provider.create_async_generator(
        model=resolved_model,
        messages=request.messages,
        # 'proxy' may not be declared on ChatRequest — degrade to None instead
        # of raising AttributeError (the original source flagged this assumption).
        proxy=getattr(request, "proxy", None),
    ):
        if isinstance(response, str):
            responses.append(response)
        else:
            # Image-style response object exposing .images and .alt.
            responses.append({"image_url": response.images, "alt": response.alt})
    # Pure-text stream -> single string; mixed content keeps the raw list.
    if all(isinstance(part, str) for part in responses):
        content: Any = "".join(responses)
    else:
        content = responses
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": resolved_model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }
# Helper function to format prompt from messages
def format_prompt(messages: "list[Message]") -> str:
    """Join message contents into one newline-separated prompt string.

    String content is used verbatim; list content contributes its first
    element's ``"text"`` value when that element is a dict. Other shapes
    (empty lists, non-dict elements) are skipped.

    Fix vs. original: ``msg.content[0]`` no longer raises IndexError on an
    empty list or AttributeError when the first element is not a dict.
    """
    parts = []
    for msg in messages:
        raw = msg.content
        if isinstance(raw, str):
            parts.append(raw)
        elif isinstance(raw, list) and raw and isinstance(raw[0], dict):
            parts.append(raw[0].get("text", ""))
    return "\n".join(parts)