Gourisankar Padihary committed on
Commit
bcc15bd
·
1 Parent(s): 0ea6d19

load dataset locally

Browse files
data/load_dataset.py CHANGED
@@ -1,9 +1,23 @@
 
1
  import logging
2
  from datasets import load_dataset
 
3
 
4
- def load_data(data_set_name):
5
- logging.info("Loading dataset")
6
- dataset = load_dataset("rungalileo/ragbench", data_set_name, split="test")
 
 
 
 
 
 
 
 
 
 
 
 
7
  logging.info("Dataset loaded successfully")
8
  logging.info(f"Number of documents found: {dataset.num_rows}")
9
- return dataset
 
1
+ import os
2
  import logging
3
  from datasets import load_dataset
4
+ import pickle # For saving the dataset locally
5
 
6
def load_data(data_set_name, local_path="local_datasets"):
    """Load the 'test' split of a ragbench dataset, caching it locally.

    On the first call the dataset is downloaded from Hugging Face and
    pickled under *local_path*; later calls read the pickle instead.

    Args:
        data_set_name: Name of the ragbench configuration (e.g. 'covidqa').
        local_path: Directory used for the pickle cache (created if missing).

    Returns:
        The dataset object (a Hugging Face Dataset on first download, or
        the unpickled cached object on subsequent calls).
    """
    os.makedirs(local_path, exist_ok=True)
    dataset_file = os.path.join(local_path, f"{data_set_name}_test.pkl")

    dataset = None
    if os.path.exists(dataset_file):
        logging.info("Loading dataset from local storage")
        try:
            # NOTE: unpickling is only acceptable because this cache file is
            # written by us below; never unpickle files from untrusted sources.
            with open(dataset_file, "rb") as f:
                dataset = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, OSError):
            # Corrupt or truncated cache: fall through and re-download.
            logging.warning("Local cache %s unreadable; re-downloading", dataset_file)
            dataset = None

    if dataset is None:
        logging.info("Loading dataset from Hugging Face")
        dataset = load_dataset("rungalileo/ragbench", data_set_name, split="test")
        logging.info("Saving %s dataset locally", data_set_name)
        with open(dataset_file, "wb") as f:
            pickle.dump(dataset, f)

    logging.info("Dataset loaded successfully")
    logging.info("Number of documents found: %s", dataset.num_rows)
    return dataset
generator/initialize_llm.py CHANGED
@@ -3,14 +3,14 @@ import os
3
  from langchain_groq import ChatGroq
4
 
5
  def initialize_generation_llm():
6
- os.environ["GROQ_API_KEY"] = "your_groq_api_key"
7
  model_name = "llama3-8b-8192"
8
  llm = ChatGroq(model=model_name, temperature=0.7)
9
  logging.info(f'Generation LLM {model_name} initialized')
10
  return llm
11
 
12
  def initialize_validation_llm():
13
- os.environ["GROQ_API_KEY"] = "your_groq_api_key"
14
  model_name = "llama3-70b-8192"
15
  llm = ChatGroq(model=model_name, temperature=0.7)
16
  logging.info(f'Validation LLM {model_name} initialized')
 
3
  from langchain_groq import ChatGroq
4
 
5
def initialize_generation_llm():
    """Create the generation LLM client (Groq llama3-8b-8192, temperature 0.7).

    Requires the GROQ_API_KEY environment variable to be set by the
    deployment environment; the underlying ChatGroq client reads it.

    Returns:
        A configured ChatGroq client.

    Raises:
        EnvironmentError: If GROQ_API_KEY is not set.
    """
    # SECURITY: never hard-code or commit API keys. The previous revision
    # embedded a live key here; it must be revoked and supplied via the
    # environment instead.
    if not os.environ.get("GROQ_API_KEY"):
        raise EnvironmentError("GROQ_API_KEY environment variable is not set")
    model_name = "llama3-8b-8192"
    llm = ChatGroq(model=model_name, temperature=0.7)
    logging.info('Generation LLM %s initialized', model_name)
    return llm
11
 
12
def initialize_validation_llm():
    """Create the validation LLM client (Groq llama3-70b-8192, temperature 0.7).

    Requires the GROQ_API_KEY environment variable to be set by the
    deployment environment; the underlying ChatGroq client reads it.

    Raises:
        EnvironmentError: If GROQ_API_KEY is not set.
    """
    # SECURITY: never hard-code or commit API keys. The previous revision
    # embedded a live key here; it must be revoked and supplied via the
    # environment instead.
    if not os.environ.get("GROQ_API_KEY"):
        raise EnvironmentError("GROQ_API_KEY environment variable is not set")
    model_name = "llama3-70b-8192"
    llm = ChatGroq(model=model_name, temperature=0.7)
    logging.info('Validation LLM %s initialized', model_name)
main.py CHANGED
@@ -12,7 +12,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
12
 
13
  def main():
14
  logging.info("Starting the RAG pipeline")
15
- data_set_name = 'techqa'
16
 
17
  # Load the dataset
18
  dataset = load_data(data_set_name)
@@ -36,11 +36,11 @@ def main():
36
  val_llm = initialize_validation_llm()
37
 
38
  # Sample question
39
- row_num = 7
40
  query = dataset[row_num]['question']
41
 
42
  # Call generate_metrics for above sample question
43
- generate_metrics(gen_llm, val_llm, vector_store, query)
44
 
45
  #Compute RMSE and AUC-ROC for entire dataset
46
  compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, 10)
 
12
 
13
  def main():
14
  logging.info("Starting the RAG pipeline")
15
+ data_set_name = 'covidqa'
16
 
17
  # Load the dataset
18
  dataset = load_data(data_set_name)
 
36
  val_llm = initialize_validation_llm()
37
 
38
  # Sample question
39
+ row_num = 2
40
  query = dataset[row_num]['question']
41
 
42
  # Call generate_metrics for above sample question
43
+ #generate_metrics(gen_llm, val_llm, vector_store, query)
44
 
45
  #Compute RMSE and AUC-ROC for entire dataset
46
  compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, 10)
retriever/embed_documents.py CHANGED
@@ -1,7 +1,17 @@
 
 
1
  from langchain_huggingface import HuggingFaceEmbeddings
2
  from langchain_community.vectorstores import FAISS
3
 
4
- def embed_documents(documents):
5
  embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-MiniLM-L3-v2")
6
- vector_store = FAISS.from_texts([doc['text'] for doc in documents], embedding_model)
 
 
 
 
 
 
 
 
7
  return vector_store
 
1
+ import os
2
+ import logging
3
  from langchain_huggingface import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
 
6
def embed_documents(documents, embedding_path="embeddings.faiss"):
    """Embed document texts into a FAISS vector store, cached on disk.

    Args:
        documents: Iterable of dicts, each with a 'text' key holding the
            raw document string.
        embedding_path: Path for the persisted FAISS index.
            NOTE(review): the default is shared across datasets — calling
            this for a different dataset without changing the path will
            silently reuse stale embeddings; pass a dataset-specific path.

    Returns:
        A FAISS vector store over the documents' texts.
    """
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-MiniLM-L3-v2")

    if os.path.exists(embedding_path):
        logging.info("Loading embeddings from local file")
        try:
            # SECURITY: load_local unpickles data (allow_dangerous_deserialization);
            # only acceptable because this index is produced locally by us below.
            return FAISS.load_local(embedding_path, embedding_model, allow_dangerous_deserialization=True)
        except Exception:
            # Corrupt or incompatible index on disk: rebuild from scratch
            # instead of crashing the pipeline.
            logging.warning("Failed to load local index at %s; regenerating", embedding_path)

    logging.info("Generating and saving embeddings")
    vector_store = FAISS.from_texts([doc['text'] for doc in documents], embedding_model)
    vector_store.save_local(embedding_path)
    return vector_store