# NOTE: file recovered from a web file-viewer scrape (commit 2e34814);
# viewer chrome ("Spaces", runtime status, file size, line-number gutter)
# stripped so the module parses.
import os
import torch
from tqdm import tqdm
from color_transfer_loss import ColorTransferLoss
from configs import paths_config, hyperparameters, global_config
from training.coaches.base_coach import BaseCoach
from utils.log_utils import log_images_from_w
class MultiIDCoach(BaseCoach):
    """PTI coach that tunes one shared generator over multiple identities.

    Phase 1 inverts every input image to a pivot latent; phase 2 fine-tunes
    the generator weights jointly against all (image, pivot) pairs.
    """

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def train(self):
        """Run multi-identity pivotal tuning and save the tuned generator.

        Side effects: creates embedding directories under the configured
        embedding base dir, mutates ``self.G`` in place, advances
        ``global_config.training_step``, optionally logs to wandb, and
        serializes the whole generator to the checkpoints dir.
        """
        self.G.synthesis.train()
        self.G.mapping.train()

        w_path_dir = f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f"{w_path_dir}/{paths_config.pti_results_keyword}", exist_ok=True)

        use_ball_holder = True
        w_pivots = []
        images = []

        # Phase 1: compute a pivot latent per image, capped at
        # hyperparameters.max_images_to_invert.
        for fname, image in self.data_loader:
            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            image_name = fname[0]
            # e4e inversions live under a separate results keyword than
            # the default (W-space) PTI inversions.
            if hyperparameters.first_inv_type == "w+":
                embedding_dir = (
                    f"{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}"
                )
            else:
                embedding_dir = (
                    f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
                )
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivot = self.get_inversion(w_path_dir, image_name, image)
            w_pivots.append(w_pivot)
            images.append((image_name, image))
            self.image_counter += 1

        # Phase 2: optimize the generator against all pivots for
        # max_pti_steps passes over the collected set.
        for _ in tqdm(range(hyperparameters.max_pti_steps)):
            self.image_counter = 0

            for (image_name, image), w_pivot in zip(images, w_pivots):
                if self.image_counter >= hyperparameters.max_images_to_invert:
                    break

                real_images_batch = image.to(global_config.device)
                generated_images = self.forward(w_pivot)
                loss, l2_loss_val, loss_lpips = self.calc_loss(
                    generated_images,
                    real_images_batch,
                    image_name,
                    self.G,
                    use_ball_holder,
                    w_pivot,
                    {},
                )

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Locality regularization is applied only every
                # locality_regularization_interval global steps.
                use_ball_holder = (
                    global_config.training_step
                    % hyperparameters.locality_regularization_interval
                    == 0
                )

                global_config.training_step += 1
                self.image_counter += 1

        if self.use_wandb:
            log_images_from_w(w_pivots, self.G, [name for name, _ in images])

        torch.save(
            self.G,
            f"{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pt",
        )