# File size: 3,279 Bytes
# cce774d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import threading
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder

# Application object. The metadata below feeds the generated OpenAPI schema,
# which is browsable at /docs (Swagger UI) and /redoc (ReDoc).
app = FastAPI(
    docs_url="/docs",
    redoc_url="/redoc",
    title="Document Reranker API",
    version="1.0",
    description="An API for reranking documents using a CrossEncoder model.",
)

# CORS: accept requests from any origin, with any method and header, so a browser
# front-end can call the API directly.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests. Nothing visible here relies on
# cookies/auth, so consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],  # tighten to the known front-end origins in production
    allow_credentials=True,
)

# Pick the compute device once at import time: CUDA when available, else CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if DEVICE.type == "cuda":
    # Report which GPU (index 0) was picked up.
    logger.warning(f"Using device: {DEVICE} (GPU: {torch.cuda.get_device_name(0)})")
else:
    logger.warning(f"Using device: {DEVICE} (Running on CPU)")

# Build the reranker once when the module is imported; every request shares this
# instance rather than reloading weights per call. trust_remote_code is enabled
# (presumably because the Jina checkpoint ships custom modeling code), and
# downloads are cached in a "models" directory relative to the working directory.
model = CrossEncoder(
    "jinaai/jina-reranker-v1-turbo-en",
    cache_dir="models",
    device=DEVICE,
    trust_remote_code=True,
)


class RerankerRequest(BaseModel):
    """Request body for the rerank endpoint."""

    query: str = Field(..., description="The search query string")
    documents: List[str] = Field(..., description="List of documents to rerank")
    return_documents: bool = Field(
        True, description="Whether to return document content in results"
    )
    # top_k is forwarded unchanged to CrossEncoder.rank. ge=1 rejects 0 and
    # negative values with a 422 instead of passing them to the model (which
    # presumably uses it as a slice bound, so a negative value would silently
    # drop documents). null is accepted and means "return every document".
    top_k: Optional[int] = Field(
        3,
        ge=1,
        description="Number of top results to return (null returns all documents)",
    )


class RankedResult(BaseModel):
    """A single reranked document returned by the rerank endpoint."""

    # Relevance score reported by model.rank for this (query, document) pair.
    score: float
    # Position of the document in the request's documents list (model.rank's corpus_id).
    index: int
    # Document text; left as None when the request sets return_documents to false.
    document: Optional[str] = None


class RerankerResponse(BaseModel):
    """Response body of the rerank endpoint."""

    # Ranked results in the order model.rank produced them (presumably best-first).
    results: List[RankedResult]


# Serializes access to the shared CrossEncoder. Inference now runs in a worker
# thread (so it no longer blocks the event loop); this lock keeps model/tokenizer
# calls one-at-a-time, as they were when they ran directly on the event loop.
_MODEL_LOCK = threading.Lock()


def _rank_serialized(
    query: str, documents: List[str], return_documents: bool, top_k: Optional[int]
):
    """Call ``model.rank`` while holding ``_MODEL_LOCK``; blocking, run off the event loop."""
    with _MODEL_LOCK:
        return model.rank(
            query,
            documents,
            return_documents=return_documents,
            top_k=top_k,
        )


@app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
async def rerank_documents(request: RerankerRequest):
    """
    Rerank the given documents by their relevance to the query.

    - **query**: The input query string.
    - **documents**: A list of documents to be reranked.
    - **return_documents**: Whether to include document content in results.
    - **top_k**: Number of top-ranked documents to return.

    Returns a list of ranked documents with their scores and their indexes
    in the submitted `documents` list.
    """
    # Nothing to rank: answer directly rather than handing an empty pair list to
    # the model (some sentence-transformers versions index the first pair and fail).
    if not request.documents:
        return RerankerResponse(results=[])

    try:
        # model.rank is blocking CPU/GPU work; running it in the threadpool keeps
        # the event loop free to serve other requests while inference runs.
        results = await run_in_threadpool(
            _rank_serialized,
            request.query,
            request.documents,
            request.return_documents,
            request.top_k,
        )

        # Map the model's result dicts onto the response schema.
        formatted_results = [
            RankedResult(
                score=result["score"],
                index=result["corpus_id"],
                document=result["text"] if request.return_documents else None,
            )
            for result in results
        ]

        return RerankerResponse(results=formatted_results)

    except Exception as e:
        # Keep the full traceback in the server log; the client only sees the message.
        logger.exception("Error in reranking")
        raise HTTPException(
            status_code=500, detail=f"Error in reranking: {str(e)}"
        ) from e


# Script entry point: serve the app with Uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    from uvicorn import run

    run(app, port=7860, host="0.0.0.0")