echen01
add PTI
2e34814
raw
history blame
4.24 kB
import os
import torch
from tqdm import tqdm
from configs import paths_config, hyperparameters, global_config
from training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w
from color_transfer_loss import ColorTransferLoss
import copy
class SingleIDCoach(BaseCoach):
def __init__(self, data_loader, use_wandb):
super().__init__(data_loader, use_wandb)
def train(self):
w_path_dir = f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
os.makedirs(w_path_dir, exist_ok=True)
os.makedirs(f"{w_path_dir}/{paths_config.pti_results_keyword}", exist_ok=True)
use_ball_holder = True
for fname, image in tqdm(self.data_loader):
image_name = fname[0]
self.restart_training()
if self.image_counter >= hyperparameters.max_images_to_invert:
break
embedding_dir = (
f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
)
os.makedirs(embedding_dir, exist_ok=True)
w_pivot = None
if hyperparameters.use_last_w_pivots:
w_pivot = self.load_inversions(w_path_dir, image_name)
elif not hyperparameters.use_last_w_pivots or w_pivot is None:
w_pivot = self.calc_inversions(image, image_name)
# w_pivot = w_pivot.detach().clone().to(global_config.device)
w_pivot = w_pivot.to(global_config.device)
torch.save(w_pivot, f"{embedding_dir}/0.pt")
# w_pivot = torch.load(
# f"{embedding_dir}/0.pt", map_location=global_config.device
# )
log_images_counter = 0
real_images_batch = image.to(global_config.device)
if hyperparameters.color_transfer_lambda > 0:
self.color_losses = {}
for y in self.years:
_, init_rgbs = self.forward_sibling(
self.siblings[y].synthesis, w_pivot
)
self.color_losses[y] = ColorTransferLoss(init_rgbs)
for i in tqdm(range(hyperparameters.max_pti_steps)):
rgbs = {}
if hyperparameters.color_transfer_lambda > 0:
for y in self.years:
G_sibling_aug = copy.deepcopy(self.siblings[y])
for p_pti, p_orig, p in zip(
self.G.synthesis.parameters(),
self.original_G.synthesis.parameters(),
G_sibling_aug.synthesis.parameters(),
):
delta = p_pti - p_orig
p += delta
rgbs[y] = self.forward_sibling(
G_sibling_aug.synthesis, w_pivot
)[1]
generated_images = self.forward(w_pivot)
loss, l2_loss_val, loss_lpips = self.calc_loss(
generated_images,
real_images_batch,
image_name,
self.G,
use_ball_holder,
w_pivot,
rgbs,
)
self.optimizer.zero_grad()
if loss_lpips <= hyperparameters.LPIPS_value_threshold:
break
loss.backward()
self.optimizer.step()
use_ball_holder = (
global_config.training_step
% hyperparameters.locality_regularization_interval
== 0
)
if (
self.use_wandb
and log_images_counter % global_config.image_rec_result_log_snapshot
== 0
):
log_images_from_w([w_pivot], self.G, [image_name])
global_config.training_step += 1
log_images_counter += 1
self.image_counter += 1
torch.save(
self.G,
f"{paths_config.checkpoints_dir}/model_{global_config.run_name}_{image_name}.pt",
)