UniWorld-V1 / univa /eval /wise /step3_wise_cal.py
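"""Step 3 of the WISE evaluation: aggregate per-sample judge scores into WiScores.

Reads one or more JSONL score files, computes a WiScore for every sample,
buckets the results by prompt_id into the CULTURE / TIME / SPACE / BIOLOGY /
PHYSICS / CHEMISTRY categories, and prints per-file reports plus an overall
summary (including the weighted Overall WiScore when '--category all' is used).
"""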
import json
import os
import argparse
from collections import defaultdict
def calculate_wiscore(consistency, realism, aesthetic_quality):
    """Compute the per-sample WiScore: the weighted sum
    0.7 * consistency + 0.2 * realism + 0.1 * aesthetic_quality, divided by 2."""
    return (0.7 * consistency + 0.2 * realism + 0.1 * aesthetic_quality) / 2
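# Illustrative values (assuming the judge scores lie on a 0-2 scale, which the
# division by 2 suggests):
#   calculate_wiscore(2, 2, 2)       -> (1.4 + 0.4 + 0.2) / 2 = 1.0
#   calculate_wiscore(1.5, 1.0, 2.0) -> (1.05 + 0.2 + 0.2) / 2 = 0.725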
# Define expected prompt ID ranges at a global level for easy access
EXPECTED_PROMPT_RANGES = {
"culture": range(1, 401),
"space-time": range(401, 701), # Covers both TIME (401-567) and SPACE (568-700)
"science": range(701, 1001), # Covers BIOLOGY (701-800), PHYSICS (801-900), CHEMISTRY (901-1000)
"all": range(1, 1001) # Full range for combined evaluation
}
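# Each line of an input JSONL file is expected to provide the fields read below;
# an illustrative record (values made up) looks like:
# {"prompt_id": 1, "consistency": 2, "realism": 1.5, "aesthetic_quality": 2}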
def process_jsonl_file_segment(file_path, category_arg=None):
"""
Processes a segment of a JSONL file, collecting scores and present prompt_ids.
Performs prompt_id validation if a specific category_arg is provided for a single file.
Returns collected data or None if critical errors or missing prompt_ids (for single-file validation).
"""
segment_scores = defaultdict(list)
segment_present_prompt_ids = set()
if not os.path.exists(file_path):
print(f"Error: File '{file_path}' not found.")
return None
try:
with open(file_path, 'r', encoding='utf-8') as file:
for line_num, line in enumerate(file, 1):
try:
data = json.loads(line)
prompt_id = data.get('prompt_id')
if prompt_id is None:
print(f"Warning: File '{file_path}', Line {line_num}: Missing 'prompt_id'. Skipping this line.")
continue
if not isinstance(prompt_id, int):
print(f"Warning: File '{file_path}', Line {line_num}: 'prompt_id' is not an integer. Skipping this line.")
continue
segment_present_prompt_ids.add(prompt_id)
consistency = data.get('consistency')
realism = data.get('realism')
aesthetic_quality = data.get('aesthetic_quality')
if not all(isinstance(val, (int, float)) for val in [consistency, realism, aesthetic_quality]):
print(f"Warning: File '{file_path}', Line {line_num}: One or more score values are not numeric. Skipping this line for category calculation.")
continue
wiscore = calculate_wiscore(consistency, realism, aesthetic_quality)
# Determine category based on prompt_id
if 1 <= prompt_id <= 400:
segment_scores['CULTURE'].append(wiscore)
elif 401 <= prompt_id <= 567:
segment_scores['TIME'].append(wiscore)
elif 568 <= prompt_id <= 700:
segment_scores['SPACE'].append(wiscore)
elif 701 <= prompt_id <= 800:
segment_scores['BIOLOGY'].append(wiscore)
elif 801 <= prompt_id <= 900:
segment_scores['PHYSICS'].append(wiscore)
elif 901 <= prompt_id <= 1000:
segment_scores['CHEMISTRY'].append(wiscore)
else:
print(f"Warning: File '{file_path}', Line {line_num}: prompt_id {prompt_id} is outside defined categories. Skipping this line.")
continue
except json.JSONDecodeError:
print(f"Warning: File '{file_path}', Line {line_num}: Invalid JSON format. Skipping this line.")
except KeyError as e:
print(f"Warning: File '{file_path}', Line {line_num}: Missing expected key '{e}'. Skipping this line.")
except Exception as e:
print(f"Error reading file '{file_path}': {e}")
return None
# --- Single-file prompt_id validation logic ---
if category_arg and category_arg != 'all' and category_arg in EXPECTED_PROMPT_RANGES:
expected_ids_for_this_category = set(EXPECTED_PROMPT_RANGES[category_arg])
missing_ids_in_segment = expected_ids_for_this_category - segment_present_prompt_ids
if missing_ids_in_segment:
print(f"Error: File '{file_path}': When evaluating as '--category {category_arg}', "
f"missing the following prompt_ids: {sorted(list(missing_ids_in_segment))}")
return None # Return None if required prompt_ids are missing for a specific category file
return {
'scores': segment_scores,
'present_prompt_ids': segment_present_prompt_ids,
'file_path': file_path
}
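# Illustrative return value for a small culture-only file (hypothetical data):
# {'scores': {'CULTURE': [1.0, 0.725]},
#  'present_prompt_ids': {1, 2},
#  'file_path': 'culture_scores.jsonl'}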
def main():
parser = argparse.ArgumentParser(
description="Evaluate JSONL files for model performance, categorizing scores by prompt_id."
)
parser.add_argument(
'files',
metavar='FILE',
nargs='+', # Accepts one or more file paths
help="Path(s) to the JSONL file(s) to be evaluated (e.g., cultural_common_sense_ModelName_scores.jsonl)"
)
parser.add_argument(
'--category',
type=str,
choices=['culture', 'space-time', 'science', 'all'],
default='all',
help="Specify the category of the JSONL file(s) for specific prompt_id validation. Choose from 'culture', 'space-time', 'science', or 'all' (default). If evaluating a single category file, use the corresponding category."
)
args = parser.parse_args()
all_raw_results = []
# Process each file to collect raw scores and prompt IDs
for file_path in args.files:
print(f"\n--- Processing file: {file_path} ---")
# Pass the category argument to process_jsonl_file_segment
# This enables single-file validation logic
results = process_jsonl_file_segment(file_path, args.category if len(args.files) == 1 else None)
if results:
all_raw_results.append(results)
else:
print(f"Could not process '{file_path}'. Please check previous warnings/errors.")
if not all_raw_results:
print("No valid data processed from any of the provided files. Exiting.")
return # Exit if no files were successfully processed
# Aggregate data across all successful files
aggregated_scores = defaultdict(list)
combined_present_prompt_ids = set()
final_file_reports = {} # To store calculated averages/counts per file for individual display
for file_data in all_raw_results:
file_path = file_data['file_path']
combined_present_prompt_ids.update(file_data['present_prompt_ids'])
# Calculate scores for this individual file (for individual file report)
current_file_avg_scores = {}
current_file_num_samples = {}
detected_categories_in_file = []
for category, scores_list in file_data['scores'].items():
aggregated_scores[category].extend(scores_list) # Aggregate for overall score later
if scores_list: # Only add to individual file report if samples exist
current_file_avg_scores[category] = sum(scores_list) / len(scores_list)
current_file_num_samples[category] = len(scores_list)
detected_categories_in_file.append(category)
final_file_reports[file_path] = {
'average': current_file_avg_scores,
'num_processed_samples': current_file_num_samples,
'detected_categories': detected_categories_in_file
}
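    # final_file_reports now maps each file path to a per-file summary, e.g.
    # (illustrative numbers):
    # {'average': {'CULTURE': 0.73}, 'num_processed_samples': {'CULTURE': 400},
    #  'detected_categories': ['CULTURE']}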
# --- Step 1: Validate Prompt IDs for 'all' category scenario ---
# This check happens only when --category all is explicitly chosen or is the default for multiple files.
# Single-file specific category validation happens inside process_jsonl_file_segment.
if args.category == 'all':
expected_prompt_ids_for_all = set(EXPECTED_PROMPT_RANGES['all'])
missing_prompt_ids_in_combined = expected_prompt_ids_for_all - combined_present_prompt_ids
if missing_prompt_ids_in_combined:
print(f"\nError: When '--category all' is specified, the combined files are missing the following prompt_ids:")
print(f"Missing IDs: {sorted(list(missing_prompt_ids_in_combined))}")
print("\nAborting overall evaluation due to incomplete data.")
return # Exit if combined prompt IDs are missing when 'all' is expected
# --- Step 2: Display individual file reports ---
print("\n" + "="*50)
print(" Individual File Reports")
print("="*50 + "\n")
ordered_categories = ['CULTURE', 'TIME', 'SPACE', 'BIOLOGY', 'PHYSICS', 'CHEMISTRY']
for file_path, file_data in final_file_reports.items():
print(f"--- Evaluation Results for File: {file_path} ---")
        categories_to_print = [cat for cat in ordered_categories if cat in file_data['detected_categories']]
if not categories_to_print:
print(" No scores found for any defined categories in this file.")
else:
for category in categories_to_print:
avg_score = file_data['average'].get(category, 0)
sample_count = file_data['num_processed_samples'].get(category, 0)
print(f" Category: {category}")
print(f" Average WiScore: {avg_score:.2f}")
print(f" Number of samples: {sample_count}\n")
print("-" * (len(file_path) + 30) + "\n")
# --- Step 3: Calculate and Display Overall Summary (if applicable) ---
print("\n" + "="*50)
print(" Overall Evaluation Summary")
print("="*50 + "\n")
# Calculate overall averages from aggregated scores
overall_avg_scores = {
category: sum(scores) / len(scores) if len(scores) > 0 else 0
for category, scores in aggregated_scores.items()
}
overall_num_samples = {
category: len(scores)
for category, scores in aggregated_scores.items()
}
# Print overall category scores (only for categories that have samples)
    overall_categories_to_print = [cat for cat in ordered_categories if overall_num_samples.get(cat, 0) > 0]
    if not overall_categories_to_print:
        print("No valid scores found for any categories in the aggregated data.")
else:
print("Aggregated Category Scores:")
for category in overall_categories_to_print:
print(f" Category: {category}")
print(f" Average WiScore: {overall_avg_scores.get(category, 0):.2f}")
print(f" Number of samples: {overall_num_samples.get(category, 0)}\n")
# Calculate and print Overall WiScore if '--category all' was specified and all categories have samples
all_categories_have_overall_samples = all(overall_num_samples.get(cat, 0) > 0 for cat in ordered_categories)
if args.category == 'all' and all_categories_have_overall_samples:
cultural_score = overall_avg_scores.get('CULTURE', 0)
time_score = overall_avg_scores.get('TIME', 0)
space_score = overall_avg_scores.get('SPACE', 0)
biology_score = overall_avg_scores.get('BIOLOGY', 0)
physics_score = overall_avg_scores.get('PHYSICS', 0)
chemistry_score = overall_avg_scores.get('CHEMISTRY', 0)
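        # The weights are each category's share of the 1000 prompts: CULTURE 400/1000,
        # TIME 167/1000, SPACE 133/1000, and 100/1000 each for BIOLOGY, PHYSICS and
        # CHEMISTRY, so they sum to 1.0.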
overall_wiscore = (0.4 * cultural_score + 0.167 * time_score + 0.133 * space_score +
0.1 * biology_score + 0.1 * physics_score + 0.1 * chemistry_score)
print("\n--- Overall WiScore Across All Categories ---")
print(f"Overall WiScore: {overall_wiscore:.2f}")
print("Cultural\tTime\tSpace\tBiology\tPhysics\tChemistry\tOverall")
print(f"{cultural_score:.2f}\t\t{time_score:.2f}\t{space_score:.2f}\t{biology_score:.2f}\t{physics_score:.2f}\t{chemistry_score:.2f}\t\t{overall_wiscore:.2f}")
elif args.category == 'all' and not all_categories_have_overall_samples:
print("\nOverall WiScore cannot be calculated: Not all categories have samples in the aggregated data when '--category all' is specified.")
else:
print(f"\nOverall WiScore calculation skipped. To calculate overall score, use '--category all' and provide files covering all prompt IDs.")
if __name__ == "__main__":
main()
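# Example invocations (file names are illustrative):
#   python step3_wise_cal.py culture_scores.jsonl --category culture
#   python step3_wise_cal.py culture_scores.jsonl spacetime_scores.jsonl science_scores.jsonl --category all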