from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from typing import Optional import json from time import time import logging from .api import InferenceApi from .schemas import ( GenerateRequest, EmbeddingRequest, EmbeddingResponse, SystemStatusResponse, ValidationResponse, ChatCompletionRequest, ChatCompletionResponse, QueryExpansionResponse, QueryExpansionRequest, ChunkRerankResponse, ChunkRerankRequest ) router = APIRouter() logger = logging.getLogger(__name__) api = None config = None def init_router(inference_api: InferenceApi, conf): """Initialize router with an already setup API instance""" global api, config api = inference_api config = conf logger.info("Router initialized with Inference API instance") @router.post("/generate") async def generate_text(request: GenerateRequest): """Generate text response from prompt""" logger.info(f"Received generation request for prompt: {request.prompt[:50]}...") try: response = await api.generate_response( prompt=request.prompt, system_message=request.system_message, max_new_tokens=request.max_new_tokens ) logger.info("Successfully generated response") return {"generated_text": response} except Exception as e: logger.error(f"Error in generate_text endpoint: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/generate/stream") async def generate_stream(request: GenerateRequest): """Generate streaming text response from prompt""" logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...") try: return StreamingResponse( api.generate_stream( prompt=request.prompt, system_message=request.system_message, max_new_tokens=request.max_new_tokens ), media_type="text/event-stream" ) except Exception as e: logger.error(f"Error in generate_stream endpoint: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/chat/completions") async def create_chat_completion(request: ChatCompletionRequest): """OpenAI-compatible chat completion endpoint""" logger.info(f"Received chat completion request with {len(request.messages)} messages") try: # Extract the last user message, or combine messages if needed last_message = request.messages[-1].content if request.stream: async def generate_stream(): async for chunk in api.generate_stream( prompt=last_message, ): # Parse the SSE format from LLM Server if chunk.startswith('data: '): chunk = chunk[6:].strip() # Remove "data: " and trailing \n\n chunk = chunk + " " logger.debug(f"Sending chunk: {chunk}...") if chunk == '[DONE]': continue response_chunk = { "id": "chatcmpl-123", "object": "chat.completion.chunk", "created": int(time()), "model": request.model, "choices": [{ "index": 0, "delta": { "content": chunk }, "finish_reason": None }] } yield f"data: {json.dumps(response_chunk)}\n\n" yield "data: [DONE]\n\n" return StreamingResponse( generate_stream(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", } ) else: # For non-streaming, generate the full response response_text = await api.generate_response( prompt=last_message, ) # Convert to OpenAI format return ChatCompletionResponse.from_response( content=response_text, model=request.model ) except Exception as e: logger.error(f"Error in chat completion endpoint: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/expand_query", response_model=QueryExpansionResponse) async def expand_query(request: QueryExpansionRequest): """Expand a query for RAG processing""" logger.info(f"Received query expansion request: {request.query[:50]}...") try: result = await api.expand_query( query=request.query, system_message=request.system_message ) logger.info("Successfully expanded query") return result except FileNotFoundError as e: logger.error(f"Template file not found: {str(e)}") raise HTTPException(status_code=500, detail="Query expansion template not found") except json.JSONDecodeError as e: logger.error(f"Invalid JSON response from LLM: {str(e)}") raise HTTPException(status_code=500, detail="Invalid response format from LLM") except Exception as e: logger.error(f"Error in expand_query endpoint: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/rerank", response_model=ChunkRerankResponse) async def rerank_chunks(request: ChunkRerankRequest): """Rerank chunks based on their relevance to the query""" logger.info(f"Received reranking request for query: {request.query[:50]}...") try: result = await api.rerank_chunks( query=request.query, chunks=request.chunks, system_message=request.system_message ) logger.info(f"Successfully reranked {len(request.chunks)} chunks") return result except Exception as e: logger.error(f"Error in rerank_chunks endpoint: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/embedding", response_model=EmbeddingResponse) async def generate_embedding(request: EmbeddingRequest): """Generate embedding vector from text""" logger.info(f"Received embedding request for text: {request.text[:50]}...") try: embedding = await api.generate_embedding(request.text) logger.info(f"Successfully generated embedding of dimension {len(embedding)}") return EmbeddingResponse( embedding=embedding, dimension=len(embedding) ) except Exception as e: logger.error(f"Error in generate_embedding endpoint: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/system/status", response_model=SystemStatusResponse, summary="Check System Status", description="Returns comprehensive system status including CPU, Memory, GPU, Storage, and Model information") async def check_system(): """Get system status from LLM Server""" try: return await api.check_system_status() except Exception as e: logger.error(f"Error checking system status: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/system/validate", response_model=ValidationResponse, summary="Validate System Configuration", description="Validates system configuration, folders, and model setup") async def validate_system(): """Get system validation status from LLM Server""" try: return await api.validate_system() except Exception as e: logger.error(f"Error validating system: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/model/initialize", summary="Initialize default or specified model", description="Initialize model for use. Uses default model from config if none specified.") async def initialize_model(model_name: Optional[str] = None): """Initialize a model for use""" try: return await api.initialize_model(model_name) except Exception as e: logger.error(f"Error initializing model: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/model/initialize/embedding", summary="Initialize embedding model", description="Initialize a separate model specifically for generating embeddings") async def initialize_embedding_model(model_name: Optional[str] = None): """Initialize a model specifically for embeddings""" try: return await api.initialize_embedding_model(model_name) except Exception as e: logger.error(f"Error initializing embedding model: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/model/download", summary="Download default or specified model", description="Downloads model files. Uses default model from config if none specified.") async def download_model(model_name: Optional[str] = None): """Download model files to local storage""" try: # Use model name from config if none provided model_to_download = model_name or config["model"]["defaults"]["model_name"] logger.info(f"Received request to download model: {model_to_download}") result = await api.download_model(model_to_download) logger.info(f"Successfully downloaded model: {model_to_download}") return result except Exception as e: logger.error(f"Error downloading model: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.on_event("shutdown") async def shutdown_event(): """Clean up resources on shutdown""" if api: await api.cleanup()