# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Sequence
import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import print_log
from rich.console import Console
from rich.table import Table
from xtuner.registry import BUILDER
class MMLUMetric(BaseMetric):
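    """MMLU benchmark metric.

    Collects per-sample predictions during evaluation and reports accuracy
    per subject, per subcategory, per category, and averaged over subjects.
    """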
METAINFO = {
'subcategories': {
'abstract_algebra': ['math'],
'anatomy': ['health'],
'astronomy': ['physics'],
'business_ethics': ['business'],
'clinical_knowledge': ['health'],
'college_biology': ['biology'],
'college_chemistry': ['chemistry'],
'college_computer_science': ['computer science'],
'college_mathematics': ['math'],
'college_medicine': ['health'],
'college_physics': ['physics'],
'computer_security': ['computer science'],
'conceptual_physics': ['physics'],
'econometrics': ['economics'],
'electrical_engineering': ['engineering'],
'elementary_mathematics': ['math'],
'formal_logic': ['philosophy'],
'global_facts': ['other'],
'high_school_biology': ['biology'],
'high_school_chemistry': ['chemistry'],
'high_school_computer_science': ['computer science'],
'high_school_european_history': ['history'],
'high_school_geography': ['geography'],
'high_school_government_and_politics': ['politics'],
'high_school_macroeconomics': ['economics'],
'high_school_mathematics': ['math'],
'high_school_microeconomics': ['economics'],
'high_school_physics': ['physics'],
'high_school_psychology': ['psychology'],
'high_school_statistics': ['math'],
'high_school_us_history': ['history'],
'high_school_world_history': ['history'],
'human_aging': ['health'],
'human_sexuality': ['culture'],
'international_law': ['law'],
'jurisprudence': ['law'],
'logical_fallacies': ['philosophy'],
'machine_learning': ['computer science'],
'management': ['business'],
'marketing': ['business'],
'medical_genetics': ['health'],
'miscellaneous': ['other'],
'moral_disputes': ['philosophy'],
'moral_scenarios': ['philosophy'],
'nutrition': ['health'],
'philosophy': ['philosophy'],
'prehistory': ['history'],
'professional_accounting': ['other'],
'professional_law': ['law'],
'professional_medicine': ['health'],
'professional_psychology': ['psychology'],
'public_relations': ['politics'],
'security_studies': ['politics'],
'sociology': ['culture'],
'us_foreign_policy': ['politics'],
'virology': ['health'],
'world_religions': ['philosophy'],
},
'categories': {
'STEM': [
'physics', 'chemistry', 'biology', 'computer science', 'math',
'engineering'
],
'humanities': ['history', 'philosophy', 'law'],
'social sciences':
['politics', 'culture', 'economics', 'geography', 'psychology'],
'other (business, health, misc.)': ['other', 'business', 'health'],
},
}
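    # Derive the deduplicated list of subcategory names from the
    # subject -> subcategory mapping above; used as aggregation keys in
    # ``compute_metrics``.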
METAINFO['subcategories_list'] = list({
subcat
for subcats in METAINFO['subcategories'].values() for subcat in subcats
})
def __init__(self, tokenizer, *args, **kwargs):
super().__init__(*args, **kwargs)
tokenizer = BUILDER.build(tokenizer)
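        # Cache the (first) token id of each option letter; the answer is
        # predicted by comparing the logits of these four tokens.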
self.abcd_idx = [
tokenizer.encode('A', add_special_tokens=False)[0],
tokenizer.encode('B', add_special_tokens=False)[0],
tokenizer.encode('C', add_special_tokens=False)[0],
tokenizer.encode('D', add_special_tokens=False)[0],
]
@staticmethod
def ABCD_to_0123(abcd):
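        """Map an answer letter ('A'-'D') to its option index (0-3)."""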
return {'A': 0, 'B': 1, 'C': 2, 'D': 3}[abcd]
@staticmethod
def find_first_zero_index(tensor):
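        """Return the index of the first zero element in ``tensor``, or
        ``None`` if the tensor contains no zeros."""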
indices = torch.nonzero(tensor == 0)
if indices.numel() > 0:
return indices[0].item()
else:
return None
@staticmethod
def accuracy(preds, gts):
"""Computes the accuracy for preds and gts."""
correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)]
acc = np.mean(correct) * 100
return acc
def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.
Args:
data_batch (Any): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from
the model.
"""
subjects = data_batch['data_samples']['subjects']
gts = [
self.ABCD_to_0123(gt)
for gt in data_batch['data_samples']['labels']
]
preds = []
for sample, attn_mask, subject, gt in zip(
data_samples, data_batch['data']['attention_mask'], subjects,
gts):
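            # Read the logits at the last non-padded position and compare
            # the scores of the four option tokens; the argmax is the
            # predicted option.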
pred_logits = sample['logits']
first_zero_idx = self.find_first_zero_index(attn_mask)
pred_idx = -1 if first_zero_idx is None else first_zero_idx - 1
            pred_logits_abcd = pred_logits[pred_idx, self.abcd_idx]
            pred = torch.argmax(pred_logits_abcd).item()
preds.append(pred)
self.results.append((subject, pred, gt))
def compute_metrics(self, results: list) -> dict:
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
subjects_results = {
subject: {
'preds': [],
'gts': []
}
for subject in self.METAINFO['subcategories'].keys()
}
subcats_results = {
subcat: {
'preds': [],
'gts': []
}
for subcat in self.METAINFO['subcategories_list']
}
cats_results = {
cat: {
'preds': [],
'gts': []
}
for cat in self.METAINFO['categories'].keys()
}
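        # Group the per-sample results by subject, then propagate each result
        # to the subcategory buckets defined in METAINFO.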
for subject, pred, gt in results:
subjects_results[subject]['preds'].append(pred)
subjects_results[subject]['gts'].append(gt)
subcats = self.METAINFO['subcategories'][subject]
for subcat in subcats:
subcats_results[subcat]['preds'].append(pred)
subcats_results[subcat]['gts'].append(gt)
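        # Roll the subcategory results up into their parent categories.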
for cat, subcats in self.METAINFO['categories'].items():
for subcat in subcats:
if subcat in subcats_results:
cats_results[cat]['preds'].extend(
subcats_results[subcat]['preds'])
cats_results[cat]['gts'].extend(
subcats_results[subcat]['gts'])
subjects_metrics = dict()
subcats_metrics = dict()
cats_metrics = dict()
for subject in self.METAINFO['subcategories'].keys():
assert len(subjects_results[subject]['preds']) == len(
subjects_results[subject]['gts'])
if len(subjects_results[subject]['preds']) == 0:
print_log(f'Skip subject {subject} for mmlu', 'current')
else:
score = self.accuracy(subjects_results[subject]['preds'],
subjects_results[subject]['gts'])
subjects_metrics[f'{subject}'] = score
for subcat in self.METAINFO['subcategories_list']:
assert len(subcats_results[subcat]['preds']) == len(
subcats_results[subcat]['gts'])
if len(subcats_results[subcat]['preds']) == 0:
print_log(f'Skip subcategory {subcat} for mmlu', 'current')
else:
score = self.accuracy(subcats_results[subcat]['preds'],
subcats_results[subcat]['gts'])
subcats_metrics[f'{subcat}'] = score
for cat in self.METAINFO['categories'].keys():
assert len(cats_results[cat]['preds']) == len(
cats_results[cat]['gts'])
if len(cats_results[cat]['preds']) == 0:
print_log(f'Skip category {cat} for mmlu', 'current')
else:
score = self.accuracy(cats_results[cat]['preds'],
cats_results[cat]['gts'])
cats_metrics[f'{cat}'] = score
metrics = dict()
metrics.update(subjects_metrics)
metrics.update(subcats_metrics)
metrics.update(cats_metrics)
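        # The overall score is the unweighted mean of the per-subject
        # accuracies.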
metrics['average'] = np.mean(list(subjects_metrics.values()))
table_metrics = dict()
table_metrics.update(cats_metrics)
table_metrics['average'] = np.mean(list(subjects_metrics.values()))
self._print_results(table_metrics)
return metrics
def _print_results(self, table_metrics: dict) -> None:
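        """Render the per-category accuracies as a rich table and log it."""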
table_title = ' MMLU Benchmark '
table = Table(title=table_title)
console = Console()
table.add_column('Categories', justify='left')
table.add_column('Accuracy (%)', justify='right')
for cat, acc in table_metrics.items():
table.add_row(cat, f'{acc:.1f}')
with console.capture() as capture:
console.print(table, end='')
print_log('\n' + capture.get(), 'current')
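

# Usage sketch (illustrative only): the metric is typically built from an
# xtuner/mmengine config, with ``tokenizer`` given as a config dict that
# BUILDER can instantiate. The model name below is an assumption chosen for
# illustration, not something this file prescribes.
#
#   from transformers import AutoTokenizer
#
#   metric = MMLUMetric(
#       tokenizer=dict(
#           type=AutoTokenizer.from_pretrained,
#           pretrained_model_name_or_path='internlm/internlm2-chat-7b',
#           trust_remote_code=True))
#
#   # mmengine's Evaluator calls ``process`` on each batch and then
#   # ``evaluate`` (which invokes ``compute_metrics``) to obtain the
#   # accuracy dictionary.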