"""Interface for wrapping third-party inference APIs in a ComposerModel for
in-context learning (ICL) evaluation."""

from typing import Any, Dict, Optional

import torch
from composer.core.types import Batch
from composer.metrics import InContextLearningMetric
from composer.metrics.nlp import (InContextLearningLMAccuracy,
                                  InContextLearningLMExpectedCalibrationError,
                                  InContextLearningMCExpectedCalibrationError,
                                  InContextLearningMultipleChoiceAccuracy,
                                  InContextLearningQAAccuracy,
                                  LanguageCrossEntropy, LanguagePerplexity)
from composer.models import ComposerModel
from torchmetrics import Metric
from transformers import AutoTokenizer

class InferenceAPIEvalWrapper(ComposerModel):
    """Base ComposerModel for evaluating an inference API on ICL tasks.

    Subclasses implement ``get_next_token_logit_tensor`` to query the API;
    this base class handles metric setup, batch scoring, and metric updates.
    """

    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
        self.tokenizer = tokenizer
        self.labels = None
        # Set up the eval metrics this wrapper can report.
        eval_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            InContextLearningLMAccuracy(),
            InContextLearningMultipleChoiceAccuracy(),
            InContextLearningQAAccuracy(),
            InContextLearningLMExpectedCalibrationError(),
            InContextLearningMCExpectedCalibrationError()
        ]
        self.eval_metrics = {
            metric.__class__.__name__: metric for metric in eval_metrics
        }
        super().__init__()

    def get_metrics(self, is_train: bool = False):
        if is_train:
            raise NotImplementedError(
                'You cannot use inference wrappers for training')
        else:
            metrics = self.eval_metrics

        return metrics if metrics else {}

    def get_next_token_logit_tensor(self,
                                    prompt: str) -> Optional[torch.Tensor]:
        """Return the API's next-token logits for ``prompt`` as a 1D tensor.

        Subclasses must override this; returning ``None`` skips the token.
        """
        raise NotImplementedError

    def rebatch(self, batch: Batch):
        # Subclasses can override this to reshape the batch before metric
        # updates; the default is a no-op.
        return batch

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
        # Score one example at a time. Context tokens are given one-hot
        # "logits" shifted by one position (the logits at position i predict
        # token i + 1); each expected continuation token is then scored by
        # asking the API for next-token logits on the prompt decoded so far.
        output_logits_batch = []
        for tokens, cont_idxs in zip(batch['input_ids'],
                                     batch['continuation_indices']):
            seqlen = tokens.shape[0]
            tokens = tokens.tolist()
            cont_idxs = cont_idxs.tolist()
            expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
            output_logits = torch.nn.functional.one_hot(
                torch.tensor(tokens[1:cont_idxs[0]]),
                num_classes=self.tokenizer.vocab_size)
            for i in range(len(expected_cont_tokens)):
                # Decode the context plus the continuation tokens accepted so
                # far, then query the API for the next token's logits.
                prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] +
                                               expected_cont_tokens[0:i])
                next_logit_tensor = self.get_next_token_logit_tensor(prompt)
                if next_logit_tensor is None:
                    continue
                output_logits = torch.cat(
                    [output_logits,
                     next_logit_tensor.reshape(1, -1)])
            # Right-pad with one-hot pad tokens so every example's logits have
            # shape (seqlen, vocab_size) and the batch can be stacked.
            padding = torch.nn.functional.one_hot(
                torch.full((seqlen - output_logits.shape[0],),
                           self.tokenizer.pad_token_id),
                num_classes=self.tokenizer.vocab_size)
            output_logits = torch.cat([output_logits, padding])
            output_logits_batch.append(output_logits)

        return torch.stack(output_logits_batch).to(batch['input_ids'].device)

    def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
        batch = self.rebatch(batch)
        self.labels = batch.pop('labels')
        # Shift labels left so they align with next-token logits; the final
        # position has no target, so mask it with -100.
        self.labels[:, :-1] = self.labels[:, 1:].clone()
        self.labels[:, -1] = -100
        if isinstance(metric, InContextLearningMetric) and batch.get(
                'mode', None) == 'icl_task':
            assert self.labels is not None
            metric.update(batch, outputs, self.labels)
        else:
            raise NotImplementedError(
                'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task'
            )

    def forward(self):
        raise NotImplementedError(
            "Inference API wrapper doesn't support forward")

    def loss(self):
        raise NotImplementedError("Inference API wrapper doesn't support loss")