In [1]:
import operator
import warnings
from typing import *
import traceback

import os
import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging
import matplotlib.pyplot as plt
import numpy as np
import re

from medrax.agent import *
from medrax.tools import *
from medrax.utils import *

import json
import openai
import os
import glob
import time
import logging
from datetime import datetime
from tenacity import retry, wait_exponential, stop_after_attempt

warnings.filterwarnings("ignore")
# Load variables from a local .env file into the environment (presumably the
# OpenAI API key used by ChatOpenAI — confirm against your .env).
_ = load_dotenv()


# Setup directory paths
# NOTE: ROOT and MODEL_DIR hold placeholder text and must be edited before running.
ROOT = "set this directory to where MedRAX is, .e.g /home/MedRAX"
PROMPT_FILE = f"{ROOT}/medrax/docs/system_prompts.txt"
BENCHMARK_FILE = f"{ROOT}/benchmark/questions"
MODEL_DIR = "set this to where the tool models are, e.g /home/models"
FIGURES_DIR = f"{ROOT}/benchmark/figures"

model_name = "medrax"
temperature = 0.2
medrax_logs = f"{ROOT}/experiments/medrax_logs"
# Each logging.info(json.dumps(...)) call below appends one JSON object per line.
log_filename = f"{medrax_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

# Fail fast with a readable message instead of an opaque error from the log handler.
if not os.path.isdir(ROOT):
    raise FileNotFoundError(f"ROOT directory not found: {ROOT!r} (edit ROOT in this cell)")
# basicConfig opens log_filename immediately, so its directory must already exist.
os.makedirs(medrax_logs, exist_ok=True)
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s", force=True)
device = "cuda"

In [2]:
def get_tools():
    """Instantiate every MedRAX tool used by the benchmark, in a fixed order.

    Each tool is placed on the module-level ``device``; tools that accept a
    ``cache_dir`` receive ``MODEL_DIR``. The phrase-grounding and LLaVA-Med tools
    are constructed with ``load_in_8bit=True``.

    Returns:
        list: [report generator, classifier, segmentation, phrase grounding,
        VQA, LLaVA-Med] tool instances.
    """
    return [
        ChestXRayReportGeneratorTool(cache_dir=MODEL_DIR, device=device),
        ChestXRayClassifierTool(device=device),
        ChestXRaySegmentationTool(device=device),
        XRayPhraseGroundingTool(
            cache_dir=MODEL_DIR,
            temp_dir="temp",
            device=device,
            load_in_8bit=True,
        ),
        XRayVQATool(cache_dir=MODEL_DIR, device=device),
        LlavaMedTool(cache_dir=MODEL_DIR, device=device, load_in_8bit=True),
    ]


def get_agent(tools):
    """Build a fresh MedRAX agent wrapping ``tools``.

    The system prompt is the "MEDICAL_ASSISTANT" entry of PROMPT_FILE. The model is
    gpt-4o via ChatOpenAI with the module-level ``temperature`` and top_p=0.95.
    Every call creates its own MemorySaver checkpointer, so the fixed thread id "1"
    does not collide between agents.

    Args:
        tools: tool instances handed to the Agent.

    Returns:
        tuple: ``(agent, thread)``, where ``thread`` is the run config passed to the
        agent's workflow.
    """
    system_prompt = load_prompts_from_file(PROMPT_FILE)["MEDICAL_ASSISTANT"]
    llm = ChatOpenAI(model="gpt-4o", temperature=temperature, top_p=0.95)
    agent = Agent(
        llm,
        tools=tools,
        log_tools=True,
        log_dir="logs",
        system_prompt=system_prompt,
        checkpointer=MemorySaver(),
    )
    run_config = {"configurable": {"thread_id": "1"}}
    return agent, run_config


def run_medrax(agent, thread, prompt, image_urls=None):
    """Send one user turn (text plus optional images) to the agent and return its reply.

    Args:
        agent: agent exposing ``workflow.stream`` / ``workflow.get_state``.
        thread: run config for the workflow, e.g. ``{"configurable": {"thread_id": "1"}}``.
        prompt: text of the user message.
        image_urls: optional iterable of image URLs, each attached as an
            ``image_url`` content part. ``None`` (the default) means no images.

    Returns:
        tuple: ``(final_text, agent_state_str)``. ``final_text`` is the stripped content
        of the last message in the last streamed update. ``agent_state_str`` is ``str()``
        of the workflow state for ``thread``.

    Raises:
        RuntimeError: if the workflow stream yields no updates at all.
    """
    # A None default avoids the shared-mutable-default pitfall of ``image_urls=[]``.
    content = [{"type": "text", "text": prompt}]
    content += [{"type": "image_url", "image_url": {"url": url}} for url in (image_urls or [])]
    messages = [HumanMessage(content=content)]

    # Keep only the most recent node update; its last message is the agent's reply.
    last_update = None
    for event in agent.workflow.stream({"messages": messages}, thread):
        for update in event.values():
            last_update = update
    if last_update is None:
        raise RuntimeError("Agent workflow produced no output")

    final_response = last_update["messages"][-1].content.strip()
    agent_state = agent.workflow.get_state(thread)
    return final_response, str(agent_state)

In [3]:
def _parse_required_figures(question_data):
    """Return the question's required figures as strings prefixed with "Figure ".

    ``question_data["figures"]`` may be a JSON-encoded string, a plain string, a list,
    or a scalar; any non-list value is wrapped in a one-element list. Parse failures,
    including a missing "figures" key, are printed and yield an empty list.
    """
    try:
        figures = question_data["figures"]
        if isinstance(figures, str):
            try:
                parsed = json.loads(figures)
            except json.JSONDecodeError:
                parsed = [figures]
        else:
            parsed = figures
        # json.loads can return a scalar (e.g. "2" -> 2); always iterate over a list.
        if not isinstance(parsed, list):
            parsed = [parsed]
        required_figures = [str(fig) for fig in parsed]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    return [fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures]


def _collect_subfigures(required_figures, case_details, case_id):
    """Resolve "Figure N" / "Figure Nx" references to subfigure dicts from ``case_details``.

    A trailing letter (the "a" in "Figure 2a") keeps only subfigures whose "number"
    ends with, or whose "label" equals, that letter (case-insensitive). Without a
    letter, every subfigure of the matching figure is returned.
    """
    subfigures = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure.get("number") == f"Figure {base_figure_num}"
        ]
        if not matching_figures:
            print(f"No matching figure found for {figure} in case {case_id}")
            continue

        for case_figure in matching_figures:
            candidates = case_figure.get("subfigures", [])
            if figure_letter:
                letter = figure_letter.lower()
                candidates = [
                    subfig
                    for subfig in candidates
                    if (subfig.get("number") or "").lower().endswith(letter)
                    or (subfig.get("label") or "").lower() == letter
                ]
            subfigures.extend(candidates)
    return subfigures


def _build_input_log(prompt, question_data, subfigures):
    """Return the "input" section shared by the success and error log entries."""
    # .get() so a malformed question cannot raise from inside the error handler.
    return {
        "messages": prompt,
        "question_data": {
            "question": question_data.get("question"),
            "explanation": question_data.get("explanation"),
            "metadata": question_data.get("metadata", {}),
            "figures": question_data.get("figures"),
        },
        "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
        "image_captions": [subfig.get("caption", "") for subfig in subfigures],
    }


def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):
    """Ask the agent one benchmark question, log the outcome, and return its answers.

    The agent is queried twice on the same ``thread``:
    1. with the question, each subfigure's expected local path, and its image URLs;
    2. with a follow-up asking for a single option letter.
    Exactly one JSON record is written with ``logging.info``, for success or failure.

    Args:
        question_data: question dict; reads "question", "figures", "answer",
            "explanation" and optional "metadata".
        case_details: case dict whose "figures" entries carry "number" and "subfigures".
        case_id: case identifier; also the figure sub-directory under FIGURES_DIR.
        question_id: question identifier, used only for logging.
        agent: agent passed through to ``run_medrax``.
        thread: run config passed through to ``run_medrax``.

    Returns:
        tuple: ``(final_response, model_answer)`` on success. ``("", "")`` if querying
        the agent or building the log entry raises.
    """
    required_figures = _parse_required_figures(question_data)
    subfigures = _collect_subfigures(required_figures, case_details, case_id)

    # Mention each subfigure's expected local file in the prompt and attach its URL.
    figure_prompt = ""
    image_urls = []
    for subfig in subfigures:
        if "number" in subfig:
            subfig_number = subfig["number"].lower().strip().replace(" ", "_") + ".jpg"
            subfig_path = os.path.join(FIGURES_DIR, case_id, subfig_number)
            figure_prompt += f"{subfig_number} located at {subfig_path}\n"
        if "url" in subfig:
            image_urls.append(subfig["url"])
        else:
            print(f"Subfigure missing URL: {subfig}")

    prompt = (
        "Answer this question correctly using chain of thought reasoning and "
        "carefully evaluating choices. Solve using your own vision and reasoning and then "
        "use tools to complement your reasoning. Trust your own judgement over any tools.\n"
        f"{question_data['question']}\n{figure_prompt}"
    )

    try:
        start_time = time.perf_counter()
        final_response, _ = run_medrax(
            agent=agent, thread=thread, prompt=prompt, image_urls=image_urls
        )
        # Second turn on the same thread; the follow-up prompt relies on the first turn.
        model_answer, agent_state = run_medrax(
            agent=agent,
            thread=thread,
            prompt="If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)",
        )
        duration = time.perf_counter() - start_time
        model_answer = model_answer.strip()

        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "duration": round(duration, 2),
            "usage": "",
            "cost": 0,
            "raw_response": final_response,
            "model_answer": model_answer,
            "correct_answer": question_data["answer"][0],
            "input": _build_input_log(prompt, question_data, subfigures),
            "agent_state": agent_state,
        }
        logging.info(json.dumps(log_entry))
        return final_response, model_answer

    except Exception as e:
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "error",
            "error": str(e),
            "traceback": traceback.format_exc(),
            "cost": 0,
            "input": _build_input_log(prompt, question_data, subfigures),
        }
        logging.info(json.dumps(log_entry))
        print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
        return "", ""


def load_benchmark_questions(case_id, benchmark_dir="../benchmark/questions"):
    """Return the question JSON files for one case, sorted by path.

    Args:
        case_id: case identifier. Questions live in ``{benchmark_dir}/{case_id}/``
            and are named ``{case_id}_*.json``.
        benchmark_dir: root of the question tree. The default is relative to the
            current working directory, as before.

    Returns:
        list[str]: sorted matching paths; empty if the case has no questions.
    """
    # glob's order is filesystem-dependent. Sorting keeps the order in which questions
    # are asked on a case's shared agent thread reproducible across runs.
    pattern = os.path.join(benchmark_dir, str(case_id), f"{case_id}_*.json")
    return sorted(glob.glob(pattern))


def count_total_questions(benchmark_dir="../benchmark/questions"):
    """Count benchmark cases and questions under ``benchmark_dir``.

    A case is a non-hidden sub-directory. Its questions are the files matching
    ``{case_id}_*.json``, the naming scheme used when questions are loaded, so
    progress totals agree with what is actually processed.

    Args:
        benchmark_dir: root of the question tree (default relative to the CWD).

    Returns:
        tuple: ``(total_cases, total_questions)``.
    """
    # Skip stray files and hidden dirs (e.g. .ipynb_checkpoints) so they are not counted as cases.
    case_ids = [
        entry
        for entry in os.listdir(benchmark_dir)
        if not entry.startswith(".") and os.path.isdir(os.path.join(benchmark_dir, entry))
    ]
    total_questions = sum(
        len(glob.glob(os.path.join(benchmark_dir, case_id, f"{case_id}_*.json")))
        for case_id in case_ids
    )
    return len(case_ids), total_questions


def main(tools, start_after_case_id=17158, metadata_file="../data/eurorad_metadata.json"):
    """Run the benchmark over every case in ``metadata_file`` and print progress.

    A new agent is built for each case that has questions. All questions of that case
    are then asked on the same agent and thread.

    Args:
        tools: tool instances shared by all per-case agents.
        start_after_case_id: resume point. Cases whose numeric "case_id" is <= this
            value are skipped; ``None`` processes every case. The default keeps the
            previously hard-coded cutoff.
        metadata_file: JSON file mapping case id -> case details.

    A question counts as skipped when ``create_multimodal_request`` returns an empty
    response. That function returns ``("", "")`` on error, so an ``is None`` check
    would never fire.
    """
    with open(metadata_file, "r") as file:
        data = json.load(file)

    # Note: totals cover the whole question tree, including cases skipped by the resume cutoff.
    total_cases, total_questions = count_total_questions()
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0

    print(f"Beginning benchmark evaluation for model {model_name} with temperature {temperature}\n")

    for case_id, case_details in data.items():
        if start_after_case_id is not None and int(case_details["case_id"]) <= start_after_case_id:
            continue

        print("----------------------------------------------------------------")

        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        # Only build an agent for cases that actually have questions.
        agent, thread = get_agent(tools)
        cases_processed += 1
        for question_file in question_files:
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]

            questions_processed += 1
            final_response, model_answer = create_multimodal_request(
                question_data, case_details, case_id, question_id, agent, thread
            )

            # Empty (or None) response means the request failed; count it as skipped.
            if not final_response:
                skipped_questions += 1
                print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
                continue

            print(
                f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
            )
            print(f"Case ID: {case_id}")
            print(f"Question ID: {question_id}")
            print(f"Final Response: {final_response}")
            print(f"Model Answer: {model_answer}")
            print(f"Correct Answer: {question_data['answer']}")
            print("----------------------------------------------------------------\n")

    print("\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Questions Skipped: {skipped_questions}")

In [None]:
# Load every tool model once; all per-case agents in main() share these instances.
tools = get_tools()
main(tools=tools)