Gourisankar Padihary committed on
Commit 5485d7c · 1 Parent(s): 5184c29

Further update

Files changed (5):
  1. app.py +7 -6
  2. config.py +14 -0
  3. main.py +10 -16
  4. retriever/embed_documents.py +4 -2
  5. retriever/retrieve_documents.py +12 -4
app.py CHANGED
@@ -4,9 +4,9 @@ import threading
 import time
 from generator.compute_metrics import get_attributes_text
 from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
-from io import StringIO
+from config import AppConfig, ConfigConstants
 
-def launch_gradio(vector_store, gen_llm, val_llm):
+def launch_gradio(config: AppConfig):
     """
     Launch the Gradio app with pre-initialized objects.
     """
@@ -43,7 +43,7 @@ def launch_gradio(vector_store, gen_llm, val_llm):
     def answer_question(query, state):
         try:
             # Generate response using the passed objects
-            response, source_docs = retrieve_and_generate_response(gen_llm, vector_store, query)
+            response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)
 
             # Update state with the response and source documents
             state["query"] = query
@@ -66,7 +66,7 @@ def launch_gradio(vector_store, gen_llm, val_llm):
         query = state.get("query", "")
 
         # Generate metrics using the passed objects
-        attributes, metrics = generate_metrics(val_llm, response, source_docs, query, 1)
+        attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)
 
         attributes_text = get_attributes_text(attributes)
 
@@ -87,8 +87,9 @@ def launch_gradio(vector_store, gen_llm, val_llm):
 
         # Section to display LLM names
         with gr.Row():
-            model_info = f"Generation LLM: {gen_llm.name if hasattr(gen_llm, 'name') else 'Unknown'}\n"
-            model_info += f"Validation LLM: {val_llm.name if hasattr(val_llm, 'name') else 'Unknown'}\n"
+            model_info = f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
+            model_info += f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
+            model_info += f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
             gr.Textbox(value=model_info, label="Model Information", interactive=False)  # Read-only textbox
 
         # State to store response and source documents
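
Note: the net effect of this change shows up at the call site. A minimal
before/after sketch (it mirrors the main.py hunk further down):

    # Before: three objects threaded through as positional arguments
    launch_gradio(vector_store, gen_llm, val_llm)

    # After: a single AppConfig instance carries them, and the Gradio
    # handlers read config.gen_llm, config.val_llm and config.vector_store
    config = AppConfig(vector_store=vector_store, gen_llm=gen_llm, val_llm=val_llm)
    launch_gradio(config)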
config.py ADDED
@@ -0,0 +1,14 @@
+
+class ConfigConstants:
+    # Constants related to datasets and models
+    DATA_SET_NAMES = ['covidqa', 'techqa', 'cuad']
+    EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
+    RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
+    DEFAULT_CHUNK_SIZE = 1000
+    CHUNK_OVERLAP = 200
+
+class AppConfig:
+    def __init__(self, vector_store, gen_llm, val_llm):
+        self.vector_store = vector_store
+        self.gen_llm = gen_llm
+        self.val_llm = val_llm
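
Note: ConfigConstants is meant to be read at class level, with no
instantiation, while AppConfig only bundles the three runtime objects. A
minimal usage sketch (my_store, my_gen_llm and my_val_llm are placeholders
for the objects built in main.py):

    from config import AppConfig, ConfigConstants

    print(ConfigConstants.EMBEDDING_MODEL_NAME)   # sentence-transformers/paraphrase-MiniLM-L3-v2
    print(ConfigConstants.DEFAULT_CHUNK_SIZE)     # 1000

    config = AppConfig(vector_store=my_store, gen_llm=my_gen_llm, val_llm=my_val_llm)
    assert config.gen_llm is my_gen_llm           # attributes are plain references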
main.py CHANGED
@@ -1,4 +1,5 @@
 import logging
+from config import AppConfig, ConfigConstants
 from data.load_dataset import load_data
 from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
 from retriever.chunk_documents import chunk_documents
@@ -12,32 +13,23 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 
 def main():
     logging.info("Starting the RAG pipeline")
-
-
-    # Load single dataset
-    #dataset = load_data(data_set_name)
-    #logging.info("Dataset loaded")
-    # List of datasets to load
-    data_set_names = ['covidqa', 'techqa', 'cuad']
-
-    default_chunk_size = 1000
-    chunk_overlap = 200
 
     # Dictionary to store chunked documents
     all_chunked_documents = []
-    # Load multiple datasets
     datasets = {}
-    for data_set_name in data_set_names:
+
+    # Load multiple datasets
+    for data_set_name in ConfigConstants.DATA_SET_NAMES:
         logging.info(f"Loading dataset: {data_set_name}")
         datasets[data_set_name] = load_data(data_set_name)
 
         # Set chunk size based on dataset name
-        chunk_size = default_chunk_size
+        chunk_size = ConfigConstants.DEFAULT_CHUNK_SIZE
         if data_set_name == 'cuad':
             chunk_size = 4000  # Custom chunk size for 'cuad'
 
         # Chunk documents
-        chunked_documents = chunk_documents(datasets[data_set_name], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        chunked_documents = chunk_documents(datasets[data_set_name], chunk_size=chunk_size, chunk_overlap=ConfigConstants.CHUNK_OVERLAP)
         all_chunked_documents.extend(chunked_documents)  # Combine all chunks
 
     # Access individual datasets
@@ -58,11 +50,13 @@ def main():
     val_llm = initialize_validation_llm()
 
     # Compute RMSE and AUC-ROC for entire dataset
-    data_set_name = 'covidqa'
+    # Enable the code below to run the calculation
+    #data_set_name = 'covidqa'
     #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
 
     # Launch the Gradio app
-    launch_gradio(vector_store, gen_llm, val_llm)
+    config = AppConfig(vector_store=vector_store, gen_llm=gen_llm, val_llm=val_llm)
+    launch_gradio(config)
 
     logging.info("Finished!!!")
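Note: with these constants, the per-dataset chunk sizes resolve as sketched
below; the 'cuad' override reflects its much longer contract documents:

    for name in ConfigConstants.DATA_SET_NAMES:
        size = 4000 if name == 'cuad' else ConfigConstants.DEFAULT_CHUNK_SIZE
        print(name, size)   # covidqa 1000, techqa 1000, cuad 4000
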
retriever/embed_documents.py CHANGED
@@ -3,9 +3,11 @@ import logging
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 
-def embed_documents(documents, embedding_path="embeddings.faiss"):
-    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-MiniLM-L3-v2")
+from config import ConfigConstants
 
+def embed_documents(documents, embedding_path="embeddings.faiss"):
+    embedding_model = HuggingFaceEmbeddings(model_name=ConfigConstants.EMBEDDING_MODEL_NAME)
+
     if os.path.exists(embedding_path):
         logging.info("Loading embeddings from local file")
         vector_store = FAISS.load_local(embedding_path, embedding_model, allow_dangerous_deserialization=True)
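
Note: embed_documents caches the FAISS index on disk, so only the first run
pays the embedding cost. A usage sketch, assuming the chunked documents
produced in main.py (the save branch is not shown in this hunk, but is
implied by the load path):

    from retriever.embed_documents import embed_documents

    # First call: embeds all chunks and writes the index to embeddings.faiss
    vector_store = embed_documents(all_chunked_documents)

    # Later calls: the index is loaded from disk via FAISS.load_local
    vector_store = embed_documents(all_chunked_documents)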
retriever/retrieve_documents.py CHANGED
@@ -1,13 +1,19 @@
+import logging
 import numpy as np
 from transformers import pipeline
 
+from config import ConfigConstants
+
 def retrieve_top_k_documents(vector_store, query, top_k=5):
     documents = vector_store.similarity_search(query, k=top_k)
+    logging.info(f"Top {top_k} documents retrieved for the query")
+
     documents = rerank_documents(query, documents)
+
     return documents
 
 # Reranking: Cross-Encoder for refining top-k results
-def rerank_documents(query, documents, reranker_model_name="cross-encoder/ms-marco-electra-base"):
+def rerank_documents(query, documents):
     """
     Re-rank documents using a cross-encoder model.
 
@@ -20,7 +26,7 @@ def rerank_documents(query, documents, reranker_model_name="cross-encoder/ms-mar
         list: Re-ranked list of Document objects with updated scores.
     """
     # Initialize the cross-encoder model
-    reranker = pipeline("text-classification", model=reranker_model_name, return_all_scores=False)
+    reranker = pipeline("text-classification", model=ConfigConstants.RE_RANKER_MODEL_NAME, top_k=1)
 
     # Pair the query with each document's text
     rerank_inputs = [{"text": query, "text_pair": doc.page_content} for doc in documents]
@@ -28,12 +34,14 @@ def rerank_documents(query, documents, reranker_model_name="cross-encoder/ms-mar
     # Get relevance scores for each query-document pair
     scores = reranker(rerank_inputs)
 
     # Attach the new scores to the documents
     for doc, score in zip(documents, scores):
-        doc.metadata["rerank_score"] = score["score"]  # Add score to document metadata
+        doc.metadata["rerank_score"] = score[0]['score']  # Access score from the first item in the list
 
     # Sort documents by the rerank_score in descending order
     documents = sorted(documents, key=lambda x: x.metadata.get("rerank_score", 0), reverse=True)
+    logging.info("Re-ranked documents using a cross-encoder model")
+
     return documents
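
Note: swapping return_all_scores=False for top_k=1 changes the pipeline's
output shape, which is why the metadata assignment became score[0]['score'].
With top_k set, a transformers text-classification pipeline returns a list
of {label, score} dicts per input instead of a single dict. A behavior
sketch (the query/passage pair and score are illustrative, and the exact
output shape can vary across transformers versions):

    from transformers import pipeline

    reranker = pipeline("text-classification",
                        model="cross-encoder/ms-marco-electra-base",
                        top_k=1)

    pairs = [{"text": "what is covid?",
              "text_pair": "COVID-19 is caused by the SARS-CoV-2 virus."}]
    scores = reranker(pairs)

    # With top_k=1, each input yields a one-element list of dicts,
    # e.g. [[{'label': 'LABEL_0', 'score': 0.93}]]
    print(scores[0][0]["score"])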