"""Document Reranker API.

A small FastAPI service exposing ``POST /rerank``: it scores a list of
documents against a query with a sentence-transformers ``CrossEncoder`` and
returns them ordered by the model's relevance score.
"""

import os
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder

# Initialize FastAPI app with documentation metadata
app = FastAPI(
    title="Document Reranker API",
    description="An API for reranking documents using a CrossEncoder model.",
    version="1.0",
    docs_url="/docs",  # Swagger UI
    redoc_url="/redoc",  # ReDoc UI
)

# Enable CORS (optional but useful for frontend integration).
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed requests to this API; restrict allow_origins before
# exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (modify as needed)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Device selection: use CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.warning(
    f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
)

# Ensure a writable cache directory for downloaded model weights.
MODEL_CACHE_DIR = "models_cache"
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

# Load the model once at import/startup so requests don't pay the load cost.
try:
    model = CrossEncoder(
        "jinaai/jina-reranker-v1-turbo-en",
        trust_remote_code=True,
        device=DEVICE,
        cache_dir=MODEL_CACHE_DIR,
    )
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception(f"Failed to load model: {e}")
    raise RuntimeError("Model loading failed. Check logs for details.") from e


class RerankerRequest(BaseModel):
    """Request body for ``POST /rerank``."""

    query: str = Field(..., description="The search query string")
    documents: List[str] = Field(..., description="List of documents to rerank")
    return_documents: bool = Field(
        True, description="Whether to return document content in results"
    )
    # ge=0 rejects negative values up front (HTTP 422). NOTE(review): top_k is
    # presumably applied as a slice bound inside CrossEncoder.rank, where a
    # negative value would drop documents instead of limiting them — confirm
    # against the installed sentence-transformers version.
    top_k: int = Field(3, ge=0, description="Number of top results to return")


class RankedResult(BaseModel):
    """One reranked document: its score, its position in the request list, and
    optionally its text."""

    score: float
    index: int  # position of the document in RerankerRequest.documents
    document: Optional[str] = None  # populated only when return_documents is True


class RerankerResponse(BaseModel):
    """Response body for ``POST /rerank``."""

    results: List[RankedResult]


@app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
async def rerank_documents(request: RerankerRequest):
    """
    Reranks the given list of documents based on their relevance to the query.

    - **query**: The input query string.
    - **documents**: A list of documents to be reranked.
    - **return_documents**: Whether to include document content in results.
    - **top_k**: Number of top-ranked documents to return.

    Returns:
    - A list of ranked documents with scores and indexes (empty when no
      documents are supplied).

    Raises:
    - HTTPException (500) if the model fails while scoring.
    """
    # Nothing to score: return early instead of handing the model an empty batch.
    if not request.documents:
        return RerankerResponse(results=[])

    try:
        # model.rank is synchronous inference. Calling it directly inside this
        # coroutine would block the event loop (stalling all other requests)
        # until it finished, so run it in the threadpool instead.
        results = await run_in_threadpool(
            model.rank,
            request.query,
            request.documents,
            return_documents=request.return_documents,
            top_k=request.top_k,
        )

        # Convert the model's dict results into the response schema.
        formatted_results = [
            RankedResult(
                score=float(result["score"]),
                index=int(result["corpus_id"]),
                document=result["text"] if request.return_documents else None,
            )
            for result in results
        ]
        return RerankerResponse(results=formatted_results)
    except Exception as e:
        logger.exception(f"Error in reranking: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error in reranking: {str(e)}"
        ) from e


# Run the FastAPI app with Uvicorn
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)