import copy
import os

import torch
from tqdm import tqdm

from configs import paths_config, hyperparameters, global_config
from training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w
from color_transfer_loss import ColorTransferLoss


class SingleIDCoach(BaseCoach):
    """Coach that tunes the generator separately for every image of the loader.

    For each image the coach calls ``restart_training()``, obtains a pivot
    latent, runs up to ``hyperparameters.max_pti_steps`` optimisation steps and
    saves ``self.G`` to its own checkpoint file.
    """

    def __init__(self, data_loader, use_wandb):
        """Forward ``data_loader`` and ``use_wandb`` unchanged to ``BaseCoach``."""
        super().__init__(data_loader, use_wandb)

    def train(self):
        """Run the per-image tuning loop over ``self.data_loader``.

        Side effects visible in this method:
          * creates the embedding directories under
            ``paths_config.embedding_base_dir``;
          * saves each pivot latent to ``<embedding_dir>/0.pt``;
          * saves ``self.G`` to ``paths_config.checkpoints_dir`` after each image;
          * increments ``self.image_counter`` and ``global_config.training_step``;
          * (re)builds ``self.color_losses`` when
            ``hyperparameters.color_transfer_lambda > 0``.

        The loop stops once ``self.image_counter`` reaches
        ``hyperparameters.max_images_to_invert``.
        """
        w_path_dir = f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f"{w_path_dir}/{paths_config.pti_results_keyword}", exist_ok=True)

        # Deliberately initialised once, outside the image loop: the value
        # computed in the last step of one image carries over to the next.
        use_ball_holder = True

        for fname, image in tqdm(self.data_loader):
            image_name = fname[0]

            # Kept ahead of the image-count check, as in the original ordering.
            self.restart_training()

            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            embedding_dir = (
                f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
            )
            os.makedirs(embedding_dir, exist_ok=True)

            # Reuse a stored pivot when requested, otherwise compute a fresh one.
            # The fallback is a separate ``if`` (not ``elif``) so that a load
            # which yields None still falls through to ``calc_inversions``
            # instead of leaving ``w_pivot`` as None and failing on ``.to()``.
            w_pivot = None
            if hyperparameters.use_last_w_pivots:
                w_pivot = self.load_inversions(w_path_dir, image_name)
            if w_pivot is None:
                w_pivot = self.calc_inversions(image, image_name)

            w_pivot = w_pivot.to(global_config.device)
            torch.save(w_pivot, f"{embedding_dir}/0.pt")

            log_images_counter = 0
            real_images_batch = image.to(global_config.device)

            use_color_transfer = hyperparameters.color_transfer_lambda > 0

            if use_color_transfer:
                # One colour-transfer loss per year, built from the second
                # output of the untouched sibling at this pivot.
                self.color_losses = {}
                for y in self.years:
                    _, init_rgbs = self.forward_sibling(
                        self.siblings[y].synthesis, w_pivot
                    )
                    self.color_losses[y] = ColorTransferLoss(init_rgbs)

            for _ in tqdm(range(hyperparameters.max_pti_steps)):
                rgbs = {}
                if use_color_transfer:
                    for y in self.years:
                        # Work on a fresh copy each step so the stored sibling
                        # weights are never modified.
                        sibling_aug = copy.deepcopy(self.siblings[y])
                        for p_pti, p_orig, p in zip(
                            self.G.synthesis.parameters(),
                            self.original_G.synthesis.parameters(),
                            sibling_aug.synthesis.parameters(),
                        ):
                            # Shift the copied sibling weight by the tuned
                            # generator's current offset from the original one.
                            # NOTE(review): presumably this lets gradients reach
                            # self.G through the sibling outputs - confirm.
                            delta = p_pti - p_orig
                            p += delta
                        rgbs[y] = self.forward_sibling(
                            sibling_aug.synthesis, w_pivot
                        )[1]

                generated_images = self.forward(w_pivot)
                # The second returned value (L2 loss) is not used here.
                loss, _, loss_lpips = self.calc_loss(
                    generated_images,
                    real_images_batch,
                    image_name,
                    self.G,
                    use_ball_holder,
                    w_pivot,
                    rgbs,
                )

                self.optimizer.zero_grad()

                # Early stop for this image once the LPIPS term is low enough.
                if loss_lpips <= hyperparameters.LPIPS_value_threshold:
                    break

                loss.backward()
                self.optimizer.step()

                # Flag passed to calc_loss on the next step; true only on
                # every ``locality_regularization_interval``-th training step.
                use_ball_holder = (
                    global_config.training_step
                    % hyperparameters.locality_regularization_interval
                    == 0
                )

                if (
                    self.use_wandb
                    and log_images_counter
                    % global_config.image_rec_result_log_snapshot
                    == 0
                ):
                    log_images_from_w([w_pivot], self.G, [image_name])

                global_config.training_step += 1
                log_images_counter += 1

            self.image_counter += 1

            torch.save(
                self.G,
                f"{paths_config.checkpoints_dir}/model_{global_config.run_name}_{image_name}.pt",
            )