geekyrakshit committed
Commit bb79bf4 · unverified · 2 Parent(s): a7ff122 d197e7f

Merge pull request #4 from soumik12345/feat/colpali-retrieval
.gitignore CHANGED
@@ -7,4 +7,6 @@ cursor_prompt.txt
  test.py
  **.pdf
  images/
- wandb/
+ wandb/
+ .byaldi/
+ artifacts/
README.md CHANGED
@@ -1,3 +1,3 @@
  # MedRAG Multi-Modal

- Multi-modal RAG for the medical domain.
+ Multi-modal RAG for the medical domain.
docs/index.md CHANGED
@@ -1,3 +1,40 @@
  # MedRAG Multi-Modal

- Multi-modal RAG for the medical domain.
+ Multi-modal RAG for the medical domain.
+
+ ## Installation
+
+ ### For Development
+
+ On macOS, run
+
+ ```bash
+ brew install poppler
+ ```
+
+ On Debian/Ubuntu, run
+
+ ```bash
+ sudo apt-get install -y poppler-utils
+ ```
+
+ Then, clone the repository and install the dependencies with uv into the virtual environment `.venv`:
+
+ ```bash
+ git clone https://github.com/soumik12345/medrag-multi-modal
+ cd medrag-multi-modal
+ pip install -U pip uv
+ uv sync
+ ```
+
+ After this, activate the virtual environment:
+
+ ```bash
+ source .venv/bin/activate
+ ```
+
+ In the activated virtual environment, you can optionally install Flash Attention (required for ColPali):
+
+ ```bash
+ uv pip install flash-attn --no-build-isolation
+ ```
docs/installation/development.md ADDED
@@ -0,0 +1,40 @@
+ # Setting up the development environment
+
+ ## Install Poppler
+
+ On macOS, run
+
+ ```bash
+ brew install poppler
+ ```
+
+ On Debian/Ubuntu, run
+
+ ```bash
+ sudo apt-get install -y poppler-utils
+ ```
+
+ ## Install the dependencies
+
+ Then, clone the repository and install the dependencies with uv into the virtual environment `.venv`:
+
+ ```bash
+ git clone https://github.com/soumik12345/medrag-multi-modal
+ cd medrag-multi-modal
+ pip install -U pip uv
+ uv sync
+ ```
+
+ After this, activate the virtual environment:
+
+ ```bash
+ source .venv/bin/activate
+ ```
+
+ ## [Optional] Install Flash Attention
+
+ In the activated virtual environment, you can optionally install Flash Attention (required for ColPali):
+
+ ```bash
+ uv pip install flash-attn --no-build-isolation
+ ```
docs/installation/install.md ADDED
@@ -0,0 +1,9 @@
+ # Installation
+
+ You just need to clone the repository and run the `install.sh` script:
+
+ ```bash
+ git clone https://github.com/soumik12345/medrag-multi-modal
+ cd medrag-multi-modal
+ sh install.sh
+ ```
docs/retreival/multi_modal_retrieval.md ADDED
@@ -0,0 +1,3 @@
+ # Multi-Modal Retrieval
+
+ ::: medrag_multi_modal.retrieval.multi_modal_retrieval
install.sh ADDED
@@ -0,0 +1,30 @@
+ #!/bin/bash
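+ # Installs the poppler system dependency for the detected OS, then clones the
+ # repository and installs the package with its `core` extras.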
+
+ OS_TYPE=$(uname -s)
+
+ if [ "$OS_TYPE" = "Darwin" ]; then
+     echo "Detected macOS."
+     brew install poppler
+ elif [ "$OS_TYPE" = "Linux" ]; then
+     if [ -f /etc/os-release ]; then
+         . /etc/os-release
+         if [ "$ID" = "ubuntu" ] || [ "$ID" = "debian" ]; then
+             echo "Detected Ubuntu/Debian."
+             sudo apt-get update
+             sudo apt-get install -y poppler-utils
+         else
+             echo "Unsupported Linux distribution: $ID"
+             exit 1
+         fi
+     else
+         echo "Cannot detect Linux distribution."
+         exit 1
+     fi
+ else
+     echo "Unsupported OS: $OS_TYPE"
+     exit 1
+ fi
+
+ git clone https://github.com/soumik12345/medrag-multi-modal
+ cd medrag-multi-modal
+ pip install -U .[core]
medrag_multi_modal/document_loader/load_image.py CHANGED
@@ -13,8 +13,8 @@ from medrag_multi_modal.document_loader.load_text import TextLoader
 
  class ImageLoader(TextLoader):
      """
-     ImageLoader is a class that extends the `TextLoader` class to handle the extraction and
-     loading of images from a PDF file.
+     `ImageLoader` is a class that extends the `TextLoader` class to handle the extraction and
+     loading of pages from a PDF file as images.
 
      This class provides functionality to convert specific pages of a PDF document into images
      and optionally publish these images to a Weave dataset.
medrag_multi_modal/retrieval/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .multi_modal_retrieval import MultiModalRetriever
+
+ __all__ = ["MultiModalRetriever"]
medrag_multi_modal/retrieval/multi_modal_retrieval.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ from typing import Any, Optional
+
+ import weave
+ from byaldi import RAGMultiModalModel
+ from PIL import Image
+
+ import wandb
+
+ from ..utils import get_wandb_artifact
+
+
+ class MultiModalRetriever(weave.Model):
+     """
+     MultiModalRetriever is a class that facilitates the retrieval of page images using ColPali.
+
+     This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
+     It can be initialized with a pre-trained model or from a specified W&B artifact. The class
+     also provides methods to index new data and to predict/retrieve documents based on a query.
+
+     !!! example "Indexing Data"
+         ```python
+         import wandb
+         from medrag_multi_modal.retrieval import MultiModalRetriever
+
+         wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
+         retriever = MultiModalRetriever()
+         retriever.index(
+             data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+             weave_dataset_name="grays-anatomy-images:v0",
+             index_name="grays-anatomy",
+         )
+         ```
+
+     !!! example "Retrieving Documents"
+         ```python
+         import weave
+
+         import wandb
+         from medrag_multi_modal.retrieval import MultiModalRetriever
+
+         weave.init(project_name="ml-colabs/medrag-multi-modal")
+         retriever = MultiModalRetriever.from_artifact(
+             index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
+             metadata_dataset_name="grays-anatomy-images:v0",
+             data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+         )
+         retriever.predict(
+             query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
+             top_k=3,
+         )
+         ```
+
+     Attributes:
+         model_name (str): The name of the model to be used for retrieval.
+     """
+
+     model_name: str
+     _docs_retrieval_model: Optional[RAGMultiModalModel] = None
+     _metadata: Optional[list[dict]] = None
+     _data_artifact_dir: Optional[str] = None
+
+     def __init__(
+         self,
+         model_name: str = "vidore/colpali-v1.2",
+         docs_retrieval_model: Optional[RAGMultiModalModel] = None,
+         data_artifact_dir: Optional[str] = None,
+         metadata_dataset_name: Optional[str] = None,
+     ):
+         super().__init__(model_name=model_name)
+         self._docs_retrieval_model = (
+             docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name)
+         )
+         self._data_artifact_dir = data_artifact_dir
+ self._metadata = (
75
+ [dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
76
+ if metadata_dataset_name
77
+ else None
78
+ )
79
+
80
+ @classmethod
81
+ def from_artifact(
82
+ cls,
83
+ index_artifact_name: str,
84
+ metadata_dataset_name: str,
85
+ data_artifact_name: str,
86
+ ):
87
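+         """
+         Instantiate the retriever from W&B artifacts: downloads the ColPali index
+         artifact and the page-image data artifact, loads the index with
+         `RAGMultiModalModel.from_index`, and reads the metadata from the given
+         Weave dataset.
+         """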
+         index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
+         data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
+         docs_retrieval_model = RAGMultiModalModel.from_index(
+             index_path=os.path.join(index_artifact_dir, "index")
+         )
+         return cls(
+             docs_retrieval_model=docs_retrieval_model,
+             metadata_dataset_name=metadata_dataset_name,
+             data_artifact_dir=data_artifact_dir,
+         )
+
+     def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
+ data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
100
+ self._docs_retrieval_model.index(
101
+ input_path=data_artifact_dir,
102
+ index_name=index_name,
103
+ store_collection_with_index=False,
104
+ overwrite=True,
105
+ )
106
+ if wandb.run:
107
+ artifact = wandb.Artifact(
108
+ name=index_name,
109
+ type="colpali-index",
110
+ metadata={"weave_dataset_name": weave_dataset_name},
111
+ )
112
+ artifact.add_dir(
113
+ local_path=os.path.join(".byaldi", index_name), name="index"
114
+ )
115
+ artifact.save()
116
+
117
+ @weave.op()
118
+ def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
119
+ """
120
+ Predicts and retrieves the top-k most relevant documents/images for a given query
121
+ using ColPali.
122
+
123
+ This function uses the document retrieval model to search for the most relevant
124
+ documents based on the provided query. It returns a list of dictionaries, each
125
+ containing the document image, document ID, and the relevance score.
126
+
127
+ Args:
128
+ query (str): The search query string.
129
+ top_k (int, optional): The number of top results to retrieve. Defaults to 10.
130
+
131
+ Returns:
132
+ list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
133
+ - "doc_image" (PIL.Image.Image): The image of the document.
134
+ - "doc_id" (str): The ID of the document.
135
+ - "score" (float): The relevance score of the document.
136
+ """
137
+ results = self._docs_retrieval_model.search(query=query, k=top_k)
138
+ retrieved_results = []
139
+ for result in results:
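+             # Page images are stored in the data artifact as `<doc_id>.png`.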
+             retrieved_results.append(
+                 {
+                     "doc_image": Image.open(
+                         os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png")
+                     ),
+                     "doc_id": result["doc_id"],
+                     "score": result["score"],
+                 }
+             )
+         return retrieved_results
medrag_multi_modal/utils.py ADDED
@@ -0,0 +1,12 @@
+ import wandb
+
+
+ def get_wandb_artifact(artifact_name: str, artifact_type: str) -> str:
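+     """Download a W&B artifact and return its local directory path.
+
+     If a run is active, the artifact is fetched with `wandb.use_artifact` so it
+     is tracked as an input to the run; otherwise the public `wandb.Api` is used.
+     """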
+     if wandb.run:
+         artifact = wandb.use_artifact(artifact_name, type=artifact_type)
+         artifact_dir = artifact.download()
+     else:
+         api = wandb.Api()
+         artifact = api.artifact(artifact_name)
+         artifact_dir = artifact.download()
+     return artifact_dir
mkdocs.yml CHANGED
@@ -59,9 +59,14 @@ extra_javascript:
 
  nav:
    - Home: 'index.md'
+   - Setup:
+     - Installation: 'installation/install.md'
+     - Development: 'installation/development.md'
    - Document Loader:
      - Text Loader: 'document_loader/load_text.md'
      - Text and Image Loader: 'document_loader/load_text_image.md'
      - Image Loader: 'document_loader/load_image.md'
+   - Retrieval:
+     - Multi-Modal Retrieval: 'retreival/multi_modal_retrieval.md'
 
- repo_url: https://github.com/soumik12345/medrag-multi-modal
+ repo_url: https://github.com/soumik12345/medrag-multi-modal