"""GizAI provider adapter.

Maps OpenAI-style chat/image requests onto the GizAI API and converts
GizAI responses back into OpenAI-compatible payloads. GizAI does not
support streaming, so streaming requests are served as a single
non-streaming completion.
"""

import uuid
from datetime import datetime
import json
from typing import Any, Dict

import httpx
from fastapi import HTTPException

from api.models import ChatRequest
from api.logger import setup_logger
from api.config import MODEL_MAPPING, GIZAI_API_ENDPOINT, GIZAI_HEADERS

logger = setup_logger(__name__)

# Seconds to wait for a response from the GizAI endpoint.
# Image generation can be slow, hence the generous budget.
REQUEST_TIMEOUT = 100

# List of models supported by GizAI
GIZAI_CHAT_MODELS = [
    'chat-gemini-flash',
    'chat-gemini-pro',
    'chat-gpt4m',
    'chat-gpt4',
    'claude-sonnet',
    'claude-haiku',
    'llama-3-70b',
    'llama-3-8b',
    'mistral-large',
    'chat-o1-mini',
]

GIZAI_IMAGE_MODELS = [
    'flux1',
    'sdxl',
    'sd',
    'sd35',
]

GIZAI_MODELS = GIZAI_CHAT_MODELS + GIZAI_IMAGE_MODELS

# Maps public (OpenAI-style) model names to GizAI's internal identifiers.
GIZAI_MODEL_ALIASES = {
    # Chat model aliases
    "gemini-flash": "chat-gemini-flash",
    "gemini-pro": "chat-gemini-pro",
    "gpt-4o-mini": "chat-gpt4m",
    "gpt-4o": "chat-gpt4",
    "claude-3.5-sonnet": "claude-sonnet",
    "claude-3-haiku": "claude-haiku",
    "llama-3.1-70b": "llama-3-70b",
    "llama-3.1-8b": "llama-3-8b",
    "o1-mini": "chat-o1-mini",
    # Image model aliases
    "sd-1.5": "sd",
    "sd-3.5": "sd35",
    "flux-schnell": "flux1",
}


def get_gizai_model(model: str) -> str:
    """Resolve a requested model name to a GizAI model identifier.

    Resolution order: project-wide MODEL_MAPPING, then GizAI's native
    model list, then the alias table. Unknown models fall back to
    'chat-gemini-flash' rather than erroring.
    """
    model = MODEL_MAPPING.get(model, model)
    if model in GIZAI_MODELS:
        return model
    if model in GIZAI_MODEL_ALIASES:
        return GIZAI_MODEL_ALIASES[model]
    # Default model
    return 'chat-gemini-flash'


def is_image_model(model: str) -> bool:
    """Return True if *model* is one of GizAI's image-generation models."""
    return model in GIZAI_IMAGE_MODELS


async def process_streaming_response(request: ChatRequest):
    """Serve a streaming request as non-streaming (GizAI has no streaming API)."""
    return await process_non_streaming_response(request)


async def _post_to_gizai(client: httpx.AsyncClient, data: Dict[str, Any]) -> Dict[str, Any]:
    """POST *data* to the GizAI endpoint and return the decoded JSON body.

    Translates transport-level failures into HTTPException: upstream HTTP
    errors keep their original status code; connection/timeout errors
    become a 500.
    """
    try:
        response = await client.post(
            GIZAI_API_ENDPOINT,
            headers=GIZAI_HEADERS,
            json=data,
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error occurred: {e}")
        raise HTTPException(status_code=e.response.status_code, detail=str(e))
    except httpx.RequestError as e:
        logger.error(f"Error occurred during request: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def _handle_image_request(
    client: httpx.AsyncClient, request: ChatRequest, model: str
) -> Dict[str, Any]:
    """Run an image-generation request; the last message is used as the prompt."""
    prompt = request.messages[-1].content
    data = {
        "model": model,
        "input": {
            "width": "1024",
            "height": "1024",
            "steps": 4,
            "output_format": "webp",
            "batch_size": 1,
            "mode": "plan",
            "prompt": prompt,
        },
    }
    response_data = await _post_to_gizai(client, data)
    if response_data.get('status') == 'completed' and response_data.get('output'):
        images = response_data['output']
        # Return image response (e.g., URLs)
        return {
            "id": f"imggen-{uuid.uuid4()}",
            "object": "image_generation",
            "created": int(datetime.now().timestamp()),
            "model": request.model,
            "data": images,
        }
    raise HTTPException(status_code=500, detail="Image generation failed")


async def _handle_chat_request(
    client: httpx.AsyncClient, request: ChatRequest, model: str
) -> Dict[str, Any]:
    """Run a chat completion; the full conversation is flattened into one message."""
    # GizAI accepts a single "human" message, so the whole conversation
    # is serialized as "role: content" lines.
    messages_content = "\n".join(
        f"{msg.role}: {msg.content}" for msg in request.messages
    )
    data = {
        "model": model,
        "input": {
            "messages": [
                {
                    "type": "human",
                    "content": messages_content,
                }
            ],
            "mode": "plan",
        },
        "noStream": True,
    }
    response_data = await _post_to_gizai(client, data)
    output = response_data.get('output', '')
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": output},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }


async def process_non_streaming_response(request: ChatRequest):
    """Dispatch *request* to GizAI and return an OpenAI-compatible response.

    Image models get an image-generation payload; everything else is a
    chat completion. Raises HTTPException(400) when no messages were
    supplied (previously this surfaced as an IndexError / opaque 500).
    """
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")
    model = get_gizai_model(request.model)
    async with httpx.AsyncClient() as client:
        if is_image_model(model):
            return await _handle_image_request(client, request, model)
        return await _handle_chat_request(client, request, model)