Spaces:
Sleeping
Sleeping
File size: 3,279 Bytes
cce774d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import torch
from loguru import logger
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import CrossEncoder
from typing import List, Optional
# Initialize FastAPI app with documentation metadata
app = FastAPI(
    title="Document Reranker API",
    description="An API for reranking documents using a CrossEncoder model.",
    version="1.0",
    docs_url="/docs",  # Swagger UI
    redoc_url="/redoc",  # ReDoc UI
)

# Enable CORS (optional but useful for frontend integration)
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation. Confirm whether credentialed
# cross-origin requests are actually needed; if not, disable credentials or
# list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (modify as needed)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Device selection: use CUDA when torch reports it as available, else CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Runs once at import time. The message is emitted at WARNING level and names
# GPU 0 when running on CUDA.
logger.warning(
    f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
)

# Load the model at startup to avoid reloading for each request
# NOTE(review): trust_remote_code=True presumably allows code shipped with the
# model repository to run during loading -- confirm this is intended.
# NOTE(review): cache_dir="models" looks like a relative download/cache
# location; verify the installed sentence-transformers version accepts this
# keyword.
model = CrossEncoder(
    "jinaai/jina-reranker-v1-turbo-en",
    trust_remote_code=True,
    device=DEVICE,
    cache_dir="models",
)
class RerankerRequest(BaseModel):
    # Request payload accepted by the POST /rerank endpoint.

    # Query that every document is scored against.
    query: str = Field(..., description="The search query string")
    # Candidate texts; results refer back to them through their index.
    documents: List[str] = Field(..., description="List of documents to rerank")
    # When False, results carry only score and index (document stays None).
    return_documents: bool = Field(
        True, description="Whether to return document content in results"
    )
    # ge=0 rejects negative counts at validation time. The value is forwarded
    # unchanged to the ranker as its result limit, where a negative number
    # would presumably truncate from the wrong end instead of limiting the
    # output.
    top_k: int = Field(3, ge=0, description="Number of top results to return")
class RankedResult(BaseModel):
    # One entry of the reranked output returned by POST /rerank.

    # Relevance score, filled from the "score" field of the model's output.
    score: float
    # Filled from the model's "corpus_id" field; presumably the position of
    # the document in the submitted list -- confirm against the library docs.
    index: int
    # Document text; left as None when the request set return_documents=False.
    document: Optional[str] = None
class RerankerResponse(BaseModel):
    # Response body of POST /rerank.

    # Ranked entries, kept in the order the model returned them.
    results: List[RankedResult]
@app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
async def rerank_documents(request: RerankerRequest) -> RerankerResponse:
    """
    Reranks the given list of documents based on their relevance to the query.

    - **query**: The input query string.
    - **documents**: A list of documents to be reranked.
    - **return_documents**: Whether to include document content in results.
    - **top_k**: Number of top-ranked documents to return.

    Returns:
    - A list of ranked documents with scores and indexes. An empty
      `documents` list produces an empty result list.

    Raises:
    - HTTP 500 with an "Error in reranking: ..." detail when the model call
      or the result formatting fails.
    """
    # Nothing to score: answer directly instead of handing the model an empty
    # batch, which may raise inside the library and surface as a 500.
    if not request.documents:
        return RerankerResponse(results=[])

    # NOTE(review): the model call below is synchronous and runs inside an
    # async handler, so it occupies the event loop for its whole duration.
    # Offloading it to a worker thread would need confirmation that the shared
    # model/tokenizer tolerate concurrent use.
    try:
        # Call the model's rank method with the provided parameters
        ranked = model.rank(
            request.query,
            request.documents,
            return_documents=request.return_documents,
            top_k=request.top_k,
        )
        # Format the results based on the model's output
        formatted_results = [
            RankedResult(
                # The model may hand back NumPy scalars; coerce to builtin
                # types so validation and JSON serialization are predictable.
                score=float(result["score"]),
                index=int(result["corpus_id"]),
                document=result["text"] if request.return_documents else None,
            )
            for result in ranked
        ]
    except Exception as e:
        # Keep the full traceback in the server log; the client only receives
        # the short message below.
        logger.exception("Reranking failed")
        raise HTTPException(
            status_code=500, detail=f"Error in reranking: {str(e)}"
        ) from e

    return RerankerResponse(results=formatted_results)
# Script entry point: serve the API with Uvicorn when this file is run directly
if __name__ == "__main__":
    from uvicorn import run as run_server

    # Bind to all interfaces on port 7860
    run_server(app, host="0.0.0.0", port=7860)
|