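"""Evaluate WiScore results from JSONL score files.

Each input line is expected to carry a prompt_id plus consistency, realism, and
aesthetic_quality scores. Results are grouped into CULTURE, TIME, SPACE, BIOLOGY,
PHYSICS, and CHEMISTRY buckets by prompt_id range and reported per file and in
aggregate.
"""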
import json
import os
import argparse
from collections import defaultdict


def calculate_wiscore(consistency, realism, aesthetic_quality):
    """Calculates the WiScore based on given components."""
    return (0.7 * consistency + 0.2 * realism + 0.1 * aesthetic_quality) / 2


# Define expected prompt ID ranges at a global level for easy access
EXPECTED_PROMPT_RANGES = {
    "culture": range(1, 401),
    "space-time": range(401, 701),  # Covers both TIME (401-567) and SPACE (568-700)
    "science": range(701, 1001),    # Covers BIOLOGY (701-800), PHYSICS (801-900), CHEMISTRY (901-1000)
    "all": range(1, 1001)           # Full range for combined evaluation
}


def process_jsonl_file_segment(file_path, category_arg=None):
    """
    Processes a segment of a JSONL file, collecting scores and the prompt_ids present.
    Performs prompt_id validation if a specific category_arg is provided for a single file.
    Returns the collected data, or None on critical errors or missing prompt_ids
    (for single-file validation).
    """
    segment_scores = defaultdict(list)
    segment_present_prompt_ids = set()

    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' not found.")
        return None
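
    # Parse the file line by line; malformed or incomplete lines are skipped with
    # a warning rather than aborting the whole run.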
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line_num, line in enumerate(file, 1):
                try:
                    data = json.loads(line)
                    prompt_id = data.get('prompt_id')
                    if prompt_id is None:
                        print(f"Warning: File '{file_path}', Line {line_num}: Missing 'prompt_id'. Skipping this line.")
                        continue
                    if not isinstance(prompt_id, int):
                        print(f"Warning: File '{file_path}', Line {line_num}: 'prompt_id' is not an integer. Skipping this line.")
                        continue
                    segment_present_prompt_ids.add(prompt_id)

                    consistency = data.get('consistency')
                    realism = data.get('realism')
                    aesthetic_quality = data.get('aesthetic_quality')
                    if not all(isinstance(val, (int, float)) for val in [consistency, realism, aesthetic_quality]):
                        print(f"Warning: File '{file_path}', Line {line_num}: One or more score values are not numeric. Skipping this line for category calculation.")
                        continue

                    wiscore = calculate_wiscore(consistency, realism, aesthetic_quality)

                    # Determine category based on prompt_id
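                    # (These sub-ranges refine EXPECTED_PROMPT_RANGES: 'space-time' splits into
                    # TIME and SPACE, and 'science' into BIOLOGY, PHYSICS, and CHEMISTRY.)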
                    if 1 <= prompt_id <= 400:
                        segment_scores['CULTURE'].append(wiscore)
                    elif 401 <= prompt_id <= 567:
                        segment_scores['TIME'].append(wiscore)
                    elif 568 <= prompt_id <= 700:
                        segment_scores['SPACE'].append(wiscore)
                    elif 701 <= prompt_id <= 800:
                        segment_scores['BIOLOGY'].append(wiscore)
                    elif 801 <= prompt_id <= 900:
                        segment_scores['PHYSICS'].append(wiscore)
                    elif 901 <= prompt_id <= 1000:
                        segment_scores['CHEMISTRY'].append(wiscore)
                    else:
                        print(f"Warning: File '{file_path}', Line {line_num}: prompt_id {prompt_id} is outside defined categories. Skipping this line.")
                        continue
                except json.JSONDecodeError:
                    print(f"Warning: File '{file_path}', Line {line_num}: Invalid JSON format. Skipping this line.")
                except KeyError as e:
                    print(f"Warning: File '{file_path}', Line {line_num}: Missing expected key '{e}'. Skipping this line.")
    except Exception as e:
        print(f"Error reading file '{file_path}': {e}")
        return None

    # --- Single-file prompt_id validation logic ---
    if category_arg and category_arg != 'all' and category_arg in EXPECTED_PROMPT_RANGES:
        expected_ids_for_this_category = set(EXPECTED_PROMPT_RANGES[category_arg])
        missing_ids_in_segment = expected_ids_for_this_category - segment_present_prompt_ids
        if missing_ids_in_segment:
            print(f"Error: File '{file_path}': When evaluating as '--category {category_arg}', "
                  f"missing the following prompt_ids: {sorted(list(missing_ids_in_segment))}")
            return None  # Return None if required prompt_ids are missing for a specific category file
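
    # Hand back the raw per-category score lists and the set of prompt_ids seen,
    # so main() can aggregate them across multiple files.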
    return {
        'scores': segment_scores,
        'present_prompt_ids': segment_present_prompt_ids,
        'file_path': file_path
    }


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate JSONL files for model performance, categorizing scores by prompt_id."
    )
    parser.add_argument(
        'files',
        metavar='FILE',
        nargs='+',  # Accepts one or more file paths
        help="Path(s) to the JSONL file(s) to be evaluated (e.g., cultural_common_sense_ModelName_scores.jsonl)"
    )
    parser.add_argument(
        '--category',
        type=str,
        choices=['culture', 'space-time', 'science', 'all'],
        default='all',
        help="Specify the category of the JSONL file(s) for prompt_id validation. Choose from "
             "'culture', 'space-time', 'science', or 'all' (default). When evaluating a single "
             "category file, pass the corresponding category."
    )
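
    # Example invocations (script and file names are illustrative):
    #   python evaluate.py culture_ModelName_scores.jsonl --category culture
    #   python evaluate.py culture.jsonl spacetime.jsonl science.jsonl --category all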
    args = parser.parse_args()

    all_raw_results = []
    # Process each file to collect raw scores and prompt IDs
    for file_path in args.files:
        print(f"\n--- Processing file: {file_path} ---")
        # Pass the category argument through only when a single file is given;
        # this enables the single-file prompt_id validation inside process_jsonl_file_segment.
        results = process_jsonl_file_segment(file_path, args.category if len(args.files) == 1 else None)
        if results:
            all_raw_results.append(results)
        else:
            print(f"Could not process '{file_path}'. Please check previous warnings/errors.")

    if not all_raw_results:
        print("No valid data processed from any of the provided files. Exiting.")
        return  # Exit if no files were successfully processed

    # Aggregate data across all successful files
    aggregated_scores = defaultdict(list)
    combined_present_prompt_ids = set()
    final_file_reports = {}  # To store calculated averages/counts per file for individual display

    for file_data in all_raw_results:
        file_path = file_data['file_path']
        combined_present_prompt_ids.update(file_data['present_prompt_ids'])

        # Calculate scores for this individual file (for individual file report)
        current_file_avg_scores = {}
        current_file_num_samples = {}
        detected_categories_in_file = []
        for category, scores_list in file_data['scores'].items():
            aggregated_scores[category].extend(scores_list)  # Aggregate for overall score later
            if scores_list:  # Only add to individual file report if samples exist
                current_file_avg_scores[category] = sum(scores_list) / len(scores_list)
                current_file_num_samples[category] = len(scores_list)
                detected_categories_in_file.append(category)
        final_file_reports[file_path] = {
            'average': current_file_avg_scores,
            'num_processed_samples': current_file_num_samples,
            'detected_categories': detected_categories_in_file
        }

    # --- Step 1: Validate Prompt IDs for 'all' category scenario ---
    # This check happens only when --category all is explicitly chosen or is the default for multiple files.
    # Single-file specific category validation happens inside process_jsonl_file_segment.
    if args.category == 'all':
        expected_prompt_ids_for_all = set(EXPECTED_PROMPT_RANGES['all'])
        missing_prompt_ids_in_combined = expected_prompt_ids_for_all - combined_present_prompt_ids
        if missing_prompt_ids_in_combined:
            print("\nError: When '--category all' is specified, the combined files are missing the following prompt_ids:")
            print(f"Missing IDs: {sorted(list(missing_prompt_ids_in_combined))}")
            print("\nAborting overall evaluation due to incomplete data.")
            return  # Exit if combined prompt IDs are missing when 'all' is expected
# --- Step 2: Display individual file reports ---
print("\n" + "="*50)
print(" Individual File Reports")
print("="*50 + "\n")
ordered_categories = ['CULTURE', 'TIME', 'SPACE', 'BIOLOGY', 'PHYSICS', 'CHEMISTRY']
for file_path, file_data in final_file_reports.items():
print(f"--- Evaluation Results for File: {file_path} ---")
categories_to_print = sorted([cat for cat in ordered_categories if cat in file_data['detected_categories']],
key=lambda x: ordered_categories.index(x))
if not categories_to_print:
print(" No scores found for any defined categories in this file.")
else:
for category in categories_to_print:
avg_score = file_data['average'].get(category, 0)
sample_count = file_data['num_processed_samples'].get(category, 0)
print(f" Category: {category}")
print(f" Average WiScore: {avg_score:.2f}")
print(f" Number of samples: {sample_count}\n")
print("-" * (len(file_path) + 30) + "\n")

    # --- Step 3: Calculate and Display Overall Summary (if applicable) ---
    print("\n" + "="*50)
    print(" Overall Evaluation Summary")
    print("="*50 + "\n")

    # Calculate overall averages from aggregated scores
    overall_avg_scores = {
        category: sum(scores) / len(scores) if len(scores) > 0 else 0
        for category, scores in aggregated_scores.items()
    }
    overall_num_samples = {
        category: len(scores)
        for category, scores in aggregated_scores.items()
    }

    # Print overall category scores (only for categories that have samples)
    overall_categories_to_print = sorted([cat for cat in ordered_categories if overall_num_samples.get(cat, 0) > 0],
                                         key=lambda x: ordered_categories.index(x))

    if not overall_categories_to_print and args.category != 'all':
        print("No valid scores found for any categories in the aggregated data.")
    else:
        print("Aggregated Category Scores:")
        for category in overall_categories_to_print:
            print(f" Category: {category}")
            print(f" Average WiScore: {overall_avg_scores.get(category, 0):.2f}")
            print(f" Number of samples: {overall_num_samples.get(category, 0)}\n")

    # Calculate and print Overall WiScore if '--category all' was specified and all categories have samples
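    # The weights below match each category's share of the 1000 prompt IDs:
    # CULTURE 400, TIME 167, SPACE 133, and 100 each for BIOLOGY, PHYSICS, and CHEMISTRY.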
    all_categories_have_overall_samples = all(overall_num_samples.get(cat, 0) > 0 for cat in ordered_categories)
    if args.category == 'all' and all_categories_have_overall_samples:
        cultural_score = overall_avg_scores.get('CULTURE', 0)
        time_score = overall_avg_scores.get('TIME', 0)
        space_score = overall_avg_scores.get('SPACE', 0)
        biology_score = overall_avg_scores.get('BIOLOGY', 0)
        physics_score = overall_avg_scores.get('PHYSICS', 0)
        chemistry_score = overall_avg_scores.get('CHEMISTRY', 0)
        overall_wiscore = (0.4 * cultural_score + 0.167 * time_score + 0.133 * space_score +
                           0.1 * biology_score + 0.1 * physics_score + 0.1 * chemistry_score)
        print("\n--- Overall WiScore Across All Categories ---")
        print(f"Overall WiScore: {overall_wiscore:.2f}")
        print("Cultural\tTime\tSpace\tBiology\tPhysics\tChemistry\tOverall")
        print(f"{cultural_score:.2f}\t\t{time_score:.2f}\t{space_score:.2f}\t{biology_score:.2f}\t{physics_score:.2f}\t{chemistry_score:.2f}\t\t{overall_wiscore:.2f}")
    elif args.category == 'all' and not all_categories_have_overall_samples:
        print("\nOverall WiScore cannot be calculated: not all categories have samples in the aggregated data when '--category all' is specified.")
    else:
        print("\nOverall WiScore calculation skipped. To calculate the overall score, use '--category all' and provide files covering all prompt IDs.")


if __name__ == "__main__":
    main()