from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from typing import Optional
from time import time
import json
import logging
import uuid

from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse,
    ChatCompletionRequest,
    ChatCompletionResponse
)

router = APIRouter()
logger = logging.getLogger(__name__)
api: Optional[InferenceApi] = None


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint"""
    logger.info(f"Received chat completion request with {len(request.messages)} messages")
    try:
        # Extract the last user message, or combine messages if needed
        last_message = request.messages[-1].content

        if request.stream:
            # For streaming, create a generator that yields OpenAI-compatible chunks
            async def generate_stream():
                async for chunk in api.generate_stream(
                    prompt=last_message,
                ):
                    # Create a streaming response chunk in OpenAI format;
                    # clients expect a unique id per completion
                    response_chunk = {
                        "id": f"chatcmpl-{uuid.uuid4().hex}",
                        "object": "chat.completion.chunk",
                        "created": int(time()),
                        "model": request.model,
                        "choices": [{
                            "index": 0,
                            "delta": {"content": chunk},
                            "finish_reason": None
                        }]
                    }
                    yield f"data: {json.dumps(response_chunk)}\n\n"
                # Send the terminating sentinel expected by OpenAI clients
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream"
            )
        else:
            # For non-streaming, generate the full response
            response_text = await api.generate_response(
                prompt=last_message,
            )
            # Convert to OpenAI format
            return ChatCompletionResponse.from_response(
                content=response_text,
                model=request.model
            )
    except Exception as e:
        logger.error(f"Error in chat completion endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


async def init_router(config: dict):
    """Initialize router with config and Inference API instance"""
    global api
    api = InferenceApi(config)
    await api.setup()
    logger.info("Router initialized with Inference API instance")


@router.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate text response from prompt"""
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    """Generate streaming text response from prompt"""
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    try:
        # api.generate_stream is an async generator; wrap it in a
        # StreamingResponse so FastAPI streams the chunks instead of
        # attempting to serialize the generator object
        return StreamingResponse(
            api.generate_stream(
                prompt=request.prompt,
                system_message=request.system_message,
                max_new_tokens=request.max_new_tokens
            ),
            media_type="text/plain"
        )
    except Exception as e:
        logger.error(f"Error in generate_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/embedding", response_model=EmbeddingResponse)
async def generate_embedding(request: EmbeddingRequest):
    """Generate embedding vector from text"""
    logger.info(f"Received embedding request for text: {request.text[:50]}...")
    try:
        embedding = await api.generate_embedding(request.text)
        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
        return EmbeddingResponse(
            embedding=embedding,
            dimension=len(embedding)
        )
    except Exception as e:
        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
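# Illustrative client call for the embedding endpoint above. This is a sketch,
# not part of the API surface: the host and port are assumptions, and the JSON
# field follows EmbeddingRequest as used in generate_embedding (request.text).
#
#   curl -X POST http://localhost:8000/embedding \
#        -H "Content-Type: application/json" \
#        -d '{"text": "sample text to embed"}'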
@router.get("/system/status",
            response_model=SystemStatusResponse,
            summary="Check System Status",
            description="Returns comprehensive system status including CPU, Memory, GPU, Storage, and Model information")
async def check_system():
    """Get system status from LLM Server"""
    try:
        return await api.check_system_status()
    except Exception as e:
        logger.error(f"Error checking system status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/system/validate",
            response_model=ValidationResponse,
            summary="Validate System Configuration",
            description="Validates system configuration, folders, and model setup")
async def validate_system():
    """Get system validation status from LLM Server"""
    try:
        return await api.validate_system()
    except Exception as e:
        logger.error(f"Error validating system: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/initialize",
             summary="Initialize default or specified model",
             description="Initialize model for use. Uses default model from config if none specified.")
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use"""
    try:
        return await api.initialize_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/initialize/embedding",
             summary="Initialize embedding model",
             description="Initialize a separate model specifically for generating embeddings")
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model specifically for embeddings"""
    try:
        return await api.initialize_embedding_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing embedding model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown"""
    if api:
        await api.close()
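# Minimal wiring sketch (illustrative only): a FastAPI app that mounts this
# router and calls init_router at startup so the module-level `api` is ready
# before requests arrive. The module path `.routes` and the config keys are
# assumptions; InferenceApi defines what the config dict actually needs.
#
#   from fastapi import FastAPI
#   from .routes import router, init_router
#
#   app = FastAPI()
#   app.include_router(router)
#
#   @app.on_event("startup")
#   async def startup():
#       await init_router({"model_name": "default"})  # hypothetical config keys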