# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder
# Metadata shown in the generated OpenAPI schema and the interactive docs pages.
_APP_METADATA = dict(
    title="Document Reranker API",
    description="An API for reranking documents using a CrossEncoder model.",
    version="1.0",
)

# The FastAPI application. Swagger UI is served at /docs and ReDoc at /redoc.
app = FastAPI(**_APP_METADATA, docs_url="/docs", redoc_url="/redoc")
# Cross-origin access is left fully open, so a browser front-end served from any
# host can call this API. Narrow allow_origins before deploying somewhere sensitive.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests. Confirm
# how CORSMiddleware handles "*" together with allow_credentials=True, and whether
# credentials are needed at all, since no endpoint here reads cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Run on CUDA when PyTorch reports it available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a human-readable suffix for the startup log line (GPU name or CPU note).
if DEVICE.type == "cuda":
    _device_detail = "GPU: " + torch.cuda.get_device_name(0)
else:
    _device_detail = "Running on CPU"
logger.warning(f"Using device: {DEVICE} ({_device_detail})")
# Hugging Face checkpoint used for scoring query/document pairs.
_RERANKER_CHECKPOINT = "jinaai/jina-reranker-v1-turbo-en"

# Loaded once at import time and shared by every request, so the weights are not
# reloaded per call. Downloaded files are cached under ./models.
# NOTE(review): trust_remote_code=True runs Python code shipped in the model repo,
# presumably because this checkpoint needs custom modelling code. Pin a revision if
# the repo is not fully trusted.
model = CrossEncoder(
    _RERANKER_CHECKPOINT,
    device=DEVICE,
    cache_dir="models",
    trust_remote_code=True,
)
class RerankerRequest(BaseModel):
    """Request body accepted by the rerank endpoint."""

    query: str = Field(..., description="The search query string")
    documents: List[str] = Field(..., description="List of documents to rerank")
    return_documents: bool = Field(
        True, description="Whether to return document content in results"
    )
    # ge=1: this value is forwarded unchecked to model.rank(top_k=...). Zero or
    # negative counts have no meaningful "top k" interpretation, so reject them
    # at validation time (HTTP 422) instead of returning odd results.
    top_k: int = Field(3, ge=1, description="Number of top results to return")
class RankedResult(BaseModel):
    # One scored document in the rerank response.

    # Relevance score reported by model.rank for this query/document pair.
    score: float = Field(...)
    # The `corpus_id` reported by model.rank: position in the request's `documents`.
    index: int = Field(...)
    # Document text. Stays None when the request set return_documents=False.
    document: Optional[str] = Field(default=None)
class RerankerResponse(BaseModel):
    # Response envelope: ranked hits, in the order model.rank returned them.
    results: List[RankedResult] = Field(...)
# NOTE(review): this copy of the file had no route decorator on the handler, so
# FastAPI never registered it. "/rerank" is a reconstruction; confirm the path
# that existing clients call.
@app.post("/rerank", response_model=RerankerResponse)
async def rerank_documents(request: RerankerRequest):
    """
    Reranks the given list of documents based on their relevance to the query.

    - **query**: The input query string.
    - **documents**: A list of documents to be reranked.
    - **return_documents**: Whether to include document content in results.
    - **top_k**: Number of top-ranked documents to return.

    Returns:
    - A list of ranked documents with scores and indexes (indexes refer to
      positions in the submitted `documents` list). An empty `documents`
      list yields an empty result list.

    Raises:
    - HTTP 500 if the model call or result formatting fails.
    """
    if not request.documents:
        # Nothing to score, so skip the model call entirely.
        return RerankerResponse(results=[])

    try:
        # model.rank is a synchronous, compute-heavy call. Running it in the
        # threadpool keeps this async handler from blocking the event loop
        # for every other in-flight request.
        results = await run_in_threadpool(
            model.rank,
            request.query,
            request.documents,
            return_documents=request.return_documents,
            top_k=request.top_k,
        )
        # "text" is only read when document content was requested.
        formatted_results = [
            RankedResult(
                score=result["score"],
                index=result["corpus_id"],
                document=result["text"] if request.return_documents else None,
            )
            for result in results
        ]
    except Exception as e:
        # Keep the traceback in the server log; the client only sees the message.
        logger.exception("Reranking failed")
        raise HTTPException(
            status_code=500, detail=f"Error in reranking: {str(e)}"
        ) from e

    return RerankerResponse(results=formatted_results)
# Direct-execution entry point: serve the app with Uvicorn on all interfaces,
# port 7860. The import is local so that importing this module does not require it.
if __name__ == "__main__":
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)