"""giz.ai provider: model-name resolution plus chat and image generation calls."""

import uuid
from datetime import datetime
import json
from typing import Any, Dict, Optional

import httpx
from fastapi import HTTPException

from api.models import ChatRequest
from api.logger import setup_logger

logger = setup_logger(__name__)

# Base URL for giz.ai
GIZAI_BASE_URL = "https://app.giz.ai"
GIZAI_API_ENDPOINT = f"{GIZAI_BASE_URL}/api/data/users/inferenceServer.infer"

# Timeout (in seconds) applied to every request sent to giz.ai
GIZAI_TIMEOUT_SECONDS = 100

# Headers for giz.ai
GIZAI_HEADERS = {
    'Accept': 'application/json, text/plain, */*',
    'Accept-Language': 'en-US,en;q=0.9',
    'Cache-Control': 'no-cache',
    'Connection': 'keep-alive',
    'Content-Type': 'application/json',
    'Origin': 'https://app.giz.ai',
    'Pragma': 'no-cache',
    'Sec-Fetch-Dest': 'empty',
    'Sec-Fetch-Mode': 'cors',
    'Sec-Fetch-Site': 'same-origin',
    'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
    'sec-ch-ua': '"Not?A_Brand";v="99", "Chromium";v="130"',
    'sec-ch-ua-mobile': '?0',
    'sec-ch-ua-platform': '"Linux"'
}

# List of models supported by giz.ai
GIZAI_CHAT_MODELS = [
    'chat-gemini-flash',
    'chat-gemini-pro',
    'chat-gpt4m',
    'chat-gpt4',
    'claude-sonnet',
    'claude-haiku',
    'llama-3-70b',
    'llama-3-8b',
    'mistral-large',
    'chat-o1-mini'
]

GIZAI_IMAGE_MODELS = [
    'flux1',
    'sdxl',
    'sd',
    'sd35',
]

GIZAI_MODELS = GIZAI_CHAT_MODELS + GIZAI_IMAGE_MODELS

GIZAI_MODEL_ALIASES = {
    # Chat model aliases
    "gemini-flash": "chat-gemini-flash",
    "gemini-pro": "chat-gemini-pro",
    "gpt-4o-mini": "chat-gpt4m",
    "gpt-4o": "chat-gpt4",
    "claude-3.5-sonnet": "claude-sonnet",
    "claude-3-haiku": "claude-haiku",
    "llama-3.1-70b": "llama-3-70b",
    "llama-3.1-8b": "llama-3-8b",
    "o1-mini": "chat-o1-mini",
    # Image model aliases
    "sd-1.5": "sd",
    "sd-3.5": "sd35",
    "flux-schnell": "flux1",
}

# Model used when the requested name is neither a giz.ai model nor an alias
GIZAI_DEFAULT_MODEL = 'chat-gemini-flash'


def get_gizai_model(model: str) -> str:
    """Resolve a requested model name to a giz.ai model identifier.

    Names already in ``GIZAI_MODELS`` are returned unchanged, names in
    ``GIZAI_MODEL_ALIASES`` are translated, and anything else falls back
    to ``GIZAI_DEFAULT_MODEL``.
    """
    if model in GIZAI_MODELS:
        return model
    return GIZAI_MODEL_ALIASES.get(model, GIZAI_DEFAULT_MODEL)


def is_image_model(model: str) -> bool:
    """Return True when *model* is one of the giz.ai image models."""
    return model in GIZAI_IMAGE_MODELS


async def _post_to_gizai(data: Dict[str, Any]) -> Dict[str, Any]:
    """POST *data* to the giz.ai inference endpoint and return the JSON object.

    Raises:
        HTTPException: with the upstream status code when giz.ai answers
            with an error status; with status 500 when the request itself
            fails, the body is not valid JSON, or the JSON is not an object.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                GIZAI_API_ENDPOINT,
                headers=GIZAI_HEADERS,
                json=data,
                timeout=GIZAI_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            payload = response.json()
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error occurred: {e}")
        raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
    except httpx.RequestError as e:
        logger.error(f"Error occurred during request: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    except ValueError as e:
        # response.json() raises a ValueError subclass on a non-JSON body.
        logger.error(f"Invalid JSON received from giz.ai: {e}")
        raise HTTPException(status_code=500, detail="Invalid response from giz.ai") from e

    # Callers use dict access on the payload; reject any other JSON shape here.
    if not isinstance(payload, dict):
        logger.error("Unexpected non-object JSON received from giz.ai")
        raise HTTPException(status_code=500, detail="Invalid response from giz.ai")
    return payload


async def process_streaming_response(request: ChatRequest):
    """Handle a streaming request by delegating to the non-streaming path.

    giz.ai does not support streaming, so the request is processed as a
    regular non-streaming call and its result is returned as-is.
    """
    return await process_non_streaming_response(request)


async def process_non_streaming_response(request: ChatRequest):
    """Send *request* to giz.ai and return the result as a response dict.

    For image models the content of the last message is used as the prompt
    and an ``image_generation`` dict is returned; otherwise all messages are
    flattened into a single human message and a ``chat.completion`` dict is
    returned. The ``model`` field echoes ``request.model`` as supplied by the
    caller, not the resolved giz.ai name.

    Raises:
        HTTPException: 400 when an image request carries no messages; 500
            when image generation does not complete; otherwise whatever
            ``_post_to_gizai`` raises for upstream failures.
    """
    model = get_gizai_model(request.model)

    if is_image_model(model):
        # Image generation
        if not request.messages:
            raise HTTPException(
                status_code=400,
                detail="At least one message is required for image generation",
            )
        prompt = request.messages[-1].content
        data = {
            "model": model,
            "input": {
                "width": "1024",
                "height": "1024",
                "steps": 4,
                "output_format": "webp",
                "batch_size": 1,
                "mode": "plan",
                "prompt": prompt
            }
        }
        response_data = await _post_to_gizai(data)

        images = response_data.get('output')
        if response_data.get('status') != 'completed' or not images:
            raise HTTPException(status_code=500, detail="Image generation failed")

        # Return image response (e.g., URLs)
        return {
            "id": f"imggen-{uuid.uuid4()}",
            "object": "image_generation",
            "created": int(datetime.now().timestamp()),
            "model": request.model,
            "data": images,
        }

    # Chat completion: the whole conversation is sent as one "human" message,
    # one "role: content" line per original message.
    messages_content = "\n".join(f"{msg.role}: {msg.content}" for msg in request.messages)
    data = {
        "model": model,
        "input": {
            "messages": [
                {
                    "type": "human",
                    "content": messages_content
                }
            ],
            "mode": "plan"
        },
        "noStream": True
    }
    response_data = await _post_to_gizai(data)
    output = response_data.get('output', '')

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": output},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }