from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from typing import Optional
import json
from time import time
from uuid import uuid4
import logging

from .api import InferenceApi
# Assumed import: download_model below references a `config` dict that was
# never imported in the original; the module path here is a guess.
from .config import config
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse,
    ChatCompletionRequest,
    ChatCompletionResponse
)
router = APIRouter()
logger = logging.getLogger(__name__)

api: Optional[InferenceApi] = None


def init_router(inference_api: InferenceApi):
    """Initialize the router with an already set-up API instance."""
    global api
    api = inference_api
    logger.info("Router initialized with Inference API instance")


@router.post("/generate")  # route path assumed; the original decorator was lost
async def generate_text(request: GenerateRequest):
    """Generate a text response from a prompt."""
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/generate/stream")  # route path assumed; the original decorator was lost
async def generate_stream(request: GenerateRequest):
    """Generate a streaming text response from a prompt."""
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    try:
        return StreamingResponse(
            api.generate_stream(
                prompt=request.prompt,
                system_message=request.system_message,
                max_new_tokens=request.max_new_tokens
            ),
            media_type="text/event-stream"
        )
    except Exception as e:
        # Only errors raised before streaming starts are caught here; failures
        # inside the generator occur after response headers have been sent.
        logger.error(f"Error in generate_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
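

# Example client call for the streaming endpoint (path as assumed above); -N
# disables curl's buffering so SSE chunks print as they arrive:
#
#   curl -N -X POST http://localhost:8000/generate/stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_new_tokens": 64}'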


@router.post("/chat/completions")  # route path assumed; the original decorator was lost
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint."""
    logger.info(f"Received chat completion request with {len(request.messages)} messages")
    try:
        # Use the last message as the prompt; earlier messages are not combined here
        last_message = request.messages[-1].content
        if request.stream:
            # Unique per-request id instead of the original hardcoded placeholder
            completion_id = f"chatcmpl-{uuid4().hex}"

            # Wrap the raw token stream in OpenAI-format SSE chunks; renamed from
            # generate_stream to avoid shadowing the module-level endpoint above
            async def stream_chunks():
                async for chunk in api.generate_stream(
                    prompt=last_message,
                ):
                    response_chunk = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": int(time()),
                        "model": request.model,
                        "choices": [{
                            "index": 0,
                            "delta": {
                                "content": chunk
                            },
                            "finish_reason": None
                        }]
                    }
                    yield f"data: {json.dumps(response_chunk)}\n\n"
                # Signal end of stream, as the OpenAI API does
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                stream_chunks(),
                media_type="text/event-stream"
            )
        else:
            # For non-streaming, generate the full response in one call
            response_text = await api.generate_response(
                prompt=last_message,
            )
            # Convert to the OpenAI response format
            return ChatCompletionResponse.from_response(
                content=response_text,
                model=request.model
            )
    except Exception as e:
        logger.error(f"Error in chat completion endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
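

# Sketch of exercising the OpenAI-compatible endpoint with the official
# `openai` client (base URL, port, and model name are assumptions):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000", api_key="unused")
#   resp = client.chat.completions.create(
#       model="local-model",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(resp.choices[0].message.content)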


@router.post("/embedding", response_model=EmbeddingResponse)  # route path assumed
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector from text."""
    logger.info(f"Received embedding request for text: {request.text[:50]}...")
    try:
        embedding = await api.generate_embedding(request.text)
        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
        return EmbeddingResponse(
            embedding=embedding,
            dimension=len(embedding)
        )
    except Exception as e:
        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
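

# Example embedding request (path as assumed above); the response shape follows
# EmbeddingResponse: {"embedding": [...], "dimension": N}:
#
#   curl -X POST http://localhost:8000/embedding \
#        -H "Content-Type: application/json" \
#        -d '{"text": "some text to embed"}'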


@router.get("/system/status", response_model=SystemStatusResponse)  # path and response model assumed
async def check_system():
    """Get system status from the LLM Server."""
    try:
        return await api.check_system_status()
    except Exception as e:
        logger.error(f"Error checking system status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/system/validate", response_model=ValidationResponse)  # path and response model assumed
async def validate_system():
    """Get system validation status from the LLM Server."""
    try:
        return await api.validate_system()
    except Exception as e:
        logger.error(f"Error validating system: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/initialize")  # route path assumed
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use."""
    try:
        return await api.initialize_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/initialize/embedding")  # route path assumed
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model specifically for embeddings."""
    try:
        return await api.initialize_embedding_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing embedding model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/download")  # route path assumed
async def download_model(model_name: Optional[str] = None):
    """Download model files to local storage."""
    try:
        # Fall back to the default model name from config if none is provided
        model_to_download = model_name or config["model"]["defaults"]["model_name"]
        logger.info(f"Received request to download model: {model_to_download}")
        result = await api.download_model(model_to_download)
        logger.info(f"Successfully downloaded model: {model_to_download}")
        return result
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.on_event("shutdown")  # registration assumed; cleanup must run when the app stops
async def shutdown_event():
    """Clean up resources on shutdown."""
    if api:
        await api.cleanup()