radames's picture
add inversion
d9778ff
import os
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from utils import common, train_utils
from criteria import id_loss, w_norm, moco_loss
from configs import data_configs
from datasets.images_dataset import ImagesDataset
from criteria.lpips.lpips import LPIPS
from models.psp import pSp
from training.ranger import Ranger
class Coach:
    """Train a pSp encoder (optionally together with its decoder).

    The constructor builds the network, the loss modules whose weights are
    positive, the optimizer, the train/test data loaders, a TensorBoard
    ``SummaryWriter`` and the checkpoint directory from ``opts``.
    :meth:`train` then runs the optimize / log / validate / checkpoint loop
    until ``opts.max_steps`` is reached.
    """

    def __init__(self, opts):
        """Set up the network, losses, optimizer, data and loggers.

        Args:
            opts: Training options namespace. It is mutated in place:
                ``opts.device`` is set to the training device, and
                ``opts.save_interval`` is set to ``opts.max_steps`` when it
                is ``None``.

        Raises:
            ValueError: If both ``id_lambda`` and ``moco_lambda`` are > 0.
        """
        self.opts = opts
        self.global_step = 0
        self.device = 'cuda:0'  # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES
        self.opts.device = self.device

        if self.opts.use_wandb:
            # Imported only when enabled (presumably so wandb stays an
            # optional dependency).
            from utils.wandb_utils import WBLogger
            self.wb_logger = WBLogger(self.opts)

        # Initialize network
        self.net = pSp(self.opts).to(self.device)

        # Estimate latent_avg via dense sampling if latent_avg is not available
        if self.net.latent_avg is None:
            self.net.latent_avg = self.net.decoder.mean_latent(int(1e5))[0].detach()

        # Initialize losses: only modules with a positive weight are built,
        # and calc_loss only touches those same modules.
        if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0:
            raise ValueError('Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!')
        # NOTE: calc_loss uses F.mse_loss directly; this module is kept as an attribute.
        self.mse_loss = nn.MSELoss().to(self.device).eval()
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        if self.opts.id_lambda > 0:
            self.id_loss = id_loss.IDLoss().to(self.device).eval()
        if self.opts.w_norm_lambda > 0:
            self.w_norm_loss = w_norm.WNormLoss(start_from_latent_avg=self.opts.start_from_latent_avg)
        if self.opts.moco_lambda > 0:
            self.moco_loss = moco_loss.MocoLoss().to(self.device).eval()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize dataset
        self.train_dataset, self.test_dataset = self.configure_datasets()
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.opts.batch_size,
                                           shuffle=True,
                                           num_workers=int(self.opts.workers),
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.opts.test_batch_size,
                                          shuffle=False,
                                          num_workers=int(self.opts.test_workers),
                                          drop_last=True)

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps

    def train(self):
        """Optimize until ``global_step`` reaches ``opts.max_steps``.

        The train dataloader is re-iterated as many times as needed. On each
        step the method:

        - logs images every ``image_interval`` steps (and every 25 steps
          during the first 1000);
        - logs metrics every ``board_interval`` steps;
        - validates every ``val_interval`` steps and at the final step;
        - saves ``best_model.pt`` whenever the validation loss improves;
        - saves a periodic checkpoint every ``save_interval`` steps and at
          the final step.

        Raises:
            ValueError: If the train dataloader yields no batches. Without
                this check the outer ``while`` would spin forever, because
                ``global_step`` only advances once per batch.
        """
        if len(self.train_dataloader) == 0:
            raise ValueError('Training dataloader yields no batches: the training set is empty '
                             'or smaller than batch_size (drop_last=True).')
        self.net.train()
        while self.global_step < self.opts.max_steps:
            for batch_idx, batch in enumerate(self.train_dataloader):
                self.optimizer.zero_grad()
                x, y = batch
                x, y = x.to(self.device).float(), y.to(self.device).float()
                # Call the module (not .forward) so registered hooks run.
                y_hat, latent = self.net(x, return_latents=True)
                loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
                loss.backward()
                self.optimizer.step()

                # Logging related
                if self.global_step % self.opts.image_interval == 0 or (self.global_step < 1000 and self.global_step % 25 == 0):
                    self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces')
                if self.global_step % self.opts.board_interval == 0:
                    self.print_metrics(loss_dict, prefix='train')
                    self.log_metrics(loss_dict, prefix='train')

                # Log images of first batch to wandb
                if self.opts.use_wandb and batch_idx == 0:
                    self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="train", step=self.global_step, opts=self.opts)

                # Validation related
                val_loss_dict = None
                if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
                    val_loss_dict = self.validate()
                    if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
                        self.best_val_loss = val_loss_dict['loss']
                        self.checkpoint_me(val_loss_dict, is_best=True)

                if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
                    # Prefer validation metrics for the checkpoint log when available.
                    if val_loss_dict is not None:
                        self.checkpoint_me(val_loss_dict, is_best=False)
                    else:
                        self.checkpoint_me(loss_dict, is_best=False)

                if self.global_step == self.opts.max_steps:
                    print('OMG, finished training!')
                    break

                self.global_step += 1

    def validate(self):
        """Evaluate on the test dataloader under ``torch.no_grad``.

        Images are logged for every batch. The network is switched back to
        train mode before returning.

        Returns:
            The per-batch loss dicts aggregated by
            ``train_utils.aggregate_loss_dict`` (also logged and printed).
            At ``global_step == 0`` this is only a sanity run: it stops after
            five batches (``batch_idx`` 0-4) and returns ``None`` without
            logging metrics.
        """
        self.net.eval()
        agg_loss_dicts = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            x, y = batch

            with torch.no_grad():
                x, y = x.to(self.device).float(), y.to(self.device).float()
                y_hat, latent = self.net(x, return_latents=True)
                _, cur_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
            agg_loss_dicts.append(cur_loss_dict)

            # Logging related
            self.parse_and_log_images(id_logs, x, y, y_hat,
                                      title='images/test/faces',
                                      subscript='{:04d}'.format(batch_idx))

            # Log images of first batch to wandb
            if self.opts.use_wandb and batch_idx == 0:
                self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="test", step=self.global_step, opts=self.opts)

            # For first step just do sanity test on small amount of data
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None  # Do not log, inaccurate in first batch

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dicts)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test')

        self.net.train()
        return loss_dict

    def checkpoint_me(self, loss_dict, is_best):
        """Save a checkpoint and append a record to ``timestamp.txt``.

        Args:
            loss_dict: Loss values written into the text record.
            is_best: If true, save to ``best_model.pt`` and record
                ``self.best_val_loss``; the caller sets that attribute
                before calling. Otherwise save to ``iteration_{step}.pt``.
        """
        save_name = 'best_model.pt' if is_best else f'iteration_{self.global_step}.pt'
        save_dict = self.__get_save_dict()
        checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
        torch.save(save_dict, checkpoint_path)
        with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a', encoding='utf-8') as f:
            if is_best:
                f.write(f'**Best**: Step - {self.global_step}, Loss - {self.best_val_loss} \n{loss_dict}\n')
                if self.opts.use_wandb:
                    self.wb_logger.log_best_model()
            else:
                f.write(f'Step - {self.global_step}, \n{loss_dict}\n')

    def configure_optimizers(self):
        """Build the optimizer over encoder (and optionally decoder) parameters.

        Decoder parameters are included only when ``opts.train_decoder`` is
        set. Uses Adam if ``opts.optim_name == 'adam'``, otherwise Ranger,
        both with ``opts.learning_rate``.
        """
        params = list(self.net.encoder.parameters())
        if self.opts.train_decoder:
            params += list(self.net.decoder.parameters())
        if self.opts.optim_name == 'adam':
            return torch.optim.Adam(params, lr=self.opts.learning_rate)
        return Ranger(params, lr=self.opts.learning_rate)

    def configure_datasets(self):
        """Create the train and test ``ImagesDataset`` for ``opts.dataset_type``.

        Returns:
            A ``(train_dataset, test_dataset)`` tuple.

        Raises:
            ValueError: If ``opts.dataset_type`` is not a key of
                ``data_configs.DATASETS``. Previously the exception object
                was built but never raised, so a ``KeyError`` surfaced on
                the next line instead.
        """
        if self.opts.dataset_type not in data_configs.DATASETS:
            raise ValueError(f'{self.opts.dataset_type} is not a valid dataset_type')
        print(f'Loading dataset for {self.opts.dataset_type}')
        dataset_args = data_configs.DATASETS[self.opts.dataset_type]
        transforms_dict = dataset_args['transforms'](self.opts).get_transforms()
        train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
                                      target_root=dataset_args['train_target_root'],
                                      source_transform=transforms_dict['transform_source'],
                                      target_transform=transforms_dict['transform_gt_train'],
                                      opts=self.opts)
        test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'],
                                     target_root=dataset_args['test_target_root'],
                                     source_transform=transforms_dict['transform_source'],
                                     target_transform=transforms_dict['transform_test'],
                                     opts=self.opts)
        if self.opts.use_wandb:
            self.wb_logger.log_dataset_wandb(train_dataset, dataset_name="Train")
            self.wb_logger.log_dataset_wandb(test_dataset, dataset_name="Test")
        print(f"Number of training samples: {len(train_dataset)}")
        print(f"Number of test samples: {len(test_dataset)}")
        return train_dataset, test_dataset

    def calc_loss(self, x, y, y_hat, latent):
        """Compute the weighted sum of all enabled losses.

        Each term is included only when its ``opts.*_lambda`` weight is > 0.

        Args:
            x: Input batch.
            y: Target batch.
            y_hat: Network output batch.
            latent: Latent codes returned by the network (used by the
                w-norm term).

        Returns:
            A ``(loss, loss_dict, id_logs)`` tuple:

            - ``loss`` is the total. It stays the float ``0.0`` if no term
              is enabled, in which case ``loss.backward()`` in ``train``
              would fail.
            - ``loss_dict`` maps term names to Python floats.
            - ``id_logs`` comes from the ID/MoCo loss, or ``None``.
        """
        loss_dict = {}
        loss = 0.0
        id_logs = None
        if self.opts.id_lambda > 0:
            loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x)
            loss_dict['loss_id'] = float(loss_id)
            loss_dict['id_improve'] = float(sim_improvement)
            loss += loss_id * self.opts.id_lambda
        if self.opts.l2_lambda > 0:
            loss_l2 = F.mse_loss(y_hat, y)
            loss_dict['loss_l2'] = float(loss_l2)
            loss += loss_l2 * self.opts.l2_lambda
        if self.opts.lpips_lambda > 0:
            loss_lpips = self.lpips_loss(y_hat, y)
            loss_dict['loss_lpips'] = float(loss_lpips)
            loss += loss_lpips * self.opts.lpips_lambda
        # NOTE(review): fixed crop window [35:223, 32:220] in (H, W); presumably
        # a face region for 256x256 images -- confirm against the data config.
        if self.opts.lpips_lambda_crop > 0:
            loss_lpips_crop = self.lpips_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220])
            loss_dict['loss_lpips_crop'] = float(loss_lpips_crop)
            loss += loss_lpips_crop * self.opts.lpips_lambda_crop
        if self.opts.l2_lambda_crop > 0:
            loss_l2_crop = F.mse_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220])
            loss_dict['loss_l2_crop'] = float(loss_l2_crop)
            loss += loss_l2_crop * self.opts.l2_lambda_crop
        if self.opts.w_norm_lambda > 0:
            loss_w_norm = self.w_norm_loss(latent, self.net.latent_avg)
            loss_dict['loss_w_norm'] = float(loss_w_norm)
            loss += loss_w_norm * self.opts.w_norm_lambda
        if self.opts.moco_lambda > 0:
            loss_moco, sim_improvement, id_logs = self.moco_loss(y_hat, y, x)
            loss_dict['loss_moco'] = float(loss_moco)
            loss_dict['id_improve'] = float(sim_improvement)
            loss += loss_moco * self.opts.moco_lambda

        loss_dict['loss'] = float(loss)
        return loss, loss_dict, id_logs

    def log_metrics(self, metrics_dict, prefix):
        """Write each metric to TensorBoard as ``{prefix}/{key}`` (and to wandb if enabled)."""
        for key, value in metrics_dict.items():
            self.logger.add_scalar(f'{prefix}/{key}', value, self.global_step)
        if self.opts.use_wandb:
            self.wb_logger.log(prefix, metrics_dict, self.global_step)

    def print_metrics(self, metrics_dict, prefix):
        """Print each metric on its own tab-indented line under a step header."""
        print(f'Metrics for {prefix}, step {self.global_step}')
        for key, value in metrics_dict.items():
            print(f'\t{key} = ', value)

    def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2):
        """Save a figure of input/target/output images for the first samples of a batch.

        Args:
            id_logs: Optional per-sample dicts of extra entries to show.
            x: Input batch.
            y: Target batch.
            y_hat: Output batch.
            title: Sub-directory of the log dir to save under.
            subscript: Optional filename prefix.
            display_count: Maximum number of samples to show. It is clamped
                to the batch size, so small batches no longer raise
                ``IndexError``.
        """
        im_data = []
        for i in range(min(display_count, len(x))):
            cur_im_data = {
                'input_face': common.log_input_image(x[i], self.opts),
                'target_face': common.tensor2im(y[i]),
                'output_face': common.tensor2im(y_hat[i]),
            }
            if id_logs is not None:
                cur_im_data.update(id_logs[i])
            im_data.append(cur_im_data)
        self.log_images(title, im_data=im_data, subscript=subscript)

    def log_images(self, name, im_data, subscript=None, log_latest=False):
        """Render ``im_data`` with ``common.vis_faces`` and save it as a JPEG.

        The file is saved under ``<log_dir>/<name>/``. It is named by the
        zero-padded step, or by step 0 when ``log_latest`` is set, optionally
        prefixed by ``subscript``.
        """
        fig = common.vis_faces(im_data)
        step = 0 if log_latest else self.global_step
        if subscript:
            path = os.path.join(self.logger.log_dir, name, f'{subscript}_{step:04d}.jpg')
        else:
            path = os.path.join(self.logger.log_dir, name, f'{step:04d}.jpg')
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig.savefig(path)
        # Close the figure so repeated logging does not accumulate open figures.
        plt.close(fig)

    def __get_save_dict(self):
        """Return the dict passed to ``torch.save``: network weights plus options."""
        save_dict = {
            'state_dict': self.net.state_dict(),
            'opts': vars(self.opts)
        }
        # Store latent_avg alongside the weights when training started from it,
        # so it is available at inference time.
        if self.opts.start_from_latent_avg:
            save_dict['latent_avg'] = self.net.latent_avg
        return save_dict