Spaces:
Sleeping
Sleeping
File size: 5,074 Bytes
15aea1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
from typing import List, Optional, Union

from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever
from langchain_chroma import Chroma
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from tqdm import tqdm
from transformers import AutoTokenizer

from ..utilities.llm_models import get_llm_model_embedding
from .document_loader import DocumentLoader
from .prompts import DEFAULT_QUERY_PROMPT
from .vector_store import get_collection_name
class VectorStoreManager:
    """
    Manages vector store initialization, updates, and retrieval.

    Maintains two parallel retrieval backends — a persisted Chroma vector
    store (dense embeddings) and an in-memory BM25 retriever (sparse,
    lexical) — and can combine them into a single multi-query ensemble
    retriever.
    """

    def __init__(self, persist_directory: str, batch_size: int = 64):
        """
        Initializes the VectorStoreManager with the given parameters.

        Args:
            persist_directory (str): Directory to persist the Chroma store.
            batch_size (int): Number of documents embedded per batch.
        """
        self.persist_directory = persist_directory
        self.batch_size = batch_size
        self.embeddings = get_llm_model_embedding()
        self.collection_name = get_collection_name()
        # Both backends start empty; populated by initialize_vector_store().
        self.vector_stores: dict[str, Union[Chroma, BM25Retriever, None]] = {
            "chroma": None,
            "bm25": None,
        }
        self.tokenizer = AutoTokenizer.from_pretrained(
            os.getenv("HF_MODEL", "meta-llama/Llama-3.2-1B")
        )
        self.vs_initialized = False
        self.vector_store = None  # set by create_retriever()

    def _batch_process_documents(self, documents: List[Document]) -> None:
        """
        Embeds documents into Chroma in batches, then builds the BM25 index.

        Args:
            documents (List[Document]): Non-empty list of documents to index.
        """
        for start in tqdm(
            range(0, len(documents), self.batch_size), desc="Processing documents"
        ):
            batch = documents[start : start + self.batch_size]
            if not self.vs_initialized:
                # First batch creates (or re-opens) the persisted collection.
                self.vector_stores["chroma"] = Chroma.from_documents(
                    collection_name=self.collection_name,
                    documents=batch,
                    embedding=self.embeddings,
                    persist_directory=self.persist_directory,
                )
                self.vs_initialized = True
            else:
                self.vector_stores["chroma"].add_documents(batch)
        # BM25 is in-memory, so it is (re)built over the whole corpus at once.
        # NOTE(review): BM25Retriever.from_documents has no `tokenizer`
        # parameter — this kwarg is most likely swallowed/ignored by the
        # pydantic model. If HF tokenization is intended for BM25, it should
        # be `preprocess_func=self.tokenizer.tokenize`; confirm before changing.
        self.vector_stores["bm25"] = BM25Retriever.from_documents(
            documents, tokenizer=self.tokenizer
        )

    def initialize_vector_store(
        self, documents: Optional[List[Document]] = None
    ) -> None:
        """
        Initializes the vector stores, from new documents or from disk.

        Args:
            documents (Optional[List[Document]]): Documents to index. When
                None (or empty), the persisted Chroma collection is re-opened
                without re-embedding and the BM25 index is rebuilt from its
                stored contents.
        """
        if documents:
            self._batch_process_documents(documents)
        else:
            # Re-open the persisted collection; no embedding work is done here.
            self.vector_stores["chroma"] = Chroma(
                collection_name=self.collection_name,
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings,
            )
            stored = self.vector_stores["chroma"].get(
                include=["documents", "metadatas"]
            )
            # `ids` are always returned by Chroma.get, even when not in
            # `include`.
            documents = [
                Document(page_content=content, id=doc_id, metadata=metadata)
                for content, doc_id, metadata in zip(
                    stored["documents"],
                    stored["ids"],
                    stored["metadatas"],
                )
            ]
            # Guard: building BM25 over an empty corpus crashes (rank_bm25
            # divides by the corpus size). Leave "bm25" as None instead.
            if documents:
                self.vector_stores["bm25"] = BM25Retriever.from_documents(documents)
        self.vs_initialized = True

    def create_retriever(
        self, llm, n_documents: int, bm25_portion: float = 0.8
    ) -> MultiQueryRetriever:
        """
        Creates a multi-query retriever over a BM25 + Chroma ensemble.

        Args:
            llm: Language model used by MultiQueryRetriever to rephrase the
                user query into multiple variants.
            n_documents (int): Number of documents each retriever returns.
            bm25_portion (float): Weight of BM25 in the ensemble; Chroma
                receives the complementary weight (1 - bm25_portion).

        Returns:
            MultiQueryRetriever: Wraps the ensemble retriever; also cached on
            ``self.vector_store``.
        """
        self.vector_stores["bm25"].k = n_documents
        self.vector_store = MultiQueryRetriever.from_llm(
            retriever=EnsembleRetriever(
                retrievers=[
                    self.vector_stores["bm25"],
                    self.vector_stores["chroma"].as_retriever(
                        search_kwargs={"k": n_documents}
                    ),
                ],
                weights=[bm25_portion, 1 - bm25_portion],
            ),
            llm=llm,
            include_original=True,
            prompt=DEFAULT_QUERY_PROMPT,
        )
        return self.vector_store

    def load_and_process_documents(self, doc_dir) -> List[Document]:
        """
        Loads and processes documents from the specified directory.

        Args:
            doc_dir: Directory to load documents from (passed straight to
                DocumentLoader).

        Returns:
            List[Document]: List of loaded and processed documents.
        """
        loader = DocumentLoader(doc_dir)
        return loader.load_documents()
|