# Inference-API/main/routes.py
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from typing import Optional
from time import time
import json
import logging

from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse,
    ChatCompletionRequest,
    ChatCompletionResponse
)

router = APIRouter()
logger = logging.getLogger(__name__)

# Module-level handle to the Inference API; populated by init_router() at startup.
api = None
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint"""
logger.info(f"Received chat completion request with {len(request.messages)} messages")
try:
# Extract the last user message, or combine messages if needed
last_message = request.messages[-1].content
if request.stream:
# For streaming, we need to create a generator that yields OpenAI-compatible chunks
async def generate_stream():
async for chunk in api.generate_stream(
prompt=last_message,
):
# Create a streaming response chunk in OpenAI format
response_chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": int(time()),
"model": request.model,
"choices": [{
"index": 0,
"delta": {
"content": chunk
},
"finish_reason": None
}]
}
yield f"data: {json.dumps(response_chunk)}\n\n"
# Send the final chunk
yield f"data: [DONE]\n\n"
return StreamingResponse(
generate_stream(),
media_type="text/event-stream"
)
else:
# For non-streaming, generate the full response
response_text = await api.generate_response(
prompt=last_message,
)
# Convert to OpenAI format
return ChatCompletionResponse.from_response(
content=response_text,
model=request.model
)
except Exception as e:
logger.error(f"Error in chat completion endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
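# Illustrative client call for the OpenAI-compatible endpoint above. This is a
# sketch, not part of the router: the base URL, API key, and model name are
# assumptions and depend on how the service is deployed.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
#   completion = client.chat.completions.create(
#       model="local-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )
#   print(completion.choices[0].message.content)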
async def init_router(config: dict):
    """Initialize router with config and Inference API instance"""
    global api
    api = InferenceApi(config)
    await api.setup()
    logger.info("Router initialized with Inference API instance")
@router.post("/generate")
async def generate_text(request: GenerateRequest):
"""Generate text response from prompt"""
logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
try:
response = await api.generate_response(
prompt=request.prompt,
system_message=request.system_message,
max_new_tokens=request.max_new_tokens
)
logger.info("Successfully generated response")
return {"generated_text": response}
except Exception as e:
logger.error(f"Error in generate_text endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
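# Example request against /generate (illustrative; host, port, and field values
# are assumptions — the body mirrors the GenerateRequest fields used above).
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={
#           "prompt": "Explain server-sent events in one sentence.",
#           "system_message": "You are a concise assistant.",
#           "max_new_tokens": 128,
#       },
#   )
#   print(resp.json()["generated_text"])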
@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
"""Generate streaming text response from prompt"""
logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
try:
return api.generate_stream(
prompt=request.prompt,
system_message=request.system_message,
max_new_tokens=request.max_new_tokens
)
except Exception as e:
logger.error(f"Error in generate_stream endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
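# Example of consuming the stream (illustrative; host and port are assumptions).
#
#   import requests
#
#   with requests.post("http://localhost:8000/generate/stream",
#                      json={"prompt": "Tell me a short story"},
#                      stream=True) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)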
@router.post("/embedding", response_model=EmbeddingResponse)
async def generate_embedding(request: EmbeddingRequest):
"""Generate embedding vector from text"""
logger.info(f"Received embedding request for text: {request.text[:50]}...")
try:
embedding = await api.generate_embedding(request.text)
logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
return EmbeddingResponse(
embedding=embedding,
dimension=len(embedding)
)
except Exception as e:
logger.error(f"Error in generate_embedding endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
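# Example embedding call (illustrative; host and port are assumptions — the
# response carries the vector and its dimension, per EmbeddingResponse).
#
#   import requests
#
#   resp = requests.post("http://localhost:8000/embedding",
#                        json={"text": "hello world"})
#   data = resp.json()
#   print(data["dimension"], data["embedding"][:5])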
@router.get("/system/status",
response_model=SystemStatusResponse,
summary="Check System Status",
description="Returns comprehensive system status including CPU, Memory, GPU, Storage, and Model information")
async def check_system():
"""Get system status from LLM Server"""
try:
return await api.check_system_status()
except Exception as e:
logger.error(f"Error checking system status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/system/validate",
response_model=ValidationResponse,
summary="Validate System Configuration",
description="Validates system configuration, folders, and model setup")
async def validate_system():
"""Get system validation status from LLM Server"""
try:
return await api.validate_system()
except Exception as e:
logger.error(f"Error validating system: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/model/initialize",
summary="Initialize default or specified model",
description="Initialize model for use. Uses default model from config if none specified.")
async def initialize_model(model_name: Optional[str] = None):
"""Initialize a model for use"""
try:
return await api.initialize_model(model_name)
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/model/initialize/embedding",
summary="Initialize embedding model",
description="Initialize a separate model specifically for generating embeddings")
async def initialize_embedding_model(model_name: Optional[str] = None):
"""Initialize a model specifically for embeddings"""
try:
return await api.initialize_embedding_model(model_name)
except Exception as e:
logger.error(f"Error initializing embedding model: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
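# Example initialization calls (illustrative; host, port, and the model name are
# assumptions). model_name is passed as a query parameter because it is declared
# as a plain optional function argument above.
#
#   import requests
#
#   requests.post("http://localhost:8000/model/initialize")            # default model from config
#   requests.post("http://localhost:8000/model/initialize",
#                 params={"model_name": "my-model"})                    # explicit model
#   requests.post("http://localhost:8000/model/initialize/embedding")  # embedding model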
@router.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown"""
    # Note: router-level on_event hooks are deprecated in newer FastAPI releases
    # in favor of lifespan handlers, but they still function.
    if api:
        await api.close()