swaroop-uddandarao committed
Commit bea31e7 · 1 Parent(s): a7d8778

added rerank model options

Files changed (2):
  1. app.py +10 -1
  2. finetuneresults.py +83 -16
app.py CHANGED
@@ -19,7 +19,10 @@ from huggingface_hub import dataset_info
 
 # Load embedding model
 QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
-RERANKING_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+RERANKING_MODELS = {
+    "MS MARCO MiniLM": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+    "MonoT5 Base": "castorini/monot5-base-msmarco",
+}
 PROMPT_MODEL = "llama-3.3-70b-specdec"
 EVAL_MODEL = "llama-3.3-70b-specdec"
 WINDOW_SIZE = 5
@@ -107,6 +110,12 @@ with gr.Blocks() as iface:
         label="Select a Model"
     )
 
+    reranker_dropdown = gr.Dropdown(
+        list(RERANKING_MODELS.keys()),
+        value="MS MARCO MiniLM",
+        label="Select Reranking Model"
+    )
+
     submit_button = gr.Button("Evaluate Model")
 
     with gr.Row():
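The hunks above only register the new dropdown; the click wiring that maps the selected label back to a Hugging Face checkpoint id is outside this diff. A minimal, self-contained sketch of how such a mapping could look (the handler name evaluate_model and the results_output textbox are illustrative placeholders, not part of app.py):

import gradio as gr

# Hypothetical wiring sketch -- the real click handler in app.py is not shown in this commit.
RERANKING_MODELS = {
    "MS MARCO MiniLM": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "MonoT5 Base": "castorini/monot5-base-msmarco",
}

def evaluate_model(reranker_choice):
    # Map the dropdown's display label to the checkpoint id used by the reranking code.
    return f"Selected reranker checkpoint: {RERANKING_MODELS[reranker_choice]}"

with gr.Blocks() as iface:
    reranker_dropdown = gr.Dropdown(
        list(RERANKING_MODELS.keys()),
        value="MS MARCO MiniLM",
        label="Select Reranking Model",
    )
    submit_button = gr.Button("Evaluate Model")
    results_output = gr.Textbox(label="Result")
    submit_button.click(evaluate_model, inputs=reranker_dropdown, outputs=results_output)

iface.launch()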
finetuneresults.py CHANGED
@@ -1,5 +1,62 @@
 from sentence_transformers import CrossEncoder
-
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import torch
+import numpy as np
+from typing import List, Tuple
+
+class MonoT5Reranker:
+    def __init__(self, model_name: str):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Using device: {self.device}")
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
+        self.model.to(self.device)
+        self.model.eval()
+
+    def predict(self, query_doc_pairs: List[Tuple[str, str]]) -> np.ndarray:
+        scores = []
+        batch_size = 8  # Adjust based on your GPU/CPU memory
+
+        for i in range(0, len(query_doc_pairs), batch_size):
+            batch_pairs = query_doc_pairs[i:i + batch_size]
+
+            # Format input as per MonoT5 requirements
+            inputs = [f"Query: {query} Document: {doc}" for query, doc in batch_pairs]
+
+            # Tokenize
+            encoded = self.tokenizer(
+                inputs,
+                padding=True,
+                truncation=True,
+                max_length=512,
+                return_tensors="pt"
+            ).to(self.device)
+
+            # Get predictions
+            with torch.no_grad():
+                outputs = self.model(**encoded)
+                batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
+                scores.extend(batch_scores.tolist())
+
+        return np.array(scores)
+
+class MSMARCOReranker:
+    def __init__(self, model_name: str):
+        self.model = CrossEncoder(model_name)
+
+    def predict(self, query_doc_pairs: List[Tuple[str, str]]) -> np.ndarray:
+        return self.model.predict(query_doc_pairs)
+
+
+def get_reranker(model_name: str):
+    """Factory function to get appropriate reranker based on model name."""
+    if "monot5" in model_name.lower():
+        print(f"Using MonoT5 reranker: {model_name}")
+        return MonoT5Reranker(model_name)
+    else:
+        print(f"Using MS MARCO reranker: {model_name}")
+        return MSMARCOReranker(model_name)
+
 """
 Retrieves unique full documents based on the top-ranked document IDs.
 
@@ -37,25 +94,35 @@ Returns:
 """
 
 def rerank_documents(query, retrieved_docs_df, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
-
-    # Load Cross-Encoder model
-    model = CrossEncoder(model_name)
-
-    # Prepare query-document pairs
-    query_doc_pairs = [(query, " ".join(doc)) for doc in retrieved_docs_df["document"]]
-
-    # Compute relevance scores
-    scores = model.predict(query_doc_pairs)
-
-    # Add scores to the DataFrame
-    retrieved_docs_df["relevance_score"] = scores
-
-    # Sort by score in descending order (higher score = more relevant)
-    reranked_docs_df = retrieved_docs_df.sort_values(by="relevance_score", ascending=False).reset_index(drop=True)
-
-    return reranked_docs_df
+    """Reranks documents using the specified reranking model."""
+    try:
+        # Load Cross-Encoder model
+        model = get_reranker(model_name)
+
+        # Prepare query-document pairs
+        query_doc_pairs = [(query, " ".join(doc)) for doc in retrieved_docs_df["document"]]
+
+        # Compute relevance scores
+        scores = model.predict(query_doc_pairs)
+
+        # Add scores to the DataFrame
+        retrieved_docs_df["relevance_score"] = scores
+
+        # Sort by score in descending order (higher score = more relevant)
+        reranked_docs_df = retrieved_docs_df.sort_values(by="relevance_score", ascending=False).reset_index(drop=True)
+
+        return reranked_docs_df
+    except Exception as e:
+        print(f"Error in reranking: {e}")
+        # Return original order if reranking fails
+        retrieved_docs_df["relevance_score"] = 1.0
+        return retrieved_docs_df
 
 def FineTuneAndRerankSearchResults(top_10_chunk_results, rag_extarcted_data, question, reranking_model):
-    unique_docs= retrieve_full_documents(top_10_chunk_results, rag_extarcted_data)
-    reranked_results = rerank_documents(question, unique_docs, reranking_model)
-    return reranked_results
+    try:
+        unique_docs= retrieve_full_documents(top_10_chunk_results, rag_extarcted_data)
+        reranked_results = rerank_documents(question, unique_docs, reranking_model)
+        return reranked_results
+    except Exception as e:
+        print(f"Error in FineTuneAndRerankSearchResults: {e}")
+        return None
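For a quick end-to-end check of the new dispatch path, rerank_documents can be exercised on a toy DataFrame. A minimal sketch, assuming finetuneresults.py is importable from the working directory and that the "document" column holds lists of chunk strings (as the " ".join(doc) call implies); the doc_id column is illustrative only:

import pandas as pd

from finetuneresults import rerank_documents

# Toy retrieval output: each "document" entry is a list of chunk strings.
docs = pd.DataFrame({
    "doc_id": [1, 2, 3],
    "document": [
        ["The Eiffel Tower", "is a landmark in Paris, France."],
        ["Python is a programming language", "widely used for machine learning."],
        ["The Great Wall of China", "stretches across northern China."],
    ],
})

# "cross-encoder/ms-marco-MiniLM-L-6-v2" routes through MSMARCOReranker;
# "castorini/monot5-base-msmarco" would route through MonoT5Reranker instead.
reranked = rerank_documents(
    "Where is the Eiffel Tower?",
    docs,
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
)

# Highest relevance_score first; the Paris document should come out on top.
print(reranked[["doc_id", "relevance_score"]])

If anything in the reranking path raises, the try/except in rerank_documents falls back to the original order with relevance_score set to 1.0, so the call still returns a DataFrame.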