import itertools
from collections import defaultdict
from typing import List, Optional, Sequence

import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import print_log
from rich.console import Console
from rich.table import Table


class RewardMetric(BaseMetric):
    """Reward model evaluation metric."""

    default_prefix: Optional[str] = ''

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        logits = torch.cat(
            [sample['logits'].unsqueeze(0) for sample in data_samples], dim=0)
        labels = data_batch['data']['labels']
        ds_names = data_batch['data_samples']['ds_names']

        # Labels mark each response as chosen (0) or rejected (1); the i-th
        # chosen response is compared against the i-th rejected response.
        chosen_idx = torch.where(labels == 0)
        rejected_idx = torch.where(labels == 1)
        chosen_logits = logits[chosen_idx].cpu()
        rejected_logits = logits[rejected_idx].cpu()

        correct = chosen_logits > rejected_logits
        self.results.append({
            'chosen_logits': chosen_logits,
            'rejected_logits': rejected_logits,
            'correct': correct,
            'ds_names': ds_names
        })

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.
        metrics = {}

        correct = torch.cat([res['correct'] for res in results])
        chosen_logits = torch.cat([res['chosen_logits'] for res in results])
        rejected_logits = torch.cat(
            [res['rejected_logits'] for res in results])
        ds_names = list(itertools.chain(*[res['ds_names'] for res in results]))

        # Group the pairwise results by dataset name.
        grouped_correct = defaultdict(list)
        grouped_chosen_logits = defaultdict(list)
        grouped_rejected_logits = defaultdict(list)
        for i, ds_name in enumerate(ds_names):
            grouped_correct[ds_name].append(correct[i])
            grouped_chosen_logits[ds_name].append(chosen_logits[i])
            grouped_rejected_logits[ds_name].append(rejected_logits[i])

        # Print per-dataset metrics in a rich table.
        table = Table(title='Reward Metrics')
        table.add_column('Dataset Name')
        table.add_column('Accuracy')
        table.add_column('Chosen Score')
        table.add_column('Rejected Score')

        for ds_name in grouped_correct.keys():
            ds_correct = torch.stack(grouped_correct[ds_name])
            ds_chosen_logits = torch.stack(grouped_chosen_logits[ds_name])
            ds_rejected_logits = torch.stack(grouped_rejected_logits[ds_name])

            acc = ds_correct.float().mean()
            metrics[f'accuracy/{ds_name}'] = acc.item()
            metrics[f'chosen_score/{ds_name}'] = ds_chosen_logits.mean().item()
            metrics[f'rejected_score/{ds_name}'] = \
                ds_rejected_logits.mean().item()
            table.add_row(ds_name, f'{acc:.4f}',
                          f'{ds_chosen_logits.mean():.4f}',
                          f'{ds_rejected_logits.mean():.4f}')

        console = Console()
        with console.capture() as capture:
            console.print(table, end='')
        print_log('\n' + capture.get(), 'current')

        return metrics
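

# A minimal usage sketch (not part of the original module): it drives the
# mmengine ``BaseMetric`` interface by hand with toy tensors. The batch layout
# (scalar per-sample logits, alternating chosen/rejected labels, one entry in
# ``ds_names`` per chosen/rejected pair, and the dataset name 'toy_ds') is an
# assumption made for illustration only.
if __name__ == '__main__':
    metric = RewardMetric()
    data_batch = {
        'data': {'labels': torch.tensor([0, 1, 0, 1])},
        'data_samples': {'ds_names': ['toy_ds', 'toy_ds']},
    }
    data_samples = [{'logits': torch.tensor(v)} for v in (1.2, 0.3, -0.1, 0.4)]
    metric.process(data_batch, data_samples)
    # Expected: accuracy/toy_ds == 0.5 (the first pair is ranked correctly,
    # the second is not), plus the mean chosen/rejected scores.
    print(metric.compute_metrics(metric.results))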