"""Vector store management: Chroma initialization, batch ingestion, and retriever construction for the medivocate collection."""
import os
from typing import List, Optional

from langchain.retrievers import MultiQueryRetriever
from langchain_chroma import Chroma
from langchain_core.documents import Document
from tqdm import tqdm

from ..utilities.llm_models import get_llm_model_embedding
from .document_loader import DocumentLoader
from .prompts import DEFAULT_QUERY_PROMPT
def get_collection_name() -> str:
    """
    Derive the Chroma collection name from the HF_MODEL environment variable.

    The model id is stripped of any ":tag" suffix and any "org/" prefix,
    then prefixed with "medivocate-". Falls back to "default_model" when
    HF_MODEL is unset.

    Returns:
        str: Processed collection name.
    """
    model_id = os.getenv("HF_MODEL", "default_model")
    base = model_id.split(":")[0]
    short_name = base.rsplit("/", 1)[-1]
    return f"medivocate-{short_name}"
class VectorStoreManager:
    """
    Manages Chroma vector store initialization, batch ingestion, and
    retriever construction.

    Attributes:
        persist_directory (str): Directory where Chroma persists the collection.
        batch_size (int): Number of documents embedded per batch.
        embeddings: Embedding model returned by get_llm_model_embedding().
        collection_name (str): Collection name derived from the HF_MODEL env var.
        vector_stores (dict): Maps backend name ("chroma") to its store; the
            value is None until the store is created or loaded.
        vs_initialized (bool): True once the "chroma" store exists.
    """

    def __init__(self, persist_directory: str, batch_size: int = 64):
        """
        Initializes the VectorStoreManager with the given parameters.

        Args:
            persist_directory (str): Directory to persist the vector store.
            batch_size (int): Number of documents to process in each batch.
        """
        self.persist_directory = persist_directory
        self.batch_size = batch_size
        self.embeddings = get_llm_model_embedding()
        self.collection_name = get_collection_name()
        # Lazily populated: stays None until initialize_vector_store() or
        # _batch_process_documents() creates/loads the Chroma collection.
        self.vector_stores: dict[str, Optional[Chroma]] = {"chroma": None}
        self.vs_initialized = False

    def _batch_process_documents(self, documents: List[Document]) -> None:
        """
        Embeds and stores documents in batches of ``self.batch_size``.

        The first batch creates the Chroma collection via ``from_documents``;
        subsequent batches are appended with ``add_documents``.

        Args:
            documents (List[Document]): List of documents to process.
        """
        for start in tqdm(
            range(0, len(documents), self.batch_size), desc="Processing documents"
        ):
            batch = documents[start : start + self.batch_size]
            if not self.vs_initialized:
                # First batch: create the persisted collection.
                self.vector_stores["chroma"] = Chroma.from_documents(
                    collection_name=self.collection_name,
                    documents=batch,
                    embedding=self.embeddings,
                    persist_directory=self.persist_directory,
                )
                self.vs_initialized = True
            else:
                self.vector_stores["chroma"].add_documents(batch)

    def initialize_vector_store(self, documents: Optional[List[Document]] = None) -> None:
        """
        Initializes the vector store from documents, or loads the persisted one.

        Args:
            documents (Optional[List[Document]]): Documents to ingest. When
                None or empty, the existing persisted collection is opened
                instead of creating a new one.
        """
        if documents:
            self._batch_process_documents(documents)
        else:
            # No documents supplied: attach to the existing persisted collection.
            self.vector_stores["chroma"] = Chroma(
                collection_name=self.collection_name,
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings,
            )
            self.vs_initialized = True

    def create_retriever(
        self, llm, n_documents: int, bm25_portion: float = 0.8
    ) -> MultiQueryRetriever:
        """
        Creates a multi-query retriever on top of the Chroma store.

        Args:
            llm: Language model used to generate query variants.
            n_documents (int): Number of documents to retrieve per query.
            bm25_portion (float): Currently UNUSED — presumably a leftover
                from an earlier BM25/vector ensemble setup; kept only for
                backward compatibility of the signature.

        Returns:
            MultiQueryRetriever: Configured retriever (also stored on
            ``self.vector_store`` — name kept for backward compatibility,
            although it holds a retriever, not a vector store).
        """
        self.vector_store = MultiQueryRetriever.from_llm(
            retriever=self.vector_stores["chroma"].as_retriever(
                search_kwargs={"k": n_documents}
            ),
            llm=llm,
            include_original=True,
            prompt=DEFAULT_QUERY_PROMPT,
        )
        return self.vector_store

    def load_and_process_documents(self, doc_dir: str) -> List[Document]:
        """
        Loads and processes documents from the specified directory.

        Args:
            doc_dir (str): Directory to load documents from.

        Returns:
            List[Document]: List of processed documents.
        """
        loader = DocumentLoader(doc_dir)
        return loader.load_documents()