{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import operator\n", "import warnings\n", "from typing import *\n", "import traceback\n", "\n", "import os\n", "import torch\n", "from dotenv import load_dotenv\n", "from IPython.display import Image\n", "from langgraph.checkpoint.memory import MemorySaver\n", "from langgraph.graph import END, StateGraph\n", "from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage\n", "from langchain_openai import ChatOpenAI\n", "from transformers import logging\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import re\n", "\n", "from medrax.agent import *\n", "from medrax.tools import *\n", "from medrax.utils import *\n", "\n", "import json\n", "import openai\n", "import os\n", "import glob\n", "import time\n", "import logging\n", "from datetime import datetime\n", "from tenacity import retry, wait_exponential, stop_after_attempt\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "_ = load_dotenv()\n", "\n", "\n", "# Setup directory paths\n", "ROOT = \"set this directory to where MedRAX is, .e.g /home/MedRAX\"\n", "PROMPT_FILE = f\"{ROOT}/medrax/docs/system_prompts.txt\"\n", "BENCHMARK_FILE = f\"{ROOT}/benchmark/questions\"\n", "MODEL_DIR = f\"set this to where the tool models are, e.g /home/models\"\n", "FIGURES_DIR = f\"{ROOT}/benchmark/figures\"\n", "\n", "model_name = \"medrax\"\n", "temperature = 0.2\n", "medrax_logs = f\"{ROOT}/experiments/medrax_logs\"\n", "log_filename = f\"{medrax_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n", "logging.basicConfig(filename=log_filename, level=logging.INFO, format=\"%(message)s\", force=True)\n", "device = \"cuda\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def get_tools():\n", " report_tool = ChestXRayReportGeneratorTool(cache_dir=MODEL_DIR, device=device)\n", " xray_classification_tool = ChestXRayClassifierTool(device=device)\n", " segmentation_tool = ChestXRaySegmentationTool(device=device)\n", " grounding_tool = XRayPhraseGroundingTool(\n", " cache_dir=MODEL_DIR, temp_dir=\"temp\", device=device, load_in_8bit=True\n", " )\n", " xray_vqa_tool = XRayVQATool(cache_dir=MODEL_DIR, device=device)\n", " llava_med_tool = LlavaMedTool(cache_dir=MODEL_DIR, device=device, load_in_8bit=True)\n", "\n", " return [\n", " report_tool,\n", " xray_classification_tool,\n", " segmentation_tool,\n", " grounding_tool,\n", " xray_vqa_tool,\n", " llava_med_tool,\n", " ]\n", "\n", "\n", "def get_agent(tools):\n", " prompts = load_prompts_from_file(PROMPT_FILE)\n", " prompt = prompts[\"MEDICAL_ASSISTANT\"]\n", "\n", " checkpointer = MemorySaver()\n", " model = ChatOpenAI(model=\"gpt-4o\", temperature=temperature, top_p=0.95)\n", " agent = Agent(\n", " model,\n", " tools=tools,\n", " log_tools=True,\n", " log_dir=\"logs\",\n", " system_prompt=prompt,\n", " checkpointer=checkpointer,\n", " )\n", " thread = {\"configurable\": {\"thread_id\": \"1\"}}\n", " return agent, thread\n", "\n", "\n", "def run_medrax(agent, thread, prompt, image_urls=[]):\n", " messages = [\n", " HumanMessage(\n", " content=[\n", " {\"type\": \"text\", \"text\": prompt},\n", " ]\n", " + [{\"type\": \"image_url\", \"image_url\": {\"url\": image_url}} for image_url in image_urls]\n", " )\n", " ]\n", "\n", " final_response = None\n", " for event in agent.workflow.stream({\"messages\": messages}, thread):\n", " for v in event.values():\n", " final_response = v\n", "\n", " final_response = 
final_response[\"messages\"][-1].content.strip()\n", " agent_state = agent.workflow.get_state(thread)\n", "\n", " return final_response, str(agent_state)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):\n", " # Parse required figures\n", " try:\n", " # Try multiple ways of parsing figures\n", " if isinstance(question_data[\"figures\"], str):\n", " try:\n", " required_figures = json.loads(question_data[\"figures\"])\n", " except json.JSONDecodeError:\n", " required_figures = [question_data[\"figures\"]]\n", " elif isinstance(question_data[\"figures\"], list):\n", " required_figures = question_data[\"figures\"]\n", " else:\n", " required_figures = [str(question_data[\"figures\"])]\n", " except Exception as e:\n", " print(f\"Error parsing figures: {e}\")\n", " required_figures = []\n", "\n", " # Ensure each figure starts with \"Figure \"\n", " required_figures = [\n", " fig if fig.startswith(\"Figure \") else f\"Figure {fig}\" for fig in required_figures\n", " ]\n", "\n", " subfigures = []\n", " for figure in required_figures:\n", " # Handle both regular figures and those with letter suffixes\n", " base_figure_num = \"\".join(filter(str.isdigit, figure))\n", " figure_letter = \"\".join(filter(str.isalpha, figure.split()[-1])) or None\n", "\n", " # Find matching figures in case details\n", " matching_figures = [\n", " case_figure\n", " for case_figure in case_details.get(\"figures\", [])\n", " if case_figure[\"number\"] == f\"Figure {base_figure_num}\"\n", " ]\n", "\n", " if not matching_figures:\n", " print(f\"No matching figure found for {figure} in case {case_id}\")\n", " continue\n", "\n", " for case_figure in matching_figures:\n", " # If a specific letter is specified, filter subfigures\n", " if figure_letter:\n", " matching_subfigures = [\n", " subfig\n", " for subfig in case_figure.get(\"subfigures\", [])\n", " if subfig.get(\"number\", \"\").lower().endswith(figure_letter.lower())\n", " or subfig.get(\"label\", \"\").lower() == figure_letter.lower()\n", " ]\n", " subfigures.extend(matching_subfigures)\n", " else:\n", " # If no letter specified, add all subfigures\n", " subfigures.extend(case_figure.get(\"subfigures\", []))\n", "\n", " # Add images to content\n", " figure_prompt = \"\"\n", " image_urls = []\n", "\n", " for subfig in subfigures:\n", " if \"number\" in subfig:\n", " subfig_number = subfig[\"number\"].lower().strip().replace(\" \", \"_\") + \".jpg\"\n", " subfig_path = os.path.join(FIGURES_DIR, case_id, subfig_number)\n", " figure_prompt += f\"{subfig_number} located at {subfig_path}\\n\"\n", " if \"url\" in subfig:\n", " image_urls.append(subfig[\"url\"])\n", " else:\n", " print(f\"Subfigure missing URL: {subfig}\")\n", "\n", " prompt = (\n", " f\"Answer this question correctly using chain of thought reasoning and \"\n", " \"carefully evaluating choices. Solve using our own vision and reasoning and then\"\n", " \"use tools to complement your reasoning. 
Trust your own judgement over any tools.\\n\"\n", "        f\"{question_data['question']}\\n{figure_prompt}\"\n", "    )\n", "\n", "    try:\n", "        start_time = time.time()\n", "\n", "        final_response, agent_state = run_medrax(\n", "            agent=agent, thread=thread, prompt=prompt, image_urls=image_urls\n", "        )\n", "        model_answer, agent_state = run_medrax(\n", "            agent=agent,\n", "            thread=thread,\n", "            prompt=\"If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)\",\n", "        )\n", "        duration = time.time() - start_time\n", "\n", "        log_entry = {\n", "            \"case_id\": case_id,\n", "            \"question_id\": question_id,\n", "            \"timestamp\": datetime.now().isoformat(),\n", "            \"model\": model_name,\n", "            \"temperature\": temperature,\n", "            \"duration\": round(duration, 2),\n", "            \"usage\": \"\",\n", "            \"cost\": 0,\n", "            \"raw_response\": final_response,\n", "            \"model_answer\": model_answer.strip(),\n", "            \"correct_answer\": question_data[\"answer\"][0],\n", "            \"input\": {\n", "                \"messages\": prompt,\n", "                \"question_data\": {\n", "                    \"question\": question_data[\"question\"],\n", "                    \"explanation\": question_data[\"explanation\"],\n", "                    \"metadata\": question_data.get(\"metadata\", {}),\n", "                    \"figures\": question_data[\"figures\"],\n", "                },\n", "                \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n", "                \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n", "            },\n", "            \"agent_state\": agent_state,\n", "        }\n", "        logging.info(json.dumps(log_entry))\n", "        return final_response, model_answer.strip()\n", "\n", "    except Exception as e:\n", "        log_entry = {\n", "            \"case_id\": case_id,\n", "            \"question_id\": question_id,\n", "            \"timestamp\": datetime.now().isoformat(),\n", "            \"model\": model_name,\n", "            \"temperature\": temperature,\n", "            \"status\": \"error\",\n", "            \"error\": str(e),\n", "            \"cost\": 0,\n", "            \"input\": {\n", "                \"messages\": prompt,\n", "                \"question_data\": {\n", "                    \"question\": question_data[\"question\"],\n", "                    \"explanation\": question_data[\"explanation\"],\n", "                    \"metadata\": question_data.get(\"metadata\", {}),\n", "                    \"figures\": question_data[\"figures\"],\n", "                },\n", "                \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n", "                \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n", "            },\n", "        }\n", "        logging.info(json.dumps(log_entry))\n", "        print(f\"Error processing case {case_id}, question {question_id}: {str(e)}\")\n", "        return \"\", \"\"\n", "\n", "\n", "def load_benchmark_questions(case_id):\n", "    benchmark_dir = BENCHMARK_FILE\n", "    return glob.glob(f\"{benchmark_dir}/{case_id}/{case_id}_*.json\")\n", "\n", "\n", "def count_total_questions():\n", "    total_cases = len(glob.glob(f\"{BENCHMARK_FILE}/*\"))\n", "    total_questions = sum(\n", "        len(glob.glob(f\"{BENCHMARK_FILE}/{case_id}/*.json\"))\n", "        for case_id in os.listdir(BENCHMARK_FILE)\n", "    )\n", "    return total_cases, total_questions\n", "\n", "\n", "def main(tools):\n", "    with open(\"../data/eurorad_metadata.json\", \"r\") as file:\n", "        data = json.load(file)\n", "\n", "    total_cases, total_questions = count_total_questions()\n", "    cases_processed = 0\n", "    questions_processed = 0\n", "    skipped_questions = 0\n", "\n", "    print(f\"Beginning benchmark evaluation for model {model_name} with temperature {temperature}\\n\")\n", "\n", "    for case_id, case_details in data.items():\n", "        # Only evaluate cases beyond this ID (e.g. to resume a partially completed run)\n", "        if int(case_details[\"case_id\"]) <= 17158:\n", "            continue\n", "\n", "        
print(f\"----------------------------------------------------------------\")\n", " agent, thread = get_agent(tools)\n", "\n", " question_files = load_benchmark_questions(case_id)\n", " if not question_files:\n", " continue\n", "\n", " cases_processed += 1\n", " for question_file in question_files:\n", " with open(question_file, \"r\") as file:\n", " question_data = json.load(file)\n", " question_id = os.path.basename(question_file).split(\".\")[0]\n", "\n", " # agent, thread = get_agent(tools)\n", " questions_processed += 1\n", " final_response, model_answer = create_multimodal_request(\n", " question_data, case_details, case_id, question_id, agent, thread\n", " )\n", "\n", " # Handle cases where response is None\n", " if final_response is None:\n", " skipped_questions += 1\n", " print(f\"Skipped question: Case ID {case_id}, Question ID {question_id}\")\n", " continue\n", "\n", " print(\n", " f\"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}\"\n", " )\n", " print(f\"Case ID: {case_id}\")\n", " print(f\"Question ID: {question_id}\")\n", " print(f\"Final Response: {final_response}\")\n", " print(f\"Model Answer: {model_answer}\")\n", " print(f\"Correct Answer: {question_data['answer']}\")\n", " print(f\"----------------------------------------------------------------\\n\")\n", "\n", " print(f\"\\nBenchmark Summary:\")\n", " print(f\"Total Cases Processed: {cases_processed}\")\n", " print(f\"Total Questions Processed: {questions_processed}\")\n", " print(f\"Total Questions Skipped: {skipped_questions}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tools = get_tools()\n", "main(tools)" ] } ], "metadata": { "kernelspec": { "display_name": "medmax", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 2 }