"""Benchmark LLaVA-Med on the MedMAX multiple-choice question set.

Sends each case's images and question to a model worker, validates the
single-letter answers, and writes live and final JSON logs.
"""

import argparse
import json
import requests
import base64
from PIL import Image
from io import BytesIO
from llava.conversation import conv_templates
import time
import os
import glob
import logging
from datetime import datetime
from tqdm import tqdm
import re
from typing import Dict, List, Optional, Union, Any, Tuple


def process_image(image_path: str, target_size: int = 640) -> Image.Image:
    """Process and resize an image to match model requirements.

    Args:
        image_path: Path to the input image file
        target_size: Target size for both width and height in pixels

    Returns:
        PIL.Image: Processed and padded image with dimensions (target_size, target_size)
    """
    image = Image.open(image_path)
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Scale so the longer side fits target_size while preserving aspect ratio
    ratio = min(target_size / image.width, target_size / image.height)
    new_size = (int(image.width * ratio), int(image.height * ratio))

    # Resize image
    image = image.resize(new_size, Image.LANCZOS)

    # Create a square black canvas and paste the resized image in the center
    new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
    offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
    new_image.paste(image, offset)

    return new_image


def validate_answer(response_text: str) -> Optional[str]:
    """Extract and validate a single-letter response from the model's output.

    Handles multiple response formats and edge cases.

    Args:
        response_text: The full text output from the model

    Returns:
        A single letter answer (A-F), or None if no valid answer is found
    """
    if not response_text:
        return None

    # Clean the response text
    cleaned = response_text.strip()

    # Patterns to extract the answer, ordered from strictest to loosest
    extraction_patterns = [
        # Strict format with explicit letter answer
        r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b",
        # Patterns for extracting from longer descriptions
        r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
        r"\b(?:answer|option)\s*([A-F])[):]\s*",
        # Patterns for extracting from descriptive sentences
        r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
        r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b",
        # Patterns with contextual words
        r"characteriz[e]?d?\s+by\s+([A-F])\b",
        r"indicat[e]?s?\s+([A-F])\b",
        # Fallback to "Option X" or "X)" formats
        r"Option\s*([A-F])\b",
        r"\b([A-F])\)\s*",
        # Fallback to a standalone letter
        r"^\s*([A-F])\s*$",
    ]

    # Try each pattern in order
    for pattern in extraction_patterns:
        matches = re.findall(pattern, cleaned, re.IGNORECASE)
        for match in matches:
            # re.findall returns tuples when a pattern has multiple groups;
            # keep only a valid single-letter group
            if isinstance(match, tuple):
                match = match[0] if match[0] in "ABCDEF" else None
            if match and match.upper() in "ABCDEF":
                return match.upper()

    # Final fallback: look for standalone letters anywhere in the text
    context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
    context_letters = [m for m in context_matches if m in "ABCDEF"]
    if context_letters:
        return context_letters[0]

    # No valid answer found
    return None
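
# Illustrative sanity checks for validate_answer (the strings below are
# made-up model outputs, not benchmark data):
#   validate_answer("The SINGLE LETTER answer is: C")  ->  "C"
#   validate_answer("I think the answer is b).")       ->  "B"
#   validate_answer("No option can be determined.")    ->  None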


def load_benchmark_questions(case_id: str) -> List[str]:
    """Find all question files for a given case ID.

    Args:
        case_id: The ID of the medical case

    Returns:
        List of paths to question JSON files
    """
    benchmark_dir = "MedMAX/benchmark/questions"
    return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")


def count_total_questions() -> Tuple[int, int]:
    """Count the total number of cases and questions in the benchmark.

    Returns:
        Tuple containing (total_cases, total_questions)
    """
    total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("MedMAX/benchmark/questions")
    )
    return total_cases, total_questions


def create_inference_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    worker_addr: str,
    model_name: str,
    raw_output: bool = False,
) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
    """Create and send an inference request to the worker.

    Args:
        question_data: Dictionary containing question details and figures
        case_details: Dictionary containing case information and figures
        case_id: Identifier for the medical case
        question_id: Identifier for the specific question
        worker_addr: Address of the worker endpoint
        model_name: Name of the model to use
        raw_output: Whether to return raw model output

    Returns:
        If raw_output is False: Tuple of (validated_answer, duration)
        If raw_output is True: Dictionary with full inference details
        If no local images are found: the special marker ("skipped", 0.0)
    """
    system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'. """
    # Note: system_prompt is only recorded in the raw-output log; it is not
    # prepended to the prompt sent to the worker.

    prompt = f"""Given the following medical case:

Please answer this multiple choice question:
{question_data['question']}

Base your answer only on the provided images and case information.
Respond with your SINGLE LETTER answer:
"""

    try:
        # Parse required figures; the field may be a JSON string, a plain
        # string, or already a list
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    # Resolve figure references (e.g. "Figure 2a") to local image paths
    image_paths = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        for case_figure in matching_figures:
            if figure_letter:
                # Keep only the subfigures matching the letter suffix
                subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
            else:
                subfigures = case_figure.get("subfigures", [])

            for subfig in subfigures:
                if "local_path" in subfig:
                    image_paths.append("MedMAX/data/" + subfig["local_path"])

    if not image_paths:
        print(f"No local images found for case {case_id}, question {question_id}")
        return "skipped", 0.0  # Special marker so the caller can skip this question

    try:
        start_time = time.time()

        # Process each image
        processed_images = [process_image(path) for path in image_paths]

        # Create conversation
        conv = conv_templates["mistral_instruct"].copy()

        # Add the image token and message
        if "<image>" not in prompt:
            text = prompt + "\n<image>"
        else:
            text = prompt
        message = (text, processed_images[0], "Default")  # Currently handles the first image only
        conv.append_message(conv.roles[0], message)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        headers = {"User-Agent": "LLaVA-Med Client"}
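
        # The worker streams its response as NUL-delimited JSON chunks; each
        # chunk carries the full text generated so far, so the reader below
        # keeps only the most recent chunk.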
        pload = {
            "model": model_name,
            "prompt": prompt,
            "max_new_tokens": 150,  # Modest cap; we only need a single letter
            "temperature": 0.5,  # Lower temperature for more focused responses
            "stop": conv.sep2,
            "images": conv.get_images(),
            "top_p": 1,  # No nucleus filtering
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
        }

        max_retries = 3
        retry_delay = 5
        response_text = None

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    worker_addr + "/worker_generate_stream",
                    headers=headers,
                    json=pload,
                    stream=True,
                    timeout=30,
                )

                complete_output = ""
                for chunk in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if chunk:
                        data = json.loads(chunk.decode("utf-8"))
                        if data["error_code"] == 0:
                            # Keep only the text after the instruction tag
                            output = data["text"].split("[/INST]")[-1]
                            complete_output = output
                        else:
                            print(f"\nError: {data['text']} (error_code: {data['error_code']})")
                            if attempt < max_retries - 1:
                                time.sleep(retry_delay)
                                break
                            return None, None

                if complete_output:
                    response_text = complete_output
                    break

            except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
                if attempt < max_retries - 1:
                    print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    print(f"\nFailed after {max_retries} attempts: {str(e)}")
                    return None, None

        duration = time.time() - start_time

        if raw_output:
            inference_details = {
                "raw_output": response_text,
                "validated_answer": validate_answer(response_text),
                "duration": duration,
                "prompt": prompt,
                "system_prompt": system_prompt,
                "image_paths": image_paths,
                "payload": pload,
            }
            return inference_details

        return validate_answer(response_text), duration

    except Exception as e:
        print(f"Error in inference request: {str(e)}")
        return None, None
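

# The "images" field of the payload holds serialized image data (base64 in the
# LLaVA serving stack) and dominates the payload size, so it is stripped
# before each payload is written to the log (see clean_payload below).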
def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Remove image-related and large data from the payload to keep the log lean.

    Args:
        payload: Original request payload dictionary

    Returns:
        Cleaned payload dictionary with large data removed
    """
    if not payload:
        return None

    # Copy the payload to avoid modifying the original
    cleaned_payload = payload.copy()

    # Remove large or sensitive data
    if "images" in cleaned_payload:
        del cleaned_payload["images"]

    return cleaned_payload


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
    parser.add_argument("--output-dir", type=str, default="benchmark_results")
    parser.add_argument(
        "--raw-output", action="store_true", help="Return raw model output without validation"
    )
    parser.add_argument(
        "--num-cases",
        type=int,
        help="Number of cases to process when inspecting raw outputs",
        default=2,
    )
    args = parser.parse_args()

    # Set up the output directory
    os.makedirs(args.output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Set up live logging files
    live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
    final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")

    # Initialize the live log file with the start of a JSON array
    with open(live_log_filename, "w") as live_log_file:
        live_log_file.write("[\n")

    # Set up logging
    logging.basicConfig(
        filename=os.path.join(args.output_dir, f"benchmark_{timestamp}.log"),
        level=logging.INFO,
        format="%(message)s",
    )

    # Resolve the worker address, either directly or via the controller
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        try:
            requests.post(args.controller_address + "/refresh_all_workers")
            ret = requests.post(args.controller_address + "/list_models")
            models = ret.json()["models"]
            ret = requests.post(
                args.controller_address + "/get_worker_address", json={"model": args.model_name}
            )
            worker_addr = ret.json()["address"]
            print(f"Worker address: {worker_addr}")
        except requests.exceptions.RequestException as e:
            print(f"Failed to connect to controller: {e}")
            return

    if worker_addr == "":
        print("No available worker")
        return

    # Load cases with local paths
    with open("MedMAX/data/updated_cases.json", "r") as file:
        data = json.load(file)

    total_cases, total_questions = count_total_questions()
    print(f"\nStarting benchmark with {args.model_name}")
    print(f"Found {total_cases} cases with {total_questions} total questions")

    results = {
        "model": args.model_name,
        "timestamp": datetime.now().isoformat(),
        "total_cases": total_cases,
        "total_questions": total_questions,
        "results": [],
    }

    cases_processed = 0
    questions_processed = 0
    correct_answers = 0
    skipped_questions = 0
    total_processed_entries = 0

    # Process each case
    for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        cases_processed += 1
        for question_file in tqdm(
            question_files, desc=f"Processing questions for case {case_id}", leave=False
        ):
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]
            questions_processed += 1

            # Get the model's answer
            inference_result = create_inference_request(
                question_data,
                case_details,
                case_id,
                question_id,
                worker_addr,
                args.model_name,
                raw_output=True,  # Always use raw output for detailed logging
            )

            # Handle skipped questions
            if inference_result == ("skipped", 0.0):
                skipped_questions += 1
                print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")
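
                # Entries are appended to the live log as they complete, so an
                # interrupted run still leaves a readable (if unterminated)
                # JSON array behind.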
images)") # Log skipped question skipped_entry = { "case_id": case_id, "question_id": question_id, "status": "skipped", "reason": "No images found", } with open(live_log_filename, "a") as live_log_file: json.dump(skipped_entry, live_log_file, indent=2) live_log_file.write(",\n") # Add comma for next entry continue # Extract information answer = inference_result["validated_answer"] duration = inference_result["duration"] # Prepare detailed logging entry log_entry = { "case_id": case_id, "question_id": question_id, "question": question_data["question"], "correct_answer": question_data["answer"], "raw_output": inference_result["raw_output"], "validated_answer": answer, "model_answer": answer, "is_correct": answer == question_data["answer"] if answer else False, "duration": duration, "system_prompt": inference_result["system_prompt"], "input_prompt": inference_result["prompt"], "image_paths": inference_result["image_paths"], "payload": clean_payload(inference_result["payload"]), } # Write to live log file with open(live_log_filename, "a") as live_log_file: json.dump(log_entry, live_log_file, indent=2) live_log_file.write(",\n") # Add comma for next entry # Print to console print(f"\nCase {case_id}, Question {question_id}") print(f"Model Answer: {answer}") print(f"Correct Answer: {question_data['answer']}") print(f"Time taken: {duration:.2f}s") # Track correct answers if answer == question_data["answer"]: correct_answers += 1 # Append to results results["results"].append(log_entry) total_processed_entries += 1 # Optional: break if reached specified number of cases if args.raw_output and cases_processed == args.num_cases: break # Optional: break if reached specified number of cases if args.raw_output and cases_processed == args.num_cases: break # Close live log file with open(live_log_filename, "a") as live_log_file: # Remove trailing comma and close JSON array live_log_file.seek(live_log_file.tell() - 2, 0) # Go back 2 chars to remove ',\n' live_log_file.write("\n]") # Calculate final statistics results["summary"] = { "cases_processed": cases_processed, "questions_processed": questions_processed, "total_processed_entries": total_processed_entries, "correct_answers": correct_answers, "skipped_questions": skipped_questions, "accuracy": ( correct_answers / (questions_processed - skipped_questions) if (questions_processed - skipped_questions) > 0 else 0 ), } # Save final results with open(final_results_filename, "w") as f: json.dump(results, f, indent=2) print(f"\nBenchmark Summary:") print(f"Total Cases Processed: {cases_processed}") print(f"Total Questions Processed: {questions_processed}") print(f"Total Processed Entries: {total_processed_entries}") print(f"Correct Answers: {correct_answers}") print(f"Skipped Questions: {skipped_questions}") print(f"Accuracy: {(correct_answers / (questions_processed - skipped_questions) * 100):.2f}%") print(f"\nResults saved to {args.output_dir}") print(f"Live log: {live_log_filename}") print(f"Final results: {final_results_filename}") if __name__ == "__main__": main()