import os
from typing import List, Optional

from langchain.retrievers import MultiQueryRetriever
from langchain_chroma import Chroma
from langchain_core.documents import Document
from tqdm import tqdm

from ..utilities.llm_models import get_llm_model_embedding
from .document_loader import DocumentLoader
from .prompts import DEFAULT_QUERY_PROMPT

def get_collection_name() -> str:
    """
    Derives the Chroma collection name from the HF_MODEL environment variable,
    keeping only the bare model name (no namespace or tag suffix).

    Returns:
        str: Collection name of the form "medivocate-<model-name>".
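
    Example:
        With HF_MODEL="some-org/some-model:latest" (an illustrative value, not
        one set in this module), the result is "medivocate-some-model".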
    """
    return "medivocate-" + os.getenv("HF_MODEL", "default_model").split(":")[0].split("/")[-1]


class VectorStoreManager:
    """
    Manages vector store initialization, updates, and retrieval.
    """

    def __init__(self, persist_directory: str, batch_size: int = 64):
        """
        Initializes the VectorStoreManager with the given parameters.

        Args:
            persist_directory (str): Directory to persist the vector store.
            batch_size (int): Number of documents to process in each batch.
        """
        self.persist_directory = persist_directory
        self.batch_size = batch_size
        self.embeddings = get_llm_model_embedding()
        self.collection_name = get_collection_name()
        self.vector_stores: dict[str, Optional[Chroma]] = {"chroma": None}
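        # Becomes True once the Chroma collection has been created or loaded.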
        self.vs_initialized = False

    def _batch_process_documents(self, documents: List[Document]):
        """
        Processes documents in batches for vector store initialization.

        Args:
            documents (List[Document]): List of documents to process.
        """
        for i in tqdm(
            range(0, len(documents), self.batch_size), desc="Processing documents"
        ):
            batch = documents[i : i + self.batch_size]
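            # The first batch creates the Chroma collection; subsequent batches
            # are appended to it.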
            if not self.vs_initialized:
                self.vector_stores["chroma"] = Chroma.from_documents(
                    collection_name=self.collection_name,
                    documents=batch,
                    embedding=self.embeddings,
                    persist_directory=self.persist_directory,
                )
                self.vs_initialized = True
            else:
                self.vector_stores["chroma"].add_documents(batch)

    def initialize_vector_store(self, documents: Optional[List[Document]] = None):
        """
        Initializes or loads the vector store.

        Args:
            documents (List[Document], optional): List of documents to initialize the vector store with.
        """
        if documents:
            self._batch_process_documents(documents)
        else:
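            # No documents supplied: open the already-persisted collection.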
            self.vector_stores["chroma"] = Chroma(
                collection_name=self.collection_name,
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings,
            )
        self.vs_initialized = True

    def create_retriever(
        self, llm, n_documents: int, bm25_portion: float = 0.8
    ) -> MultiQueryRetriever:
        """
        Creates a retriever using Chroma.

        Args:
            llm: Language model to use for the retriever.
            n_documents (int): Number of documents to retrieve.
            bm25_portion (float): Portion of BM25 to use in the retriever.

        Returns:
            MultiQueryRetriever: Configured retriever.
        """
        self.vector_store = MultiQueryRetriever.from_llm(
            retriever=self.vector_stores["chroma"].as_retriever(
                search_kwargs={"k": n_documents}
            ),
            llm=llm,
            include_original=True,
            prompt=DEFAULT_QUERY_PROMPT,
        )
        return self.vector_store

    def load_and_process_documents(self, doc_dir: str) -> List[Document]:
        """
        Loads and processes documents from the specified directory.

        Returns:
            List[Document]: List of processed documents.
        """
        loader = DocumentLoader(doc_dir)
        return loader.load_documents()
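
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the paths, LLM handle, and query below are
# assumptions for demonstration, not values defined in this module. Because of
# the relative imports, use this from within the package rather than running
# the file directly.
#
#   manager = VectorStoreManager(persist_directory="./chroma_db")
#   docs = manager.load_and_process_documents("./data/docs")
#   manager.initialize_vector_store(docs)   # or call with no args to reload
#   retriever = manager.create_retriever(llm=some_llm, n_documents=5)
#   results = retriever.invoke("Example question about the indexed corpus")
# ---------------------------------------------------------------------------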