|
|
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import gc |
|
import random |
|
from tqdm import tqdm |
|
|
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
|
class Trainer:
    """Distributed-data-parallel trainer for a VQGAN-style model.

    Runs a two-optimizer (autoencoder / discriminator) training loop,
    checkpoints model, optimizers, scaler and RNG streams so training can
    resume deterministically — including mid-epoch via ``last_batch_idx`` —
    and logs per-step losses to CSV on the rank-0 process only.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        datasets,
        optimizers,
        save_every: int,
        save_checkpoint_path: str,
        load_checkpoint_path: str,
        config,
    ) -> None:
        """
        Args:
            model: unwrapped model; moved to the local GPU and wrapped in
                DDP here.
            datasets: pair of DataLoaders ``(train, valid)``.
            optimizers: pair ``(autoencoder optimizer, discriminator optimizer)``.
            save_every: emit a mid-epoch checkpoint every this many steps.
            save_checkpoint_path: directory for checkpoints and CSV logs.
            load_checkpoint_path: checkpoint file to resume from, if it exists.
            config: experiment config; must expose ``model.batch_size`` and
                be convertible with ``vars()``.
        """
        self.config = config
        # Local rank comes from OpenMPI's environment when launched under
        # mpirun; fall back to 0 for a single-process run. RANK (global
        # rank) must be provided by the launcher.
        self.local_rank = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)

        self.train_data = datasets[0]
        self.valid_data = datasets[1]

        self.opt_ae = optimizers[0]
        self.opt_disc = optimizers[1]

        # NOTE(review): the scaler is saved/restored in checkpoints but
        # never applied in _run_batch (no autocast, no scaler.scale/step) —
        # confirm whether mixed-precision training was intended.
        self.scaler = torch.cuda.amp.GradScaler()

        self.save_every = save_every
        self.epochs_run = 1        # first epoch to run (1-based)
        self.last_batch_idx = -1   # last completed batch of a partial epoch; -1 = none
        self.save_checkpoint_path = save_checkpoint_path

        if os.path.exists(load_checkpoint_path):
            print("Loading checkpoint")
            self._load_checkpoint(load_checkpoint_path)

        # Wrap AFTER restoring the checkpoint so the saved plain (non-DDP)
        # state_dict keys line up. device_ids pins this replica to its GPU.
        self.model = DDP(self.model, device_ids=[self.local_rank], find_unused_parameters=True)

    def _load_checkpoint(self, checkpoint_path):
        """Restore model, optimizers, scaler and RNG streams from a checkpoint.

        A mid-epoch checkpoint carries the last completed batch index; in
        that case the SAME epoch is re-run and earlier batches are skipped,
        otherwise training resumes at the following epoch.
        """
        # Fall back to CPU so a checkpoint can also be inspected/loaded on
        # a CUDA-less host (was: unconditional f"cuda:{rank}").
        loc = f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu"
        ckpt_dict = torch.load(checkpoint_path, map_location=loc)
        self.model.load_state_dict(ckpt_dict["MODEL_STATE"])
        self.opt_ae.load_state_dict(ckpt_dict["optimizer"][0])
        self.opt_disc.load_state_dict(ckpt_dict["optimizer"][1])
        self.scaler.load_state_dict(ckpt_dict["scaler"])
        self.last_batch_idx = ckpt_dict.get("last_batch_idx", -1)
        self.epochs_run = ckpt_dict["EPOCHS_RUN"] + 1 if self.last_batch_idx == -1 else ckpt_dict["EPOCHS_RUN"]

        # Restore every RNG stream that was captured at save time so data
        # order / augmentation stay reproducible across restarts.
        for key, value in ckpt_dict.get('rng', {}).items():
            if key == 'torch_state':
                torch.set_rng_state(value.cpu())
            elif key == 'cuda_state':
                if torch.cuda.is_available():
                    torch.cuda.set_rng_state(value.cpu())
            elif key == 'numpy_state':
                np.random.set_state(value)
            elif key == 'python_state':
                random.setstate(value)
            else:
                print('unrecognized state')

        print(f"Resuming training from checkpoint at Epoch {self.epochs_run}")

    def _save_checkpoint(self, epoch, config, last_idx):
        """Write a checkpoint (model, optimizers, scaler, hparams, RNG state).

        Args:
            epoch: epoch being checkpointed (stored under ``EPOCHS_RUN``).
            config: config object; stored as ``vars(config)``.
            last_idx: last completed batch index for a mid-epoch save, or
                -1 for an end-of-epoch save (affects the filename too).
        """
        # Capture all RNG streams; CUDA state only when a GPU is present
        # (was: `if np:` / `if random:` — modules are always truthy).
        rng_state = {
            'torch_state': torch.get_rng_state(),
            'numpy_state': np.random.get_state(),
            'python_state': random.getstate(),
        }
        if torch.cuda.is_available():
            rng_state['cuda_state'] = torch.cuda.get_rng_state()

        checkpoint = {
            # .module unwraps DDP so the saved keys are loadable pre-wrap.
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
            "optimizer": [self.opt_ae.state_dict(), self.opt_disc.state_dict()],
            "scaler": self.scaler.state_dict(),
            "hparams": vars(config),
            "last_batch_idx": last_idx,
            "rng": rng_state,
        }

        if last_idx == -1:
            filename = f'VQGAN_{epoch}.pt'
        else:
            filename = f'VQGAN_{last_idx}_{epoch}.pt'

        out_path = os.path.join(self.save_checkpoint_path, filename)
        torch.save(checkpoint, out_path)
        print(f"Epoch {epoch} | Training checkpoint saved at {out_path}.")

    def train(self, max_epochs: int):
        """Train from ``self.epochs_run`` through ``max_epochs`` (inclusive),
        saving an end-of-epoch checkpoint on rank 0."""
        for epoch in range(self.epochs_run, max_epochs + 1):
            self._run_epoch(epoch)
            if self.global_rank == 0:
                self._save_checkpoint(epoch, self.config, last_idx=-1)

    def _run_epoch(self, epoch):
        """Run one training epoch, with mid-epoch checkpointing and CSV logging."""
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {self.config.model.batch_size} | Steps: {len(self.train_data)} | LastIdx: {self.last_batch_idx}")

        # Reshuffle the DistributedSampler so every epoch sees a new order.
        self.train_data.sampler.set_epoch(epoch)

        train_losses = pd.DataFrame()
        steps_per_epoch = len(self.train_data)
        for step, data in enumerate(tqdm(self.train_data)):

            # Skip batches already processed before a mid-epoch restart.
            if step <= self.last_batch_idx:
                continue

            self.model.train()

            x_train = data['data'].to(self.local_rank)
            # BUG FIX: was `step * epoch`, which is non-monotonic across
            # epochs and always 0 at step 0; use a true global step count.
            global_step = (epoch - 1) * steps_per_epoch + step
            recon_loss, aeloss, perceptual_loss, g_image_loss, image_gan_feat_loss, commitment_loss, perplexity = self._run_batch(global_step, x_train)

            if self.global_rank == 0:
                df = pd.DataFrame({
                    'recon_loss': [recon_loss.detach().cpu().item()],
                    'aeloss': [aeloss.detach().cpu().item()],
                    'perceptual_loss': [perceptual_loss.detach().cpu().item()],
                    'g_image_loss': [g_image_loss.detach().cpu().item()],
                    'image_gan_feat_loss': [image_gan_feat_loss.detach().cpu().item()],
                    'perplexity': [perplexity.detach().cpu().item()],
                    'commitment_loss': [commitment_loss.detach().cpu().item()]
                })
                train_losses = pd.concat([train_losses, df], axis=0)
                print(f"[Training] recon_loss={recon_loss.item()}, aeloss={aeloss.item()}, perceptual_loss={perceptual_loss.item()}, g_image_loss={g_image_loss.item()}, image_gan_feat_loss={image_gan_feat_loss.item()}, perplexity={perplexity.item()}, commitment_loss={commitment_loss.item()}")

            if self.global_rank == 0 and step % self.save_every == 0 and step != 0:
                self._save_checkpoint(epoch, self.config, step)
                train_losses.to_csv(os.path.join(self.save_checkpoint_path, f'training_loss_{step}_epoch{epoch}.csv'), index=False)
                # BUG FIX: was reset to pd.Series(), which corrupts the
                # following pd.concat calls; keep it a DataFrame.
                train_losses = pd.DataFrame()

        # Epoch completed in full: the next epoch starts from batch 0.
        self.last_batch_idx = -1

        if self.global_rank == 0:
            train_losses.to_csv(os.path.join(self.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False)

    def _run_batch(self, global_step, x):
        """One optimization step for each of the two optimizers.

        Returns:
            Tuple of individual loss terms for logging:
            (recon_loss, aeloss, perceptual_loss, g_image_loss,
             image_gan_feat_loss, commitment_loss, perplexity).
        """
        # --- autoencoder / generator update (optimizer_idx=0) ---
        self.opt_ae.zero_grad()
        recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss, others = self.model.forward(global_step, x, optimizer_idx=0, gpu_id=self.local_rank)
        commitment_loss = vq_output['commitment_loss']
        loss_ae = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss
        loss_ae.backward()
        self.opt_ae.step()

        # --- discriminator update (optimizer_idx=1) ---
        self.opt_disc.zero_grad()
        loss_disc = self.model.forward(global_step, x, optimizer_idx=1, gpu_id=self.local_rank)
        loss_disc.backward()
        self.opt_disc.step()

        g_image_loss, image_gan_feat_loss, commitment_loss, perplexity = others[0], others[1], others[2], others[3]

        return recon_loss, aeloss, perceptual_loss, g_image_loss, image_gan_feat_loss, commitment_loss, perplexity
|
|