# Jina_Re_Rank / app.py
# Web-page header residue kept for provenance: uploader avatar "Deep8591",
# commit message "Upload 3 files", commit cce774d (verified), file size 3.28 kB.
import torch
from loguru import logger
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import CrossEncoder
from typing import List, Optional
# Application object. The title, description and version are the metadata
# FastAPI is given for its generated documentation; the two *_url arguments
# set where the interactive documentation pages are served.
app = FastAPI(
    version="1.0",
    title="Document Reranker API",
    description="An API for reranking documents using a CrossEncoder model.",
    redoc_url="/redoc",  # ReDoc UI
    docs_url="/docs",  # Swagger UI
)
# Enable CORS so that browser front ends served from other origins can call
# this API.
#
# Credentials are deliberately disabled: the FastAPI CORS documentation states
# that wildcard origins/methods/headers cannot be combined with
# allow_credentials=True, and nothing in this file reads cookies or
# Authorization headers. If credentialed requests are ever needed, replace the
# "*" entries with explicit values before turning this flag back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (modify as needed)
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Device selection: use CUDA when torch reports it as available, else the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the parenthesised detail shown in the startup log line.
if DEVICE.type == "cuda":
    _device_details = "GPU: " + torch.cuda.get_device_name(0)
else:
    _device_details = "Running on CPU"

logger.warning(f"Using device: {DEVICE} ({_device_details})")
# Instantiate the reranker once, at import time, so that every request reuses
# the same loaded model instead of reloading it per call.
# NOTE(review): trust_remote_code=True presumably lets code shipped with the
# model repository run locally — confirm the model source is trusted.
model = CrossEncoder(
    "jinaai/jina-reranker-v1-turbo-en",
    device=DEVICE,
    cache_dir="models",
    trust_remote_code=True,
)
class RerankerRequest(BaseModel):
    # Request body accepted by the /rerank endpoint.
    # (Documented with comments rather than a class docstring so the text does
    # not become part of the generated schema.)

    # Query every document is scored against.
    query: str = Field(..., description="The search query string")
    # Candidate documents; result indexes refer to positions in this list.
    documents: List[str] = Field(..., description="List of documents to rerank")
    # When false, the endpoint leaves RankedResult.document as None.
    return_documents: bool = Field(
        True, description="Whether to return document content in results"
    )
    # ge=0 rejects negative values at validation time. Without it a negative
    # top_k is forwarded unchecked to model.rank, where it presumably acts as
    # a negative slice bound and silently drops results — confirm against the
    # installed sentence-transformers version.
    top_k: int = Field(3, ge=0, description="Number of top results to return")
class RankedResult(BaseModel):
    # One entry of the reranked output returned by /rerank.

    # Taken from the "score" key of a model.rank() result.
    score: float
    # Taken from the "corpus_id" key of a model.rank() result.
    index: int
    # Document text; stays None when the request set return_documents to false.
    document: Optional[str] = None
class RerankerResponse(BaseModel):
    # Response body of /rerank (declared as the route's response_model).

    # Entries appear in the order in which model.rank() produced them.
    results: List[RankedResult]
@app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
async def rerank_documents(request: RerankerRequest):
    """
    Reranks the given list of documents based on their relevance to the query.

    - **query**: The input query string.
    - **documents**: A list of documents to be reranked.
    - **return_documents**: Whether to include document content in results.
    - **top_k**: Number of top-ranked documents to return.

    Returns:
    - A list of ranked documents with scores and indexes. An empty
      `documents` list yields an empty result list.

    Errors:
    - Responds with HTTP 500 if ranking or result formatting fails.
    """
    # Imported locally so this handler carries its own dependency; the helper
    # ships with the fastapi package the file already imports from.
    from fastapi.concurrency import run_in_threadpool

    # Nothing to score: answer directly rather than invoking the model on an
    # empty batch.
    if not request.documents:
        return RerankerResponse(results=[])

    try:
        # model.rank is a synchronous call. Awaiting it through the thread
        # pool keeps the event loop free to serve other requests while the
        # model is busy, instead of blocking inside this coroutine.
        results = await run_in_threadpool(
            model.rank,
            request.query,
            request.documents,
            return_documents=request.return_documents,
            top_k=request.top_k,
        )
        # Map each result dict onto the response schema. The "text" key is
        # only read when the caller asked for document content.
        formatted_results = [
            RankedResult(
                score=result["score"],
                index=result["corpus_id"],
                document=result["text"] if request.return_documents else None,
            )
            for result in results
        ]
        return RerankerResponse(results=formatted_results)
    except Exception as e:
        # Record the traceback server-side; the client only receives the
        # message. Chaining with "from e" preserves the original cause.
        logger.exception("Error in reranking")
        raise HTTPException(
            status_code=500, detail=f"Error in reranking: {str(e)}"
        ) from e
# Script entry point: when the file is executed directly, serve the app with
# Uvicorn on host 0.0.0.0, port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app=app, port=7860, host="0.0.0.0")