from typing import Dict, List, Optional, Tuple, Union, Any
import json
import argparse
from collections import defaultdict

from tqdm import tqdm

QUESTION_TYPES = {
    "Detailed Finding Analysis": ["detection", "localization", "characterization"],
    "Pattern Recognition & Relations": ["detection", "classification", "relationship"],
    "Spatial Understanding": ["localization", "comparison", "relationship"],
    "Clinical Decision Making": ["classification", "comparison", "diagnosis"],
    "Diagnostic Classification": ["classification", "characterization", "diagnosis"],
}


def extract_answer_letter(answer: Optional[Union[str, Any]]) -> Optional[str]:
    """
    Extract just the letter from various answer formats.

    Args:
        answer: The answer text to extract the letter from

    Returns:
        Optional[str]: The extracted letter in uppercase, or None if no letter is found
    """
    if not answer:
        return None

    # Convert to string and strip surrounding whitespace
    answer = str(answer).strip()

    # If it's just a single letter, return it
    if len(answer) == 1 and answer.isalpha():
        return answer.upper()

    # Extract the letter from formats like "A)", "A.", "A:", "A-", or "A some text"
    if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
        return answer[0].upper()

    # Extract the letter from formats like "A) Some text"
    if answer.startswith(("A)", "B)", "C)", "D)", "E)", "F)")):
        return answer[0].upper()

    return None


def analyze_gpt4_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in GPT-4 format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
            - overall_accuracy (float)
            - category_accuracies (Dict)
            - question_type_stats (Dict)
            - correct_ids (List[str])
            - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    processed_questions = 0
    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        # Stop once we've hit the maximum number of questions
        if max_questions is not None and processed_questions >= max_questions:
            break

        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                processed_questions += 1
                is_correct = model_letter == correct_letter
                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1
        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def analyze_llama_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in Llama format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
            - overall_accuracy (float)
            - category_accuracies (Dict)
            - question_type_stats (Dict)
            - correct_ids (List[str])
            - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # If max_questions is set, limit the number of lines processed
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter
                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1
        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def analyze_chexagent_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in CheXagent format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
            - overall_accuracy (float)
            - category_accuracies (Dict)
            - question_type_stats (Dict)
            - correct_ids (List[str])
            - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # If max_questions is set, limit the number of lines processed
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter
                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1
        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def process_results(
    category_performance: Dict,
    all_questions: int,
    all_correct: int,
    correct_ids: Optional[List[str]] = None,
    incorrect_ids: Optional[List[str]] = None,
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Process raw results into final statistics.

    Args:
        category_performance: Dict containing performance by category
        all_questions: Total number of questions
        all_correct: Total number of correct answers
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions

    Returns:
        Tuple containing:
            - overall_accuracy (float)
            - category_accuracies (Dict)
            - question_type_stats (Dict)
            - correct_ids (List[str])
            - incorrect_ids (List[str])
    """
    category_accuracies = {
        category: {
            "accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    question_type_stats = {}
    for qtype, categories in QUESTION_TYPES.items():
        total = sum(
            category_performance[cat]["total"]
            for cat in categories
            if cat in category_performance
        )
        correct = sum(
            category_performance[cat]["correct"]
            for cat in categories
            if cat in category_performance
        )
        question_type_stats[qtype] = {
            "accuracy": (correct / total * 100) if total > 0 else 0,
            "total": total,
            "correct": correct,
        }

    overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0

    return (
        overall_accuracy,
        category_accuracies,
        question_type_stats,
        correct_ids or [],
        incorrect_ids or [],
    )


def print_analysis(
    overall_accuracy: float,
    category_accuracies: Dict,
    question_type_stats: Dict,
    correct_ids: List[str],
    incorrect_ids: List[str],
    model_name: str,
) -> None:
    """
    Print analysis results.

    Args:
        overall_accuracy: Overall accuracy percentage
        category_accuracies: Dict containing accuracy metrics by category
        question_type_stats: Dict containing stats by question type
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions
        model_name: Name of the model being analyzed
    """
    total_questions = len(correct_ids) + len(incorrect_ids)
    print(
        f"\nOverall Accuracy: {overall_accuracy:.2f}% "
        f"({len(correct_ids)} correct out of {total_questions} questions)"
    )

    print("\nCategory Performance:")
    sorted_categories = sorted(
        category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
    )
    for category, metrics in sorted_categories:
        print(f"{category}:")
        print(f"  Accuracy: {metrics['accuracy']:.2f}%")
        print(f"  Total Questions: {metrics['total']}")
        print(f"  Correct Questions: {metrics['correct']}")

    print("\nQuestion Type Performance:")
    sorted_types = sorted(
        question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True
    )
    for qtype, metrics in sorted_types:
        print(f"\n{qtype}:")
        print(f"  Accuracy: {metrics['accuracy']:.2f}%")
        print(f"  Total Questions: {metrics['total']}")
        print(f"  Correct Questions: {metrics['correct']}")
        print(f"  Categories: {', '.join(QUESTION_TYPES[qtype])}")

    # Save question IDs to JSON
    question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}
    output_filename = f"{model_name}_question_ids.json"
    with open(output_filename, "w") as f:
        json.dump(question_ids, f, indent=2)
    print(f"\nQuestion IDs have been saved to {output_filename}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze benchmark results")
    parser.add_argument("results_file", help="Path to results file")
    parser.add_argument("benchmark_dir", nargs="?", help="Path to benchmark questions directory")
    parser.add_argument(
        "--model",
        choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
        default="gpt4",
        help="Specify model format (default: gpt4)",
    )
    parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
    args = parser.parse_args()

    if args.model == "gpt4":
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    elif args.model == "llama":
        results = analyze_llama_results(args.results_file, args.max_questions)
    elif args.model == "chexagent":
        results = analyze_chexagent_results(args.results_file, args.max_questions)
    elif args.model == "medrax":
        # MedRAX results are parsed with the GPT-4 analyzer
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    else:
        parser.error(f"Unsupported model: {args.model}")

    print_analysis(*results, args.model)