medrag / medrag_multi_modal /retrieval /multi_modal_retrieval.py
mratanusarkar's picture
chore: format & linting + __init__ + fix: imports
e0aff18
raw
history blame
5.6 kB
import os
from typing import Any, Optional
import wandb
import weave
from byaldi import RAGMultiModalModel
from PIL import Image
from ..utils import get_wandb_artifact
class MultiModalRetriever(weave.Model):
    """
    MultiModalRetriever is a class that facilitates the retrieval of page images using ColPali.

    This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
    It can be initialized with a pre-trained model or from a specified W&B artifact. The class
    also provides methods to index new data and to predict/retrieve documents based on a query.

    !!! example "Indexing Data"
        ```python
        import wandb
        from medrag_multi_modal.retrieval import MultiModalRetriever

        wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
        retriever = MultiModalRetriever()
        retriever.index(
            data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
            weave_dataset_name="grays-anatomy-images:v0",
            index_name="grays-anatomy",
        )
        ```

    !!! example "Retrieving Documents"
        ```python
        import weave

        import wandb
        from medrag_multi_modal.retrieval import MultiModalRetriever

        weave.init(project_name="ml-colabs/medrag-multi-modal")
        retriever = MultiModalRetriever.from_artifact(
            index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
            metadata_dataset_name="grays-anatomy-images:v0",
            data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
        )
        retriever.predict(
            query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
            top_k=3,
        )
        ```

    Attributes:
        model_name (str): The name of the model to be used for retrieval.
    """

    model_name: str
    _docs_retrieval_model: Optional[RAGMultiModalModel] = None
    # Rows of the metadata dataset, one dict per row (populated in `__init__`).
    _metadata: Optional[list[dict[str, Any]]] = None
    # Directory that `predict` reads `<doc_id>.png` page images from.
    _data_artifact_dir: Optional[str] = None

    def __init__(
        self,
        model_name: str = "vidore/colpali-v1.2",
        docs_retrieval_model: Optional[RAGMultiModalModel] = None,
        data_artifact_dir: Optional[str] = None,
        metadata_dataset_name: Optional[str] = None,
    ):
        """
        Create a retriever, loading a pre-trained model when none is supplied.

        Args:
            model_name (str): Name passed to `RAGMultiModalModel.from_pretrained`
                when `docs_retrieval_model` is not given.
            docs_retrieval_model (Optional[RAGMultiModalModel]): An already
                constructed retrieval model (e.g. one loaded from an index).
            data_artifact_dir (Optional[str]): Directory containing the page
                images, named `<doc_id>.png`, that `predict` opens.
            metadata_dataset_name (Optional[str]): Weave reference whose rows
                are loaded into memory as a list of dicts; skipped when falsy.
        """
        super().__init__(model_name=model_name)
        # Explicit None check: do not rely on the truthiness of a model object.
        if docs_retrieval_model is None:
            docs_retrieval_model = RAGMultiModalModel.from_pretrained(self.model_name)
        self._docs_retrieval_model = docs_retrieval_model
        self._data_artifact_dir = data_artifact_dir
        self._metadata = (
            [dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
            if metadata_dataset_name
            else None
        )

    @classmethod
    def from_artifact(
        cls,
        index_artifact_name: str,
        metadata_dataset_name: str,
        data_artifact_name: str,
    ):
        """
        Build a retriever from a previously saved index artifact.

        Both artifacts are resolved through `get_wandb_artifact` (presumably a
        download to a local directory; confirm against the helper), and the
        retrieval model is loaded from the `index` sub-directory of the index
        artifact.

        Args:
            index_artifact_name (str): Name of the artifact of type
                `colpali-index` holding the index.
            metadata_dataset_name (str): Weave reference of the metadata dataset.
            data_artifact_name (str): Name of the artifact of type `dataset`
                holding the page images.

        Returns:
            MultiModalRetriever: A retriever backed by the loaded index.
        """
        index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        docs_retrieval_model = RAGMultiModalModel.from_index(
            index_path=os.path.join(index_artifact_dir, "index")
        )
        return cls(
            docs_retrieval_model=docs_retrieval_model,
            metadata_dataset_name=metadata_dataset_name,
            data_artifact_dir=data_artifact_dir,
        )

    def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
        """
        Index the images of a data artifact and optionally publish the index.

        The index is built with `overwrite=True` and without storing the
        collection alongside it. When a W&B run is active, the directory
        `.byaldi/<index_name>` is saved as a `colpali-index` artifact.

        Args:
            data_artifact_name (str): Name of the artifact of type `dataset`
                whose contents are indexed.
            weave_dataset_name (str): Recorded in the published artifact's
                metadata only; it is not loaded here.
            index_name (str): Name of the index and of the published artifact.
        """
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        self._docs_retrieval_model.index(
            input_path=data_artifact_dir,
            index_name=index_name,
            store_collection_with_index=False,
            overwrite=True,
        )
        # Remember where the indexed images live so that `predict` can open
        # them on this same instance right after indexing.
        self._data_artifact_dir = data_artifact_dir
        if wandb.run:
            artifact = wandb.Artifact(
                name=index_name,
                type="colpali-index",
                metadata={"weave_dataset_name": weave_dataset_name},
            )
            artifact.add_dir(
                local_path=os.path.join(".byaldi", index_name), name="index"
            )
            artifact.save()

    @weave.op()
    def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
        """
        Predicts and retrieves the top-k most relevant documents/images for a given query
        using ColPali.

        This function uses the document retrieval model to search for the most relevant
        documents based on the provided query. It returns a list of dictionaries, each
        containing the document image, document ID, and the relevance score.

        Args:
            query (str): The search query string.
            top_k (int, optional): The number of top results to retrieve. Defaults to 3.

        Returns:
            list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
                - "doc_image" (PIL.Image.Image): The image of the document.
                - "doc_id" (str): The ID of the document.
                - "score" (float): The relevance score of the document.

        Raises:
            ValueError: If no data artifact directory is known, i.e. the retriever
                was not created via `from_artifact`, was not given
                `data_artifact_dir`, and `index` has not been called.
        """
        data_dir = self._data_artifact_dir
        if data_dir is None:
            # Fail before the search with an actionable message instead of the
            # opaque TypeError that os.path.join(None, ...) would raise.
            raise ValueError(
                "No data artifact directory is set; create the retriever with "
                "`from_artifact`, pass `data_artifact_dir`, or call `index` first."
            )
        results = self._docs_retrieval_model.search(query=query, k=top_k)
        return [
            {
                "doc_image": Image.open(
                    os.path.join(data_dir, f"{result['doc_id']}.png")
                ),
                "doc_id": result["doc_id"],
                "score": result["score"],
            }
            for result in results
        ]