geekyrakshit commited on
Commit
b123ef7
·
1 Parent(s): 1f48fed

add: page number citation for MedQAAssistant

Browse files
medrag_multi_modal/assistant/medqa_assistant.py CHANGED
@@ -17,10 +17,19 @@ class MedQAAssistant(weave.Model):
17
  retrieved_chunks = self.retriever.predict(
18
  query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
19
  )
20
- retrieved_chunks = [chunk["text"] for chunk in retrieved_chunks]
 
 
 
 
 
 
 
21
  system_prompt = """
22
- You are a medical expert. You are given a query and a list of chunks from a medical document.
23
  """
24
- return self.llm_client.predict(
25
- system_prompt=system_prompt, user_prompt=retrieved_chunks
26
  )
 
 
 
17
  retrieved_chunks = self.retriever.predict(
18
  query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
19
  )
20
+
21
+ retrieved_chunk_texts = []
22
+ page_indices = set()
23
+ for chunk in retrieved_chunks:
24
+ retrieved_chunk_texts.append(chunk["text"])
25
+ page_indices.add(int(chunk["page_idx"]))
26
+ page_numbers = ", ".join(map(str, page_indices))
27
+
28
  system_prompt = """
29
+ You are an expert in medical science. You are given a query and a list of chunks from a medical document.
30
  """
31
+ response = self.llm_client.predict(
32
+ system_prompt=system_prompt, user_prompt=[query, *retrieved_chunk_texts]
33
  )
34
+ response += f"\n\n**Source:** {'Pages' if len(page_numbers) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
35
+ return response