# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Sequence

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import print_log
from rich.console import Console
from rich.table import Table

from xtuner.registry import BUILDER


class MMLUMetric(BaseMetric):
    """Accuracy metric for the MMLU benchmark, reported per subject, per
    subcategory, and per category."""

    METAINFO = {
        'subcategories': {
            'abstract_algebra': ['math'],
            'anatomy': ['health'],
            'astronomy': ['physics'],
            'business_ethics': ['business'],
            'clinical_knowledge': ['health'],
            'college_biology': ['biology'],
            'college_chemistry': ['chemistry'],
            'college_computer_science': ['computer science'],
            'college_mathematics': ['math'],
            'college_medicine': ['health'],
            'college_physics': ['physics'],
            'computer_security': ['computer science'],
            'conceptual_physics': ['physics'],
            'econometrics': ['economics'],
            'electrical_engineering': ['engineering'],
            'elementary_mathematics': ['math'],
            'formal_logic': ['philosophy'],
            'global_facts': ['other'],
            'high_school_biology': ['biology'],
            'high_school_chemistry': ['chemistry'],
            'high_school_computer_science': ['computer science'],
            'high_school_european_history': ['history'],
            'high_school_geography': ['geography'],
            'high_school_government_and_politics': ['politics'],
            'high_school_macroeconomics': ['economics'],
            'high_school_mathematics': ['math'],
            'high_school_microeconomics': ['economics'],
            'high_school_physics': ['physics'],
            'high_school_psychology': ['psychology'],
            'high_school_statistics': ['math'],
            'high_school_us_history': ['history'],
            'high_school_world_history': ['history'],
            'human_aging': ['health'],
            'human_sexuality': ['culture'],
            'international_law': ['law'],
            'jurisprudence': ['law'],
            'logical_fallacies': ['philosophy'],
            'machine_learning': ['computer science'],
            'management': ['business'],
            'marketing': ['business'],
            'medical_genetics': ['health'],
            'miscellaneous': ['other'],
            'moral_disputes': ['philosophy'],
            'moral_scenarios': ['philosophy'],
            'nutrition': ['health'],
            'philosophy': ['philosophy'],
            'prehistory': ['history'],
            'professional_accounting': ['other'],
            'professional_law': ['law'],
            'professional_medicine': ['health'],
            'professional_psychology': ['psychology'],
            'public_relations': ['politics'],
            'security_studies': ['politics'],
            'sociology': ['culture'],
            'us_foreign_policy': ['politics'],
            'virology': ['health'],
            'world_religions': ['philosophy'],
        },
        'categories': {
            'STEM': [
                'physics', 'chemistry', 'biology', 'computer science', 'math',
                'engineering'
            ],
            'humanities': ['history', 'philosophy', 'law'],
            'social sciences':
            ['politics', 'culture', 'economics', 'geography', 'psychology'],
            'other (business, health, misc.)': ['other', 'business', 'health'],
        },
    }
    # Flat, deduplicated list of every subcategory name referenced above.
    METAINFO['subcategories_list'] = list({
        subcat
        for subcats in METAINFO['subcategories'].values()
        for subcat in subcats
    })

    def __init__(self, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        tokenizer = BUILDER.build(tokenizer)
        self.abcd_idx = [
            tokenizer.encode('A', add_special_tokens=False)[0],
            tokenizer.encode('B', add_special_tokens=False)[0],
            tokenizer.encode('C', add_special_tokens=False)[0],
            tokenizer.encode('D', add_special_tokens=False)[0],
        ]
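
    # ``self.abcd_idx`` holds one vocabulary id per answer letter. The exact
    # values depend on the tokenizer; the code assumes the first id returned
    # by ``encode`` is the token for the letter itself (true for typical
    # single-character tokens).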

    @staticmethod
    def ABCD_to_0123(abcd):
        return {'A': 0, 'B': 1, 'C': 2, 'D': 3}[abcd]

    @staticmethod
    def find_first_zero_index(tensor):
        indices = torch.nonzero(tensor == 0)
        if indices.numel() > 0:
            return indices[0].item()
        else:
            return None
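
    # Illustrative behaviour of the helper above:
    #   find_first_zero_index(torch.tensor([1, 1, 1, 0])) -> 3
    #   find_first_zero_index(torch.tensor([1, 1, 1, 1])) -> None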

    @staticmethod
    def accuracy(preds, gts):
        """Compute the accuracy (%) of ``preds`` against ``gts``."""
        correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)]
        acc = np.mean(correct) * 100
        return acc
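
    # Illustrative behaviour: accuracy([0, 1, 2, 3], [0, 1, 2, 0]) -> 75.0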

    def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch (Any): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        subjects = data_batch['data_samples']['subjects']
        gts = [
            self.ABCD_to_0123(gt)
            for gt in data_batch['data_samples']['labels']
        ]
        preds = []
        for sample, attn_mask, subject, gt in zip(
                data_samples, data_batch['data']['attention_mask'], subjects,
                gts):
            pred_logits = sample['logits']
            # The answer is read at the last non-padded position: one step
            # before the first zero in the attention mask, or the final
            # position if there is no padding.
            first_zero_idx = self.find_first_zero_index(attn_mask)
            pred_idx = -1 if first_zero_idx is None else first_zero_idx - 1
            pred_logits_abcd = pred_logits[pred_idx, self.abcd_idx]
            pred = torch.argmax(pred_logits_abcd).item()
            preds.append(pred)
            self.results.append((subject, pred, gt))
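
    # Shape assumptions in ``process`` (not enforced here):
    # ``sample['logits']`` is (seq_len, vocab_size) and each attention mask is
    # a 1-D tensor of 1s for real tokens followed by 0s for padding.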

    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        subjects_results = {
            subject: {
                'preds': [],
                'gts': []
            }
            for subject in self.METAINFO['subcategories'].keys()
        }
        subcats_results = {
            subcat: {
                'preds': [],
                'gts': []
            }
            for subcat in self.METAINFO['subcategories_list']
        }
        cats_results = {
            cat: {
                'preds': [],
                'gts': []
            }
            for cat in self.METAINFO['categories'].keys()
        }
        for subject, pred, gt in results:
            subjects_results[subject]['preds'].append(pred)
            subjects_results[subject]['gts'].append(gt)
            subcats = self.METAINFO['subcategories'][subject]
            for subcat in subcats:
                subcats_results[subcat]['preds'].append(pred)
                subcats_results[subcat]['gts'].append(gt)
        for cat, subcats in self.METAINFO['categories'].items():
            for subcat in subcats:
                if subcat in subcats_results:
                    cats_results[cat]['preds'].extend(
                        subcats_results[subcat]['preds'])
                    cats_results[cat]['gts'].extend(
                        subcats_results[subcat]['gts'])

        subjects_metrics = dict()
        subcats_metrics = dict()
        cats_metrics = dict()
        for subject in self.METAINFO['subcategories'].keys():
            assert len(subjects_results[subject]['preds']) == len(
                subjects_results[subject]['gts'])
            if len(subjects_results[subject]['preds']) == 0:
                print_log(f'Skip subject {subject} for mmlu', 'current')
            else:
                score = self.accuracy(subjects_results[subject]['preds'],
                                      subjects_results[subject]['gts'])
                subjects_metrics[f'{subject}'] = score
        for subcat in self.METAINFO['subcategories_list']:
            assert len(subcats_results[subcat]['preds']) == len(
                subcats_results[subcat]['gts'])
            if len(subcats_results[subcat]['preds']) == 0:
                print_log(f'Skip subcategory {subcat} for mmlu', 'current')
            else:
                score = self.accuracy(subcats_results[subcat]['preds'],
                                      subcats_results[subcat]['gts'])
                subcats_metrics[f'{subcat}'] = score
        for cat in self.METAINFO['categories'].keys():
            assert len(cats_results[cat]['preds']) == len(
                cats_results[cat]['gts'])
            if len(cats_results[cat]['preds']) == 0:
                print_log(f'Skip category {cat} for mmlu', 'current')
            else:
                score = self.accuracy(cats_results[cat]['preds'],
                                      cats_results[cat]['gts'])
                cats_metrics[f'{cat}'] = score

        metrics = dict()
        metrics.update(subjects_metrics)
        metrics.update(subcats_metrics)
        metrics.update(cats_metrics)
        # The reported average is the unweighted mean over per-subject
        # accuracies, not over subcategories or categories.
        metrics['average'] = np.mean(list(subjects_metrics.values()))

        table_metrics = dict()
        table_metrics.update(cats_metrics)
        table_metrics['average'] = np.mean(list(subjects_metrics.values()))
        self._print_results(table_metrics)
        return metrics

    def _print_results(self, table_metrics: dict) -> None:
        table_title = ' MMLU Benchmark '
        table = Table(title=table_title)
        console = Console()
        table.add_column('Categories', justify='left')
        table.add_column('Accuracy (%)', justify='right')
        for cat, acc in table_metrics.items():
            table.add_row(cat, f'{acc:.1f}')
        with console.capture() as capture:
            console.print(table, end='')
        print_log('\n' + capture.get(), 'current')
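

# A minimal usage sketch, under assumptions not stated in this file: the
# tokenizer config follows the usual xtuner pattern of a buildable dict, and
# '<model-name>' is a placeholder, not a real checkpoint:
#
#   from transformers import AutoTokenizer
#   metric = MMLUMetric(
#       tokenizer=dict(
#           type=AutoTokenizer.from_pretrained,
#           pretrained_model_name_or_path='<model-name>'))
#
# The self-checks below exercise only the stateless helpers and run without
# any tokenizer or model.
if __name__ == '__main__':
    assert MMLUMetric.ABCD_to_0123('C') == 2
    assert MMLUMetric.find_first_zero_index(torch.tensor([1, 1, 0, 0])) == 2
    assert MMLUMetric.accuracy([0, 1, 2, 3], [0, 1, 2, 0]) == 75.0
    print('MMLUMetric helper self-checks passed')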