"""FastAPI router exposing text-generation, embedding, and system-management
endpoints backed by a shared InferenceApi instance."""
from fastapi import APIRouter, HTTPException
from typing import Optional
from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse
)
import logging

# Router collecting the inference endpoints; presumably mounted by the
# application elsewhere (route decorators are not visible in this chunk).
router = APIRouter()
logger = logging.getLogger(__name__)

# Shared InferenceApi instance. None until init_router() is called; the
# endpoint handlers below read this module-level global.
api: Optional[InferenceApi] = None
def init_router(config: dict):
    """Build the shared InferenceApi from *config*.

    Must run before any endpoint handler in this module is served;
    until then the module-level ``api`` stays None.
    """
    # Rebind the module-level global that the handlers read.
    global api
    api = InferenceApi(config)
    logger.info("Router initialized with Inference API instance")
async def generate_text(request: GenerateRequest):
    """Generate a complete text response from the request's prompt.

    Args:
        request: GenerateRequest carrying ``prompt``, ``system_message``
            and ``max_new_tokens``.

    Returns:
        dict: the generated text under the ``"generated_text"`` key.

    Raises:
        HTTPException: 503 when the router was never initialized via
            init_router(); 500 when generation fails.
    """
    logger.info("Received generation request for prompt: %s...", request.prompt[:50])
    # Guard against a missing init_router() call: fail with a clear 503
    # instead of an AttributeError on None turning into an opaque 500.
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
    except HTTPException:
        # Let deliberate HTTP errors from lower layers pass through unchanged.
        raise
    except Exception as e:
        # logger.exception records the traceback; chain the cause for debugging.
        logger.exception("Error in generate_text endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info("Successfully generated response")
    return {"generated_text": response}
async def generate_stream(request: GenerateRequest):
    """Return a streaming text-generation response for the request's prompt.

    Args:
        request: GenerateRequest carrying ``prompt``, ``system_message``
            and ``max_new_tokens``.

    Returns:
        Whatever ``api.generate_stream`` returns (passed through as-is).

    Raises:
        HTTPException: 503 when the router was never initialized via
            init_router(); 500 when setting up the stream fails.
    """
    logger.info("Received streaming generation request for prompt: %s...", request.prompt[:50])
    # Fail fast with a clear 503 when init_router() was never called.
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        # NOTE(review): if generate_stream() is lazy (returns a generator),
        # errors raised while streaming occur after this handler returns and
        # are NOT caught here — this except only covers stream setup.
        return api.generate_stream(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
    except HTTPException:
        # Let deliberate HTTP errors from lower layers pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Error in generate_stream endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector for the request's text.

    Args:
        request: EmbeddingRequest carrying the ``text`` to embed.

    Returns:
        EmbeddingResponse with the vector and its dimension.

    Raises:
        HTTPException: 503 when the router was never initialized via
            init_router(); 500 when embedding generation fails.
    """
    logger.info("Received embedding request for text: %s...", request.text[:50])
    # Fail fast with a clear 503 when init_router() was never called.
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        embedding = await api.generate_embedding(request.text)
    except HTTPException:
        # Let deliberate HTTP errors from lower layers pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Error in generate_embedding endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info("Successfully generated embedding of dimension %d", len(embedding))
    return EmbeddingResponse(
        embedding=embedding,
        dimension=len(embedding)
    )
async def check_system():
    """Proxy the LLM Server's system-status check.

    Returns:
        Whatever ``api.check_system_status`` returns (passed through).

    Raises:
        HTTPException: 503 when the router was never initialized via
            init_router(); 500 when the status check fails.
    """
    # Fail fast with a clear 503 when init_router() was never called.
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.check_system_status()
    except HTTPException:
        # Let deliberate HTTP errors from lower layers pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Error checking system status")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def validate_system():
    """Proxy the LLM Server's system-validation check.

    Returns:
        Whatever ``api.validate_system`` returns (passed through).

    Raises:
        HTTPException: 503 when the router was never initialized via
            init_router(); 500 when validation fails.
    """
    # Fail fast with a clear 503 when init_router() was never called.
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.validate_system()
    except HTTPException:
        # Let deliberate HTTP errors from lower layers pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Error validating system")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use on the LLM Server.

    Args:
        model_name: name of the model to initialize; None presumably lets
            the server pick its default — confirm against InferenceApi.

    Returns:
        Whatever ``api.initialize_model`` returns (passed through).

    Raises:
        HTTPException: 503 when the router was never initialized via
            init_router(); 500 when model initialization fails.
    """
    # Fail fast with a clear 503 when init_router() was never called.
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.initialize_model(model_name)
    except HTTPException:
        # Let deliberate HTTP errors from lower layers pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Error initializing model")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model specifically for embeddings on the LLM Server.

    Args:
        model_name: name of the embedding model to initialize; None presumably
            lets the server pick its default — confirm against InferenceApi.

    Returns:
        Whatever ``api.initialize_embedding_model`` returns (passed through).

    Raises:
        HTTPException: 503 when the router was never initialized via
            init_router(); 500 when model initialization fails.
    """
    # Fail fast with a clear 503 when init_router() was never called.
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.initialize_embedding_model(model_name)
    except HTTPException:
        # Let deliberate HTTP errors from lower layers pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Error initializing embedding model")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def shutdown_event():
    """Release the shared InferenceApi's resources at application shutdown."""
    # api stays None when init_router() was never called — nothing to close then.
    if api:
        await api.close()