import torch
import uuid
import re
import os
import json
import chromadb
from .asg_splitter import TextSplitting
from langchain_huggingface import HuggingFaceEmbeddings
import time
import concurrent.futures


class Retriever:
    client = None
    cur_dir = os.getcwd()
    chromadb_path = os.path.join(cur_dir, "chromadb")

    def __init__(self):
        self.client = chromadb.PersistentClient(path=self.chromadb_path)

    def create_collection_chroma(self, collection_name: str):
        """
        Create a collection named collection_name. The name must follow these rules:\n
        0. The name must be unique; if it already exists, the existing collection is used instead.\n
        1. The length of the name must be between 3 and 63 characters.\n
        2. The name must start and end with a lowercase letter or a digit, and it can contain dots, dashes, and underscores in between.\n
        3. The name must not contain two consecutive dots.\n
        4. The name must not be a valid IP address.\n
        """
        try:
            self.client.create_collection(name=collection_name)
        except chromadb.db.base.UniqueConstraintError:
            self.get_collection_chroma(collection_name)
        return collection_name

    def get_collection_chroma(self, collection_name: str):
        collection = self.client.get_collection(name=collection_name)
        return collection

    def add_documents_chroma(self, collection_name: str, embeddings_list: list[list[float]], documents_list: list[str], metadata_list: list[dict]):
        """
        Make sure that embeddings_list and metadata_list match documents_list element by element.\n
        Example of one metadata: {"doc_name": "Test2.pdf", "page": "9"}\n
        The id will be created automatically as uuid v4.\n
        The chunk contents and metadata will be logged (appended) into ./logs/<collection_name>.json
        """
        collection = self.get_collection_chroma(collection_name)
        num = len(documents_list)
        ids = [str(uuid.uuid4()) for _ in range(num)]

        collection.add(
            documents=documents_list,
            metadatas=metadata_list,
            embeddings=embeddings_list,
            ids=ids
        )
        logpath = os.path.join(self.cur_dir, "logs", f"{collection_name}.json")
        os.makedirs(os.path.dirname(logpath), exist_ok=True)

        logs = []
        try:
            with open(logpath, 'r', encoding="utf-8") as chunklog:
                logs = json.load(chunklog)
        except (FileNotFoundError, json.decoder.JSONDecodeError):
            logs = []

        added_log = [{"chunk_id": ids[i], "metadata": metadata_list[i], "page_content": documents_list[i]}
                     for i in range(num)]
        logs.extend(added_log)

        # write back
        with open(logpath, "w", encoding="utf-8") as chunklog:
            json.dump(logs, chunklog, indent=4)
        print(f"Logged document information to '{logpath}'.")

    def query_chroma(self, collection_name: str, query_embeddings: list[list[float]], n_results: int = 5) -> dict:
        # return the n closest results (chunks and metadatas) in order
        collection = self.get_collection_chroma(collection_name)
        result = collection.query(
            query_embeddings=query_embeddings,
            n_results=n_results,
        )
        return result

    def update_chroma(self, collection_name: str, id_list: list[str], embeddings_list: list[list[float]], documents_list: list[str], metadata_list: list[dict]):
        collection = self.get_collection_chroma(collection_name)
        num = len(documents_list)
        collection.update(
            ids=id_list,
            embeddings=embeddings_list,
            metadatas=metadata_list,
            documents=documents_list,
        )
        update_list = [{"chunk_id": id_list[i], "metadata": metadata_list[i], "page_content": documents_list[i]} for i in range(num)]

        # update the chunk log
        logs = []
        logpath = os.path.join(self.cur_dir, "logs", f"{collection_name}.json")
        try:
            with open(logpath, 'r', encoding="utf-8") as chunklog:
                logs = json.load(chunklog)
        except (FileNotFoundError, json.decoder.JSONDecodeError):
            logs = []  # the old log does not exist or is empty, so there is nothing to update
        else:
            for i in range(num):
                for log in logs:
                    if log["chunk_id"] == update_list[i]["chunk_id"]:
                        log["metadata"] = update_list[i]["metadata"]
                        log["page_content"] = update_list[i]["page_content"]
                        break

            with open(logpath, "w", encoding="utf-8") as chunklog:
                json.dump(logs, chunklog, indent=4)
            print(f"Updated log file at '{logpath}'.")

    def delete_collection_entries_chroma(self, collection_name: str, id_list: list[str]):
        collection = self.get_collection_chroma(collection_name)
        collection.delete(ids=id_list)
        print(f"Deleted entries with ids: {id_list} from collection '{collection_name}'.")

    def delete_collection_chroma(self, collection_name: str):
        print(f"The collection {collection_name} will be deleted forever!")
        self.client.delete_collection(collection_name)
        try:
            logpath = os.path.join(self.cur_dir, "logs", f"{collection_name}.json")
            print(f"Collection {collection_name} has been removed, deleting the log file of this collection")
            os.remove(logpath)
        except FileNotFoundError:
            print("The log of this collection did not exist!")

    def list_collections_chroma(self):
        collections = self.client.list_collections()
        return collections


# Generate a legal collection name from a PDF filename
def legal_pdf(filename: str) -> str:
    pdf_index = filename.lower().rfind('.pdf')
    if pdf_index != -1:
        name_before_pdf = filename[:pdf_index]
    else:
        name_before_pdf = filename
    name_before_pdf = name_before_pdf.strip()
    name = re.sub(r'[^a-zA-Z0-9._-]', '', name_before_pdf)
    name = name.lower()
    while '..' in name:
        name = name.replace('..', '.')
    name = name[:63]
    if len(name) < 3:
        name = name.ljust(3, '0')  # fill with '0' if the length is less than 3
    if not re.match(r'^[a-z0-9]', name):
        name = 'a' + name[1:]
    if not re.search(r'[a-z0-9]$', name):  # re.search: only fix the last character when it is actually illegal
        name = name[:-1] + 'a'
    ip_pattern = re.compile(r'^(\d{1,3}\.){3}\d{1,3}$')
    if ip_pattern.match(name):
        name = 'ip_' + name
    return name
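
# Illustrative sketch of legal_pdf on a hypothetical filename (not a file in this repo):
#   legal_pdf("Attention Is All You Need (2017).pdf")  ->  "attentionisallyouneed2017"
# Spaces and parentheses are stripped, the result is lowercased, and the start/end/length
# rules documented in create_collection_chroma are enforced.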

def process_pdf(file_path: str, survey_id: str, embedder: HuggingFaceEmbeddings, mode: str):
    # Load and split the PDF
    split_start_time = time.time()
    splitters = TextSplitting().mineru_recursive_splitter(file_path, survey_id, mode)

    documents_list = [document.page_content for document in splitters]
    for i in range(len(documents_list)):
        documents_list[i] = documents_list[i].replace('\n', ' ')
    print(f"Splitting took {time.time() - split_start_time} seconds.")

    # Embed the documents
    # embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    embed_start_time = time.time()
    doc_results = embedder.embed_documents(documents_list)
    if isinstance(doc_results, torch.Tensor):
        embeddings_list = doc_results.tolist()
    else:
        embeddings_list = doc_results
    print(f"Embedding took {time.time() - embed_start_time} seconds.")

    # Prepare metadata
    metadata_list = [{"doc_name": os.path.basename(file_path)} for _ in range(len(documents_list))]

    # Derive a legal collection name from the file name
    title = os.path.splitext(os.path.basename(file_path))[0]
    title_new = title.strip()
    invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*', '_']
    for char in invalid_chars:
        title_new = title_new.replace(char, ' ')
    collection_name = legal_pdf(title_new)

    retriever = Retriever()
    retriever.list_collections_chroma()
    retriever.create_collection_chroma(collection_name)
    retriever.add_documents_chroma(
        collection_name=collection_name,
        embeddings_list=embeddings_list,
        documents_list=documents_list,
        metadata_list=metadata_list
    )

    return collection_name, embeddings_list, documents_list, metadata_list, title_new
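
# Rough usage sketch for process_pdf. Every argument below is a hypothetical placeholder;
# the real file path, survey id, and splitting mode come from the calling pipeline:
#
#   embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   collection_name, embeddings, chunks, metadata, title = process_pdf(
#       "uploads/example_paper.pdf",   # hypothetical path
#       "survey_0001",                 # hypothetical survey id
#       embedder,
#       mode="default",                # placeholder; valid modes depend on TextSplitting
#   )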

def query_embeddings(collection_name: str, query_list: list):
    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    retriever = Retriever()
    final_context = ""
    seen_chunks = set()

    for query_text in query_list:
        query_embeddings = embedder.embed_query(query_text)
        query_result = retriever.query_chroma(collection_name=collection_name, query_embeddings=[query_embeddings], n_results=2)
        query_result_chunks = query_result["documents"][0]
        # query_result_ids = query_result["ids"][0]
        for chunk in query_result_chunks:
            if chunk not in seen_chunks:
                final_context += chunk.strip() + "//\n"
                seen_chunks.add(chunk)
    return final_context


# new version: the queries may run in parallel
def query_embeddings_new(collection_name: str, query_list: list):
    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    retriever = Retriever()
    final_context = ""
    seen_chunks = set()

    def process_query(query_text):
        query_embeddings = embedder.embed_query(query_text)
        query_result = retriever.query_chroma(
            collection_name=collection_name,
            query_embeddings=[query_embeddings],
            n_results=2
        )
        query_result_chunks = query_result["documents"][0]
        return query_result_chunks

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(process_query, query_text): query_text for query_text in query_list}
        for future in concurrent.futures.as_completed(futures):
            query_result_chunks = future.result()
            for chunk in query_result_chunks:
                if chunk not in seen_chunks:
                    final_context += chunk.strip() + "//\n"
                    seen_chunks.add(chunk)
    return final_context


# wza: variant that also returns citation data (source, distance, content) for each chunk
def query_embeddings_new_new(collection_name: str, query_list: list):
    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    retriever = Retriever()
    final_context = ""        # Stores the concatenated context
    citation_data_list = []   # Stores chunk content and collection name as source
    seen_chunks = set()       # Ensures only unique chunks are added

    def process_query(query_text):
        # Embed the query text and retrieve relevant chunks
        query_embeddings = embedder.embed_query(query_text)
        query_result = retriever.query_chroma(
            collection_name=collection_name,
            query_embeddings=[query_embeddings],
            n_results=5  # Fixed number of results
        )
        return query_result

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_query = {executor.submit(process_query, q): q for q in query_list}
        for future in concurrent.futures.as_completed(future_to_query):
            query_text = future_to_query[future]
            try:
                query_result = future.result()
            except Exception as e:
                print(f"Query '{query_text}' failed with exception: {e}")
                continue

            if "documents" not in query_result or "distances" not in query_result:
                continue
            if not query_result["documents"] or not query_result["distances"]:
                continue

            docs_list = query_result["documents"][0] if query_result["documents"] else []
            dist_list = query_result["distances"][0] if query_result["distances"] else []
            if len(docs_list) != len(dist_list):
                continue

            for chunk, distance in zip(docs_list, dist_list):
                processed_chunk = chunk.strip()
                if processed_chunk not in seen_chunks:
                    final_context += processed_chunk + "//\n"
                    seen_chunks.add(processed_chunk)
                    citation_data_list.append({
                        "source": collection_name,
                        "distance": distance,
                        "content": processed_chunk,
                    })

    return final_context, citation_data_list
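
# Shape of one citation_data_list entry returned above (the values shown are illustrative):
#   {"source": "<collection_name>", "distance": 0.42, "content": "a deduplicated chunk of text"}
# "distance" is the ChromaDB query distance for the chunk; smaller distances indicate closer
# matches under the collection's distance metric.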

# concurrent version for both collection names and queries
def query_multiple_collections(collection_names: list[str], query_list: list[str], survey_id: str) -> dict:
    """
    Query multiple collections in parallel and return the combined results.

    Args:
        collection_names (list[str]): List of collection names to query.
        query_list (list[str]): List of queries to execute on each collection.
        survey_id (str): Survey id used to build the output path for the retrieved context.

    Returns:
        dict: Combined results from all collections, grouped by collection.
    """
    # Define the embedder inside the function
    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    retriever = Retriever()

    def query_single_collection(collection_name: str):
        """
        Query a single collection for all queries in query_list.
        """
        final_context = ""
        seen_chunks = set()

        def process_query(query_text):
            # Embed the query
            query_embeddings = embedder.embed_query(query_text)
            # Query the collection
            query_result = retriever.query_chroma(
                collection_name=collection_name,
                query_embeddings=[query_embeddings],
                n_results=5
            )
            query_result_chunks = query_result["documents"][0]
            return query_result_chunks

        # Process all queries in parallel for the given collection
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {executor.submit(process_query, query_text): query_text for query_text in query_list}
            for future in concurrent.futures.as_completed(futures):
                query_result_chunks = future.result()
                for chunk in query_result_chunks:
                    if chunk not in seen_chunks:
                        final_context += chunk.strip() + "//\n"
                        seen_chunks.add(chunk)
        return final_context

    # Outer parallelism for multiple collections
    results = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(query_single_collection, collection_name): collection_name for collection_name in collection_names}
        for future in concurrent.futures.as_completed(futures):
            collection_name = futures[future]
            results[collection_name] = future.result()

    # Automatically save the results to a JSON file
    file_path = f'./src/static/data/info/{survey_id}/retrieved_context.json'
    os.makedirs(os.path.dirname(file_path), exist_ok=True)  # make sure the output directory exists
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    print(f"Results saved to {file_path}")

    return results
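

# Minimal sketch of the multi-collection retrieval path. It assumes the module is executed
# with `python -m` inside its package (the relative import of TextSplitting above prevents
# running this file as a plain script) and that the collections named below already exist
# in ./chromadb. All collection names, queries, and ids here are hypothetical placeholders.
if __name__ == "__main__":
    example_results = query_multiple_collections(
        collection_names=["attentionisallyouneed2017", "examplesurveypaper"],  # hypothetical collections
        query_list=["What problem does each paper address?", "Which datasets are used?"],
        survey_id="survey_0001",  # hypothetical survey id; controls the output JSON path
    )
    # Print the collection names and the size of the retrieved context for a quick sanity check.
    for name, context in example_results.items():
        print(name, len(context))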