Spaces:
Runtime error
Runtime error
import os | |
import torch | |
from tqdm import tqdm | |
from configs import paths_config, hyperparameters, global_config | |
from training.coaches.base_coach import BaseCoach | |
from utils.log_utils import log_images_from_w | |
from color_transfer_loss import ColorTransferLoss | |
import copy | |
class SingleIDCoach(BaseCoach):
    """Coach that tunes the generator separately for each image of the data loader.

    For every image, training state is reset, a pivot latent is obtained and
    saved, the generator is optimised for up to
    ``hyperparameters.max_pti_steps`` steps, and the tuned ``self.G`` is saved
    to ``paths_config.checkpoints_dir``.
    """

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def _resolve_w_pivot(self, w_path_dir, image_name, image):
        """Return the pivot latent for ``image``, moved to ``global_config.device``.

        A saved inversion is reused when ``hyperparameters.use_last_w_pivots``
        is set and ``self.load_inversions`` returns something other than
        ``None``. Otherwise the latent is computed with ``self.calc_inversions``.
        """
        w_pivot = None
        if hyperparameters.use_last_w_pivots:
            w_pivot = self.load_inversions(w_path_dir, image_name)
        # Independent check (not ``elif``): a missing saved pivot must fall back
        # to a fresh inversion instead of crashing on ``None.to(...)``.
        if w_pivot is None:
            w_pivot = self.calc_inversions(image, image_name)
        return w_pivot.to(global_config.device)

    def _build_color_losses(self, w_pivot):
        """Create one ``ColorTransferLoss`` per entry of ``self.years``.

        Each loss is built from the RGB output of the unmodified
        ``self.siblings[year]`` rendered at ``w_pivot``. ``train`` calls this
        before any tuning step of the current image.
        """
        self.color_losses = {}
        for year in self.years:
            _, init_rgbs = self.forward_sibling(self.siblings[year].synthesis, w_pivot)
            self.color_losses[year] = ColorTransferLoss(init_rgbs)

    def _augmented_sibling_rgbs(self, w_pivot):
        """Render every sibling after shifting its weights by the current tuning delta.

        For each year, a deep copy of ``self.siblings[year]`` has its synthesis
        parameters shifted, parameter by parameter in ``parameters()`` order, by
        ``self.G.synthesis - self.original_G.synthesis``. The copy is then
        rendered at ``w_pivot``.

        Returns:
            dict mapping each entry of ``self.years`` to the second value
            returned by ``self.forward_sibling`` for the shifted copy.
        """
        # The delta is the same for every sibling within one step, so compute it
        # once; gradients from all siblings still accumulate through it.
        deltas = [
            p_pti - p_orig
            for p_pti, p_orig in zip(
                self.G.synthesis.parameters(),
                self.original_G.synthesis.parameters(),
            )
        ]
        rgbs = {}
        for year in self.years:
            # Work on a copy so self.siblings[year] itself is never modified.
            sibling_aug = copy.deepcopy(self.siblings[year])
            # Presumably done in place so the sibling loss can backpropagate into
            # self.G through ``delta``.
            # NOTE(review): PyTorch rejects in-place ops on leaf tensors that
            # require grad; this assumes sibling parameters are frozen — confirm.
            for p, delta in zip(sibling_aug.synthesis.parameters(), deltas):
                p += delta
            rgbs[year] = self.forward_sibling(sibling_aug.synthesis, w_pivot)[1]
        return rgbs

    def train(self):
        """Tune ``self.G`` for each image yielded by ``self.data_loader``.

        Per image:

        1. call ``restart_training()``;
        2. stop once ``hyperparameters.max_images_to_invert`` images are done;
        3. resolve the pivot latent and save it to
           ``<embedding_base_dir>/<input_data_id>/<pti_results_keyword>/<image_name>/0.pt``;
        4. run up to ``hyperparameters.max_pti_steps`` optimisation steps,
           stopping early once the LPIPS loss is at or below
           ``hyperparameters.LPIPS_value_threshold``;
        5. save ``self.G`` under ``paths_config.checkpoints_dir``.

        When ``hyperparameters.color_transfer_lambda > 0``, each step also
        passes the shifted-sibling renders to ``calc_loss``.
        """
        w_path_dir = f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f"{w_path_dir}/{paths_config.pti_results_keyword}", exist_ok=True)

        # Flag forwarded to calc_loss. It persists across images and is
        # recomputed after every optimiser step.
        use_ball_holder = True

        for fname, image in tqdm(self.data_loader):
            image_name = fname[0]

            self.restart_training()

            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            embedding_dir = (
                f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
            )
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivot = self._resolve_w_pivot(w_path_dir, image_name, image)
            torch.save(w_pivot, f"{embedding_dir}/0.pt")

            log_images_counter = 0
            real_images_batch = image.to(global_config.device)

            use_color_transfer = hyperparameters.color_transfer_lambda > 0
            if use_color_transfer:
                self._build_color_losses(w_pivot)

            for _ in tqdm(range(hyperparameters.max_pti_steps)):
                rgbs = self._augmented_sibling_rgbs(w_pivot) if use_color_transfer else {}

                generated_images = self.forward(w_pivot)
                loss, _, loss_lpips = self.calc_loss(
                    generated_images,
                    real_images_batch,
                    image_name,
                    self.G,
                    use_ball_holder,
                    w_pivot,
                    rgbs,
                )

                self.optimizer.zero_grad()

                # Converged for this image: the step that reached the threshold
                # is not applied.
                if loss_lpips <= hyperparameters.LPIPS_value_threshold:
                    break

                loss.backward()
                self.optimizer.step()

                use_ball_holder = (
                    global_config.training_step
                    % hyperparameters.locality_regularization_interval
                    == 0
                )

                if (
                    self.use_wandb
                    and log_images_counter % global_config.image_rec_result_log_snapshot
                    == 0
                ):
                    log_images_from_w([w_pivot], self.G, [image_name])

                global_config.training_step += 1
                log_images_counter += 1

            self.image_counter += 1

            torch.save(
                self.G,
                f"{paths_config.checkpoints_dir}/model_{global_config.run_name}_{image_name}.pt",
            )