import argparse
import shutil
import logging
import random
import time
from pprint import pprint
from collections import defaultdict
from pathlib import Path

from scrl.rewards import load_rewards
from scrl.data import load_data_for_training
from scrl.config import load_config
from scrl.model import load_model, LinearTokenSelector, labels_to_summary
import scrl.utils as utils
import scrl.sampling as sampling

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, AutoTokenizer
from sklearn import preprocessing

from nltk import word_tokenize


def print_if(x, do_print=True):
    if do_print:
        print(x)


class TrainingManager:
    """
    Object for saving/loading model checkpoints and for tracking and saving
    metrics measured during training, e.g. loss, rewards.

    The following directory structure is built around one training run:

    dir/
        val_scores.json
        checkpoints/
            latest-model-500/
                classifier.bin
                encoder.bin
            best-model-200/
                [...]
        series/
            loss.npy
            [...]
        totals/
            loss.npy
            [...]
    """
    def __init__(self, dir):
        self.step = 0
        self.total_seconds = 0
        self.start_time = None
        self.series = defaultdict(list)
        self.totals = defaultdict(float)
        self.dir = dir
        dir.mkdir(exist_ok=True)
        for subdir_name in ("checkpoints", "series", "totals"):
            (dir / subdir_name).mkdir(exist_ok=True)

    def start_clock(self):
        # Offset by the time already spent so that elapsed-time tracking
        # continues seamlessly after resuming an interrupted run.
        self.start_time = time.time() - self.total_seconds

    def load(self):
        # Restore metric histories and running totals from disk.
        for p in (self.dir / "series").iterdir():
            k = p.name.split(".npy")[0]
            self.series[k] = list(utils.load_numpy(p))
        for p in (self.dir / "totals").iterdir():
            k = p.name.split(".npy")[0]
            self.totals[k] = utils.load_numpy(p)

        # Resume the step counter from the latest checkpoint.
        latest_model_dir = self.find_old_model("latest-model")
        self.total_seconds = utils.read_json(self.dir / "time.json")["total_seconds"]
        last_step = int(latest_model_dir.name.split("-")[-1])
        self.step = last_step + 1

    def update_metric(self, key, value):
        self.totals[key] += value
        self.series[key].append(value)

    def mean_metric(self, key):
        return self.totals[key] / (self.step + 1)

    def save_latest_model(self, model, checkpoint_id):
        self.save_model(model, checkpoint_id, prefix="latest-model")

    def save_model(self, model, checkpoint_id, prefix):
        # Write the new checkpoint before removing the previous one with the
        # same prefix.
        old_model_dir = self.find_old_model(prefix)
        model_dir = self.dir / "checkpoints" / f"{prefix}-{checkpoint_id}"
        model_dir.mkdir()
        model.save(
            classifier_path=model_dir / "classifier.bin",
            encoder_path=model_dir / "encoder.bin"
        )
        if old_model_dir:
            shutil.rmtree(old_model_dir)

    def find_old_model(self, prefix):
        model_path = None
        for p in (self.dir / "checkpoints").iterdir():
            if p.name.startswith(prefix):
                model_path = p
        return model_path

    def is_empty(self):
        latest_model_dir = self.find_old_model("latest-model")
        return latest_model_dir is None

    def save_data(self):
        for k, v in self.series.items():
            utils.save_numpy(v, self.dir / "series" / f"{k}.npy")
        for k, v in self.totals.items():
            utils.save_numpy(v, self.dir / "totals" / f"{k}.npy")
        utils.write_json({
            "step": self.step,
            "total_seconds": self.total_seconds
        }, self.dir / "time.json")


def label_variance(probs):
    # probs has shape (batch_size, seq_len, num_classes); measure how much
    # the class-0 probability varies across the tokens of each example.
    variances = []
    for i in range(probs.size(0)):
        distrib = probs[i, :, 0]
        var = torch.var(distrib)
        variances.append(var)
    return torch.stack(variances).mean().item()


def check_gradient(model):
    is_zero = []
    is_none = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            grad = param.grad
            if grad is None:
                is_none.append(name)
            else:
                gradsum = grad.sum().item()
                if gradsum == 0:
                    is_zero.append(name)
    print("zero-grad:", len(is_zero), is_zero)
    print("none-grad:", len(is_none), is_none)
    print()


def get_mean_max_prob(probs):
    # Mean over batch and tokens of the winning class probability: a rough
    # measure of how confident the policy currently is.
    return probs.max(dim=2).values.mean().item()


def print_training_progress(args, manager, model, probs, argmax_summaries,
                            sample_summaries, batch, argmax_details):
    print(f"[step: {manager.step}] [duration(s): {round(manager.total_seconds)}]")
    print(f"[example/s: {(args.batch_size * (manager.step + 1)) / manager.total_seconds:.3f}]")
    print(f"[s/step: {manager.total_seconds / (manager.step + 1):.3f}]")
    print(f"[avg-loss: {manager.mean_metric('loss')}]")
    print(f"[avg-max-prob: {manager.mean_metric('mean_max_prob'):.3f}]")
    print(f"[avg-a-reward: {manager.mean_metric('argmax_reward'):.3f}]")
    print(f"[avg-s-reward: {manager.mean_metric('sample_reward'):.3f}]")
    print(f"[avg-len: {manager.mean_metric('argmax_len'):.1f}]")
    print()
    print(f"[a-reward: {manager.series['argmax_reward'][-1]:.3f}]")
    print(f"[s-reward: {manager.series['sample_reward'][-1]:.3f}]")
    print(f"[max-prob: {manager.series['mean_max_prob'][-1]:.3f}]")
    print()
    print("[sentences]")
    print("\n".join(batch["document"]))
    print("\n[current policy summaries]")
    print("\n".join(argmax_summaries))
    print("\n[sampled summaries]")
    print("\n".join(sample_summaries))
    print()
    print("Reward Breakdown:")
    pprint(argmax_details)
    print()
    check_gradient(model)
    print("=" * 100)


def setup_model(args):
    model_dir = Path(args.model_dir)
    if args.fresh and model_dir.exists():
        utils.ask_rmdir(model_dir)
    manager = TrainingManager(model_dir)
    if not manager.is_empty():
        manager.load()

    if not (model_dir / "config.json").exists():
        shutil.copy(args.config, model_dir / "config.json")

    if manager.step == 0:
        # Fresh run: build a new token selector on top of a pretrained encoder.
        encoder = AutoModel.from_pretrained(args.encoder_model_id)
        embedding_size = encoder.state_dict()["embeddings.word_embeddings.weight"].shape[1]
        model = LinearTokenSelector(encoder, embedding_size).to(args.device)
    else:
        # Resumed run: continue from the latest checkpoint.
        print("loading latest model from step", manager.step - 1)
        model = load_model(
            model_dir, prefix="latest", device=args.device
        )
    return manager, model


def setup_dataset_indices(args, step):
    """
    Load pre-built indices that determine the order in which we traverse the
    dataset. When resuming an interrupted training run, we advance the index
    generator to the step where training stopped.
    """
    dataset_indices = utils.batchify(
        utils.load_numpy(args.indices),
        args.batch_size
    )
    if step > 0:
        utils.move_generator(dataset_indices, step)
    return dataset_indices


def train(
    args,
    manager,
    model,
    tokenizer,
    reward_generator,
    dataset,
    dataset_indices,
    eval_func
):

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    n_train = len(dataset["train"])
    device = args.device
    model.train()
    manager.start_clock()

    for indices in dataset_indices:

        step = manager.step
        manager.total_seconds = time.time() - manager.start_time
        if args.max_train_steps and step > args.max_train_steps:
            break
        if args.max_train_seconds and manager.total_seconds >= args.max_train_seconds:
            break

        optimizer.zero_grad()

        batch = dataset["train"][indices]
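        # Pad the token-id sequences with zeros up to the longest sequence in
        # the batch; the result has shape (batch_size, max_seq_len).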
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in batch["input_ids"]],
            batch_first=True
        ).to(device)

        logits = model(input_ids)
        probs = torch.softmax(logits, dim=2)
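        # Greedy rollout: take the argmax label per token, turn the labels
        # into a summary string, and score it with the reward function.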
        argmax_labels = torch.argmax(logits, dim=2).to(device)
        argmax_summaries = labels_to_summary(input_ids, argmax_labels, tokenizer)
        argmax_rewards, argmax_details = reward_generator(batch["document"], argmax_summaries)
        a_reward = np.mean(argmax_rewards)

        # Sampled rollouts: draw k candidate label sequences per input.
        (sample_probs, sample_summaries, sample_rewards, sample_details,
         sample_labels, sample_data) = sampling.best_of_k_samples(
            args, manager, tokenizer, reward_generator,
            input_ids, batch, probs,
            k_samples=args.k_samples,
        )
        s_reward = np.mean(sample_rewards)
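        # REINFORCE-style update with a self-critical baseline: the greedy
        # (argmax) reward serves as the baseline for the sampled summaries,
        # so a sample that beats the greedy summary yields a negative loss
        # and its label probabilities are pushed up.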
        if args.sample_aggregation == "max":
            loss = (a_reward - s_reward) * sample_probs.sum(1).mean()
        else:
            # "mean" aggregation: average the loss over all k samples.
            loss = 0.
            for sample_probs_i, s_rewards_i in zip(sample_data["probs"], sample_data["rewards"]):
                s_reward_i = np.mean(s_rewards_i)
                loss_i = (a_reward - s_reward_i) * sample_probs_i.sum(1).mean()
                loss += loss_i
            loss /= len(sample_data["rewards"])
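        # Under "max" aggregation, identical greedy and sampled rewards make
        # the loss exactly zero, so the gradient step can be skipped.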
        if args.sample_aggregation == "mean" or a_reward != s_reward:
            loss.backward()
            optimizer.step()

        argmax_len = np.mean([len(word_tokenize(s)) for s in argmax_summaries])

        manager.update_metric("time", time.time())
        manager.update_metric("loss", loss.item())
        manager.update_metric("argmax_reward", a_reward)
        manager.update_metric("sample_reward", s_reward)
        manager.update_metric("sample_prob", sample_probs.detach().cpu().numpy().mean())
        manager.update_metric("mean_max_prob", get_mean_max_prob(probs))
        manager.update_metric("label_variance", label_variance(probs))
        manager.update_metric("argmax_len", argmax_len)
        for rname, rvalues in argmax_details.items():
            manager.update_metric(f"reward|{rname}", np.mean(rvalues))

        if args.eval_every is not None and step > 0 and step % args.eval_every == 0:
            eval_func(
                args, manager, model, tokenizer, reward_generator,
                dataset["validation"]
            )
            model.train()

        if args.save_every is not None and step % args.save_every == 0:
            manager.save_latest_model(model, step)
            manager.save_data()

        if args.print_every is not None and args.verbose and step % args.print_every == 0:
            print_training_progress(
                args, manager, model, probs,
                argmax_summaries, sample_summaries, batch,
                argmax_details
            )
        manager.step += 1


def setup_and_train(args, eval_func):
    print_if("loading model", args.verbose)
    manager, model = setup_model(args)

    print_if("loading tokenizer", args.verbose)
    tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_id)

    print_if("loading rewards", args.verbose)
    reward_generator = load_rewards(args)
    print_if(f"rewards: {reward_generator.reward_names}", args.verbose)

    print_if("loading dataset", args.verbose)
    dataset = load_data_for_training(tokenizer, args.loader, args.dataset)

    dataset_indices = setup_dataset_indices(args, manager.step)

    train(
        args,
        manager,
        model,
        tokenizer,
        reward_generator,
        dataset,
        dataset_indices,
        eval_func
    )