# medrax.org / experiments / compare_runs.py
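"""Compare accuracy across multiple model prediction files."""
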
import argparse
import json
import random
import re
from collections import defaultdict
from typing import Any, Dict, List, Tuple

# Fixed category order used for per-category accuracy reporting
CATEGORY_ORDER = [
"detection",
"classification",
"localization",
"comparison",
"relationship",
"diagnosis",
"characterization",
]


def extract_letter_answer(answer: str) -> str:
    """Extract just the letter answer from various answer formats.

    Args:
        answer: The answer string to extract a letter from.

    Returns:
        str: The extracted letter (A-F) in uppercase; if no letter is found,
            the cleaned, uppercased input (empty string for empty input).
    """
if not answer:
return ""
# Convert to string and clean
answer = str(answer).strip()
# If it's just a single letter A-F, return it
if len(answer) == 1 and answer.upper() in "ABCDEF":
return answer.upper()
# Try to match patterns like "A)", "A.", "A ", etc.
match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
if match:
return match.group(1).upper()
# Try to find any standalone A-F letters preceded by space or start of string
# and followed by space, period, parenthesis or end of string
matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
if matches:
return matches[0].upper()
# Last resort: just find any A-F letter
letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
if letters:
return letters[0].upper()
# If no letter found, return original (cleaned)
return answer.strip().upper()


def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Parse a prediction file (LLaVA-style JSON or JSON Lines) and extract valid predictions.

    Args:
        file_path: Path to the prediction file to parse.

    Returns:
        Tuple containing:
            - str: Model name, or the file path if no model name was found.
            - List[Dict[str, Any]]: List of valid prediction entries.
    """
valid_predictions = []
model_name = None
# First try to parse as LLaVA format
try:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
if data.get("model") == "llava-med-v1.5-mistral-7b":
model_name = data["model"]
for result in data.get("results", []):
if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
# Extract answer with priority: model_answer > validated_answer > raw_output
model_answer = (
result.get("model_answer")
or result.get("validated_answer")
or result.get("raw_output", "")
)
# Add default categories for LLaVA results
prediction = {
"case_id": result["case_id"],
"question_id": result["question_id"],
"model_answer": model_answer,
"correct_answer": result["correct_answer"],
"input": {
"question_data": {
"metadata": {
"categories": [
"detection",
"classification",
"localization",
"comparison",
"relationship",
"diagnosis",
"characterization",
]
}
}
},
}
valid_predictions.append(prediction)
return model_name, valid_predictions
except (json.JSONDecodeError, KeyError):
pass
    # If not LLaVA format, fall back to the JSON Lines format (one JSON object per line)
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
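            # Skip interleaved HTTP client log lines that are not JSON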
if line.startswith("HTTP Request:"):
continue
try:
data = json.loads(line.strip())
if "model" in data:
model_name = data["model"]
if all(
k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
):
valid_predictions.append(data)
except json.JSONDecodeError:
continue
return model_name if model_name else file_path, valid_predictions


def filter_common_questions(
    predictions_list: List[List[Dict[str, Any]]]
) -> List[List[Dict[str, Any]]]:
    """Keep only the questions that appear in every model's predictions.

    Args:
        predictions_list: List of prediction lists from different models.

    Returns:
        List[List[Dict[str, Any]]]: Filtered predictions containing only common questions.
    """
question_sets = [
set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
]
common_questions = set.intersection(*question_sets)
return [
[p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
for preds in predictions_list
]


def calculate_accuracy(
    predictions: List[Dict[str, Any]]
) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
    """Compute overall and per-category accuracy.

    Args:
        predictions: List of prediction entries to analyze.

    Returns:
        Tuple containing:
            - float: Overall accuracy percentage.
            - int: Number of correct predictions.
            - int: Total number of scored predictions.
            - Dict[str, Dict[str, float]]: Per-category accuracy statistics.
    """
if not predictions:
return 0.0, 0, 0, {}
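    # Per-category tally of total and correct predictions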
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
correct = 0
total = 0
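
    # Print a small random sample of raw vs. extracted answers as a sanity check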
sample_size = min(5, len(predictions))
sampled_indices = random.sample(range(len(predictions)), sample_size)
print("\nSample extracted answers:")
for i in sampled_indices:
pred = predictions[i]
model_ans = extract_letter_answer(pred["model_answer"])
correct_ans = extract_letter_answer(pred["correct_answer"])
print(f"QID: {pred['question_id']}")
print(f" Raw Model Answer: {pred['model_answer']}")
print(f" Extracted Model Answer: {model_ans}")
print(f" Raw Correct Answer: {pred['correct_answer']}")
print(f" Extracted Correct Answer: {correct_ans}")
print("-" * 80)
for pred in predictions:
try:
model_ans = extract_letter_answer(pred["model_answer"])
correct_ans = extract_letter_answer(pred["correct_answer"])
categories = (
pred.get("input", {})
.get("question_data", {})
.get("metadata", {})
.get("categories", [])
)
if model_ans and correct_ans:
total += 1
is_correct = model_ans == correct_ans
if is_correct:
correct += 1
for category in categories:
category_performance[category]["total"] += 1
if is_correct:
category_performance[category]["correct"] += 1
except KeyError:
continue
category_accuracies = {
category: {
"accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
"total": stats["total"],
"correct": stats["correct"],
}
for category, stats in category_performance.items()
}
return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)


def compare_models(file_paths: List[str]) -> None:
    """Compare accuracy between multiple model prediction files.

    Args:
        file_paths: List of paths to model prediction files to compare.
    """
# Parse all files
parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
model_names, predictions_list = zip(*parsed_results)
# Get initial stats
print(f"\n📊 **Initial Accuracy**:")
results = []
category_results = []
for preds, name in zip(predictions_list, model_names):
acc, correct, total, category_acc = calculate_accuracy(preds)
results.append((acc, correct, total, name))
category_results.append(category_acc)
print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
# Get common questions across all models
filtered_predictions = filter_common_questions(predictions_list)
print(
f"\nQuestions per model after ensuring common questions: {[len(p) for p in filtered_predictions]}"
)
# Compute accuracy on common questions
print(f"\n📊 **Accuracy on Common Questions**:")
filtered_results = []
filtered_category_results = []
for preds, name in zip(filtered_predictions, model_names):
acc, correct, total, category_acc = calculate_accuracy(preds)
filtered_results.append((acc, correct, total, name))
filtered_category_results.append(category_acc)
print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
# Print category-wise accuracy
print("\nCategory Performance (Common Questions):")
for category in CATEGORY_ORDER:
print(f"\n{category.capitalize()}:")
for model_name, category_acc in zip(model_names, filtered_category_results):
stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
print(f" {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")


def main():
    """Parse command-line arguments and run the model comparison."""
parser = argparse.ArgumentParser(
description="Compare accuracy across multiple model prediction files"
)
parser.add_argument("files", nargs="+", help="Paths to model prediction files")
parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
args = parser.parse_args()
random.seed(args.seed)
compare_models(args.files)


if __name__ == "__main__":
main()
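
# Example invocation (hypothetical file names):
#   python compare_runs.py run_a.jsonl run_b.json --seed 42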