from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from typing import Optional
from uuid import uuid4
import json
from time import time
import logging

from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse,
    ChatCompletionRequest,
    ChatCompletionResponse
)
# NOTE: `config` was referenced below but never imported; the module path here
# is an assumption -- adjust to wherever the app config actually lives.
from .config import config

router = APIRouter()
logger = logging.getLogger(__name__)

api: Optional[InferenceApi] = None


def init_router(inference_api: InferenceApi):
    """Initialize the router with an already set-up API instance."""
    global api
    api = inference_api
    logger.info("Router initialized with Inference API instance")


@router.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate a text response from a prompt."""
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    """Generate a streaming text response from a prompt."""
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    try:
        return StreamingResponse(
            api.generate_stream(
                prompt=request.prompt,
                system_message=request.system_message,
                max_new_tokens=request.max_new_tokens
            ),
            media_type="text/event-stream"
        )
    except Exception as e:
        logger.error(f"Error in generate_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint."""
    logger.info(f"Received chat completion request with {len(request.messages)} messages")
    try:
        # Use only the most recent message as the prompt; earlier turns
        # in the conversation are not folded into the prompt yet.
        last_message = request.messages[-1].content

        if request.stream:
            # Unique id shared by all chunks of this completion
            # (replaces the hardcoded "chatcmpl-123" placeholder).
            completion_id = f"chatcmpl-{uuid4().hex}"

            # Wrap the raw token stream in OpenAI-compatible SSE chunks.
            async def stream_chunks():
                async for chunk in api.generate_stream(prompt=last_message):
                    response_chunk = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": int(time()),
                        "model": request.model,
                        "choices": [{
                            "index": 0,
                            "delta": {"content": chunk},
                            "finish_reason": None
                        }]
                    }
                    yield f"data: {json.dumps(response_chunk)}\n\n"
                # Terminate the stream per the OpenAI SSE convention.
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                stream_chunks(),
                media_type="text/event-stream"
            )
        else:
            # For non-streaming requests, generate the full response at once.
            response_text = await api.generate_response(prompt=last_message)
            # Convert to the OpenAI response format.
            return ChatCompletionResponse.from_response(
                content=response_text,
                model=request.model
            )
    except Exception as e:
        logger.error(f"Error in chat completion endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/embedding", response_model=EmbeddingResponse)
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector from text."""
    logger.info(f"Received embedding request for text: {request.text[:50]}...")
    try:
        embedding = await api.generate_embedding(request.text)
        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
        return EmbeddingResponse(
            embedding=embedding,
            dimension=len(embedding)
        )
    except Exception as e:
        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get(
    "/system/status",
    response_model=SystemStatusResponse,
    summary="Check System Status",
    description="Returns comprehensive system status including CPU, memory, GPU, storage, and model information"
)
async def check_system():
    """Get system status from the LLM server."""
    try:
        return await api.check_system_status()
    except Exception as e:
        logger.error(f"Error checking system status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get(
    "/system/validate",
    response_model=ValidationResponse,
    summary="Validate System Configuration",
    description="Validates system configuration, folders, and model setup"
)
async def validate_system():
    """Get system validation status from the LLM server."""
    try:
        return await api.validate_system()
    except Exception as e:
        logger.error(f"Error validating system: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/model/initialize",
    summary="Initialize default or specified model",
    description="Initialize a model for use. Uses the default model from config if none is specified."
)
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use."""
    try:
        return await api.initialize_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/model/initialize/embedding",
    summary="Initialize embedding model",
    description="Initialize a separate model specifically for generating embeddings"
)
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model specifically for embeddings."""
    try:
        return await api.initialize_embedding_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing embedding model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/model/download",
    summary="Download default or specified model",
    description="Downloads model files. Uses the default model from config if none is specified."
)
async def download_model(model_name: Optional[str] = None):
    """Download model files to local storage."""
    try:
        # Fall back to the default model name from config when none is provided.
        model_to_download = model_name or config["model"]["defaults"]["model_name"]
        logger.info(f"Received request to download model: {model_to_download}")
        result = await api.download_model(model_to_download)
        logger.info(f"Successfully downloaded model: {model_to_download}")
        return result
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


# Note: on_event is deprecated in newer FastAPI releases; a lifespan handler
# on the app is the modern equivalent, but this form still works.
@router.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown."""
    if api:
        await api.cleanup()
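
# Usage sketch: wiring this router into a FastAPI app. This is illustrative
# only -- the import path ("app.routes") and the no-argument InferenceApi
# constructor are assumptions; adapt them to the actual project layout and
# however the API instance is set up.
#
#     from fastapi import FastAPI
#     from app.api import InferenceApi
#     from app.routes import router, init_router
#
#     app = FastAPI()
#     init_router(InferenceApi())   # inject the already set-up API instance
#     app.include_router(router)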