|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluation utilities.""" |
|
|
|
import os |
|
import time |
|
from functools import partial |
|
|
|
import torch |
|
|
|
from megatron import get_args |
|
from megatron import print_rank_last, is_last_rank |
|
from megatron import mpu |
|
from megatron.schedules import get_forward_backward_func |
|
from tasks.finetune_utils import build_data_loader |
|
from tasks.finetune_utils import process_batch |
|
import json |
|
import numpy as np |
|
from tasks.label_dict import get_label_dict |
|
|
|
def accuracy_func_provider(single_dataset_provider):
    """Provide a function that scores the validation set and writes test predictions.

    Builds one dataloader for ``args.valid_data[0]`` and one for
    ``args.test_data[0]``. The returned ``metrics_func(model, step)``
    measures accuracy on the validation set. It then runs inference on the
    test set and, on the last rank only, writes two files under
    ``args.res_path[0]``:

    * ``<task>_prediction_<step>_<acc>.json``: one JSON object per line with
      an ``id`` and the label looked up via ``get_label_dict``;
    * ``<task>_prob_<step>_<acc>.npy``: the softmax rows ordered by sample
      id (``np.save`` appends the ``.npy`` suffix).

    ``<acc>`` is the validation accuracy in basis points, zero-padded to four
    digits (e.g. ``8123`` for 81.23 %).

    Args:
        single_dataset_provider: callable taking a data path and returning a
            dataset that exposes a ``dataset_name`` attribute.

    Returns:
        The ``metrics_func(model, step)`` closure.
    """
    args = get_args()

    # Index 0 is the validation set (scored); index 1 is the test set
    # (predictions only; its counts are not added to the accuracy).
    datapaths = [args.valid_data[0], args.test_data[0]]
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def _format_accuracy(correct, total):
        """Return correct/total in basis points as a zero-padded string.

        The value is rounded rather than truncated from a float's ``str()``.
        The former ``str(round(acc, 4) * 10000)[:4]`` produced ``'1000'`` for
        1.0 and ``'500.'`` for 0.05. Float error could also truncate it one
        basis point low (e.g. 0.57 -> ``'5699'``).
        """
        if total == 0:
            return 'nan'
        return '{:04d}'.format(int(round(correct * 10000 / total)))

    def _probs_by_id(predictions):
        """Pair every sample id from min to max with its softmax row.

        ``predictions`` is the ``(softmaxes, labels, ids)`` tuple returned by
        ``calculate_correct_answers``. If an id appears more than once, its
        first occurrence is used, matching the former ``list.index``
        behaviour. A lookup table replaces the per-id ``list.index`` scan,
        which was O(n^2).

        Returns:
            ``(min_id, [(id, probs), ...])`` in ascending id order.

        Raises:
            ValueError: if the ids are empty or an id between min and max
                has no prediction.
        """
        probs_list, _, ids_list = predictions
        first_index = {}
        for position, uid in enumerate(ids_list):
            first_index.setdefault(uid, position)
        min_id = min(ids_list)
        max_id = max(ids_list)
        ordered = []
        for uid in range(min_id, max_id + 1):
            if uid not in first_index:
                raise ValueError(
                    'prediction for sample id {} is missing'.format(uid))
            ordered.append((uid, probs_list[first_index[uid]]))
        return min_id, ordered

    def _generate_prediction_json(predictions, step, save_acc):
        """Write one ``{"id": ..., "label": ...}`` JSON line per test sample."""
        min_id, ordered = _probs_by_id(predictions)
        label_names = get_label_dict(args.task, write2file=True)
        output_submit_file = os.path.join(
            args.res_path[0],
            args.task + "_prediction_{}_{}.json".format(step, save_acc))
        with open(output_submit_file, "w") as writer:
            for uid, probs in ordered:
                # Argmax over the class probabilities; the first maximum wins.
                label = probs.index(max(probs))
                # Ids starting at 1 are shifted so the written ids start at 0.
                json_d = {'id': uid - 1 if min_id == 1 else uid,
                          'label': label_names[str(label)]}
                writer.write(json.dumps(json_d) + '\n')

    def _generate_prediction_prob(predictions, step, save_acc):
        """Save the id-ordered softmax rows as a numpy array."""
        _, ordered = _probs_by_id(predictions)
        output_prob_file = os.path.join(
            args.res_path[0], args.task + "_prob_{}_{}".format(step, save_acc))
        np.save(output_prob_file, np.array([probs for _, probs in ordered]))

    def metrics_func(model, step):
        """Score the validation set, then write test-set predictions.

        Args:
            model: model chunks, forwarded to ``calculate_correct_answers``.
            step: training step, embedded in the log and output file names.
        """
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0
        for index, (name, dataloader) in enumerate(dataloaders):
            output_predictions = (index == 1)
            if output_predictions:
                # Collecting predictions is only supported without data
                # parallelism.
                assert mpu.get_data_parallel_world_size() == 1
            output = calculate_correct_answers(name, model, dataloader,
                                               step, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
                correct += correct_ans
                total += total_count
                continue

            _, _, predictions = output
            # Only the last rank writes. Non-last pipeline stages receive
            # placeholder results ((0, 0, ())). Without this guard, every rank
            # of the last stage would write the same files.
            if not is_last_rank():
                continue
            save_acc = _format_accuracy(correct, total)
            print_rank_last("generate prediction...")
            _generate_prediction_json(predictions, step, save_acc)
            _generate_prediction_prob(predictions, step, save_acc)
            print_rank_last("generate done")

    return metrics_func
|
|
|
|
|
def calculate_correct_answers(name, model, dataloader,
                              step, output_predictions):
    """Calculate correct over total answers and return predictions if
    `output_predictions` is true.

    Runs a forward-only pass of ``model`` over ``dataloader`` and counts the
    argmax predictions that match the batch labels. On the last pipeline
    stage, the counts are summed across the data-parallel group.

    Args:
        name: dataset name, used in the log line.
        model: model chunks. Each is switched to eval mode for the pass and
            back to train mode afterwards, even if evaluation raises.
        dataloader: yields dict batches with a ``'label'`` entry (and
            ``'uid'`` when ``output_predictions`` is true) plus whatever
            ``process_batch`` reads.
        step: training step, used in the log line.
        output_predictions: also collect per-sample softmaxes, labels and
            ids. Requires a data-parallel world size of 1.

    Returns:
        On the last pipeline stage: ``(correct, total)``, or
        ``(correct, total, (softmaxes, labels, ids))`` when
        ``output_predictions`` is true. On other stages: ``(0, 0)`` or
        ``(0, 0, ())``.
    """
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()

    # NOTE(review): a dataset exposing ``sample_multiplier`` presumably
    # expands each sample into several model rows (e.g. multiple-choice
    # options). The micro batch size is scaled by it below.
    sample_multiplier = getattr(dataloader.dataset, 'sample_multiplier', 1)
    micro_batch_size_times_data_parallel = (
        args.orig_micro_batch_size * args.data_parallel_size)
    num_micro_batches = (
        args.orig_global_batch_size // micro_batch_size_times_data_parallel)

    def loss_func(output_predictions, labels, uids, output_tensor):
        """Count correct argmax predictions and optionally keep them.

        Returns a zero loss and a dict with ``total`` and ``correct`` counts,
        plus ``softmaxes`` / ``labels`` / ``ids`` lists when
        ``output_predictions`` is true.
        """
        logits = output_tensor
        loss_dict = {}
        if output_predictions:
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = uids.cpu().numpy().tolist()
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()
        return 0, loss_dict

    def correct_answers_forward_step(batch, model):
        """Forward one batch. ``batch`` is either an iterator or a batch dict."""
        try:
            batch_ = next(batch)
        except TypeError:
            # ``batch`` is already a batch (e.g. a dict), not an iterator.
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)
        # Take the ids from the batch actually forwarded, rather than
        # reading the caller's loop variable through a closure.
        uids = batch_['uid'] if output_predictions else None
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
        return output_tensor, partial(loss_func, output_predictions,
                                      labels, uids)

    total = 0
    correct = 0
    if output_predictions:
        # This option is only possible when data parallel size is 1.
        assert mpu.get_data_parallel_world_size() == 1
        softmaxes = []
        labels = []
        ids = []

    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size
    try:
        with torch.no_grad():
            for batch in dataloader:
                # drop_last may be False, so the final batch can be short.
                # Size the micro/global batch from the data actually present.
                actual_batch_size = len(batch['label'])
                args.micro_batch_size = actual_batch_size * sample_multiplier
                args.global_batch_size = (
                    actual_batch_size * sample_multiplier * num_micro_batches)

                loss_dicts = forward_backward_func(
                    correct_answers_forward_step, batch, model,
                    optimizer=None, timers=None, forward_only=True)

                for loss_dict in loss_dicts:
                    if output_predictions:
                        softmaxes.extend(loss_dict['softmaxes'])
                        labels.extend(loss_dict['labels'])
                        ids.extend(loss_dict['ids'])
                    total += loss_dict['total']
                    correct += loss_dict['correct']
    finally:
        # Restore the training state even if evaluation raised.
        for m in model:
            m.train()
        args.micro_batch_size = saved_micro_batch_size
        args.global_batch_size = saved_global_batch_size

    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        # An empty dataset reports 0 % instead of dividing by zero.
        if total_count:
            percent = float(correct_ans) * 100.0 / float(total_count)
        else:
            percent = 0.0
        elapsed_time = time.time() - start_time
        if not output_predictions:
            print_rank_last(' > |step: {} | metrics for {}: correct / total '
                            '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                                step, name, correct_ans, total_count,
                                percent, elapsed_time))
            return correct_ans, total_count
        return correct_ans, total_count, (softmaxes, labels, ids)

    if output_predictions:
        return 0, 0, ()
    return 0, 0
|
|