import re import json import os import glob import time import logging from datetime import datetime import torch from PIL import Image from transformers import AutoModelForCausalLM, AutoTokenizer from tqdm import tqdm # Configure model settings MODEL_NAME = "StanfordAIMI/CheXagent-2-3b" DTYPE = torch.bfloat16 DEVICE = "cuda" # Configure logging log_filename = f"model_inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s") def initialize_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]: """Initialize the CheXagent model and tokenizer. Returns: tuple containing: - AutoModelForCausalLM: The initialized CheXagent model - AutoTokenizer: The initialized tokenizer """ print("Loading model and tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, device_map="auto", trust_remote_code=True ) model = model.to(DTYPE) model.eval() return model, tokenizer def create_inference_request( question_data: dict, case_details: dict, case_id: str, question_id: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, ) -> str | None: """Create and execute an inference request for the CheXagent model. Args: question_data: Dictionary containing question details and metadata case_details: Dictionary containing case information and image paths case_id: Unique identifier for the medical case question_id: Unique identifier for the question model: The initialized CheXagent model tokenizer: The initialized tokenizer Returns: str | None: Single letter answer (A-F) if successful, None if failed """ system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer. Rules: 1. Respond with exactly one uppercase letter (A/B/C/D/E/F) 2. Do not add periods, explanations, or any other text 3. Do not use markdown or formatting 4. Do not restate the question 5. Do not explain your reasoning Examples of valid responses: A B C Examples of invalid responses: "A." "Answer: B" "C) This shows..." "The answer is D" """ prompt = f"""Given the following medical case: Please answer this multiple choice question: {question_data['question']} Base your answer only on the provided images and case information.""" # Parse required figures try: if isinstance(question_data["figures"], str): try: required_figures = json.loads(question_data["figures"]) except json.JSONDecodeError: required_figures = [question_data["figures"]] elif isinstance(question_data["figures"], list): required_figures = question_data["figures"] else: required_figures = [str(question_data["figures"])] except Exception as e: print(f"Error parsing figures: {e}") required_figures = [] required_figures = [ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures ] # Get image paths image_paths = [] for figure in required_figures: base_figure_num = "".join(filter(str.isdigit, figure)) figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None matching_figures = [ case_figure for case_figure in case_details.get("figures", []) if case_figure["number"] == f"Figure {base_figure_num}" ] for case_figure in matching_figures: subfigures = [] if figure_letter: subfigures = [ subfig for subfig in case_figure.get("subfigures", []) if subfig.get("number", "").lower().endswith(figure_letter.lower()) or subfig.get("label", "").lower() == figure_letter.lower() ] else: subfigures = case_figure.get("subfigures", []) for subfig in subfigures: if "local_path" in subfig: image_paths.append("medrax/data/" + subfig["local_path"]) if not image_paths: print(f"No local images found for case {case_id}, question {question_id}") return None try: start_time = time.time() # Prepare input for the model query = tokenizer.from_list_format( [*[{"image": path} for path in image_paths], {"text": prompt}] ) conv = [{"from": "system", "value": system_prompt}, {"from": "human", "value": query}] input_ids = tokenizer.apply_chat_template( conv, add_generation_prompt=True, return_tensors="pt" ) # Generate response with torch.no_grad(): output = model.generate( input_ids.to(DEVICE), do_sample=False, num_beams=1, temperature=1.0, top_p=1.0, use_cache=True, max_new_tokens=512, )[0] response = tokenizer.decode(output[input_ids.size(1) : -1]) duration = time.time() - start_time # Clean response clean_answer = validate_answer(response) # Log response log_entry = { "case_id": case_id, "question_id": question_id, "timestamp": datetime.now().isoformat(), "model": MODEL_NAME, "duration": round(duration, 2), "model_answer": clean_answer, "correct_answer": question_data["answer"], "input": { "question_data": { "question": question_data["question"], "explanation": question_data["explanation"], "metadata": question_data.get("metadata", {}), "figures": question_data["figures"], }, "image_paths": image_paths, }, } logging.info(json.dumps(log_entry)) return clean_answer except Exception as e: print(f"Error processing case {case_id}, question {question_id}: {str(e)}") log_entry = { "case_id": case_id, "question_id": question_id, "timestamp": datetime.now().isoformat(), "model": MODEL_NAME, "status": "error", "error": str(e), "input": { "question_data": { "question": question_data["question"], "explanation": question_data["explanation"], "metadata": question_data.get("metadata", {}), "figures": question_data["figures"], }, "image_paths": image_paths, }, } logging.info(json.dumps(log_entry)) return None def validate_answer(response_text: str) -> str | None: """Enforce strict single-letter response format. Args: response_text: Raw response text from the model Returns: str | None: Single uppercase letter (A-F) if valid, None if invalid """ if not response_text: return None # Remove all whitespace and convert to uppercase cleaned = response_text.strip().upper() # Check if it's exactly one valid letter if len(cleaned) == 1 and cleaned in "ABCDEF": return cleaned # If not, try to extract just the letter match = re.search(r"([A-F])", cleaned) return match.group(1) if match else None def load_benchmark_questions(case_id: str) -> list[str]: """Find all question files for a given case ID. Args: case_id: Unique identifier for the medical case Returns: list[str]: List of paths to question JSON files """ benchmark_dir = "../benchmark/questions" return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json") def count_total_questions() -> tuple[int, int]: """Count total number of cases and questions in benchmark. Returns: tuple containing: - int: Total number of cases - int: Total number of questions """ total_cases = len(glob.glob("../benchmark/questions/*")) total_questions = sum( len(glob.glob(f"../benchmark/questions/{case_id}/*.json")) for case_id in os.listdir("../benchmark/questions") ) return total_cases, total_questions def main(): # Load the cases with local paths with open("medrax/data/updated_cases.json", "r") as file: data = json.load(file) # Initialize model and tokenizer model, tokenizer = initialize_model() total_cases, total_questions = count_total_questions() cases_processed = 0 questions_processed = 0 skipped_questions = 0 print(f"\nBeginning inference with {MODEL_NAME}") print(f"Found {total_cases} cases with {total_questions} total questions") # Process each case with progress bar for case_id, case_details in tqdm(data.items(), desc="Processing cases"): question_files = load_benchmark_questions(case_id) if not question_files: continue cases_processed += 1 for question_file in tqdm( question_files, desc=f"Processing questions for case {case_id}", leave=False ): with open(question_file, "r") as file: question_data = json.load(file) question_id = os.path.basename(question_file).split(".")[0] questions_processed += 1 answer = create_inference_request( question_data, case_details, case_id, question_id, model, tokenizer ) if answer is None: skipped_questions += 1 continue print(f"\nCase {case_id}, Question {question_id}") print(f"Model Answer: {answer}") print(f"Correct Answer: {question_data['answer']}") print(f"\nInference Summary:") print(f"Total Cases Processed: {cases_processed}") print(f"Total Questions Processed: {questions_processed}") print(f"Total Questions Skipped: {skipped_questions}") if __name__ == "__main__": main()