# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Aggregate ICL evals into composite scores."""

import logging
import math
from enum import Enum
from typing import Dict, Optional

from composer.core import Callback, State
from composer.loggers import Logger

__all__ = ['EvalGauntlet']

log = logging.getLogger(__name__)


class Weighting(Enum):
    EQUAL = 1
    SAMPLE_SZ = 2
    LOG_SAMPLE_SZ = 3

class EvalGauntlet(Callback):
    """The EvalGauntlet aggregates ICL eval results.

    After `eval_end`, this callback inspects the logger for the individual ICL
    benchmark metrics and aggregates the scores according to the aggregation
    specification provided in the constructor.

    Args:
        logger_keys (list): The exact keys under which the individual benchmark
            metrics are logged in the logger after eval.
        categories (dict): The list of categories, the benchmarks within each
            category, the random baseline accuracy of each benchmark, and the
            number of fewshot examples used for each benchmark. See
            `llmfoundry/scripts/eval/yamls/eval_gauntlet.yaml` for the structure.
        weighting (Weighting): The weighting scheme used to balance the benchmarks
            within each category: equal weight, weight proportional to the dataset
            size, or weight proportional to the log2 of the dataset size.
            Options are 'EQUAL', 'SAMPLE_SZ', and 'LOG_SAMPLE_SZ'.
        subtract_random_baseline (bool): Whether to subtract the random baseline
            accuracy from the performance on each individual benchmark before
            aggregating.
        rescale_accuracy (bool): Whether to rescale the accuracy on each benchmark
            by (1 - random_baseline_accuracy) before aggregating. Using this
            ensures that all benchmarks max out at 1.0.
        benchmark_sizes (Optional[dict]): Optional data on benchmark sizes,
            required when not using equal weighting.
    """
    def __init__(self,
                 logger_keys: list,
                 categories: dict,
                 weighting: str = 'EQUAL',
                 subtract_random_baseline: bool = True,
                 rescale_accuracy: bool = True,
                 benchmark_sizes: Optional[dict] = None):
        if isinstance(logger_keys, dict):
            raise ValueError(
                'logger_keys now requires a list type as input, not a dict')
        # `weighting` is passed as a string (e.g. 'EQUAL'), so convert it to the
        # enum member before comparing.
        if Weighting[weighting] != Weighting.EQUAL and benchmark_sizes is None:
            raise Exception(
                'When not using equal weighting, you must provide the benchmark sizes.'
            )
        if rescale_accuracy and not subtract_random_baseline:
            raise Exception(
                'Only use accuracy rescaling in conjunction with subtracting random baseline accuracy.'
            )

        self.categories = categories
        self.weighting = Weighting[weighting]
        self.subtract_random_baseline = subtract_random_baseline
        self.rescale_accuracy = rescale_accuracy
        self.logger_keys = logger_keys

        # Precompute the weight each benchmark contributes to its category average.
        for category in self.categories:
            for benchmark in category['benchmarks']:
                bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"

                if self.weighting != Weighting.EQUAL:
                    assert benchmark_sizes is not None
                    cumulative_samples = max(
                        sum(count for name, count in benchmark_sizes.items()
                            if name.startswith(bench_name)), 1)
                else:
                    cumulative_samples = -1  # pyright

                weight = None
                if self.weighting == Weighting.EQUAL:
                    weight = 1
                elif self.weighting == Weighting.SAMPLE_SZ:
                    weight = cumulative_samples
                elif self.weighting == Weighting.LOG_SAMPLE_SZ:
                    weight = max(math.log(cumulative_samples, 2), 1)
                assert weight is not None

                benchmark['weighting'] = weight
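
    # Worked example for the weight computation above (assumed sample counts):
    # under LOG_SAMPLE_SZ, a benchmark with 3,000 aggregated samples gets weight
    # max(log2(3000), 1) ≈ 11.55, while one with 300 samples gets ≈ 8.23, so
    # larger benchmarks count for more, but only logarithmically so.
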
    def compute_averages(self, state: State) -> Dict[str, float]:
        """Average the accuracy metrics logged for each benchmark/num-fewshot pair."""
        results = {}

        for key in self.logger_keys:
            # starting at index 1 skips the "metric" part of the key, which is superfluous
            dl_name, metric_name = key.split('/')[1:-1], key.split('/')[-1]
            if 'Accuracy' not in metric_name:
                continue

            metric = state.eval_metrics.get('/'.join(dl_name),
                                            {}).get(metric_name, None)
            if metric is None:
                continue
            val = metric.compute().item()

            # ending at index 2 allows us to aggregate over dataloaders with subcategories
            key = '/'.join(dl_name[0:2])
            if key not in results:
                results[key] = []

            results[key].append(val)

        return {k: sum(v) / len(v) for k, v in results.items()}
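
    # Illustrative arithmetic for the per-benchmark score computed below
    # (assumed numbers): with raw accuracy 0.55 and random_baseline 0.25,
    # subtracting the baseline gives 0.55 - 0.25 = 0.30, and rescaling divides
    # by (1 - 0.25), yielding 0.30 / 0.75 = 0.40. Each category's composite
    # score is then the weight-normalized average of these per-benchmark scores.
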
    def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
        """Aggregate per-benchmark averages into weighted per-category and overall composite scores."""
        new_metrics = self.compute_averages(state)
        if len(new_metrics) == 0:
            return {}
        composite_scores = {}

        for category in self.categories:
            missing_metrics = []
            composite_scores[category['name']] = []
            for benchmark in category['benchmarks']:
                key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"

                if key not in new_metrics:
                    log.warning(
                        f'Could not find results for benchmark: {benchmark}.')
                    missing_metrics.append(key)
                else:
                    score = new_metrics[key]

                    if self.subtract_random_baseline:
                        score -= benchmark['random_baseline']

                    if self.rescale_accuracy and self.subtract_random_baseline:
                        score /= 1.0 - benchmark['random_baseline']

                    composite_scores[category['name']].append({
                        'name': benchmark['name'],
                        'score': score,
                        'weighting': benchmark['weighting']
                    })

            if len(missing_metrics) > 0:
                log.warning(
                    f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}"
                )
                del composite_scores[category['name']]
                continue

            total_weight = sum(
                k['weighting'] for k in composite_scores[category['name']])
            composite_scores[category['name']] = sum(
                k['score'] * (k['weighting'] / total_weight)
                for k in composite_scores[category['name']])

        composite_scores = {
            f'icl/metrics/eval_gauntlet/{k}': v
            for k, v in composite_scores.items()
        }

        composite_scores['icl/metrics/eval_gauntlet/average'] = sum(
            composite_scores.values()) / len(composite_scores.values()) if len(
                composite_scores.values()) > 0 else 0
        if logger is not None:
            logger.log_metrics(composite_scores)

        return composite_scores
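

# Minimal usage sketch (assumed wiring; `eval_gauntlet_config`, `run_logger_keys`,
# `trainer_state`, and `run_logger` are hypothetical names for this illustration):
#
#   gauntlet = EvalGauntlet(
#       logger_keys=run_logger_keys,
#       categories=eval_gauntlet_config['categories'],
#       weighting='EQUAL',
#   )
#   composite_scores = gauntlet.eval_after_all(trainer_state, run_logger)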