import torch
import uuid
import re
import os
import json
import chromadb
from .asg_splitter import TextSplitting
from langchain_community.embeddings import HuggingFaceEmbeddings
import time
import concurrent.futures
from .path_utils import get_path, setup_hf_cache
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Configure the Hugging Face cache directory so embedding models are
# downloaded once and shared across runs.
cache_dir = setup_hf_cache()

# Single sentence-transformers model used for every embedding in this module.
_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


def _init_embedder() -> HuggingFaceEmbeddings:
    """Create the shared embedder, preferring the configured HF cache folder.

    Falls back to the library's default cache location when initialization
    with ``cache_dir`` fails (e.g. the directory is missing or unwritable).
    """
    try:
        return HuggingFaceEmbeddings(model_name=_EMBED_MODEL_NAME, cache_folder=cache_dir)
    except Exception as e:
        print(f"Error initializing embedder: {e}")
        return HuggingFaceEmbeddings(model_name=_EMBED_MODEL_NAME)


class Retriever:
    """Thin wrapper around a persistent ChromaDB client.

    Collections are persisted under ``./chromadb``; every add/update is also
    mirrored into a human-readable JSON chunk log at
    ``./logs/<collection_name>.json`` for debugging.
    """

    client = None
    cur_dir = os.getcwd()
    chromadb_path = os.path.join(cur_dir, "chromadb")

    def __init__(self):
        self.client = chromadb.PersistentClient(path=self.chromadb_path)

    # ------------------------------------------------------------------
    # Internal chunk-log helpers (shared by add/update/delete methods)
    # ------------------------------------------------------------------
    def _log_path(self, collection_name: str) -> str:
        """Return the path of the JSON chunk log for *collection_name*."""
        return os.path.join(self.cur_dir, "logs", f"{collection_name}.json")

    def _read_log(self, collection_name: str) -> list:
        """Load the chunk log, returning [] when it is missing or corrupt."""
        try:
            with open(self._log_path(collection_name), 'r', encoding="utf-8") as chunklog:
                return json.load(chunklog)
        except (FileNotFoundError, json.decoder.JSONDecodeError):
            return []

    def _write_log(self, collection_name: str, logs: list) -> None:
        """Write the chunk log, creating the ``logs`` directory if needed."""
        logpath = self._log_path(collection_name)
        os.makedirs(os.path.dirname(logpath), exist_ok=True)
        with open(logpath, "w", encoding="utf-8") as chunklog:
            json.dump(logs, chunklog, indent=4)

    # ------------------------------------------------------------------
    # Collection operations
    # ------------------------------------------------------------------
    def create_collection_chroma(self, collection_name: str):
        """Create (or reuse) the collection named *collection_name*.

        The name must follow ChromaDB's rules:
        0. It must be unique; if it already exists, the existing collection
           is reused instead of raising.
        1. Length between 3 and 63 characters.
        2. Starts and ends with a lowercase letter or digit; may contain
           dots, dashes, and underscores in between.
        3. No two consecutive dots.
        4. Must not be a valid IP address.

        Returns the collection name.
        """
        try:
            self.client.create_collection(name=collection_name)
        except chromadb.db.base.UniqueConstraintError:
            # Name already taken: fall back to the existing collection.
            self.get_collection_chroma(collection_name)
        return collection_name

    def get_collection_chroma(self, collection_name: str):
        """Fetch and return an existing collection by name."""
        return self.client.get_collection(name=collection_name)

    def add_documents_chroma(self, collection_name: str, embeddings_list: list[list[float]],
                             documents_list: list[dict], metadata_list: list[dict]):
        """Add documents with their embeddings and metadata to a collection.

        ``embeddings_list`` and ``metadata_list`` must be index-aligned with
        ``documents_list``. Example metadata: ``{"doc_name": "Test2.pdf", "page": "9"}``.
        Ids are generated as UUIDv4. Chunk contents and metadata are appended
        to ``./logs/<collection_name>.json``.
        """
        collection = self.get_collection_chroma(collection_name)
        num = len(documents_list)
        ids = [str(uuid.uuid4()) for _ in range(num)]

        collection.add(
            documents=documents_list,
            metadatas=metadata_list,
            embeddings=embeddings_list,
            ids=ids,
        )

        logs = self._read_log(collection_name)
        logs.extend(
            {"chunk_id": ids[i], "metadata": metadata_list[i], "page_content": documents_list[i]}
            for i in range(num)
        )
        self._write_log(collection_name, logs)
        print(f"Logged document information to '{self._log_path(collection_name)}'.")

    def query_chroma(self, collection_name: str, query_embeddings: list[list[float]],
                     n_results: int = 5) -> dict:
        """Return the *n_results* closest chunks (documents and metadatas), in order."""
        collection = self.get_collection_chroma(collection_name)
        return collection.query(
            query_embeddings=query_embeddings,
            n_results=n_results,
        )

    def update_chroma(self, collection_name: str, id_list: list[str],
                      embeddings_list: list[list[float]], documents_list: list[str],
                      metadata_list: list[dict]):
        """Update existing entries (by id) and mirror the change into the chunk log."""
        collection = self.get_collection_chroma(collection_name)
        num = len(documents_list)
        collection.update(
            ids=id_list,
            embeddings=embeddings_list,
            metadatas=metadata_list,
            documents=documents_list,
        )

        logs = self._read_log(collection_name)
        # Index log entries by chunk_id once instead of scanning per update.
        by_id = {log.get("chunk_id"): log for log in logs}
        for i in range(num):
            log = by_id.get(id_list[i])
            if log is not None:
                log["metadata"] = metadata_list[i]
                log["page_content"] = documents_list[i]
        self._write_log(collection_name, logs)
        print(f"Updated log file at '{self._log_path(collection_name)}'.")

    def delete_collection_entries_chroma(self, collection_name: str, id_list: list[str]):
        """Delete the entries with the given ids from a collection."""
        collection = self.get_collection_chroma(collection_name)
        collection.delete(ids=id_list)
        print(f"Deleted entries with ids: {id_list} from collection '{collection_name}'.")

    def delete_collection_chroma(self, collection_name: str):
        """Delete a whole collection and its chunk log (best-effort on the log)."""
        print(f"The collection {collection_name} will be deleted forever!")
        self.client.delete_collection(collection_name)
        try:
            logpath = self._log_path(collection_name)
            print(f"Collection {collection_name} has been removed, deleting log file of this collection")
            os.remove(logpath)
        except FileNotFoundError:
            print("The log of this collection did not exist!")

    def list_collections_chroma(self):
        """Return the list of collections known to the client.

        BUG FIX: the original computed the list but returned None.
        """
        return self.client.list_collections()


def legal_pdf(filename: str) -> str:
    """Derive a ChromaDB-legal collection name from a PDF filename.

    Strips the ``.pdf`` suffix, removes illegal characters, collapses
    consecutive dots, enforces the 3-63 character length, forces legal
    first/last characters, and prefixes names that look like IP addresses.
    """
    pdf_index = filename.lower().rfind('.pdf')
    name_before_pdf = filename[:pdf_index] if pdf_index != -1 else filename
    name_before_pdf = name_before_pdf.strip()
    name = re.sub(r'[^a-zA-Z0-9._-]', '', name_before_pdf).lower()
    while '..' in name:
        name = name.replace('..', '.')
    name = name[:63]
    if len(name) < 3:
        name = name.ljust(3, '0')  # fill with '0' if the length is less than 3
    if not re.match(r'^[a-z0-9]', name):
        name = 'a' + name[1:]
    # BUG FIX: the original used re.match(r'[a-z0-9]$', name), which anchors
    # at the START of the string and therefore only ever matched one-character
    # names — every longer name had its last character replaced with 'a'
    # (e.g. "abc" -> "aba"), which also made the IP check below unreachable.
    # re.search tests the actual last character, matching the stated rules.
    # NOTE(review): collections persisted under the old (mangled) names will
    # not be found by names generated after this fix.
    if not re.search(r'[a-z0-9]$', name):
        name = name[:-1] + 'a'
    ip_pattern = re.compile(r'^(\d{1,3}\.){3}\d{1,3}$')
    if ip_pattern.match(name):
        name = 'ip_' + name
    return name


def process_pdf(file_path: str, survey_id: str, embedder: HuggingFaceEmbeddings, mode: str):
    """Split, embed, and index one PDF into its own Chroma collection.

    Returns ``(collection_name, embeddings_list, documents_list,
    metadata_list, title)`` where *title* is the cleaned file title.
    Raises ValueError when the PDF cannot be loaded or split.
    """
    # Load and split the PDF.
    split_start_time = time.time()
    splitters = TextSplitting().mineru_recursive_splitter(file_path, survey_id, mode)
    if not splitters:
        raise ValueError(f"Failed to load or split PDF: {file_path}")
    # Flatten newlines so each chunk is a single line of text.
    documents_list = [document.page_content.replace('\n', ' ') for document in splitters]
    print(f"Splitting took {time.time() - split_start_time} seconds.")

    # Embed the documents with the caller-provided embedder.
    embed_start_time = time.time()
    doc_results = embedder.embed_documents(documents_list)
    # Some backends return a tensor instead of a plain list of lists.
    if isinstance(doc_results, torch.Tensor):
        embeddings_list = doc_results.tolist()
    else:
        embeddings_list = doc_results
    print(f"Embedding took {time.time() - embed_start_time} seconds.")

    # Every chunk of this PDF shares the same metadata.
    doc_name = os.path.basename(file_path)
    metadata_list = [{"doc_name": doc_name} for _ in documents_list]

    # Build a collection name from the file title, replacing characters that
    # are problematic in names with spaces before sanitizing.
    title_new = os.path.splitext(doc_name)[0].strip()
    for char in ['<', '>', ':', '"', '/', '\\', '|', '?', '*', '_']:
        title_new = title_new.replace(char, ' ')
    collection_name = legal_pdf(title_new)

    retriever = Retriever()
    retriever.list_collections_chroma()
    retriever.create_collection_chroma(collection_name)
    retriever.add_documents_chroma(
        collection_name=collection_name,
        embeddings_list=embeddings_list,
        documents_list=documents_list,
        metadata_list=metadata_list,
    )
    return collection_name, embeddings_list, documents_list, metadata_list, title_new


def query_embeddings(collection_name: str, query_list: list):
    """Sequentially query a collection and return the deduplicated context string."""
    embedder = _init_embedder()
    retriever = Retriever()

    final_context = ""
    seen_chunks = set()
    for query_text in query_list:
        query_embeddings = embedder.embed_query(query_text)
        query_result = retriever.query_chroma(
            collection_name=collection_name,
            query_embeddings=[query_embeddings],
            n_results=2,
        )
        query_result_chunks = query_result["documents"][0]
        for chunk in query_result_chunks:
            if chunk not in seen_chunks:
                final_context += chunk.strip() + "//\n"
                seen_chunks.add(chunk)
    return final_context


# new, may be in parallel
def query_embeddings_new(collection_name: str, query_list: list):
    """Query a collection for all queries in parallel; return deduplicated context.

    NOTE: chunk order depends on thread completion order, so the returned
    string is not deterministic across runs.
    """
    embedder = _init_embedder()
    retriever = Retriever()

    final_context = ""
    seen_chunks = set()

    def process_query(query_text):
        # Embed one query and fetch its two closest chunks.
        query_embeddings = embedder.embed_query(query_text)
        query_result = retriever.query_chroma(
            collection_name=collection_name,
            query_embeddings=[query_embeddings],
            n_results=2,
        )
        return query_result["documents"][0]

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(process_query, query_text): query_text
                   for query_text in query_list}
        for future in concurrent.futures.as_completed(futures):
            for chunk in future.result():
                if chunk not in seen_chunks:
                    final_context += chunk.strip() + "//\n"
                    seen_chunks.add(chunk)
    return final_context


# wza
def query_embeddings_new_new(collection_name: str, query_list: list):
    """Parallel query that also collects citation data.

    Returns ``(final_context, citation_data_list)`` where each citation entry
    records the source collection, retrieval distance, and chunk content.
    """
    embedder = _init_embedder()
    retriever = Retriever()

    final_context = ""          # Stores concatenated context
    citation_data_list = []     # Stores chunk content and collection name as source
    seen_chunks = set()         # Ensures unique chunks are added

    def process_query(query_text):
        # Embed the query text and retrieve relevant chunks.
        query_embeddings = embedder.embed_query(query_text)
        return retriever.query_chroma(
            collection_name=collection_name,
            query_embeddings=[query_embeddings],
            n_results=5,  # Fixed number of results
        )

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_query = {executor.submit(process_query, q): q for q in query_list}
        for future in concurrent.futures.as_completed(future_to_query):
            query_text = future_to_query[future]
            try:
                query_result = future.result()
            except Exception as e:
                print(f"Query '{query_text}' failed with exception: {e}")
                continue

            # Defensive checks: skip malformed or empty results.
            if "documents" not in query_result or "distances" not in query_result:
                continue
            if not query_result["documents"] or not query_result["distances"]:
                continue
            docs_list = query_result["documents"][0] if query_result["documents"] else []
            dist_list = query_result["distances"][0] if query_result["distances"] else []
            if len(docs_list) != len(dist_list):
                continue

            for chunk, distance in zip(docs_list, dist_list):
                processed_chunk = chunk.strip()
                if processed_chunk not in seen_chunks:
                    final_context += processed_chunk + "//\n"
                    seen_chunks.add(processed_chunk)
                    citation_data_list.append({
                        "source": collection_name,
                        "distance": distance,
                        "content": processed_chunk,
                    })
    return final_context, citation_data_list


# concurrent version for both collection names and queries
def query_multiple_collections(collection_names: list[str], query_list: list[str],
                               survey_id: str) -> dict:
    """Query multiple collections in parallel and return the combined results.

    Args:
        collection_names (list[str]): List of collection names to query.
        query_list (list[str]): List of queries to execute on each collection.
        survey_id (str): Used to locate the output path for the saved results.

    Returns:
        dict: Combined results from all collections, grouped by collection.
              Also saved to ``retrieved_context.json`` under the survey's
              info directory.
    """
    embedder = _init_embedder()
    retriever = Retriever()

    def query_single_collection(collection_name: str):
        """Query a single collection for all queries in the query_list."""
        final_context = ""
        seen_chunks = set()

        def process_query(query_text):
            # Embed the query and fetch its five closest chunks.
            query_embeddings = embedder.embed_query(query_text)
            query_result = retriever.query_chroma(
                collection_name=collection_name,
                query_embeddings=[query_embeddings],
                n_results=5,
            )
            return query_result["documents"][0]

        # Process all queries in parallel for the given collection.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {executor.submit(process_query, query_text): query_text
                       for query_text in query_list}
            for future in concurrent.futures.as_completed(futures):
                for chunk in future.result():
                    if chunk not in seen_chunks:
                        final_context += chunk.strip() + "//\n"
                        seen_chunks.add(chunk)
        return final_context

    # Outer parallelism across collections.
    results = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(query_single_collection, collection_name): collection_name
                   for collection_name in collection_names}
        for future in concurrent.futures.as_completed(futures):
            collection_name = futures[future]
            results[collection_name] = future.result()

    # Automatically save the results to a JSON file.
    file_path = get_path('info', survey_id, 'retrieved_context.json')
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    print(f"Results saved to {file_path}")

    return results