# Inference-API / app/routes.py
# Author: AurelioAguirre — commit "FIRST" (47031d7), 4.68 kB
# (header reconstructed as a comment from file-viewer residue)
from fastapi import APIRouter, HTTPException
from typing import Optional
from .api import InferenceApi
from .schemas import (
GenerateRequest,
EmbeddingRequest,
EmbeddingResponse,
SystemStatusResponse,
ValidationResponse
)
import logging
# Module-level router; every endpoint below is registered on it and the
# importing application mounts it (after calling init_router).
router = APIRouter()
logger = logging.getLogger(__name__)
# Shared InferenceApi instance backing all endpoints. Stays None until
# init_router() is called; endpoints will fail if served before that.
api: Optional[InferenceApi] = None
def init_router(config: dict) -> None:
    """Initialize router with config and Inference API instance.

    Must be called once at application startup, before any endpoint in
    this module is served: they all dispatch to the module-level ``api``
    created here.

    Args:
        config: Application configuration dict, passed straight through
            to InferenceApi (its expected schema is defined there).
    """
    global api
    api = InferenceApi(config)
    logger.info("Router initialized with Inference API instance")
@router.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate a text completion for the given prompt.

    Delegates to the module-level InferenceApi instance (created by
    init_router) and returns the result as ``{"generated_text": ...}``.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            generation fails for any reason.
    """
    # Lazy %-args instead of an f-string so formatting is skipped when
    # the log level filters this record out.
    logger.info("Received generation request for prompt: %s...", request.prompt[:50])
    try:
        # Keep the try body minimal: only the call that can actually fail.
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
    except Exception as e:
        # logger.exception records the full traceback, which the previous
        # logger.error(f"...") call silently dropped.
        logger.exception("Error in generate_text endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info("Successfully generated response")
    return {"generated_text": response}
@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    """Generate a streaming text response for the given prompt.

    Returns whatever api.generate_stream yields; only errors raised while
    *constructing* the stream are converted to a 500 here — the call is
    not awaited, so failures during iteration escape this handler.

    Raises:
        HTTPException: 500 if creating the stream fails.
    """
    logger.info("Received streaming generation request for prompt: %s...", request.prompt[:50])
    try:
        # NOTE(review): the object is returned as-is; FastAPI does not wrap a
        # bare (async) generator in a StreamingResponse automatically — confirm
        # api.generate_stream returns a Response-compatible object.
        return api.generate_stream(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
    except Exception as e:
        # Preserve the traceback in the log; the client still gets only detail.
        logger.exception("Error in generate_stream endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/embedding", response_model=EmbeddingResponse)
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector for the given text.

    Delegates to the module-level InferenceApi instance and wraps the
    vector plus its length in an EmbeddingResponse.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            embedding generation fails.
    """
    logger.info("Received embedding request for text: %s...", request.text[:50])
    try:
        embedding = await api.generate_embedding(request.text)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") lost.
        logger.exception("Error in generate_embedding endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info("Successfully generated embedding of dimension %s", len(embedding))
    return EmbeddingResponse(
        embedding=embedding,
        dimension=len(embedding)
    )
@router.get("/system/status",
            response_model=SystemStatusResponse,
            summary="Check System Status",
            description="Returns comprehensive system status including CPU, Memory, GPU, Storage, and Model information")
async def check_system():
    """Get system status from the LLM server via the shared InferenceApi.

    Raises:
        HTTPException: 500 carrying the underlying error message if the
            status check fails.
    """
    try:
        return await api.check_system_status()
    except Exception as e:
        # Record the full traceback instead of just the message text.
        logger.exception("Error checking system status: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/system/validate",
            response_model=ValidationResponse,
            summary="Validate System Configuration",
            description="Validates system configuration, folders, and model setup")
async def validate_system():
    """Get system validation status from the LLM server.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            validation fails.
    """
    try:
        return await api.validate_system()
    except Exception as e:
        # Record the full traceback instead of just the message text.
        logger.exception("Error validating system: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/model/initialize",
             summary="Initialize default or specified model",
             description="Initialize model for use. Uses default model from config if none specified.")
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use.

    Args:
        model_name: Optional model identifier; when None, the InferenceApi
            falls back to its configured default model.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            initialization fails.
    """
    try:
        return await api.initialize_model(model_name)
    except Exception as e:
        # Record the full traceback instead of just the message text.
        logger.exception("Error initializing model: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/model/initialize/embedding",
             summary="Initialize embedding model",
             description="Initialize a separate model specifically for generating embeddings")
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model dedicated to generating embeddings.

    Args:
        model_name: Optional model identifier; when None, the InferenceApi
            chooses its default embedding model.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            initialization fails.
    """
    try:
        return await api.initialize_embedding_model(model_name)
    except Exception as e:
        # Record the full traceback instead of just the message text.
        logger.exception("Error initializing embedding model: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown.

    Closes the shared InferenceApi instance if init_router() ever created
    one; a no-op when the router was never initialized.
    """
    # NOTE(review): @router.on_event is deprecated in recent FastAPI in
    # favor of lifespan handlers — consider migrating on the next upgrade.
    if api:
        await api.close()