import os
from typing import TYPE_CHECKING, Any, Optional

import weave

if TYPE_CHECKING:
    from byaldi import RAGMultiModalModel

import wandb
from PIL import Image

from ..utils import get_wandb_artifact


class CalPaliRetriever(weave.Model):
    """
    CalPaliRetriever retrieves document page images for a text query using ColPali.

    This class leverages `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
    It can be initialized with a pre-trained ColPali checkpoint or from a specified W&B
    artifact. The class also provides methods to index new data and to retrieve documents
    based on a query.

    Attributes:
        model_name (str): The name of the model to be used for retrieval.
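
    !!! example "Initializing with a pre-trained checkpoint"
        A minimal sketch of direct initialization using the default ColPali checkpoint;
        see `from_wandb_artifact` below for loading a pre-built index from W&B.

        ```python
        from medrag_multi_modal.retrieval import CalPaliRetriever

        retriever = CalPaliRetriever(model_name="vidore/colpali-v1.2")
        ```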
    """

    model_name: str
    _docs_retrieval_model: Optional["RAGMultiModalModel"] = None
    _metadata: Optional[dict] = None
    _data_artifact_dir: Optional[str] = None

    def __init__(
        self,
        model_name: str = "vidore/colpali-v1.2",
        docs_retrieval_model: Optional["RAGMultiModalModel"] = None,
        data_artifact_dir: Optional[str] = None,
        metadata_dataset_name: Optional[str] = None,
    ):
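        """
        Initializes the CalPaliRetriever.

        Args:
            model_name (str): The name of the pre-trained ColPali checkpoint to load.
            docs_retrieval_model (Optional[RAGMultiModalModel]): An already-loaded
                `byaldi.RAGMultiModalModel`. If not provided, one is loaded from
                `model_name`.
            data_artifact_dir (Optional[str]): Local directory of the data artifact
                containing the document page images.
            metadata_dataset_name (Optional[str]): Name of the Weave dataset containing
                document metadata.
        """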
        super().__init__(model_name=model_name)
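        # Import lazily so byaldi is only required when the retriever is instantiated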
        from byaldi import RAGMultiModalModel

        self._docs_retrieval_model = (
            docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name)
        )
        self._data_artifact_dir = data_artifact_dir
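        # Optionally load document metadata rows from the referenced Weave dataset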
        self._metadata = (
            [dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
            if metadata_dataset_name
            else None
        )

    def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
        """
        Indexes a dataset of document images and saves the index as a W&B artifact.

        This method retrieves a dataset of document images from a W&B artifact using the
        provided data artifact name, then indexes the documents with the document retrieval
        model under the specified index name. The index is stored locally (without storing
        the collection alongside it) and overwrites any existing index with the same name.

        If a wandb run is active, the method creates a new W&B artifact with the specified
        index name and type "colpali-index", adds the local index directory to it, and saves
        it to W&B with metadata containing the provided Weave dataset name.

        !!! example "Indexing Data"
            First, you need to install the `Byaldi` library by Answer.AI:

            ```bash
            uv pip install "Byaldi>=0.0.5"
            ```

            Next, you can index the data by running the following code:

            ```python
            import wandb
            from medrag_multi_modal.retrieval import CalPaliRetriever

            wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
            retriever = CalPaliRetriever()
            retriever.index(
                data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
                weave_dataset_name="grays-anatomy-images:v0",
                index_name="grays-anatomy",
            )
            ```

        ??? note "Optional Speedup using Flash Attention"
            If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
            installing the `flash-attn` package.

            ```bash
            uv pip install flash-attn --no-build-isolation
            ```

        Args:
            data_artifact_name (str): The name of the W&B artifact containing the dataset.
            weave_dataset_name (str): The name of the Weave dataset to include in the artifact metadata.
            index_name (str): The name to assign to the created index.
        """
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        self._docs_retrieval_model.index(
            input_path=data_artifact_dir,
            index_name=index_name,
            store_collection_with_index=False,
            overwrite=True,
        )
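        # If a wandb run is active, log the locally saved index (under `.byaldi/<index_name>`)
        # as a W&B artifact of type "colpali-index"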
        if wandb.run:
            artifact = wandb.Artifact(
                name=index_name,
                type="colpali-index",
                metadata={"weave_dataset_name": weave_dataset_name},
            )
            artifact.add_dir(
                local_path=os.path.join(".byaldi", index_name), name="index"
            )
            artifact.save()

    @classmethod
    def from_wandb_artifact(
        cls,
        index_artifact_name: str,
        metadata_dataset_name: str,
        data_artifact_name: str,
    ):
        """
        Creates an instance of the class from Weights & Biases (wandb) artifacts.

        This method retrieves the necessary artifacts from wandb to initialize a
        CalPaliRetriever. It fetches the index artifact directory and the data artifact
        directory using the provided artifact names, loads the document retrieval model
        from the index path within the index artifact directory, and returns an instance
        of the class initialized with the loaded model, the metadata dataset name, and
        the data artifact directory.

        !!! example "Retrieving Documents"
            First, you need to install the `Byaldi` library by Answer.AI:

            ```bash
            uv pip install "Byaldi>=0.0.5"
            ```

            Next, you can retrieve the documents by running the following code:

            ```python
            import weave

            import wandb
            from medrag_multi_modal.retrieval import CalPaliRetriever

            weave.init(project_name="ml-colabs/medrag-multi-modal")
            retriever = CalPaliRetriever.from_wandb_artifact(
                index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
                metadata_dataset_name="grays-anatomy-images:v0",
                data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
            )
            ```

        ??? note "Optional Speedup using Flash Attention"
            If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
            installing the `flash-attn` package.

            ```bash
            uv pip install flash-attn --no-build-isolation
            ```

        Args:
            index_artifact_name (str): The name of the wandb artifact containing the index.
            metadata_dataset_name (str): The name of the Weave dataset containing the document metadata.
            data_artifact_name (str): The name of the wandb artifact containing the data.

        Returns:
            An instance of the class initialized with the retrieved document retrieval model,
            metadata dataset name, and data artifact directory.
        """
        from byaldi import RAGMultiModalModel

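        # Fetch the index and dataset artifact directories from W&B, then load the
        # retrieval model from the saved index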
        index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        docs_retrieval_model = RAGMultiModalModel.from_index(
            index_path=os.path.join(index_artifact_dir, "index")
        )
        return cls(
            docs_retrieval_model=docs_retrieval_model,
            metadata_dataset_name=metadata_dataset_name,
            data_artifact_dir=data_artifact_dir,
        )

    @weave.op()
    def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
        """
        Predicts and retrieves the top-k most relevant documents/images for a given query
        using ColPali.

        This function uses the document retrieval model to search for the most relevant
        documents based on the provided query. It returns a list of dictionaries, each
        containing the document image, document ID, and the relevance score.

        !!! example "Retrieving Documents"
            First, you need to install the `Byaldi` library by Answer.AI:

            ```bash
            uv pip install "Byaldi>=0.0.5"
            ```

            Next, you can retrieve the documents by running the following code:

            ```python
            import weave

            import wandb
            from medrag_multi_modal.retrieval import CalPaliRetriever

            weave.init(project_name="ml-colabs/medrag-multi-modal")
            retriever = CalPaliRetriever.from_wandb_artifact(
                index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
                metadata_dataset_name="grays-anatomy-images:v0",
                data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
            )
            retriever.predict(
                query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
                top_k=3,
            )
            ```

        ??? note "Optional Speedup using Flash Attention"
            If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
            installing the `flash-attn` package.

            ```bash
            uv pip install flash-attn --no-build-isolation
            ```

        Args:
            query (str): The search query string.
            top_k (int, optional): The number of top results to retrieve. Defaults to 3.

        Returns:
            list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
                - "doc_image" (PIL.Image.Image): The image of the document.
                - "doc_id" (str): The ID of the document.
                - "score" (float): The relevance score of the document.
        """
        results = self._docs_retrieval_model.search(query=query, k=top_k)
        retrieved_results = []
        for result in results:
            retrieved_results.append(
                {
                    "doc_image": Image.open(
                        os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png")
                    ),
                    "doc_id": result["doc_id"],
                    "score": result["score"],
                }
            )
        return retrieved_results