Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import torch
|
2 |
from loguru import logger
|
3 |
from pydantic import BaseModel, Field
|
@@ -30,12 +31,20 @@ logger.warning(
|
|
30 |
f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
|
31 |
)
|
32 |
|
|
|
|
|
|
|
33 |
# Load the model at startup to avoid reloading for each request
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
class RerankerRequest(BaseModel):
|
@@ -71,27 +80,31 @@ async def rerank_documents(request: RerankerRequest):
|
|
71 |
- A list of ranked documents with scores and indexes.
|
72 |
"""
|
73 |
try:
|
74 |
-
#
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
83 |
formatted_results = [
|
84 |
RankedResult(
|
85 |
-
score=
|
86 |
-
index=
|
87 |
-
document=
|
88 |
)
|
89 |
-
for
|
90 |
]
|
91 |
|
92 |
return RerankerResponse(results=formatted_results)
|
93 |
|
94 |
except Exception as e:
|
|
|
95 |
raise HTTPException(status_code=500, detail=f"Error in reranking: {str(e)}")
|
96 |
|
97 |
|
@@ -99,4 +112,4 @@ async def rerank_documents(request: RerankerRequest):
|
|
99 |
if __name__ == "__main__":
|
100 |
import uvicorn
|
101 |
|
102 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
1 |
+
import os
|
2 |
import torch
|
3 |
from loguru import logger
|
4 |
from pydantic import BaseModel, Field
|
|
|
31 |
f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
|
32 |
)
|
33 |
|
34 |
+
# Ensure a writable cache directory
|
35 |
+
os.makedirs("models_cache", exist_ok=True)
|
36 |
+
|
37 |
# Load the model at startup to avoid reloading for each request
|
38 |
+
try:
|
39 |
+
model = CrossEncoder(
|
40 |
+
"jinaai/jina-reranker-v1-turbo-en",
|
41 |
+
trust_remote_code=True,
|
42 |
+
device=DEVICE,
|
43 |
+
cache_dir="models_cache",
|
44 |
+
)
|
45 |
+
except Exception as e:
|
46 |
+
logger.error(f"Failed to load model: {e}")
|
47 |
+
raise RuntimeError("Model loading failed. Check logs for details.")
|
48 |
|
49 |
|
50 |
class RerankerRequest(BaseModel):
|
|
|
80 |
- A list of ranked documents with scores and indexes.
|
81 |
"""
|
82 |
try:
|
83 |
+
# Prepare model input
|
84 |
+
inputs = [[request.query, doc] for doc in request.documents]
|
85 |
+
|
86 |
+
# Get ranking scores
|
87 |
+
scores = model.predict(inputs)
|
88 |
+
|
89 |
+
# Sort scores and get top-k results
|
90 |
+
ranked_indices = sorted(
|
91 |
+
range(len(scores)), key=lambda i: scores[i], reverse=True
|
92 |
+
)[: request.top_k]
|
93 |
+
|
94 |
+
# Format results
|
95 |
formatted_results = [
|
96 |
RankedResult(
|
97 |
+
score=scores[i],
|
98 |
+
index=i,
|
99 |
+
document=request.documents[i] if request.return_documents else None,
|
100 |
)
|
101 |
+
for i in ranked_indices
|
102 |
]
|
103 |
|
104 |
return RerankerResponse(results=formatted_results)
|
105 |
|
106 |
except Exception as e:
|
107 |
+
logger.error(f"Error in reranking: {e}")
|
108 |
raise HTTPException(status_code=500, detail=f"Error in reranking: {str(e)}")
|
109 |
|
110 |
|
|
|
112 |
if __name__ == "__main__":
|
113 |
import uvicorn
|
114 |
|
115 |
+
uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
|