# api/utils.py

from datetime import datetime
import json
import uuid
import asyncio
import random
import string
from typing import Any, Dict, Optional, AsyncGenerator

import httpx
from fastapi import HTTPException
from api.config import (
    MODELS,
    MODEL_ALIASES,
    DEFAULT_MODEL,
    API_ENDPOINT,
    get_headers_api_chat,
    BASE_URL,
    MODEL_PREFIXES,
    MODEL_REFERERS,
)
from api.models import ChatRequest, Message
from api.logger import setup_logger

logger = setup_logger(__name__)

# Models whose request payload carries a single image prompt rather than a
# chat transcript.
_IMAGE_MODELS = frozenset({"flux1", "sdxl", "sd", "sd35"})

# Timeout, in seconds, applied to every upstream request.
_REQUEST_TIMEOUT = 100


def _content_text(content: Any) -> str:
    """Return the text carried by a message ``content`` value.

    A plain string is returned unchanged; any other value is indexed as
    ``content[0]["text"]``.
    """
    return content if isinstance(content, str) else content[0]["text"]


# Helper function to create a random alphanumeric chat ID
def generate_chat_id(length: int = 7) -> str:
    """Return a random string of ``length`` ASCII letters and digits."""
    characters = string.ascii_letters + string.digits
    return ''.join(random.choices(characters, k=length))


# Helper function to create a chat completion data chunk
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build one ``chat.completion.chunk`` dictionary.

    Args:
        content: Text placed in the chunk's ``delta``.
        model: Model name reported in the chunk.
        timestamp: Value stored under ``created``.
        finish_reason: Value stored under ``finish_reason``; ``None`` by default.

    Returns:
        A dictionary with a freshly generated ``chatcmpl-<uuid4>`` id and a
        single choice at index 0.
    """
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
    }


# Function to convert message to dictionary format, ensuring base64 data and optional model prefix
def message_to_dict(message: Message, model_prefix: Optional[str] = None):
    """Convert ``message`` to a ``{"role", "content"}`` dictionary.

    When ``model_prefix`` is truthy it is prepended to the text, separated by
    a space. When the message content is a two-item list whose second item
    has an ``image_url`` key, a ``data`` entry carrying that URL as
    ``imageBase64`` is added as well.
    """
    content = _content_text(message.content)
    if model_prefix:
        content = f"{model_prefix} {content}"
    if (
        isinstance(message.content, list)
        and len(message.content) == 2
        and "image_url" in message.content[1]
    ):
        # Ensure base64 images are always included for all models
        return {
            "role": message.role,
            "content": content,
            "data": {
                "imageBase64": message.content[1]["image_url"]["url"],
                "fileText": "",
                "title": "snapshot",
            },
        }
    return {"role": message.role, "content": content}


# Function to strip model prefix from content if present
def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
    """Remove the model prefix from the response content if present."""
    if model_prefix and content.startswith(model_prefix):
        logger.debug(f"Stripping prefix '{model_prefix}' from content.")
        return content[len(model_prefix):].strip()
    return content


# Function to get the correct referer URL for logging
def get_referer_url(chat_id: str, model: str) -> str:
    """Generate the referer URL based on specific models listed in MODEL_REFERERS."""
    if model in MODEL_REFERERS:
        return f"{BASE_URL}/chat/{chat_id}?model={model}"
    return BASE_URL


# Helper function to format messages
def format_messages(messages: list[Message]) -> str:
    """Join the text of every message with newlines, in order."""
    return "\n".join(_content_text(msg.content) for msg in messages)


def _resolve_model(requested: str) -> str:
    """Map a requested model name to a known model, alias target, or the default."""
    if requested in MODELS:
        return requested
    return MODEL_ALIASES.get(requested, DEFAULT_MODEL)


def _build_request_data(
    request: ChatRequest, model: str, model_prefix: str, *, no_stream: bool
) -> Dict[str, Any]:
    """Build the JSON payload sent to ``API_ENDPOINT``.

    Image models receive only the text of the last message as ``prompt``;
    chat models receive the whole transcript as a single "human" message,
    with ``model_prefix`` prepended when it is non-empty. ``no_stream`` is
    only included in the chat payload, as ``noStream``.
    """
    if model in _IMAGE_MODELS:
        return {
            "model": model,
            "input": {
                "width": "1024",
                "height": "1024",
                "steps": 4,
                "output_format": "webp",
                "batch_size": 1,
                "mode": "plan",
                "prompt": _content_text(request.messages[-1].content),
            },
        }

    transcript = format_messages(request.messages)
    return {
        "model": model,
        "input": {
            "messages": [
                {
                    "type": "human",
                    "content": f"{model_prefix} {transcript}" if model_prefix else transcript,
                }
            ],
            "mode": "plan",
        },
        "noStream": no_stream,
    }


# Process streaming response
async def process_streaming_response(request: ChatRequest) -> AsyncGenerator[str, None]:
    """Stream the upstream response as server-sent ``data:`` events.

    Every non-empty line received from the upstream API is wrapped in a
    ``chat.completion.chunk`` payload. After the upstream stream ends, a
    final chunk with ``finish_reason`` ``"stop"`` and a ``[DONE]`` marker
    are emitted.

    Raises:
        HTTPException: With the upstream status code on an HTTP error
            status, or 500 on a transport-level request error.
    """
    chat_id = generate_chat_id()
    # The referer is derived from the model name as requested, before alias
    # resolution.
    referer_url = get_referer_url(chat_id, request.model)
    logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")

    model = _resolve_model(request.model)
    model_prefix = MODEL_PREFIXES.get(model, "")
    headers_api_chat = get_headers_api_chat(referer_url)
    data = _build_request_data(request, model, model_prefix, no_stream=False)

    async with httpx.AsyncClient() as client:
        try:
            # client.post() returns a coroutine and cannot be used with
            # ``async with``; client.stream() is the context manager that
            # exposes the body incrementally.
            async with client.stream(
                "POST",
                API_ENDPOINT,
                headers=headers_api_chat,
                json=data,
                timeout=_REQUEST_TIMEOUT,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    timestamp = int(datetime.now().timestamp())
                    # NOTE(review): each raw line is forwarded as message
                    # text; adjust parsing once the upstream line format is
                    # confirmed.
                    cleaned_content = strip_model_prefix(line, model_prefix)
                    yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, model, timestamp))}\n\n"

            # Indicate end of stream
            yield f"data: {json.dumps(create_chat_completion_data('', model, int(datetime.now().timestamp()), 'stop'))}\n\n"
            yield "data: [DONE]\n\n"
        except httpx.HTTPStatusError as e:
            # NOTE(review): if this generator feeds a streaming response, the
            # response headers may already have been sent when this raises -
            # confirm how the caller surfaces the error.
            logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e


# Process non-streaming response
async def process_non_streaming_response(request: ChatRequest) -> Dict[str, Any]:
    """Call the upstream API once and return a ``chat.completion`` dictionary.

    For image models the upstream ``output`` is returned under the message's
    ``images`` key; for chat models it becomes the message ``content`` after
    the model prefix is stripped.

    Raises:
        HTTPException: With the upstream status code on an HTTP error
            status, 500 on a transport-level request error, or 502 when the
            upstream body is not JSON or an image request did not complete
            with output.
    """
    chat_id = generate_chat_id()
    # The referer is derived from the model name as requested, before alias
    # resolution.
    referer_url = get_referer_url(chat_id, request.model)
    logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")

    model = _resolve_model(request.model)
    model_prefix = MODEL_PREFIXES.get(model, "")
    headers_api_chat = get_headers_api_chat(referer_url)
    data = _build_request_data(request, model, model_prefix, no_stream=True)

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                API_ENDPOINT,
                headers=headers_api_chat,
                json=data,
                timeout=_REQUEST_TIMEOUT,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        response_data = response.json()
    except ValueError as e:
        # A non-JSON body would otherwise surface as an unhandled error.
        logger.error(f"Invalid JSON in response for Chat ID {chat_id}: {e}")
        raise HTTPException(status_code=502, detail="Upstream returned a non-JSON response") from e

    # NOTE(review): the response is assumed to be a JSON object whose
    # 'output' key holds the generated content - confirm against the
    # upstream API.
    if model in _IMAGE_MODELS:
        if response_data.get('status') != 'completed' or not response_data.get('output'):
            # Previously this path fell through and returned None.
            logger.error(f"Image generation did not complete for Chat ID {chat_id}")
            raise HTTPException(status_code=502, detail="Image generation did not complete")
        # 'output' is presumably a list of image URLs - TODO confirm.
        message = {"role": "assistant", "content": "", "images": response_data['output']}
    else:
        content = response_data.get('output', '')
        message = {"role": "assistant", "content": strip_model_prefix(content, model_prefix)}

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": message,
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }