geekyrakshit committed
Commit 6c6905f · 1 Parent(s): e197ad0

update: MedQAAssistant + FigureAnnotatorFromPageImage
medrag_multi_modal/assistant/figure_annotation.py CHANGED
@@ -1,12 +1,11 @@
 import os
 from glob import glob
-from typing import Union
+from typing import Optional, Union

 import cv2
 import weave
 from PIL import Image
 from pydantic import BaseModel
-from rich.progress import track

 from ..utils import get_wandb_artifact, read_jsonl_file
 from .llm_client import LLMClient
@@ -23,7 +22,8 @@ class FigureAnnotations(BaseModel):

 class FigureAnnotatorFromPageImage(weave.Model):
     """
-    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate figures from a page image of a scientific textbook.
+    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
+    figures from a page image of a scientific textbook.

     !!! example "Example Usage"
         ```python
@@ -39,19 +39,35 @@ class FigureAnnotatorFromPageImage(weave.Model):
         figure_annotator = FigureAnnotatorFromPageImage(
             figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
             structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
         )
-        annotations = figure_annotator.predict(
-            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6"
-        )
+        annotations = figure_annotator.predict(page_idx=34)
         ```

-    Attributes:
-        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations from the page image.
-        structured_output_llm_client (LLMClient): An LLM client used to convert the extracted annotations into a structured format.
+    Args:
+        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
+            from the page image.
+        structured_output_llm_client (LLMClient): An LLM client used to convert the extracted
+            annotations into a structured format.
+        image_artifact_address (Optional[str]): The address of the image artifact containing the
+            page images.
     """

     figure_extraction_llm_client: LLMClient
     structured_output_llm_client: LLMClient
+    _artifact_dir: str
+
+    def __init__(
+        self,
+        figure_extraction_llm_client: LLMClient,
+        structured_output_llm_client: LLMClient,
+        image_artifact_address: Optional[str] = None,
+    ):
+        super().__init__(
+            figure_extraction_llm_client=figure_extraction_llm_client,
+            structured_output_llm_client=structured_output_llm_client,
+        )
+        self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")

     @weave.op()
     def annotate_figures(
@@ -92,7 +108,7 @@ Here are some clues you need to follow:
         )

     @weave.op()
-    def predict(self, page_idx: int, image_artifact_address: str):
+    def predict(self, page_idx: int) -> dict[int, list[FigureAnnotation]]:
         """
         Predicts figure annotations for a specific page in a document.

@@ -105,22 +121,23 @@ Here are some clues you need to follow:

         Args:
             page_idx (int): The index of the page to annotate.
-            image_artifact_address (str): The address of the image artifact containing the page images.
+            image_artifact_address (str): The address of the image artifact containing the
+                page images.

         Returns:
-            dict: A dictionary containing the page index as the key and the extracted figure annotations
-                as the value.
+            dict: A dictionary containing the page index as the key and the extracted figure
+                annotations as the value.
         """
-        artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
-        metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
+
+        metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl"))
         annotations = {}
-        for item in track(metadata, description="Annotating images:"):
+        for item in metadata:
             if item["page_idx"] == page_idx:
                 page_image_file = os.path.join(
-                    artifact_dir, f"page{item['page_idx']}.png"
+                    self._artifact_dir, f"page{item['page_idx']}.png"
                 )
                 figure_image_files = glob(
-                    os.path.join(artifact_dir, f"page{item['page_idx']}_fig*.png")
+                    os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png")
                 )
                 if len(figure_image_files) > 0:
                     page_image = cv2.imread(page_image_file)
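The practical effect of this change for code that uses `FigureAnnotatorFromPageImage`: the W&B image artifact address moves from `predict()` into the constructor (where `get_wandb_artifact` resolves it once into `_artifact_dir`), and `predict()` now takes only a `page_idx`. A minimal before/after sketch of a call site, reusing the model names and artifact address from the docstring example (the surrounding `weave.init`/environment setup is assumed):

```python
from medrag_multi_modal.assistant import FigureAnnotatorFromPageImage, LLMClient

# Before this commit, the artifact address was passed on every call:
# annotations = figure_annotator.predict(
#     page_idx=34,
#     image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
# )

# After this commit, the artifact is bound once at construction time and
# predict() needs only the page index.
figure_annotator = FigureAnnotatorFromPageImage(
    figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
    structured_output_llm_client=LLMClient(model_name="gpt-4o"),
    image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)
annotations = figure_annotator.predict(page_idx=34)
# annotations maps the page index to the extracted figure annotations, e.g. {34: [...]}
```

Binding the artifact in the constructor is what lets `MedQAAssistant` (below) call the annotator once per retrieved page without threading the artifact address through its own `predict()`.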
medrag_multi_modal/assistant/medqa_assistant.py CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import weave

 from ..retrieval import SimilarityMetric
@@ -8,7 +6,50 @@ from .llm_client import LLMClient


 class MedQAAssistant(weave.Model):
-    """Cuming"""
+    """
+    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
+    language model client, a retriever model, and a figure annotator.
+
+    !!! example "Usage Example"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage,
+            LLMClient,
+            MedQAAssistant,
+        )
+        from medrag_multi_modal.retrieval import MedCPTRetriever
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+
+        llm_client = LLMClient(model_name="gemini-1.5-flash")
+
+        retriever=MedCPTRetriever.from_wandb_artifact(
+            chunk_dataset_name="grays-anatomy-chunks:v0",
+            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
+        )
+
+        figure_annotator=FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        medqa_assistant = MedQAAssistant(
+            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
+        )
+        medqa_assistant.predict(query="What is ribosome?")
+        ```
+
+    Args:
+        llm_client (LLMClient): The language model client used to generate responses.
+        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
+        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
+        top_k_chunks (int): The number of top chunks to retrieve based on similarity metric.
+        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
+    """

     llm_client: LLMClient
     retriever: weave.Model
@@ -17,7 +58,25 @@ class MedQAAssistant(weave.Model):
     retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

     @weave.op()
-    def predict(self, query: str, image_artifact_address: Optional[str] = None) -> str:
+    def predict(self, query: str) -> str:
+        """
+        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
+        from a medical document and using a language model to generate the final response.
+
+        This function performs the following steps:
+        1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
+        2. Extracts the text and page indices from the retrieved chunks.
+        3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
+        4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions.
+        5. Uses the language model client to generate a response based on the constructed prompts.
+        6. Appends the source information (page numbers) to the generated response.
+
+        Args:
+            query (str): The medical query to be answered.
+
+        Returns:
+            str: The generated response to the query, including source information.
+        """
         retrieved_chunks = self.retriever.predict(
             query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
         )
@@ -29,14 +88,13 @@ class MedQAAssistant(weave.Model):
             page_indices.add(int(chunk["page_idx"]))

         figure_descriptions = []
-        if image_artifact_address is not None:
-            for page_idx in page_indices:
-                figure_annotations = self.figure_annotator.predict(
-                    page_idx=page_idx, image_artifact_address=image_artifact_address
-                )
-                figure_descriptions += [
-                    item["figure_description"] for item in figure_annotations[page_idx]
-                ]
+        for page_idx in page_indices:
+            figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
+                page_idx
+            ]
+            figure_descriptions += [
+                item["figure_description"] for item in figure_annotations
+            ]

         system_prompt = """
         You are an expert in medical science. You are given a query and a list of chunks from a medical document.
@@ -46,5 +104,5 @@ class MedQAAssistant(weave.Model):
             user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
         )
         page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
-        response += f"\n\n**Source:** {'Pages' if len(page_numbers) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
+        response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
         return response
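The corresponding call-site change for `MedQAAssistant`: `predict()` no longer accepts an optional `image_artifact_address`, and figure descriptions are now always fetched for every retrieved page through the annotator configured above. The source footer also pluralizes correctly now, since it checks `len(page_indices)` rather than the length of the joined `page_numbers` string. A minimal sketch of the migration, assuming the `medqa_assistant` object constructed in the docstring example:

```python
# Before this commit, figure descriptions were only used when an artifact
# address was passed explicitly:
# response = medqa_assistant.predict(
#     query="What is ribosome?",
#     image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
# )

# After this commit, the artifact address lives on FigureAnnotatorFromPageImage,
# so the query is the only argument.
response = medqa_assistant.predict(query="What is ribosome?")
print(response)  # answer text, followed by "**Source:** Page(s) ... from Gray's Anatomy"
```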