geekyrakshit committed
Commit bcd7446 · 1 Parent(s): 76f0b82

add: docs for assistant module
docs/assistant/figure_annotation.md ADDED

@@ -0,0 +1,3 @@
+# Figure Annotation
+
+::: medrag_multi_modal.assistant.figure_annotation
docs/assistant/llm_client.md ADDED

@@ -0,0 +1,3 @@
+# LLM Client
+
+::: medrag_multi_modal.assistant.llm_client
docs/assistant/medqa_assistant.md ADDED

@@ -0,0 +1,3 @@
+# MedQA Assistant
+
+::: medrag_multi_modal.assistant.medqa_assistant
medrag_multi_modal/assistant/figure_annotation.py CHANGED

@@ -22,6 +22,34 @@ class FigureAnnotations(BaseModel):
 
 
 class FigureAnnotatorFromPageImage(weave.Model):
+    """
+    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate figures from a page image of a scientific textbook.
+
+    !!! example "Example Usage"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage, LLMClient
+        )
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+        figure_annotator = FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+        )
+        annotations = figure_annotator.predict(
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6"
+        )
+        ```
+
+    Attributes:
+        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations from the page image.
+        structured_output_llm_client (LLMClient): An LLM client used to convert the extracted annotations into a structured format.
+    """
+
     figure_extraction_llm_client: LLMClient
     structured_output_llm_client: LLMClient
 

@@ -65,6 +93,22 @@ Here are some clues you need to follow:
 
     @weave.op()
     def predict(self, image_artifact_address: str):
+        """
+        Predicts figure annotations for images in a given artifact directory.
+
+        This function retrieves an artifact directory using the provided image artifact address.
+        It reads metadata from a JSONL file in the artifact directory and iterates over each item in the metadata.
+        For each item, it constructs the file path for the page image and checks for the presence of figure image files.
+        If figure image files are found, it reads and converts the page image, then uses the `annotate_figures` method
+        to extract figure annotations from the page image. The extracted annotations are then structured using the
+        `extract_structured_output` method and appended to the annotations list.
+
+        Args:
+            image_artifact_address (str): The address of the image artifact.
+
+        Returns:
+            list: A list of dictionaries containing page indices and their corresponding figure annotations.
+        """
         artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
         metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
         annotations = []
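
The `predict` docstring above describes a metadata-driven loop. For orientation, here is a minimal, self-contained sketch of that control flow; the metadata key, the page/figure filename patterns, and the two callables standing in for `annotate_figures` and `extract_structured_output` are illustrative assumptions, not taken from this diff.

```python
# Illustrative sketch of the loop described in the predict() docstring above.
# Metadata keys and file naming patterns are assumed; the two callables stand in
# for the annotator's annotate_figures / extract_structured_output methods.
import os
from glob import glob
from typing import Callable

from PIL import Image


def annotate_pages(
    artifact_dir: str,
    metadata: list[dict],
    annotate_figures: Callable[[Image.Image], str],
    extract_structured_output: Callable[[str], dict],
) -> list[dict]:
    annotations = []
    for item in metadata:
        page_idx = item["page_idx"]  # assumed metadata key
        page_image_file = os.path.join(artifact_dir, f"page{page_idx}.png")  # assumed naming
        figure_files = glob(os.path.join(artifact_dir, f"page{page_idx}_fig*.png"))  # assumed naming
        if figure_files:  # only pages that actually contain figures get annotated
            page_image = Image.open(page_image_file).convert("RGB")
            raw_annotation = annotate_figures(page_image)
            structured = extract_structured_output(raw_annotation)
            annotations.append({"page_idx": page_idx, "figure_annotations": structured})
    return annotations
```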
medrag_multi_modal/assistant/llm_client.py CHANGED

@@ -59,6 +59,17 @@ OPENAI_MODELS = ["gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024
 
 
 class LLMClient(weave.Model):
+    """
+    LLMClient is a class that interfaces with different large language model (LLM) providers
+    such as Google Gemini, Mistral, and OpenAI. It abstracts the complexity of interacting with
+    these different APIs and provides a unified interface for making predictions.
+
+    Args:
+        model_name (str): The name of the model to be used for predictions.
+        client_type (Optional[ClientType]): The type of client (e.g., GEMINI, MISTRAL, OPENAI).
+            If not provided, it is inferred from the model_name.
+    """
+
     model_name: str
     client_type: Optional[ClientType]
 

@@ -196,6 +207,26 @@ class LLMClient(weave.Model):
         system_prompt: Optional[Union[str, list[str]]] = None,
         schema: Optional[Any] = None,
     ) -> Union[str, Any]:
+        """
+        Predicts the response from a language model based on the provided prompts and schema.
+
+        This function determines the client type and calls the appropriate SDK execution function
+        to get the response from the language model. It supports multiple client types including
+        GEMINI, MISTRAL, and OPENAI. Depending on the client type, it calls the corresponding
+        execution function with the provided user and system prompts, and an optional schema.
+
+        Args:
+            user_prompt (Union[str, list[str]]): The user prompt(s) to be sent to the language model.
+            system_prompt (Optional[Union[str, list[str]]]): The system prompt(s) to be sent to the language model.
+            schema (Optional[Any]): The schema to be used for parsing the response, if applicable.
+
+        Returns:
+            Union[str, Any]: The response from the language model, which could be a string or any other type
+                depending on the schema provided.
+
+        Raises:
+            ValueError: If the client type is invalid.
+        """
         if self.client_type == ClientType.GEMINI:
             return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
         elif self.client_type == ClientType.MISTRAL:
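
Since the new docstrings present `LLMClient` as a unified interface with optional schema-based structured output, a short usage sketch may help. The `Summary` schema and the prompts below are made up for illustration; only the constructor arguments, the model names, and the `predict(user_prompt, system_prompt, schema)` signature come from this diff.

```python
# Hedged usage sketch based on the LLMClient docstrings added in this diff.
# The Summary schema and the prompts are illustrative only.
import weave
from dotenv import load_dotenv
from pydantic import BaseModel

from medrag_multi_modal.assistant import LLMClient

load_dotenv()
weave.init(project_name="ml-colabs/medrag-multi-modal")


class Summary(BaseModel):  # hypothetical schema for structured output
    title: str
    key_points: list[str]


client = LLMClient(model_name="gpt-4o")  # client_type is inferred from the model name

# Plain-text response
answer = client.predict(user_prompt="Summarize the function of the mitral valve.")

# Structured response parsed against a schema (when the underlying SDK supports it)
structured = client.predict(
    user_prompt="Summarize the function of the mitral valve.",
    system_prompt="You are a concise medical writing assistant.",
    schema=Summary,
)
```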
medrag_multi_modal/assistant/medqa_assistant.py CHANGED

@@ -1,21 +1,19 @@
-from typing import Optional
 
 import weave
-from PIL import Image
 
 from ..retrieval import SimilarityMetric
 from .llm_client import LLMClient
 
 
 class MedQAAssistant(weave.Model):
+    """Assistant that answers medical queries using an LLM client and retrieved text chunks."""
     llm_client: LLMClient
     retriever: weave.Model
     top_k_chunks: int = 2
     retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
 
     @weave.op()
-    def predict(self, query: str, image: Optional[Image.Image] = None) -> str:
-        _image = image
+    def predict(self, query: str) -> str:
         retrieved_chunks = self.retriever.predict(
             query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
         )
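
With the image argument removed, `predict` now takes only a query string. A hedged end-to-end sketch follows; the stub retriever, its chunk format, and the import of `MedQAAssistant` from `medrag_multi_modal.assistant` (mirroring the other assistant classes) are assumptions, not confirmed by this diff.

```python
# Hedged sketch of the updated MedQAAssistant API (predict(query) only).
# StubRetriever and its chunk format are placeholders; in practice you would pass
# one of the retrievers from medrag_multi_modal.retrieval.
import weave
from dotenv import load_dotenv

from medrag_multi_modal.assistant import LLMClient, MedQAAssistant  # assumed import path


class StubRetriever(weave.Model):
    @weave.op()
    def predict(self, query: str, top_k: int = 2, metric=None) -> list[dict]:
        # Return chunks in whatever format the assistant expects; this is a placeholder.
        return [{"text": "The mitral valve separates the left atrium from the left ventricle."}]


load_dotenv()
weave.init(project_name="ml-colabs/medrag-multi-modal")

assistant = MedQAAssistant(
    llm_client=LLMClient(model_name="gpt-4o"),
    retriever=StubRetriever(),
    top_k_chunks=2,
)
print(assistant.predict(query="What does the mitral valve do?"))
```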
mkdocs.yml CHANGED

@@ -83,5 +83,9 @@ nav:
     - Contriever: 'retreival/contriever.md'
     - MedCPT: 'retreival/medcpt.md'
     - NV-Embed-v2: 'retreival/nv_embed_2.md'
+  - Assistant:
+    - MedQA Assistant: 'assistant/medqa_assistant.md'
+    - Figure Annotation: 'assistant/figure_annotation.md'
+    - LLM Client: 'assistant/llm_client.md'
 
 repo_url: https://github.com/soumik12345/medrag-multi-modal