Gourisankar Padihary committed on
Commit d93e32b · 1 Parent(s): 5485d7c

Support for all data sets

config.py CHANGED
@@ -1,7 +1,7 @@
 
 class ConfigConstants:
     # Constants related to datasets and models
-    DATA_SET_NAMES = ['covidqa', 'techqa', 'cuad']
+    DATA_SET_NAMES = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
     EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
     RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
     DEFAULT_CHUNK_SIZE = 1000
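
Note: the expanded DATA_SET_NAMES appears to cover the twelve RAGBench subsets. Below is a minimal loader sketch, assuming each name maps to a configuration of the rungalileo/ragbench dataset on the Hugging Face Hub; the repository's actual data-loading code is not part of this commit, so the repo id, split, and helper name here are illustrative assumptions.

# Hypothetical sketch: load every subset named in ConfigConstants.DATA_SET_NAMES.
# The "rungalileo/ragbench" repo id and the split are assumptions, not taken from this diff.
from datasets import load_dataset
from config import ConfigConstants

def load_all_datasets(split="test"):
    loaded = {}
    for name in ConfigConstants.DATA_SET_NAMES:
        # Each subset is assumed to be exposed as a separate dataset configuration
        loaded[name] = load_dataset("rungalileo/ragbench", name, split=split)
    return loaded
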
retriever/chunk_documents.py CHANGED
@@ -1,12 +1,25 @@
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+import hashlib
 
 def chunk_documents(dataset, chunk_size=1000, chunk_overlap=200):
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     documents = []
+    seen_hashes = set()  # Track hashes of chunks to avoid duplicates
+
     for data in dataset:
         text_list = data['documents']
         for text in text_list:
             chunks = text_splitter.split_text(text)
             for i, chunk in enumerate(chunks):
+                # Generate a unique hash for the chunk
+                chunk_hash = hashlib.sha256(chunk.encode()).hexdigest()
+
+                # Skip if the chunk is a duplicate
+                if chunk_hash in seen_hashes:
+                    continue
+
+                # Add the chunk to the documents list and track its hash
                 documents.append({'text': chunk, 'source': f"{data['question']}_chunk_{i}"})
+                seen_hashes.add(chunk_hash)
+
     return documents
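
For reference, a toy example (not from the repository) showing the effect of the new hash-based de-duplication in chunk_documents: identical chunk text is emitted only once, keyed by its SHA-256 hash.

# Illustrative usage only; the sample data below is made up.
from retriever.chunk_documents import chunk_documents

sample = [
    {'question': 'q1', 'documents': ['FAISS is a library for efficient similarity search.']},
    {'question': 'q2', 'documents': ['FAISS is a library for efficient similarity search.']},  # duplicate text
]

chunks = chunk_documents(sample, chunk_size=100, chunk_overlap=20)
print(len(chunks))  # 1 -- the second, identical chunk is skipped via its hash
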
retriever/embed_documents.py CHANGED
@@ -7,7 +7,7 @@ from config import ConfigConstants
 
 def embed_documents(documents, embedding_path="embeddings.faiss"):
     embedding_model = HuggingFaceEmbeddings(model_name=ConfigConstants.EMBEDDING_MODEL_NAME)
-
+
     if os.path.exists(embedding_path):
         logging.info("Loading embeddings from local file")
         vector_store = FAISS.load_local(embedding_path, embedding_model, allow_dangerous_deserialization=True)
@@ -17,3 +17,80 @@ def embed_documents(documents, embedding_path="embeddings.faiss"):
     vector_store.save_local(embedding_path)
 
     return vector_store
+
+'''import os
+import logging
+import hashlib
+from typing import List, Dict
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+from langchain_community.vectorstores import FAISS
+from langchain_huggingface import HuggingFaceEmbeddings
+from config import ConfigConstants
+
+
+def embed_documents(documents: List[Dict], embedding_path: str = "embeddings.faiss", metadata_path: str = "metadata.json") -> FAISS:
+    logging.info(f"Total documents got :{len(documents)}")
+    embedding_model = HuggingFaceEmbeddings(model_name=ConfigConstants.EMBEDDING_MODEL_NAME)
+
+    if os.path.exists(embedding_path) and os.path.exists(metadata_path):
+        logging.info("Loading embeddings and metadata from local files")
+        vector_store = FAISS.load_local(embedding_path, embedding_model, allow_dangerous_deserialization=True)
+        existing_metadata = _load_metadata(metadata_path)
+    else:
+        # Initialize FAISS with at least one document to avoid the IndexError
+        if documents:
+            vector_store = FAISS.from_texts([documents[0]['text']], embedding_model)
+        else:
+            # If no documents are provided, initialize an empty FAISS index with a dummy document
+            vector_store = FAISS.from_texts(["dummy document"], embedding_model)
+        existing_metadata = {}
+
+    # Identify new or modified documents
+    new_documents = []
+    for doc in documents:
+        doc_hash = _generate_document_hash(doc['text'])
+        if doc_hash not in existing_metadata:
+            new_documents.append(doc)
+            existing_metadata[doc_hash] = True  # Mark as processed
+
+    if new_documents:
+        logging.info(f"Generating embeddings for {len(new_documents)} new documents")
+        with ThreadPoolExecutor() as executor:
+            futures = []
+            for doc in new_documents:
+                futures.append(executor.submit(_embed_single_document, doc, embedding_model))
+
+            for future in tqdm(futures, desc="Generating embeddings", unit="doc"):
+                vector_store.add_texts([future.result()])
+
+        # Save updated embeddings and metadata
+        vector_store.save_local(embedding_path)
+        _save_metadata(metadata_path, existing_metadata)
+    else:
+        logging.info("No new documents to process. Using existing embeddings.")
+
+    return vector_store
+
+def _embed_single_document(doc: Dict, embedding_model: HuggingFaceEmbeddings) -> str:
+    return doc['text']
+
+def _generate_document_hash(text: str) -> str:
+    """Generate a unique hash for a document based on its text."""
+    return hashlib.sha256(text.encode()).hexdigest()
+
+def _load_metadata(metadata_path: str) -> Dict[str, bool]:
+    """Load metadata from a file."""
+    import json
+    if os.path.exists(metadata_path):
+        with open(metadata_path, "r") as f:
+            return json.load(f)
+    return {}
+
+def _save_metadata(metadata_path: str, metadata: Dict[str, bool]):
+    """Save metadata to a file."""
+    import json
+    with open(metadata_path, "w") as f:
+        json.dump(metadata, f)'''
+
+
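
For context, a hypothetical end-to-end sketch combining the two modules touched by this commit. The module paths and function signatures come from the diff, and similarity_search is LangChain's standard vector-store call; the sample data and query are made up.

# Illustrative glue code; not part of the commit.
from retriever.chunk_documents import chunk_documents
from retriever.embed_documents import embed_documents

dataset = [
    {'question': 'What does FAISS do?',
     'documents': ['FAISS performs similarity search over dense vectors.']},
]

docs = chunk_documents(dataset, chunk_size=1000, chunk_overlap=200)
vector_store = embed_documents(docs, embedding_path="embeddings.faiss")
results = vector_store.similarity_search("What does FAISS do?", k=1)
print(results[0].page_content)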