# test24 / api /utils.py
# Niansuh's picture
# Create utils.py
# 261bb88 verified
# raw
# history blame
# 7.33 kB
from datetime import datetime
import json
from typing import AsyncGenerator, Union
import uuid
import aiohttp
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from api.config import GIZAI_API_ENDPOINT, GIZAI_BASE_URL
from api.models import ChatRequest, ImageResponseModel, ChatCompletionResponse
from api.logger import setup_logger
# Module-level logger for this file, obtained from the project's setup_logger
# helper and named after this module.
logger = setup_logger(__name__)
class GizAI:
    """Registry of the model identifiers this module knows about.

    Holds the canonical chat and image model names, a table of friendly
    aliases, and two classmethods for resolving and classifying a name.
    """

    # Canonical chat model identifiers; the default comes first.
    default_model = "chat-gemini-flash"
    chat_models = [
        default_model,
        "chat-gemini-pro",
        "chat-gpt4m",
        "chat-gpt4",
        "claude-sonnet",
        "claude-haiku",
        "llama-3-70b",
        "llama-3-8b",
        "mistral-large",
        "chat-o1-mini",
    ]

    # Canonical image model identifiers.
    image_models = ["flux1", "sdxl", "sd", "sd35"]

    # Every canonical identifier, chat models first.
    models = chat_models + image_models

    # Friendly alias -> canonical identifier.
    model_aliases = {
        # Chat model aliases
        "gemini-flash": "chat-gemini-flash",
        "gemini-pro": "chat-gemini-pro",
        "gpt-4o-mini": "chat-gpt4m",
        "gpt-4o": "chat-gpt4",
        "claude-3.5-sonnet": "claude-sonnet",
        "claude-3-haiku": "claude-haiku",
        "llama-3.1-70b": "llama-3-70b",
        "llama-3.1-8b": "llama-3-8b",
        "o1-mini": "chat-o1-mini",
        # Image model aliases
        "sd-1.5": "sd",
        "sd-3.5": "sd35",
        "flux-schnell": "flux1",
    }

    @classmethod
    def get_model(cls, model: str) -> str:
        """Resolve *model* to a canonical identifier.

        A canonical name is returned unchanged, an alias is translated, and
        anything else falls back to ``default_model``.
        """
        if model in cls.models:
            return model
        return cls.model_aliases.get(model, cls.default_model)

    @classmethod
    def is_image_model(cls, model: str) -> bool:
        """Return True when *model* is one of the canonical image models."""
        return model in cls.image_models
# Browser-like headers sent with every GizAI request. Both the image and the
# chat request use exactly this set, so it is defined once here.
_GIZAI_HEADERS = {
    'Accept': 'application/json, text/plain, */*',
    'Accept-Language': 'en-US,en;q=0.9',
    'Cache-Control': 'no-cache',
    'Connection': 'keep-alive',
    'Content-Type': 'application/json',
    'Origin': 'https://app.giz.ai',
    'Pragma': 'no-cache',
    'Sec-Fetch-Dest': 'empty',
    'Sec-Fetch-Mode': 'cors',
    'Sec-Fetch-Site': 'same-origin',
    'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
    'sec-ch-ua': '"Not?A_Brand";v="99", "Chromium";v="130"',
    'sec-ch-ua-mobile': '?0',
    'sec-ch-ua-platform': '"Linux"'
}


def _message_text(content) -> str:
    """Return the plain text carried by a message's ``content``.

    A string is returned unchanged. Any other non-empty value is indexed as a
    list of parts and the ``"text"`` entry of the first part is used. Empty or
    missing content yields ``""`` instead of raising ``IndexError``.
    """
    if isinstance(content, str):
        return content
    if not content:
        return ""
    return content[0].get("text", "")


def _sse_payload(line: bytes) -> Union[str, None]:
    """Extract the payload of one ``data:`` line, or None for any other line.

    Only the leading ``data:`` field name and at most one following space are
    removed; occurrences of ``"data: "`` inside the payload are left intact.
    """
    decoded = line.decode('utf-8').strip()
    if not decoded.startswith("data:"):
        return None
    payload = decoded[len("data:"):]
    return payload[1:] if payload.startswith(" ") else payload


def _to_http_exception(exc: Exception) -> HTTPException:
    """Log *exc* and convert it into the HTTPException reported to the client.

    An upstream ``aiohttp.ClientResponseError`` keeps its status code; every
    other exception becomes a 500.
    """
    if isinstance(exc, aiohttp.ClientResponseError):
        logger.error(f"HTTP error occurred: {exc.status} - {exc.message}")
        return HTTPException(status_code=exc.status, detail=str(exc))
    logger.error(f"Unexpected error: {str(exc)}")
    return HTTPException(status_code=500, detail=str(exc))


async def _generate_image(request: ChatRequest, model: str) -> dict:
    """Request an image for the last message and return ``{"images", "alt"}``."""
    if not request.messages:
        # Previously this surfaced as a bare IndexError.
        raise HTTPException(status_code=400, detail="At least one message is required.")
    data = {
        "model": model,
        "input": {
            "width": "1024",
            "height": "1024",
            "steps": 4,
            "output_format": "webp",
            "batch_size": 1,
            "mode": "plan",
            "prompt": _message_text(request.messages[-1].content)
        }
    }
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(GIZAI_API_ENDPOINT, headers=_GIZAI_HEADERS, json=data) as response:
                response.raise_for_status()
                response_data = await response.json()
                if response_data.get('status') == 'completed' and response_data.get('output'):
                    return {"images": response_data['output'], "alt": "Generated Image"}
                raise HTTPException(status_code=500, detail="Image generation failed.")
    except HTTPException:
        # Our own error: pass it through instead of re-wrapping it below.
        raise
    except Exception as e:
        raise _to_http_exception(e) from e


async def _complete_chat(data: dict):
    """Send a non-streaming chat request and return the response's ``output``."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(GIZAI_API_ENDPOINT, headers=_GIZAI_HEADERS, json=data) as response:
                response.raise_for_status()
                result = await response.json()
                return result.get('output', '')
    except HTTPException:
        raise
    except Exception as e:
        raise _to_http_exception(e) from e


async def _open_chat_stream(data: dict) -> AsyncGenerator[str, None]:
    """Open a streaming chat request and return a generator over its events.

    The request is sent and its status checked before returning, so an
    upstream failure is raised as an HTTPException here rather than
    mid-stream. The session and response are NOT closed on return: the
    returned generator owns them and closes them when it finishes or is
    closed, so the caller must iterate it.
    """
    session = aiohttp.ClientSession()
    try:
        response = await session.post(GIZAI_API_ENDPOINT, headers=_GIZAI_HEADERS, json=data)
        response.raise_for_status()
    except BaseException as e:
        # Also covers cancellation, so the session is never leaked.
        await session.close()
        if isinstance(e, Exception):
            raise _to_http_exception(e) from e
        raise

    async def stream_response():
        try:
            async for line in response.content:
                payload = _sse_payload(line)
                if payload is not None:
                    yield f"data: {payload}\n\n"
        finally:
            response.close()
            await session.close()

    return stream_response()


async def process_gizai_response(request: ChatRequest, model: str) -> Union[AsyncGenerator[str, None], JSONResponse]:
    """Forward a chat request to the GizAI endpoint and return its result.

    Args:
        request: Incoming chat request; ``messages`` and ``stream`` are read.
        model: Model identifier; ``GizAI.is_image_model`` picks the path.

    Returns:
        * image model: a dict ``{"images": <output>, "alt": "Generated Image"}``;
        * chat model with ``request.stream`` true: an async generator yielding
          ``"data: <payload>\\n\\n"`` strings, which must be iterated so the
          underlying connection is released;
        * chat model otherwise: the ``output`` field of the JSON response
          (``''`` when absent).
        NOTE(review): none of these is a ``JSONResponse`` despite the
        annotation, which is kept unchanged for compatibility.

    Raises:
        HTTPException: with the upstream status code for an HTTP error
            response, 400 when an image request has no messages, and 500 for
            any other failure.
    """
    if GizAI.is_image_model(model):
        return await _generate_image(request, model)

    # NOTE(review): every message is sent with type "human" regardless of its
    # role, as in the original code - confirm this is what the API expects.
    messages_formatted = [
        {"type": "human", "content": _message_text(msg.content)}
        for msg in request.messages
    ]
    data = {
        "model": model,
        "input": {
            "messages": messages_formatted,
            "mode": "plan"
        },
        "noStream": not request.stream
    }
    if request.stream:
        return await _open_chat_stream(data)
    return await _complete_chat(data)