import logging
from typing import Optional

from fastapi import APIRouter, HTTPException

from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse,
)

router = APIRouter()
logger = logging.getLogger(__name__)

api: Optional[InferenceApi] = None


def init_router(config: dict):
    """Initialize router with config and Inference API instance"""
    global api
    api = InferenceApi()
    logger.info("Router initialized with Inference API instance")


@router.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate text response from prompt"""
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens,
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    """Generate streaming text response from prompt"""
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    try:
        return api.generate_stream(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens,
        )
    except Exception as e:
        logger.error(f"Error in generate_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/embedding", response_model=EmbeddingResponse)
async def generate_embedding(request: EmbeddingRequest):
    """Generate embedding vector from text"""
    logger.info(f"Received embedding request for text: {request.text[:50]}...")
    try:
        embedding = await api.generate_embedding(request.text)
        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
        return EmbeddingResponse(
            embedding=embedding,
            dimension=len(embedding),
        )
    except Exception as e:
        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get(
    "/system/status",
    response_model=SystemStatusResponse,
    summary="Check System Status",
    description="Returns comprehensive system status including CPU, Memory, GPU, Storage, and Model information",
)
async def check_system():
    """Get system status from LLM Server"""
    try:
        return await api.check_system_status()
    except Exception as e:
        logger.error(f"Error checking system status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get(
    "/system/validate",
    response_model=ValidationResponse,
    summary="Validate System Configuration",
    description="Validates system configuration, folders, and model setup",
)
async def validate_system():
    """Get system validation status from LLM Server"""
    try:
        return await api.validate_system()
    except Exception as e:
        logger.error(f"Error validating system: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/model/initialize",
    summary="Initialize default or specified model",
    description="Initialize model for use. Uses default model from config if none specified.",
)
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use"""
    try:
        return await api.initialize_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/model/initialize/embedding",
    summary="Initialize embedding model",
    description="Initialize a separate model specifically for generating embeddings",
)
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model specifically for embeddings"""
    try:
        return await api.initialize_embedding_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing embedding model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown"""
    if api:
        await api.close()
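

# Example wiring (illustrative sketch only): this router expects the host app
# to call init_router() before serving requests, since the endpoints rely on
# the module-level `api` instance. The module path ".routes", the empty config
# dict, and the "/api/v1" prefix below are assumptions for illustration, not
# taken from this repo:
#
#     from fastapi import FastAPI
#     from .routes import router, init_router
#
#     app = FastAPI()
#     init_router(config={})  # config dict as loaded by the host application
#     app.include_router(router, prefix="/api/v1")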