from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from typing import Optional
import json
import logging
import uuid
from time import time

from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse,
    ChatCompletionRequest,
    ChatCompletionResponse,
    QueryExpansionRequest,
    QueryExpansionResponse,
    ChunkRerankRequest,
    ChunkRerankResponse,
)

router = APIRouter()
logger = logging.getLogger(__name__)

# Module-level state, populated by init_router() before the router is mounted.
api: Optional[InferenceApi] = None
config = None

# NOTE: the route paths in the decorators below are assumptions for
# illustration; adjust them to match your deployment.


def init_router(inference_api: InferenceApi, conf):
    """Initialize the router with an already set-up API instance."""
    global api, config
    api = inference_api
    config = conf
    logger.info("Router initialized with Inference API instance")


@router.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate a text response from a prompt."""
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    """Generate a streaming text response from a prompt."""
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    try:
        return StreamingResponse(
            api.generate_stream(
                prompt=request.prompt,
                system_message=request.system_message,
                max_new_tokens=request.max_new_tokens
            ),
            media_type="text/event-stream"
        )
    except Exception as e:
        logger.error(f"Error in generate_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
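
# Example client for the streaming endpoint (sketch; host, port, and route
# path are assumptions):
#
#   import httpx
#
#   async with httpx.AsyncClient() as client:
#       async with client.stream(
#           "POST", "http://localhost:8000/generate/stream",
#           json={"prompt": "Hello"},
#       ) as resp:
#           async for line in resp.aiter_lines():
#               if line.startswith("data: "):
#                   print(line[len("data: "):])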


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint."""
    logger.info(f"Received chat completion request with {len(request.messages)} messages")
    try:
        # Only the content of the last message is forwarded as the prompt;
        # earlier conversation turns are currently ignored.
        last_message = request.messages[-1].content
        if request.stream:
            # All chunks of one completion share a single id, as in the
            # OpenAI streaming format.
            completion_id = f"chatcmpl-{uuid.uuid4().hex}"

            async def event_stream():
                async for chunk in api.generate_stream(
                    prompt=last_message,
                ):
                    # The LLM Server already emits SSE frames; strip its
                    # "data: " prefix and trailing blank line before
                    # re-wrapping the payload in the OpenAI chunk format.
                    if chunk.startswith('data: '):
                        chunk = chunk[6:].replace("\n\n", "")
                    logger.debug(f"Sending chunk: {chunk}...")
                    if chunk == '[DONE]':
                        continue
                    response_chunk = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": int(time()),
                        "model": request.model,
                        "choices": [{
                            "index": 0,
                            "delta": {
                                "content": chunk
                            },
                            "finish_reason": None
                        }]
                    }
                    yield f"data: {json.dumps(response_chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                }
            )
        else:
            # For non-streaming requests, generate the full response
            response_text = await api.generate_response(
                prompt=last_message,
            )
            # Convert to the OpenAI response format
            return ChatCompletionResponse.from_response(
                content=response_text,
                model=request.model
            )
    except Exception as e:
        logger.error(f"Error in chat completion endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
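
# The endpoint above is wire-compatible with the official OpenAI client
# (sketch; base_url and model name are assumptions):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
#   resp = client.chat.completions.create(
#       model="local-model",
#       messages=[{"role": "user", "content": "Hello"}],
#       stream=True,
#   )
#   for chunk in resp:
#       print(chunk.choices[0].delta.content or "", end="")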


@router.post("/expand_query", response_model=QueryExpansionResponse)
async def expand_query(request: QueryExpansionRequest):
    """Expand a query for RAG processing."""
    logger.info(f"Received query expansion request: {request.query[:50]}...")
    try:
        result = await api.expand_query(
            query=request.query,
            system_message=request.system_message
        )
        logger.info("Successfully expanded query")
        return result
    except FileNotFoundError as e:
        logger.error(f"Template file not found: {str(e)}")
        raise HTTPException(status_code=500, detail="Query expansion template not found")
    except json.JSONDecodeError as e:
        logger.error(f"Invalid JSON response from LLM: {str(e)}")
        raise HTTPException(status_code=500, detail="Invalid response format from LLM")
    except Exception as e:
        logger.error(f"Error in expand_query endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/rerank", response_model=ChunkRerankResponse)
async def rerank_chunks(request: ChunkRerankRequest):
    """Rerank chunks by their relevance to the query."""
    logger.info(f"Received reranking request for query: {request.query[:50]}...")
    try:
        result = await api.rerank_chunks(
            query=request.query,
            chunks=request.chunks,
            system_message=request.system_message
        )
        logger.info(f"Successfully reranked {len(request.chunks)} chunks")
        return result
    except Exception as e:
        logger.error(f"Error in rerank_chunks endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/embedding", response_model=EmbeddingResponse)
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector from text."""
    logger.info(f"Received embedding request for text: {request.text[:50]}...")
    try:
        embedding = await api.generate_embedding(request.text)
        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
        return EmbeddingResponse(
            embedding=embedding,
            dimension=len(embedding)
        )
    except Exception as e:
        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
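
# Expected wire format for the embedding endpoint (sketch; field names follow
# EmbeddingResponse above, the vector values and dimension are illustrative):
#
#   POST /embedding        {"text": "some document"}
#   -> {"embedding": [0.12, -0.04, ...], "dimension": 768}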


@router.get("/system/status", response_model=SystemStatusResponse)
async def check_system():
    """Get system status from the LLM Server."""
    try:
        return await api.check_system_status()
    except Exception as e:
        logger.error(f"Error checking system status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/system/validate", response_model=ValidationResponse)
async def validate_system():
    """Get system validation status from the LLM Server."""
    try:
        return await api.validate_system()
    except Exception as e:
        logger.error(f"Error validating system: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/initialize")
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use."""
    try:
        return await api.initialize_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/initialize/embedding")
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model specifically for embeddings."""
    try:
        return await api.initialize_embedding_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing embedding model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/download")
async def download_model(model_name: Optional[str] = None):
    """Download model files to local storage."""
    try:
        # Fall back to the model name from config when none is provided
        model_to_download = model_name or config["model"]["defaults"]["model_name"]
        logger.info(f"Received request to download model: {model_to_download}")
        result = await api.download_model(model_to_download)
        logger.info(f"Successfully downloaded model: {model_to_download}")
        return result
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


# Registration via on_event is an assumption; the handler may instead be
# attached to the application's shutdown hook elsewhere.
@router.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown."""
    if api:
        await api.cleanup()
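
# On newer FastAPI versions, on_event is deprecated in favor of a lifespan
# context on the application; the same cleanup can be expressed as (sketch):
#
#   from contextlib import asynccontextmanager
#
#   @asynccontextmanager
#   async def lifespan(app):
#       yield
#       if api:
#           await api.cleanup()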