import os
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder

# Initialize FastAPI app with documentation metadata
app = FastAPI(
    title="Document Reranker API",
    description="An API for reranking documents using a CrossEncoder model.",
    version="1.0",
    docs_url="/docs",  # Swagger UI
    redoc_url="/redoc",  # ReDoc UI
)

# Enable CORS (optional but useful for frontend integration)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (modify as needed)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Device selection
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.warning(
    f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
)

# Ensure a writable cache directory
os.makedirs("models_cache", exist_ok=True)

# Load the model at startup to avoid reloading for each request
try:
    model = CrossEncoder(
        "jinaai/jina-reranker-v1-turbo-en",
        trust_remote_code=True,
        device=DEVICE,
        cache_dir="models_cache",
    )
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise RuntimeError("Model loading failed. Check logs for details.") from e


class RerankerRequest(BaseModel):
    query: str = Field(..., description="The search query string")
    documents: List[str] = Field(..., description="List of documents to rerank")
    return_documents: bool = Field(
        True, description="Whether to return document content in results"
    )
    top_k: int = Field(3, description="Number of top results to return")


class RankedResult(BaseModel):
    score: float
    index: int
    document: Optional[str] = None


class RerankerResponse(BaseModel):
    results: List[RankedResult]


@app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
async def rerank_documents(request: RerankerRequest):
    """
    Reranks the given list of documents based on their relevance to the query.
    - **query**: The input query string.
    - **documents**: A list of documents to be reranked.
    - **return_documents**: Whether to include document content in results.
    - **top_k**: Number of top-ranked documents to return.
    Returns:
    - A list of ranked documents with scores and indexes.
    """
    try:
        # Run the cross-encoder to score and rank the documents against the query
        results = model.rank(
            request.query,
            request.documents,
            return_documents=request.return_documents,
            top_k=request.top_k,
        )
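        # Each entry in `results` is a dict with "corpus_id", "score", and,
        # when return_documents=True, the original document "text".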

        # Format the results based on the model's output
        formatted_results = [
            RankedResult(
                score=result["score"],
                index=result["corpus_id"],
                document=result["text"] if request.return_documents else None,
            )
            for result in results
        ]

        return RerankerResponse(results=formatted_results)

    except Exception as e:
        logger.error(f"Error in reranking: {e}")
        raise HTTPException(status_code=500, detail=f"Error in reranking: {str(e)}") from e


# Run the FastAPI app with Uvicorn
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
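
# Example client call (a minimal sketch, not part of the service itself): it
# assumes the server above is running locally on port 7860, that the `requests`
# package is installed, and uses made-up query/document strings purely for
# illustration.
#
#   import requests
#
#   payload = {
#       "query": "how do cross-encoders rerank search results?",
#       "documents": [
#           "Cross-encoders score a query and a document jointly.",
#           "Bi-encoders embed queries and documents separately.",
#           "The weather in Berlin is mild in spring.",
#       ],
#       "return_documents": True,
#       "top_k": 2,
#   }
#   response = requests.post("http://localhost:7860/rerank", json=payload)
#   response.raise_for_status()
#   for item in response.json()["results"]:
#       print(item["index"], round(item["score"], 4), item["document"])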