import argparse
import json
import random
import re
from collections import defaultdict
from typing import Any, Dict, List, Tuple

# Define category order
CATEGORY_ORDER = [
    "detection",
    "classification",
    "localization",
    "comparison",
    "relationship",
    "diagnosis",
    "characterization",
]


def extract_letter_answer(answer: str) -> str:
    """Extract just the letter answer from various answer formats.

    Args:
        answer: The answer string to extract a letter from

    Returns:
        str: The extracted letter in uppercase; the empty string for empty
            input, or the cleaned, uppercased original string if no letter
            is found
    """
    if not answer:
        return ""

    # Convert to string and clean
    answer = str(answer).strip()

    # If it's just a single letter A-F, return it
    if len(answer) == 1 and answer.upper() in "ABCDEF":
        return answer.upper()

    # Try to match patterns like "A)", "A.", "A ", etc.
    match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Try to find any standalone A-F letters preceded by space or start of string
    # and followed by space, period, parenthesis or end of string
    matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
    if matches:
        return matches[0].upper()

    # Last resort: just find any A-F letter
    letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
    if letters:
        return letters[0].upper()

    # If no letter found, return original (cleaned)
    return answer.strip().upper()


def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Parse JSON Lines file and extract valid predictions.

    Args:
        file_path: Path to the JSON Lines file to parse

    Returns:
        Tuple containing:
            - str: Model name, or the file path if no model name is found
            - List[Dict[str, Any]]: List of valid prediction entries
    """
    valid_predictions = []
    model_name = None

    # First try to parse as LLaVA format (a single JSON document with a "results" list)
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if data.get("model") == "llava-med-v1.5-mistral-7b":
            model_name = data["model"]
            for result in data.get("results", []):
                if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
                    # Extract answer with priority: model_answer > validated_answer > raw_output
                    model_answer = (
                        result.get("model_answer")
                        or result.get("validated_answer")
                        or result.get("raw_output", "")
                    )
                    # Add default categories for LLaVA results
                    prediction = {
                        "case_id": result["case_id"],
                        "question_id": result["question_id"],
                        "model_answer": model_answer,
                        "correct_answer": result["correct_answer"],
                        "input": {
                            "question_data": {
                                "metadata": {"categories": list(CATEGORY_ORDER)}
                            }
                        },
                    }
                    valid_predictions.append(prediction)
            return model_name, valid_predictions
    except (json.JSONDecodeError, KeyError):
        pass

    # If not LLaVA format, process as original JSON Lines format
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("HTTP Request:"):
                continue
            try:
                data = json.loads(line.strip())
                if "model" in data:
                    model_name = data["model"]
                if all(
                    k in data
                    for k in ["model_answer", "correct_answer", "case_id", "question_id"]
                ):
                    valid_predictions.append(data)
            except json.JSONDecodeError:
                continue

    return model_name if model_name else file_path, valid_predictions
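
# Illustrative behaviour of extract_letter_answer (a sketch for documentation only;
# the example strings below are assumptions, not fixtures from any dataset):
#   extract_letter_answer("B) Pneumothorax")  -> "B"  (leading "B)" pattern)
#   extract_letter_answer("Answer: C")        -> "C"  (standalone letter at end)
#   extract_letter_answer("d")                -> "D"  (single-letter answer)
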
def filter_common_questions(
    predictions_list: List[List[Dict[str, Any]]]
) -> List[List[Dict[str, Any]]]:
    """Ensure only questions that exist across all models are evaluated.

    Args:
        predictions_list: List of prediction lists from different models

    Returns:
        List[List[Dict[str, Any]]]: Filtered predictions containing only common questions
    """
    question_sets = [
        set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
    ]
    common_questions = set.intersection(*question_sets)
    return [
        [p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
        for preds in predictions_list
    ]


def calculate_accuracy(
    predictions: List[Dict[str, Any]]
) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
    """Compute overall and category-level accuracy.

    Args:
        predictions: List of prediction entries to analyze

    Returns:
        Tuple containing:
            - float: Overall accuracy percentage
            - int: Number of correct predictions
            - int: Total number of predictions
            - Dict[str, Dict[str, float]]: Category-level accuracy statistics
    """
    if not predictions:
        return 0.0, 0, 0, {}

    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    correct = 0
    total = 0

    # Print a small random sample so answer extraction can be sanity-checked by eye
    sample_size = min(5, len(predictions))
    sampled_indices = random.sample(range(len(predictions)), sample_size)
    print("\nSample extracted answers:")
    for i in sampled_indices:
        pred = predictions[i]
        model_ans = extract_letter_answer(pred["model_answer"])
        correct_ans = extract_letter_answer(pred["correct_answer"])
        print(f"QID: {pred['question_id']}")
        print(f" Raw Model Answer: {pred['model_answer']}")
        print(f" Extracted Model Answer: {model_ans}")
        print(f" Raw Correct Answer: {pred['correct_answer']}")
        print(f" Extracted Correct Answer: {correct_ans}")
        print("-" * 80)

    for pred in predictions:
        try:
            model_ans = extract_letter_answer(pred["model_answer"])
            correct_ans = extract_letter_answer(pred["correct_answer"])
            categories = (
                pred.get("input", {})
                .get("question_data", {})
                .get("metadata", {})
                .get("categories", [])
            )
            if model_ans and correct_ans:
                total += 1
                is_correct = model_ans == correct_ans
                if is_correct:
                    correct += 1
                for category in categories:
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1
        except KeyError:
            continue

    category_accuracies = {
        category: {
            "accuracy": (stats["correct"] / stats["total"]) * 100
            if stats["total"] > 0
            else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)
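
# Illustrative shape of one prediction record in the original (non-LLaVA) JSON Lines
# format, as relied on by filter_common_questions and calculate_accuracy. Field values
# below are assumptions for documentation purposes, not real data:
# {
#   "model": "some-model-name",
#   "case_id": "case_0012",
#   "question_id": "q3",
#   "model_answer": "B) ...",
#   "correct_answer": "B",
#   "input": {"question_data": {"metadata": {"categories": ["detection", "diagnosis"]}}}
# }
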
def compare_models(file_paths: List[str]) -> None:
    """Compare accuracy between multiple model prediction files.

    Args:
        file_paths: List of paths to model prediction files to compare
    """
    # Parse all files
    parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
    model_names, predictions_list = zip(*parsed_results)

    # Get initial stats
    print("\nšŸ“Š **Initial Accuracy**:")
    results = []
    category_results = []
    for preds, name in zip(predictions_list, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        results.append((acc, correct, total, name))
        category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Get common questions across all models
    filtered_predictions = filter_common_questions(predictions_list)
    print(
        f"\nQuestions per model after ensuring common questions: "
        f"{[len(p) for p in filtered_predictions]}"
    )

    # Compute accuracy on common questions
    print("\nšŸ“Š **Accuracy on Common Questions**:")
    filtered_results = []
    filtered_category_results = []
    for preds, name in zip(filtered_predictions, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        filtered_results.append((acc, correct, total, name))
        filtered_category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Print category-wise accuracy
    print("\nCategory Performance (Common Questions):")
    for category in CATEGORY_ORDER:
        print(f"\n{category.capitalize()}:")
        for model_name, category_acc in zip(model_names, filtered_category_results):
            stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
            print(
                f" {model_name}: {stats['accuracy']:.2f}% "
                f"({stats['correct']}/{stats['total']})"
            )


def main():
    parser = argparse.ArgumentParser(
        description="Compare accuracy across multiple model prediction files"
    )
    parser.add_argument("files", nargs="+", help="Paths to model prediction files")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
    args = parser.parse_args()

    random.seed(args.seed)
    compare_models(args.files)


if __name__ == "__main__":
    main()
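
# Example invocation (a sketch; the script and file names below are illustrative,
# not part of this repository):
#   python compare_models.py model_a_preds.jsonl model_b_preds.json --seed 42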