"""GPT zero-shot evaluation.""" |
|
|
|
import math |
|
|
|
import torch |
|
|
|
from megatron import get_args |
|
from megatron import print_rank_0, is_last_rank |
|
from megatron import get_tokenizer |
|
from megatron import mpu |
|
from megatron.checkpointing import load_checkpoint |
|
from megatron.model import GPTModel |
|
from megatron.training import get_model |
|
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model |
|
from megatron.p2p_communication import recv_forward, send_forward |
|
from tasks.finetune_utils import build_data_loader |
|
|
|
from .datasets import build_dataset |
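# The model passed to forward_step() may be wrapped in DistributedDataParallel
# (torch or local) or Float16Module; unwrap_model() strips these wrappers off
# before set_input_tensor() is called on the underlying GPTModel.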
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_model_provider(eval_metric):
    """Based on the evaluation metric, set the parallel-output flag and
    return the model provider."""

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""

        if eval_metric == 'loss':
            parallel_output = True
        elif eval_metric == 'accuracy':
            parallel_output = False
        else:
            raise NotImplementedError('output type for {} evaluation metric '
                                      'is not supported.'.format(eval_metric))

        print_rank_0('building GPT model ...')
        model = GPTModel(num_tokentypes=0, parallel_output=parallel_output,
                         pre_process=pre_process, post_process=post_process)

        return model

    return model_provider


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    args = get_args()
    tokenizer = get_tokenizer()
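    # Move the batch to GPU and build the shifted (input, label) pairs:
    # tokens[:, t] is used to predict labels[:, t], i.e. the next token.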
    loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
    tokens_ = batch['text'].long().cuda().contiguous()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()
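    # Get the attention mask and position ids.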
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, attention_mask, position_ids, loss_mask


def forward_step(batch, model, eval_metric):
    """Forward step."""
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
        batch)
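    # Update args.micro_batch_size to the actual size of this batch; the last
    # batch can be smaller because the data loader is built with drop_last=False.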
    args = get_args()
    args.micro_batch_size = len(labels)
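    # Receive activations from the previous pipeline stage (None on the first stage).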
    input_tensor = recv_forward()
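    # Unwrap the model, hand it the received activations, and run the forward pass.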
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output = model(tokens, position_ids, attention_mask)
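    # Send this stage's output to the next pipeline stage.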
    send_forward(output)

    if mpu.is_pipeline_last_stage():
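        # For the loss metric, sum the token-level cross-entropy over the
        # non-padded positions.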
        if eval_metric == 'loss':
            losses = mpu.vocab_parallel_cross_entropy(
                output.contiguous().float(), labels.contiguous())
            loss = torch.sum(
                losses.view(-1) * loss_mask.contiguous().view(-1).float())
            return loss
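        # For the accuracy metric, a sample counts as correct only if every
        # unmasked token is predicted correctly; masked positions are set to 1
        # so they do not affect the product.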
        if eval_metric == 'accuracy':
            outputs = torch.argmax(output, -1)
            correct = (outputs == labels).float()
            correct[(1 - loss_mask).bool()] = 1
            correct = correct.prod(-1)
            return correct.sum()

        raise NotImplementedError('forward method for evaluation metric {} '
                                  'is not implemented.'.format(eval_metric))

    return None


def evaluate(data_loader, model, eval_metric):
    """Evaluation."""
    args = get_args()
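    # Turn on evaluation mode, which disables dropout.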
    model.eval()

    total_output = 0.0
    with torch.no_grad():
        for iteration, batch in enumerate(data_loader):
            if iteration % args.log_interval == 0:
                print_rank_0('> working on iteration: {}'.format(iteration))

            output = forward_step(batch, model, eval_metric)
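            # Reduce the per-batch result across the data-parallel group and
            # accumulate it on the last pipeline stage.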
            if mpu.is_pipeline_last_stage():
                torch.distributed.all_reduce(output,
                                             group=mpu.get_data_parallel_group())

                total_output += output

    return total_output


def evaluate_and_print_results(task, data_loader, model, eval_metric):
    """Evaluate and print results on screen."""
    output = evaluate(data_loader, model, eval_metric)

    string = ' validation results on {} | '.format(task)
    if is_last_rank():
        if eval_metric == 'loss':
            num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
            num_original_tokens = data_loader.dataset.num_original_tokens
            val_loss = output / (num_tokenized_tokens - 1)
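            # 'ppl' is perplexity per tokenized token; 'adjusted ppl' rescales
            # the loss by token_ratio so it is comparable per token of the
            # original, untokenized text.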
            ppl = math.exp(min(20, val_loss))
            token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
            adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
            string += 'avg loss: {:.4E} | '.format(val_loss)
            string += 'ppl: {:.4E} | '.format(ppl)
            string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
            string += 'token ratio: {} |'.format(token_ratio)

        elif eval_metric == 'accuracy':
            num_examples = len(data_loader.dataset)
            acc = output / num_examples
            string += 'number correct: {:.4E} | '.format(output)
            string += 'total examples: {:.4E} | '.format(num_examples)
            string += 'avg accuracy: {:.4E}'.format(acc)

        else:
            raise NotImplementedError('evaluation method for {} metric is not '
                                      'implemented yet.'.format(eval_metric))

        length = len(string) + 1
        print('-' * length)
        print(string)
        print('-' * length)


def main():
    """Main program."""
    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for zero-shot GPT evaluation.")
        exit()

    if args.task == 'LAMBADA':
        eval_metric = 'accuracy'
    elif args.task == 'WIKITEXT103':
        eval_metric = 'loss'
    else:
        raise NotImplementedError('{} task is not implemented.'.format(
            args.task))
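    # Set up the model and load the checkpoint.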
    model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]
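    # Build the dataset and data loader for the task.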
    dataset = build_dataset(args.task)
    dataloader = build_data_loader(dataset, args.micro_batch_size,
                                   args.num_workers, drop_last=False)
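    # Run the evaluation and print the results on the last rank.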
    evaluate_and_print_results(args.task, dataloader, model, eval_metric)

    print_rank_0('done :-)')