import os

import torch
from tqdm import tqdm

from color_transfer_loss import ColorTransferLoss
from configs import paths_config, hyperparameters, global_config
from training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w


class MultiIDCoach(BaseCoach):
    """Coach that runs one shared optimisation over several input images.

    Inversion (``get_inversion``), the forward pass (``forward``), the loss
    (``calc_loss``), the optimizer, the generator ``self.G`` and
    ``self.image_counter`` are all inherited from ``BaseCoach``; their exact
    semantics are not visible in this file.
    """

    def __init__(self, data_loader, use_wandb):
        """Forward ``data_loader`` and ``use_wandb`` unchanged to ``BaseCoach``."""
        super().__init__(data_loader, use_wandb)

    def train(self):
        """Invert the input images, optimise over all of them, save ``self.G``.

        Steps:
          1. Put ``self.G.synthesis`` and ``self.G.mapping`` in train mode and
             create the embedding directories.
          2. Obtain one pivot per image from the data loader via
             ``get_inversion``, stopping once ``self.image_counter`` reaches
             ``hyperparameters.max_images_to_invert``.
          3. For ``hyperparameters.max_pti_steps`` passes, take one optimizer
             step per (image, pivot) pair.
          4. Optionally log images to wandb, then ``torch.save`` the whole
             ``self.G`` object into ``paths_config.checkpoints_dir``.
        """
        self.G.synthesis.train()
        self.G.mapping.train()

        embeddings_root = (
            f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
        )
        os.makedirs(embeddings_root, exist_ok=True)
        os.makedirs(
            f"{embeddings_root}/{paths_config.pti_results_keyword}", exist_ok=True
        )

        # Flag forwarded to calc_loss; presumably toggles the locality
        # regularisation term -- confirm against BaseCoach.calc_loss.
        apply_ball_holder = True
        pivots = []
        samples = []  # (image_name, image) pairs, index-aligned with `pivots`

        # ---- Phase 1: obtain one pivot per input image ----
        for file_names, image in self.data_loader:
            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            # Only the first entry of the name batch is used.
            image_name = file_names[0]

            # The per-image directory lives under the e4e results folder for
            # "w+" first-stage inversion and under the PTI folder otherwise.
            results_keyword = (
                paths_config.e4e_results_keyword
                if hyperparameters.first_inv_type == "w+"
                else paths_config.pti_results_keyword
            )
            os.makedirs(
                f"{embeddings_root}/{results_keyword}/{image_name}", exist_ok=True
            )

            pivots.append(self.get_inversion(embeddings_root, image_name, image))
            samples.append((image_name, image))
            self.image_counter += 1

        # ---- Phase 2: one optimizer step per image, repeated per pass ----
        for _ in tqdm(range(hyperparameters.max_pti_steps)):
            # The counter is reset on every pass so the per-pass image cap
            # below is evaluated from zero each time.
            self.image_counter = 0

            for (image_name, image), pivot in zip(samples, pivots):
                if self.image_counter >= hyperparameters.max_images_to_invert:
                    break

                target_batch = image.to(global_config.device)
                synthesized = self.forward(pivot)

                # Only the first of the three returned values is used here.
                # NOTE(review): the trailing empty dict is an extra positional
                # argument whose meaning is defined by calc_loss -- confirm.
                loss, _, _ = self.calc_loss(
                    synthesized,
                    target_batch,
                    image_name,
                    self.G,
                    apply_ball_holder,
                    pivot,
                    {},
                )

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Evaluated BEFORE training_step is incremented, so the flag
                # used by the next step reflects the index of the step that
                # just finished.
                apply_ball_holder = (
                    global_config.training_step
                    % hyperparameters.locality_regularization_interval
                    == 0
                )

                global_config.training_step += 1
                self.image_counter += 1

        if self.use_wandb:
            image_names = [name for name, _ in samples]
            log_images_from_w(pivots, self.G, image_names)

        # The full generator object is serialised, not just a state_dict.
        checkpoint_path = (
            f"{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pt"
        )
        torch.save(self.G, checkpoint_path)