Deep8591 committed on
Commit
48d20bf
·
verified ·
1 Parent(s): 1123987

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -19
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  from loguru import logger
3
  from pydantic import BaseModel, Field
@@ -30,12 +31,20 @@ logger.warning(
30
  f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
31
  )
32
 
 
 
 
33
  # Load the model at startup to avoid reloading for each request
34
- model = CrossEncoder(
35
- "jinaai/jina-reranker-v1-turbo-en",
36
- trust_remote_code=True,
37
- device=DEVICE
38
- )
 
 
 
 
 
39
 
40
 
41
  class RerankerRequest(BaseModel):
@@ -71,27 +80,31 @@ async def rerank_documents(request: RerankerRequest):
71
  - A list of ranked documents with scores and indexes.
72
  """
73
  try:
74
- # Call the model's rank method with the provided parameters
75
- results = model.rank(
76
- request.query,
77
- request.documents,
78
- return_documents=request.return_documents,
79
- top_k=request.top_k,
80
- )
81
-
82
- # Format the results based on the model's output
 
 
 
83
  formatted_results = [
84
  RankedResult(
85
- score=result["score"],
86
- index=result["corpus_id"],
87
- document=result["text"] if request.return_documents else None,
88
  )
89
- for result in results
90
  ]
91
 
92
  return RerankerResponse(results=formatted_results)
93
 
94
  except Exception as e:
 
95
  raise HTTPException(status_code=500, detail=f"Error in reranking: {str(e)}")
96
 
97
 
@@ -99,4 +112,4 @@ async def rerank_documents(request: RerankerRequest):
99
  if __name__ == "__main__":
100
  import uvicorn
101
 
102
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import os
2
  import torch
3
  from loguru import logger
4
  from pydantic import BaseModel, Field
 
31
  f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
32
  )
33
 
34
+ # Ensure a writable cache directory
35
+ os.makedirs("models_cache", exist_ok=True)
36
+
37
  # Load the model at startup to avoid reloading for each request
38
+ try:
39
+ model = CrossEncoder(
40
+ "jinaai/jina-reranker-v1-turbo-en",
41
+ trust_remote_code=True,
42
+ device=DEVICE,
43
+ cache_dir="models_cache",
44
+ )
45
+ except Exception as e:
46
+ logger.error(f"Failed to load model: {e}")
47
+ raise RuntimeError("Model loading failed. Check logs for details.")
48
 
49
 
50
  class RerankerRequest(BaseModel):
 
80
  - A list of ranked documents with scores and indexes.
81
  """
82
  try:
83
+ # Prepare model input
84
+ inputs = [[request.query, doc] for doc in request.documents]
85
+
86
+ # Get ranking scores
87
+ scores = model.predict(inputs)
88
+
89
+ # Sort scores and get top-k results
90
+ ranked_indices = sorted(
91
+ range(len(scores)), key=lambda i: scores[i], reverse=True
92
+ )[: request.top_k]
93
+
94
+ # Format results
95
  formatted_results = [
96
  RankedResult(
97
+ score=scores[i],
98
+ index=i,
99
+ document=request.documents[i] if request.return_documents else None,
100
  )
101
+ for i in ranked_indices
102
  ]
103
 
104
  return RerankerResponse(results=formatted_results)
105
 
106
  except Exception as e:
107
+ logger.error(f"Error in reranking: {e}")
108
  raise HTTPException(status_code=500, detail=f"Error in reranking: {str(e)}")
109
 
110
 
 
112
  if __name__ == "__main__":
113
  import uvicorn
114
 
115
+ uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)