# viLegal_bi/src/bi_kd/trainer.py
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch.utils.checkpoint import get_device_states, set_device_states
from torch.nn import DataParallel
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from .model import SharedBiEncoder
from .loss import BiEncoderNllLoss, BiEncoderDoubleNllLoss
class WarmupLinearSchedule(LambdaLR):
""" Linear warmup and then linear decay.
Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
"""
def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1, self.warmup_steps))
return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
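# Illustrative example of the schedule above: with warmup_steps=100 and t_total=1000,
# the LR multiplier ramps 0 -> 1 over the first 100 steps (step 50 -> 0.5), then decays
# linearly back to 0 over the remaining 900 steps (step 550 -> (1000 - 550) / 900 = 0.5).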
class RandContext:
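    """
    Capture the CPU and per-GPU RNG states at construction and restore them on
    __enter__, so that the second (with-grad) forward pass in the gradient-cache
    step sees exactly the same random state (e.g. dropout masks) as the first
    no-grad pass.
    """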
def __init__(self, *tensors):
self.fwd_cpu_state = torch.get_rng_state()
self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
def __enter__(self):
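        # Fork the RNG for the devices recorded at construction, then restore
        # the saved CPU and GPU states inside the forked scope.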
self._fork = torch.random.fork_rng(
devices=self.fwd_gpu_devices,
enabled=True
)
self._fork.__enter__()
torch.set_rng_state(self.fwd_cpu_state)
set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
def __exit__(self, exc_type, exc_val, exc_tb):
self._fork.__exit__(exc_type, exc_val, exc_tb)
self._fork = None
class BiTrainer():
"""
Trainer for biencoder
"""
def __init__(self,
args,
train_loader,
val_loader):
        self.parallel = torch.cuda.device_count() > 1
        print("Number of GPU(s):", torch.cuda.device_count())
self.args = args
self.train_loader = train_loader
self.val_loader = val_loader
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = SharedBiEncoder(model_checkpoint=self.args.BE_checkpoint,
representation=self.args.BE_representation,
fixed=self.args.bi_fixed)
if self.args.resume_training_from:
old_checkpoint = torch.load(self.args.resume_training_from)
self.model.load_state_dict(old_checkpoint['model'])
print("Resume training. Trained", old_checkpoint['epoch'], 'epochs.')
if self.parallel:
print("Parallel Training")
self.model = DataParallel(self.model)
self.model.to(self.device)
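        # BE_loss of exactly 0 or 1 selects the single-NLL criterion (with KD weight
        # kd_loss); any value in between selects the double-NLL criterion with
        # alpha = BE_loss.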
if self.args.BE_loss == 0.0 or self.args.BE_loss == 1.0:
self.criterion = BiEncoderNllLoss(score_type=self.args.BE_score,
kd_alpha=self.args.kd_loss)
else:
self.criterion = BiEncoderDoubleNllLoss(score_type=self.args.BE_score,
alpha=self.args.BE_loss)
self.val_criterion = BiEncoderNllLoss(score_type=self.args.BE_score)
self.optimizer = AdamW(self.model.parameters(), lr=args.BE_lr)
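        # Linear warmup over the first 10% of the total training steps, then linear decay.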
        total_steps = len(self.train_loader) * self.args.BE_num_epochs
        self.scheduler = WarmupLinearSchedule(self.optimizer, 0.1 * total_steps, total_steps)
self.epoch = 0
if self.args.resume_training_from:
self.optimizer.load_state_dict(old_checkpoint['optimizer'])
self.scheduler.load_state_dict(old_checkpoint['scheduler'])
self.epoch = old_checkpoint['epoch']
self.patience_counter = 0
self.best_val_acc = 0.0
self.epochs_count = []
self.train_losses = []
self.valid_losses = []
self.train_acc = []
self.valid_acc = []
def train_biencoder(self):
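        """
        Full training loop: validate once before training, then for each epoch
        train, validate, save the encoder whenever validation accuracy improves,
        checkpoint the optimizer/scheduler state, and finally plot the loss and
        accuracy curves. Returns the trained model via get_model().
        """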
# Compute loss and accuracy before starting (or resuming) training.
print("\n",
20 * "=",
"Validation before training",
20 * "=")
val_time, val_loss, val_acc = self.validate()
print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
.format(val_time, val_loss, (val_acc*100)))
print("\n",
20 * "=",
"Training biencoder model on device: {}".format(self.device),
20 * "=")
while self.epoch < self.args.BE_num_epochs:
            self.epoch += 1
self.epochs_count.append(self.epoch)
print("* Training epoch {}:".format(self.epoch))
epoch_avg_loss, epoch_accuracy, epoch_time = self.train()
self.train_losses.append(epoch_avg_loss)
self.train_acc.append(epoch_accuracy.to('cpu')*100)
print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%"
.format(epoch_time, epoch_avg_loss, (epoch_accuracy*100)))
print("* Validation for epoch {}:".format(self.epoch))
epoch_time, epoch_loss, epoch_accuracy = self.validate()
self.valid_losses.append(epoch_loss)
self.valid_acc.append(epoch_accuracy.to('cpu')*100)
print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
.format(epoch_time, epoch_loss, (epoch_accuracy*100)))
if epoch_accuracy <= self.best_val_acc:
self.patience_counter += 1
else:
self.best_val_acc = epoch_accuracy
self.patience_counter = 0
if self.parallel:
self.model.module.encoder.save(self.args.biencoder_path)
#torch.save(self.model.module.state_dict(), self.args.biencoder_path)
else:
self.model.encoder.save(self.args.biencoder_path)
#torch.save(self.model.state_dict(), self.args.biencoder_path)
if self.epoch == self.args.BE_num_epochs:
if self.parallel:
self.model.module.encoder.save(self.args.final_path)
#torch.save(self.model.module.state_dict(), self.args.final_path)
else:
self.model.encoder.save(self.args.final_path)
#torch.save(self.model.state_dict(), self.args.final_path)
            model_to_save = self.model.module if self.parallel else self.model
            checkpoint = {
                'epoch': self.epoch,
                'model': model_to_save.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()}
            torch.save(checkpoint, 'last_checkpoint.pth')
# Plotting of the loss curves for the train and validation sets.
plt.figure()
plt.plot(self.epochs_count, self.train_losses, "-r")
plt.plot(self.epochs_count, self.valid_losses, "-b")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(["Training loss", "Validation loss"])
plt.title("Cross entropy loss")
plt.show()
plt.figure()
plt.plot(self.epochs_count, self.train_acc, '-r')
plt.plot(self.epochs_count, self.valid_acc, "-b")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend(["Training accuracy", "Validation accuracy"])
plt.title("Accuracy")
plt.show()
        # Return the final query/context encoder model(s).
if self.parallel:
return self.model.module.get_model()
else:
return self.model.get_model()
def train(self):
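        """Run one training epoch; returns (average loss, accuracy, elapsed time)."""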
self.model.train()
epoch_start = time.time()
batch_time_avg = 0.0
epoch_loss = 0.0
epoch_correct = 0
tqdm_batch_iterator = tqdm(self.train_loader)
for i, batch in enumerate(tqdm_batch_iterator):
batch_start = time.time()
if self.args.grad_cache:
loss, num_correct = self.step_cache(batch)
else:
loss, num_correct = self.step(batch)
batch_time_avg += time.time() - batch_start
epoch_loss += loss
epoch_correct += num_correct
description = "Avg. batch proc. time: {:.4f}s, loss: {:.4f}"\
.format(batch_time_avg/(i+1),
epoch_loss/(i+1))
tqdm_batch_iterator.set_description(description)
epoch_time = time.time() - epoch_start
epoch_avg_loss = epoch_loss / len(self.train_loader)
epoch_accuracy = epoch_correct / len(self.train_loader.dataset)
return epoch_avg_loss, epoch_accuracy, epoch_time
def step(self, batch):
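        """
        Standard training step: unpack the batch (concatenating hard negatives after
        the positives when present), run one forward pass over queries and contexts,
        compute the NLL loss and update the parameters and LR schedule.
        Returns (loss value, number of correct predictions).
        """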
self.model.train()
if self.args.no_hard != 0:
q_input_ids, q_attn_mask, p_input_ids, p_attn_mask, n_input_ids, n_attn_mask, scores = tuple(t.to(self.device) for t in batch)
ctx_len = n_input_ids.size()[-1]
n_input_ids = n_input_ids.view(-1,ctx_len)
n_attn_mask = n_attn_mask.view(-1,ctx_len)
ctx_input_ids = torch.cat((p_input_ids, n_input_ids), 0)
ctx_attn_mask = torch.cat((p_attn_mask, n_attn_mask), 0)
else:
q_input_ids, q_attn_mask, ctx_input_ids, ctx_attn_mask, scores = tuple(t.to(self.device) for t in batch)
self.optimizer.zero_grad()
q_vectors, ctx_vectors = self.model(q_input_ids, q_attn_mask, ctx_input_ids, ctx_attn_mask)
loss, num_correct = self.criterion.calc(q_vectors, ctx_vectors, scores)
loss.backward()
self.optimizer.step()
self.scheduler.step()
return loss.item(), num_correct
def step_cache(self, batch):
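        """
        Gradient-cache training step for large batches: encode queries and contexts
        in small chunks without gradients (saving the RNG state per chunk), compute
        the contrastive loss on the cached representations, then re-encode each chunk
        with gradients and backpropagate the cached representation gradients through
        the encoder. Returns (loss value, number of correct predictions).
        """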
self.model.train()
self.optimizer.zero_grad()
if self.args.no_hard != 0:
q_input_ids, q_attn_mask, p_input_ids, p_attn_mask, n_input_ids, n_attn_mask, scores = tuple(t.to(self.device) for t in batch)
ctx_len = n_input_ids.size()[-1]
n_input_ids = n_input_ids.view(-1,ctx_len)
n_attn_mask = n_attn_mask.view(-1,ctx_len)
ctx_input_ids = torch.cat((p_input_ids, n_input_ids), 0)
ctx_attn_mask = torch.cat((p_attn_mask, n_attn_mask), 0)
else:
q_input_ids, q_attn_mask, ctx_input_ids, ctx_attn_mask, scores = tuple(t.to(self.device) for t in batch)
all_q_reps, all_ctx_reps = [], []
q_rnds, ctx_rnds = [], []
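        # Split queries and contexts into chunks small enough to encode one at a time.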
q_id_chunks = q_input_ids.split(self.args.q_chunk_size)
q_attn_mask_chunks = q_attn_mask.split(self.args.q_chunk_size)
ctx_id_chunks = ctx_input_ids.split(self.args.ctx_chunk_size)
ctx_attn_mask_chunks = ctx_attn_mask.split(self.args.ctx_chunk_size)
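        # First pass: encode every chunk under no_grad, saving the RNG state so the
        # second pass reproduces the same dropout masks.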
for id_chunk, attn_chunk in zip(q_id_chunks, q_attn_mask_chunks):
q_rnds.append(RandContext(id_chunk, attn_chunk))
with torch.no_grad():
q_chunk_reps = self.model(id_chunk, attn_chunk, None, None)[0]
all_q_reps.append(q_chunk_reps)
all_q_reps = torch.cat(all_q_reps)
for id_chunk, attn_chunk in zip(ctx_id_chunks, ctx_attn_mask_chunks):
ctx_rnds.append(RandContext(id_chunk, attn_chunk))
with torch.no_grad():
ctx_chunk_reps = self.model(None, None, id_chunk, attn_chunk)[1]
all_ctx_reps.append(ctx_chunk_reps)
all_ctx_reps = torch.cat(all_ctx_reps)
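        # Re-enable gradients on the cached representations so loss.backward() fills
        # their .grad fields (the encoders themselves receive no gradient yet).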
all_q_reps = all_q_reps.float().detach().requires_grad_()
all_ctx_reps = all_ctx_reps.float().detach().requires_grad_()
loss, num_correct = self.criterion.calc(all_q_reps, all_ctx_reps, scores)
if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps
loss.backward()
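        # Second pass: re-encode each chunk under its saved RNG state and push the
        # cached gradient through the encoder via a dot-product surrogate loss.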
q_grads = all_q_reps.grad.split(self.args.q_chunk_size)
ctx_grads = all_ctx_reps.grad.split(self.args.ctx_chunk_size)
for id_chunk, attn_chunk, grad, rnd in zip(q_id_chunks, q_attn_mask_chunks, q_grads, q_rnds):
with rnd:
q_chunk_reps = self.model(id_chunk, attn_chunk, None, None)[0]
surrogate = torch.dot(q_chunk_reps.flatten().float(), grad.flatten())
#surrogate = surrogate * (trainer.distributed_factor / 8.)
surrogate.backward()
for id_chunk, attn_chunk, grad, rnd in zip(ctx_id_chunks, ctx_attn_mask_chunks, ctx_grads, ctx_rnds):
with rnd:
ctx_chunk_reps = self.model(None, None, id_chunk, attn_chunk)[1]
surrogate = torch.dot(ctx_chunk_reps.flatten().float(), grad.flatten())
#surrogate = surrogate * (trainer.distributed_factor / 8.)
surrogate.backward()
#q_vectors, ctx_vectors = self.model(q_input_ids, q_attn_mask, ctx_input_ids, ctx_attn_mask)
#loss, num_correct = self.criterion.calc(q_vectors, ctx_vectors, scores)
#loss.backward()
self.optimizer.step()
self.scheduler.step()
return loss.item(), num_correct
def validate(self):
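        """Evaluate on the validation set; returns (elapsed time, average loss, accuracy)."""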
self.model.eval()
epoch_start = time.time()
total_loss = 0.0
total_correct = 0
accuracy = 0
with torch.no_grad():
tqdm_batch_iterator = tqdm(self.val_loader)
for i, batch in enumerate(tqdm_batch_iterator):
if self.args.no_hard != 0:
q_input_ids, q_attn_mask, p_input_ids, p_attn_mask, n_input_ids, n_attn_mask, scores = tuple(t.to(self.device) for t in batch)
ctx_len = n_input_ids.size()[-1]
n_input_ids = n_input_ids.view(-1,ctx_len)
n_attn_mask = n_attn_mask.view(-1,ctx_len)
ctx_input_ids = torch.cat((p_input_ids, n_input_ids), 0)
ctx_attn_mask = torch.cat((p_attn_mask, n_attn_mask), 0)
else:
q_input_ids, q_attn_mask, ctx_input_ids, ctx_attn_mask, scores = tuple(t.to(self.device) for t in batch)
q_vectors, ctx_vectors = self.model(q_input_ids, q_attn_mask, ctx_input_ids, ctx_attn_mask)
loss, num_correct = self.val_criterion.calc(q_vectors, ctx_vectors, scores)
total_loss += loss.item()
total_correct += num_correct
epoch_time = time.time() - epoch_start
val_loss = total_loss/len(self.val_loader)
accuracy = total_correct/len(self.val_loader.dataset)
return epoch_time, val_loss, accuracy
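# Minimal usage sketch (assumes an `args` namespace providing the fields read above,
# e.g. BE_checkpoint, BE_lr, BE_num_epochs, grad_cache, no_hard, ..., and DataLoaders
# yielding the tuples unpacked in step()/validate()):
#
#   trainer = BiTrainer(args, train_loader, val_loader)
#   encoder = trainer.train_biencoder()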