"""Training script: launches training via scrl.training.setup_and_train and
periodically evaluates the model, tracking the mean validation reward and the
token-level F1 score against gold summaries on held-out datasets."""
import argparse
from pathlib import Path
from pprint import pprint

import numpy as np
import torch
import tqdm
from nltk import word_tokenize
from torch.nn.utils.rnn import pad_sequence

import scrl.utils as utils
from scrl.config import load_config
from scrl.eval_metrics import compute_token_f1
from scrl.model import labels_to_summary
from scrl.training import setup_and_train


def evaluate_validation_reward(args, manager, model, tokenizer, reward_generator, dataset):
    """Return the mean reward of greedy (argmax) summaries on a held-out dataset."""
    device = args.device
    idx_range = list(range(len(dataset)))
    dataset_indices = list(utils.batchify(idx_range, args.batch_size))
    rewards = []
    for i, indices in enumerate(dataset_indices):
        if args.max_val_steps is not None and i >= args.max_val_steps:
            break
        batch = dataset[indices]
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in batch["input_ids"]], batch_first=True
        )
        logits = model(input_ids.to(device))
        # Greedy decoding: pick the highest-scoring keep/drop label per token.
        argmax_labels = torch.argmax(logits, dim=2)
        argmax_summaries = labels_to_summary(input_ids, argmax_labels, tokenizer)
        argmax_rewards, _ = reward_generator(batch["document"], argmax_summaries)
        rewards += argmax_rewards
    return np.mean(rewards)


def evaluate_validation_dataset(args, manager, model, tokenizer, reward_generator, dataset_path):
    """Return the mean token-level F1 score on a JSONL dataset whose items
    contain a source "text" and one or more reference "summaries"."""
    f1_scores = []
    dataset = list(utils.read_jsonl(dataset_path))
    dump_data = []

    for item in tqdm.tqdm(dataset):
        src = item["text"]
        tgts = item["summaries"]

        input_ids = torch.tensor(tokenizer([src])["input_ids"]).to(args.device)
        logits = model(input_ids)
        argmax_labels = torch.argmax(logits, dim=2)
        pred = labels_to_summary(input_ids, argmax_labels, tokenizer)[0]

        # Lowercase the prediction once; F1 is computed case-insensitively.
        pred_tokens = [t.lower() for t in word_tokenize(pred)]

        item_scores = []
        for tgt in tgts:
            tgt_tokens = [t.lower() for t in word_tokenize(tgt)]
            token_f1 = compute_token_f1(tgt_tokens, pred_tokens, use_counts=True)
            item_scores.append(token_f1)

        if args.dump:
            probs = torch.softmax(logits, dim=2)[0].detach().tolist()
            dump_data.append({
                "probs": probs,
                "source": src,
                "target": tgts[0],
                "f1-score": item_scores[0],
                "pred_summary": pred,
                "pred_labels": argmax_labels[0].tolist(),
            })

        # Average across references, then across the dataset.
        f1_scores.append(np.mean(item_scores))
    score = np.mean(f1_scores)

    if args.dump:
        dataset_name = dataset_path.name.split(".jsonl")[0]
        dump_dir = manager.dir / f"dump-{dataset_name}"
        dump_dir.mkdir(exist_ok=True)
        utils.write_jsonl(dump_data, dump_dir / f"step-{manager.step}.jsonl", "w")
    return score


def evaluate(args, manager, model, tokenizer, reward_generator, holdout_data):
    """Evaluate the current model and checkpoint it whenever a validation
    metric reaches a new best value."""
    step = manager.step
    val_reward = evaluate_validation_reward(
        args, manager, model, tokenizer, reward_generator, holdout_data
    )

    # Track validation rewards across steps; save the model on a new best score.
    reward_path = manager.dir / "val_rewards.jsonl"
    if reward_path.exists():
        reward_results = list(utils.read_jsonl(reward_path))
        prev_max = max(x["score"] for x in reward_results)
    else:
        reward_results = []
        prev_max = 0
    if val_reward > prev_max:
        manager.save_model(model, step, "best_val_reward")
    reward_results.append({"step": step, "score": val_reward})
    utils.write_jsonl(reward_results, reward_path, "w")
    if args.verbose:
        print("Validation Rewards:")
        pprint(reward_results)
        print()

    # Same bookkeeping per annotated validation dataset, scored by token F1.
    for val_data_path in args.validation_datasets:
        val_data_path = Path(val_data_path)
        dataset_name = val_data_path.name.split(".jsonl")[0]
        dataset_score = evaluate_validation_dataset(
            args, manager, model, tokenizer, reward_generator, val_data_path
        )
        result_path = manager.dir / f"val_data_results.{dataset_name}.jsonl"
        if result_path.exists():
            dataset_results = list(utils.read_jsonl(result_path))
            prev_max = max(x["score"] for x in dataset_results)
        else:
            dataset_results = []
            prev_max = 0
        if dataset_score > prev_max:
            manager.save_model(model, step, f"best_on_{dataset_name}")
        dataset_results.append({"step": step, "score": dataset_score})
        utils.write_jsonl(dataset_results, result_path, "w")
        if args.verbose:
            print(f"Validation Dataset Results for {dataset_name}:")
            pprint(dataset_results)
            print()


def main(args):
    utils.set_random_seed(0)
    setup_and_train(args, eval_func=evaluate)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
parser.add_argument("--config", help="path to JSON config file") |
|
parser.add_argument("--device", default="cuda") |
|
parser.add_argument("--dump", action="store_true") |
|
parser.add_argument("--verbose", action="store_true") |
|
parser.add_argument( |
|
"--fresh", |
|
action="store_true", |
|
help="delete model directory and start from scratch" |
|
) |
|
main(load_config(parser.parse_args())) |
|
|
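# Example invocation (illustrative: the script name and config path below are
# hypothetical and depend on your local setup; see scrl.config.load_config for
# the expected config fields):
#   python train.py --config config/example.json --device cuda --verbose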