test24 / api /provider /gizai.py
Niansuh's picture
Update api/provider/gizai.py
ab35934 verified
raw
history blame
6.44 kB
import uuid
from datetime import datetime
import json
from typing import Any, Dict
import httpx
from fastapi import HTTPException
from api.logger import setup_logger
from api.config import MODEL_MAPPING
# Module-level logger; the request helpers below use it to report upstream
# HTTP and transport failures.
logger = setup_logger(__name__)
# Root of the GizAI web app and the inference endpoint every request is
# POSTed to.
GIZAI_BASE_URL = "https://app.giz.ai"
GIZAI_API_ENDPOINT = GIZAI_BASE_URL + "/api/data/users/inferenceServer.infer"

# Browser-style request headers sent with every call to the endpoint above.
GIZAI_HEADERS = {
    'Accept': 'application/json, text/plain, */*',
    'Accept-Language': 'en-US,en;q=0.9',
    'Cache-Control': 'no-cache',
    'Connection': 'keep-alive',
    'Content-Type': 'application/json',
    'Origin': 'https://app.giz.ai',
    'Pragma': 'no-cache',
    'Sec-Fetch-Dest': 'empty',
    'Sec-Fetch-Mode': 'cors',
    'Sec-Fetch-Site': 'same-origin',
    'User-Agent': (
        'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 '
        '(KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36'
    ),
    'sec-ch-ua': '"Not?A_Brand";v="99", "Chromium";v="130"',
    'sec-ch-ua-mobile': '?0',
    'sec-ch-ua-platform': '"Linux"',
}
# Native GizAI model identifiers, split by capability.
# Chat (text completion) models:
GIZAI_CHAT_MODELS = [
    'chat-gemini-flash',
    'chat-gemini-pro',
    'chat-gpt4m',
    'chat-gpt4',
    'claude-sonnet',
    'claude-haiku',
    'llama-3-70b',
    'llama-3-8b',
    'mistral-large',
    'chat-o1-mini',
]

# Image generation models:
GIZAI_IMAGE_MODELS = [
    'flux1',
    'sdxl',
    'sd',
    'sd35',
]

# Every model GizAI accepts: chat models first, then image models.
GIZAI_MODELS = [*GIZAI_CHAT_MODELS, *GIZAI_IMAGE_MODELS]
# Maps commonly used public model names onto native GizAI identifiers.
GIZAI_MODEL_ALIASES = {
    # --- chat models ---
    "gemini-flash": "chat-gemini-flash",
    "gemini-pro": "chat-gemini-pro",
    "gpt-4o-mini": "chat-gpt4m",
    "gpt-4o": "chat-gpt4",
    "claude-3.5-sonnet": "claude-sonnet",
    "claude-3-haiku": "claude-haiku",
    "llama-3.1-70b": "llama-3-70b",
    "llama-3.1-8b": "llama-3-8b",
    "o1-mini": "chat-o1-mini",
    # --- image models ---
    "sd-1.5": "sd",
    "sd-3.5": "sd35",
    "flux-schnell": "flux1",
}
def get_gizai_model(model: str) -> str:
    """Resolve a requested model name to an identifier GizAI accepts.

    The name is first translated through ``MODEL_MAPPING`` (names without an
    entry pass through unchanged). The result is returned as-is when it is a
    native GizAI model, translated when it is a known alias, and replaced by
    ``'chat-gemini-flash'`` otherwise.
    """
    resolved = MODEL_MAPPING.get(model, model)
    if resolved in GIZAI_MODELS:
        return resolved
    # The alias lookup doubles as the fallback: unknown names get the default.
    return GIZAI_MODEL_ALIASES.get(resolved, 'chat-gemini-flash')
def is_image_model(model: str) -> bool:
    """Return True when *model* is listed in ``GIZAI_IMAGE_MODELS``."""
    image_models = GIZAI_IMAGE_MODELS
    return model in image_models
async def process_streaming_response(request_data):
    """Serve a "streaming" request by running it as a regular request.

    GizAI does not support streaming, so the full response is obtained from
    ``process_non_streaming_response`` and handed back as a one-element
    iterator holding its JSON serialization.

    Returns:
        A plain (synchronous) iterator yielding a single JSON string. The
        coroutine itself must be awaited to obtain that iterator.
    """
    completed = await process_non_streaming_response(request_data)
    serialized = json.dumps(completed)
    # A single-chunk iterator lets the caller consume the result like a stream.
    return iter([serialized])
async def _post_to_gizai(client: "httpx.AsyncClient", payload: Dict[str, Any]) -> Any:
    """POST *payload* to the GizAI inference endpoint and return the decoded JSON.

    Raises:
        HTTPException: with the upstream status code when GizAI answers with
            an HTTP error status; 500 when the request itself fails
            (connection, timeout, ...); 502 when the body is not valid JSON.
    """
    try:
        response = await client.post(
            GIZAI_API_ENDPOINT,
            headers=GIZAI_HEADERS,
            json=payload,
            timeout=100,
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error occurred: {e}")
        raise HTTPException(status_code=e.response.status_code, detail=str(e))
    except httpx.RequestError as e:
        logger.error(f"Error occurred during request: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    try:
        return response.json()
    except ValueError as e:
        # json.JSONDecodeError is a ValueError; without this a non-JSON body
        # would escape as an unhandled exception.
        logger.error(f"Invalid JSON in GizAI response: {e}")
        raise HTTPException(status_code=502, detail="Invalid JSON response from GizAI") from e


async def process_non_streaming_response(request_data):
    """Run one chat-completion or image-generation request against GizAI.

    The requested model is resolved with ``get_gizai_model``. Image models
    receive the content of the last message as the prompt; chat models receive
    the whole conversation flattened into a single "human" message made of
    ``role: content`` lines.

    Args:
        request_data: Mapping with a ``messages`` list (items carrying
            ``role`` and ``content``) and an optional ``model`` name.

    Returns:
        A dict with ``"object": "image_generation"`` and the generated images
        under ``"data"`` for image models, or a ``"chat.completion"`` dict
        with one assistant choice for chat models.

    Raises:
        HTTPException: 400 when ``messages`` is missing or empty; 500 when
            image generation does not complete; otherwise whatever
            ``_post_to_gizai`` raises for upstream failures.
    """
    requested_model = request_data.get('model')
    model = get_gizai_model(requested_model)

    messages = request_data.get('messages')
    if not messages:
        # Reject early instead of failing with IndexError/KeyError (image
        # branch) or sending an empty prompt upstream (chat branch).
        raise HTTPException(status_code=400, detail="At least one message is required")

    # Echo the caller's model name; when it was omitted, report the model
    # that was actually used rather than raising KeyError after the call.
    response_model = requested_model if requested_model is not None else model

    async with httpx.AsyncClient() as client:
        if is_image_model(model):
            # Image generation: only the latest message is used as the prompt.
            prompt = messages[-1]['content']
            data = {
                "model": model,
                "input": {
                    "width": "1024",
                    "height": "1024",
                    "steps": 4,
                    "output_format": "webp",
                    "batch_size": 1,
                    "mode": "plan",
                    "prompt": prompt,
                },
            }
            response_data = await _post_to_gizai(client, data)
            if response_data.get('status') == 'completed' and response_data.get('output'):
                # Return image response (e.g., URLs)
                return {
                    "id": f"imggen-{uuid.uuid4()}",
                    "object": "image_generation",
                    "created": int(datetime.now().timestamp()),
                    "model": response_model,
                    "data": response_data['output'],
                }
            raise HTTPException(status_code=500, detail="Image generation failed")

        # Chat completion: the conversation is sent as one "human" message.
        messages_content = "\n".join(
            f"{msg['role']}: {msg['content']}" for msg in messages
        )
        data = {
            "model": model,
            "input": {
                "messages": [
                    {
                        "type": "human",
                        "content": messages_content,
                    }
                ],
                "mode": "plan",
            },
            "noStream": True,
        }
        response_data = await _post_to_gizai(client, data)
        output = response_data.get('output', '')
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": int(datetime.now().timestamp()),
            "model": response_model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": output},
                    "finish_reason": "stop",
                }
            ],
            "usage": None,
        }