Deep8591 commited on
Commit
cce774d
·
verified ·
1 Parent(s): d6f554c

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +13 -0
  2. app.py +103 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ COPY . .
10
+
11
+ EXPOSE 7860
12
+
13
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from loguru import logger
3
+ from pydantic import BaseModel, Field
4
+ from fastapi import FastAPI, HTTPException
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from sentence_transformers import CrossEncoder
7
+ from typing import List, Optional
8
+
9
+ # Initialize FastAPI app with documentation metadata
10
+ app = FastAPI(
11
+ title="Document Reranker API",
12
+ description="An API for reranking documents using a CrossEncoder model.",
13
+ version="1.0",
14
+ docs_url="/docs", # Swagger UI
15
+ redoc_url="/redoc", # ReDoc UI
16
+ )
17
+
18
+ # Enable CORS (optional but useful for frontend integration)
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"], # Allow all origins (modify as needed)
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+ # Device selection
28
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ logger.warning(
30
+ f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
31
+ )
32
+
33
+ # Load the model at startup to avoid reloading for each request
34
+ model = CrossEncoder(
35
+ "jinaai/jina-reranker-v1-turbo-en",
36
+ trust_remote_code=True,
37
+ device=DEVICE,
38
+ cache_dir="models",
39
+ )
40
+
41
+
42
+ class RerankerRequest(BaseModel):
43
+ query: str = Field(..., description="The search query string")
44
+ documents: List[str] = Field(..., description="List of documents to rerank")
45
+ return_documents: bool = Field(
46
+ True, description="Whether to return document content in results"
47
+ )
48
+ top_k: int = Field(3, description="Number of top results to return")
49
+
50
+
51
+ class RankedResult(BaseModel):
52
+ score: float
53
+ index: int
54
+ document: Optional[str] = None
55
+
56
+
57
+ class RerankerResponse(BaseModel):
58
+ results: List[RankedResult]
59
+
60
+
61
+ @app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
62
+ async def rerank_documents(request: RerankerRequest):
63
+ """
64
+ Reranks the given list of documents based on their relevance to the query.
65
+
66
+ - **query**: The input query string.
67
+ - **documents**: A list of documents to be reranked.
68
+ - **return_documents**: Whether to include document content in results.
69
+ - **top_k**: Number of top-ranked documents to return.
70
+
71
+ Returns:
72
+ - A list of ranked documents with scores and indexes.
73
+ """
74
+ try:
75
+ # Call the model's rank method with the provided parameters
76
+ results = model.rank(
77
+ request.query,
78
+ request.documents,
79
+ return_documents=request.return_documents,
80
+ top_k=request.top_k,
81
+ )
82
+
83
+ # Format the results based on the model's output
84
+ formatted_results = [
85
+ RankedResult(
86
+ score=result["score"],
87
+ index=result["corpus_id"],
88
+ document=result["text"] if request.return_documents else None,
89
+ )
90
+ for result in results
91
+ ]
92
+
93
+ return RerankerResponse(results=formatted_results)
94
+
95
+ except Exception as e:
96
+ raise HTTPException(status_code=500, detail=f"Error in reranking: {str(e)}")
97
+
98
+
99
+ # Run the FastAPI app with Uvicorn
100
+ if __name__ == "__main__":
101
+ import uvicorn
102
+
103
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ loguru
5
+ fastapi
6
+ uvicorn
7
+ pydantic
8
+ sentence-transformers