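"""
Evaluate WiScore results from one or more JSONL score files.

Each input line carries per-prompt scores (consistency, realism, aesthetic
quality) that are combined into a single WiScore and averaged per category
(CULTURE, TIME, SPACE, BIOLOGY, PHYSICS, CHEMISTRY) based on the prompt_id.

Example invocations (the script and file names are placeholders; substitute
your own):
    python evaluate_wiscore.py cultural_common_sense_ModelName_scores.jsonl --category culture
    python evaluate_wiscore.py scores_part1.jsonl scores_part2.jsonl --category all
"""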
import json
import os
import argparse
from collections import defaultdict

def calculate_wiscore(consistency, realism, aesthetic_quality):
    """Calculates the WiScore based on given components."""
    return (0.7 * consistency + 0.2 * realism + 0.1 * aesthetic_quality) / 2
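
# Quick sanity check of the weighting above (a sketch; the 0-2 score scale is an
# assumption inferred from the /2 normalization):
#   calculate_wiscore(2, 2, 2) -> (0.7*2 + 0.2*2 + 0.1*2) / 2 = 1.0
#   calculate_wiscore(1, 2, 0) -> (0.7*1 + 0.2*2 + 0.1*0) / 2 = 0.55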

# Define expected prompt ID ranges at a global level for easy access
EXPECTED_PROMPT_RANGES = {
    "culture": range(1, 401),
    "space-time": range(401, 701), # Covers both TIME (401-567) and SPACE (568-700)
    "science": range(701, 1001), # Covers BIOLOGY (701-800), PHYSICS (801-900), CHEMISTRY (901-1000)
    "all": range(1, 1001) # Full range for combined evaluation
}

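# Expected shape of each input line (field names are taken from the lookups in
# process_jsonl_file_segment below; the numeric values here are illustrative only):
#   {"prompt_id": 42, "consistency": 1.5, "realism": 2, "aesthetic_quality": 1}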
def process_jsonl_file_segment(file_path, category_arg=None):
    """
    Processes a segment of a JSONL file, collecting scores and present prompt_ids.
    Performs prompt_id validation if a specific category_arg is provided for a single file.
    Returns collected data or None if critical errors or missing prompt_ids (for single-file validation).
    """
    segment_scores = defaultdict(list)
    segment_present_prompt_ids = set()
    
    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' not found.")
        return None

    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line_num, line in enumerate(file, 1):
                try:
                    data = json.loads(line)
                    
                    prompt_id = data.get('prompt_id')
                    if prompt_id is None:
                        print(f"Warning: File '{file_path}', Line {line_num}: Missing 'prompt_id'. Skipping this line.")
                        continue
                    
                    if not isinstance(prompt_id, int):
                        print(f"Warning: File '{file_path}', Line {line_num}: 'prompt_id' is not an integer. Skipping this line.")
                        continue

                    segment_present_prompt_ids.add(prompt_id)
                    
                    consistency = data.get('consistency')
                    realism = data.get('realism')
                    aesthetic_quality = data.get('aesthetic_quality')

                    if not all(isinstance(val, (int, float)) for val in [consistency, realism, aesthetic_quality]):
                        print(f"Warning: File '{file_path}', Line {line_num}: One or more score values are not numeric. Skipping this line for category calculation.")
                        continue
                    
                    wiscore = calculate_wiscore(consistency, realism, aesthetic_quality)

                    # Determine category based on prompt_id
                    if 1 <= prompt_id <= 400:
                        segment_scores['CULTURE'].append(wiscore)
                    elif 401 <= prompt_id <= 567:
                        segment_scores['TIME'].append(wiscore)
                    elif 568 <= prompt_id <= 700:
                        segment_scores['SPACE'].append(wiscore)
                    elif 701 <= prompt_id <= 800:
                        segment_scores['BIOLOGY'].append(wiscore)
                    elif 801 <= prompt_id <= 900:
                        segment_scores['PHYSICS'].append(wiscore)
                    elif 901 <= prompt_id <= 1000:
                        segment_scores['CHEMISTRY'].append(wiscore)
                    else:
                        print(f"Warning: File '{file_path}', Line {line_num}: prompt_id {prompt_id} is outside defined categories. Skipping this line.")
                        continue

                except json.JSONDecodeError:
                    print(f"Warning: File '{file_path}', Line {line_num}: Invalid JSON format. Skipping this line.")
                except KeyError as e:
                    print(f"Warning: File '{file_path}', Line {line_num}: Missing expected key '{e}'. Skipping this line.")
    except Exception as e:
        print(f"Error reading file '{file_path}': {e}")
        return None
    
    # --- Single-file prompt_id validation logic ---
    if category_arg and category_arg != 'all' and category_arg in EXPECTED_PROMPT_RANGES:
        expected_ids_for_this_category = set(EXPECTED_PROMPT_RANGES[category_arg])
        missing_ids_in_segment = expected_ids_for_this_category - segment_present_prompt_ids

        if missing_ids_in_segment:
            print(f"Error: File '{file_path}': When evaluating as '--category {category_arg}', "
                  f"missing the following prompt_ids: {sorted(list(missing_ids_in_segment))}")
            return None # Return None if required prompt_ids are missing for a specific category file
    
    return {
        'scores': segment_scores,
        'present_prompt_ids': segment_present_prompt_ids,
        'file_path': file_path
    }

def main():
    parser = argparse.ArgumentParser(
        description="Evaluate JSONL files for model performance, categorizing scores by prompt_id."
    )
    parser.add_argument(
        'files',
        metavar='FILE',
        nargs='+',  # Accepts one or more file paths
        help="Path(s) to the JSONL file(s) to be evaluated (e.g., cultural_common_sense_ModelName_scores.jsonl)"
    )
    parser.add_argument(
        '--category',
        type=str,
        choices=['culture', 'space-time', 'science', 'all'],
        default='all',
        help="Specify the category of the JSONL file(s) for specific prompt_id validation. Choose from 'culture', 'space-time', 'science', or 'all' (default). If evaluating a single category file, use the corresponding category."
    )
    
    args = parser.parse_args()
    
    all_raw_results = []
    
    # Process each file to collect raw scores and prompt IDs
    for file_path in args.files:
        print(f"\n--- Processing file: {file_path} ---")
        # Pass the category argument to process_jsonl_file_segment
        # This enables single-file validation logic
        results = process_jsonl_file_segment(file_path, args.category if len(args.files) == 1 else None)
        if results:
            all_raw_results.append(results)
        else:
            print(f"Could not process '{file_path}'. Please check previous warnings/errors.")

    if not all_raw_results:
        print("No valid data processed from any of the provided files. Exiting.")
        return # Exit if no files were successfully processed

    # Aggregate data across all successful files
    aggregated_scores = defaultdict(list)
    combined_present_prompt_ids = set()
    final_file_reports = {} # To store calculated averages/counts per file for individual display

    for file_data in all_raw_results:
        file_path = file_data['file_path']
        combined_present_prompt_ids.update(file_data['present_prompt_ids'])
        
        # Calculate scores for this individual file (for individual file report)
        current_file_avg_scores = {}
        current_file_num_samples = {}
        detected_categories_in_file = []

        for category, scores_list in file_data['scores'].items():
            aggregated_scores[category].extend(scores_list) # Aggregate for overall score later
            
            if scores_list: # Only add to individual file report if samples exist
                current_file_avg_scores[category] = sum(scores_list) / len(scores_list)
                current_file_num_samples[category] = len(scores_list)
                detected_categories_in_file.append(category)

        final_file_reports[file_path] = {
            'average': current_file_avg_scores,
            'num_processed_samples': current_file_num_samples,
            'detected_categories': detected_categories_in_file
        }

    # --- Step 1: Validate Prompt IDs for 'all' category scenario ---
    # This check happens only when --category all is explicitly chosen or is the default for multiple files.
    # Single-file specific category validation happens inside process_jsonl_file_segment.
    if args.category == 'all':
        expected_prompt_ids_for_all = set(EXPECTED_PROMPT_RANGES['all'])
        missing_prompt_ids_in_combined = expected_prompt_ids_for_all - combined_present_prompt_ids

        if missing_prompt_ids_in_combined:
            print(f"\nError: When '--category all' is specified, the combined files are missing the following prompt_ids:")
            print(f"Missing IDs: {sorted(list(missing_prompt_ids_in_combined))}")
            print("\nAborting overall evaluation due to incomplete data.")
            return # Exit if combined prompt IDs are missing when 'all' is expected

    # --- Step 2: Display individual file reports ---
    print("\n" + "="*50)
    print("                 Individual File Reports")
    print("="*50 + "\n")

    ordered_categories = ['CULTURE', 'TIME', 'SPACE', 'BIOLOGY', 'PHYSICS', 'CHEMISTRY']

    for file_path, file_data in final_file_reports.items():
        print(f"--- Evaluation Results for File: {file_path} ---")
        
        # Iterating over ordered_categories already yields the display order, so no extra sort is needed.
        categories_to_print = [cat for cat in ordered_categories if cat in file_data['detected_categories']]

        if not categories_to_print:
            print("  No scores found for any defined categories in this file.")
        else:
            for category in categories_to_print:
                avg_score = file_data['average'].get(category, 0)
                sample_count = file_data['num_processed_samples'].get(category, 0)
                print(f"  Category: {category}")
                print(f"    Average WiScore: {avg_score:.2f}")
                print(f"    Number of samples: {sample_count}\n")
        print("-" * (len(file_path) + 30) + "\n")

    # --- Step 3: Calculate and Display Overall Summary (if applicable) ---
    print("\n" + "="*50)
    print("                 Overall Evaluation Summary")
    print("="*50 + "\n")

    # Calculate overall averages from aggregated scores
    overall_avg_scores = {
        category: sum(scores) / len(scores) if len(scores) > 0 else 0
        for category, scores in aggregated_scores.items()
    }
    overall_num_samples = {
        category: len(scores)
        for category, scores in aggregated_scores.items()
    }

    # Print overall category scores (only for categories that have samples)
    # ordered_categories already gives the display order; just filter out empty categories.
    overall_categories_to_print = [cat for cat in ordered_categories if overall_num_samples.get(cat, 0) > 0]
    
    if not overall_categories_to_print:
        print("No valid scores found for any categories in the aggregated data.")
    else:
        print("Aggregated Category Scores:")
        for category in overall_categories_to_print:
            print(f"  Category: {category}")
            print(f"    Average WiScore: {overall_avg_scores.get(category, 0):.2f}")
            print(f"    Number of samples: {overall_num_samples.get(category, 0)}\n")

    # Calculate and print Overall WiScore if '--category all' was specified and all categories have samples
    all_categories_have_overall_samples = all(overall_num_samples.get(cat, 0) > 0 for cat in ordered_categories)
    
    if args.category == 'all' and all_categories_have_overall_samples:
        cultural_score = overall_avg_scores.get('CULTURE', 0)
        time_score = overall_avg_scores.get('TIME', 0)
        space_score = overall_avg_scores.get('SPACE', 0)
        biology_score = overall_avg_scores.get('BIOLOGY', 0)
        physics_score = overall_avg_scores.get('PHYSICS', 0)
        chemistry_score = overall_avg_scores.get('CHEMISTRY', 0)

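        # The weights match each category's share of the 1,000 prompt IDs:
        # 400 culture, 167 time, 133 space, and 100 each for biology, physics,
        # and chemistry.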
        overall_wiscore = (0.4 * cultural_score + 0.167 * time_score + 0.133 * space_score +
                           0.1 * biology_score + 0.1 * physics_score + 0.1 * chemistry_score)
        
        print("\n--- Overall WiScore Across All Categories ---")
        print(f"Overall WiScore: {overall_wiscore:.2f}")
        print("Cultural\tTime\tSpace\tBiology\tPhysics\tChemistry\tOverall")
        print(f"{cultural_score:.2f}\t\t{time_score:.2f}\t{space_score:.2f}\t{biology_score:.2f}\t{physics_score:.2f}\t{chemistry_score:.2f}\t\t{overall_wiscore:.2f}")
    elif args.category == 'all' and not all_categories_have_overall_samples:
        print("\nOverall WiScore cannot be calculated: Not all categories have samples in the aggregated data when '--category all' is specified.")
    else:
        print(f"\nOverall WiScore calculation skipped. To calculate overall score, use '--category all' and provide files covering all prompt IDs.")


if __name__ == "__main__":
    main()