import os
import os.path
from argparse import Namespace

import numpy as np
import torch
import wandb
from lpips import LPIPS
from torchvision import transforms

from configs import global_config, hyperparameters, paths_config
from criteria import id_loss, l2_loss, mask
from criteria.localitly_regulizer import Space_Regulizer
from models.e4e.psp import pSp
from torch_utils import misc
from torch_utils.ops import upfirdn2d
from training.projectors import w_projector  # w_plus_projector as w_projector
from utils.log_utils import log_image_from_w
from utils.models_utils import load_old_G, toogle_grad


class BaseCoach:
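    """Base class for Pivotal Tuning Inversion (PTI) coaches.

    Handles pivot-latent computation (via the w projector or a frozen e4e
    encoder), loss construction, and generator fine-tuning plumbing.
    Subclasses implement ``train`` and, when the color-transfer loss is
    enabled, are expected to provide ``self.years`` and ``self.color_losses``.
    """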

    def __init__(self, data_loader, use_wandb):
        self.use_wandb = use_wandb
        self.data_loader = data_loader
        self.w_pivots = {}
        self.image_counter = 0

        if hyperparameters.first_inv_type == "w+":
            self.initialize_e4e()

        self.e4e_image_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        # Initialize losses
        self.lpips_loss = (
            LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()
        )
        self.id_loss = (
            id_loss.IDLoss(
                paths_config.ir_se50,
                official=False,
            )
            .to(global_config.device)
            .eval()
        )
        if hyperparameters.use_mask:
            self.mask = mask.Mask()

        self.restart_training()

        # Initialize checkpoint dir
        self.checkpoint_dir = paths_config.checkpoints_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def restart_training(self):
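        """Reload a trainable copy and a frozen copy of the generator and
        rebuild the optimizer, so tuning starts fresh for a new image."""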
        # Initialize networks
        self.G = load_old_G()
        toogle_grad(self.G, True)

        self.original_G = load_old_G()

        self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss)
        self.optimizer = self.configure_optimizers()

    def get_inversion(self, w_path_dir, image_name, image):
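        """Return the pivot latent for ``image``: load a cached inversion from
        disk when ``use_last_w_pivots`` is set and one exists, otherwise
        compute a fresh one and save it under the embedding directory."""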
        embedding_dir = f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
        os.makedirs(embedding_dir, exist_ok=True)

        w_pivot = None
        if hyperparameters.use_last_w_pivots:
            w_pivot = self.load_inversions(w_path_dir, image_name)
        if not hyperparameters.use_last_w_pivots or w_pivot is None:
            w_pivot = self.calc_inversions(image, image_name)
            torch.save(w_pivot, f"{embedding_dir}/0.pt")

        w_pivot = w_pivot.to(global_config.device)
        return w_pivot

    def load_inversions(self, w_path_dir, image_name):
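        """Look up a previously computed pivot, first in the in-memory cache
        and then on disk; return None if neither holds one."""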
        if image_name in self.w_pivots:
            return self.w_pivots[image_name]

        if hyperparameters.first_inv_type == "w+":
            w_potential_path = (
                f"{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}/0.pt"
            )
        else:
            w_potential_path = (
                f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}/0.pt"
            )
        if not os.path.isfile(w_potential_path):
            return None

        w = torch.load(w_potential_path, map_location=global_config.device)
        self.w_pivots[image_name] = w
        return w

    def calc_inversions(self, image, image_name):
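        """Invert ``image`` into latent space, either with the e4e encoder
        (``first_inv_type == "w+"``) or by direct latent optimization with
        the w projector."""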
        if hyperparameters.first_inv_type == "w+":
            w = self.get_e4e_inversion(image)
        else:
            # Rescale from [-1, 1] to [0, 255] as expected by the projector.
            id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255
            w = w_projector.project(
                self.G,
                id_image,
                device=torch.device(global_config.device),
                w_avg_samples=600,
                num_steps=hyperparameters.first_inv_steps,
                w_name=image_name,
                use_wandb=self.use_wandb,
            )
        return w

    def train(self):
        # Training loop is implemented by subclasses.
        pass

    def configure_optimizers(self):
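        """Build the Adam optimizer for pivotal tuning over the synthesis
        network, skipping the ToRGB ("rgb") layers so only the feature
        convolutions are fine-tuned."""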
        # Previously experimented variants: optimizing all of self.G's
        # parameters, or only blocks at selected resolutions
        # (["64", "32", "16", "8", "4"]). Currently every synthesis parameter
        # except the ToRGB layers is tuned.
        params = []
        for n, p in self.G.synthesis.named_parameters():
            if "rgb" not in n:
                params.append(p)
        optimizer = torch.optim.Adam(params, lr=hyperparameters.pti_learning_rate)
        return optimizer

    def calc_loss(
        self,
        generated_images,
        real_images,
        log_name,
        new_G,
        use_ball_holder,
        w_batch,
        rgbs,
    ):
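        """Accumulate the pivotal-tuning loss: L2 and LPIPS reconstruction
        terms, an optional per-year color-transfer term (``self.years`` and
        ``self.color_losses`` must be provided by the subclass), an optional
        identity loss, and the locality ("ball holder") regularizer.

        Returns the total loss together with the individual L2 and LPIPS
        values (zero tensors when those terms are disabled)."""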
        loss = 0.0
        # Zero-initialize the reported terms so the return values are defined
        # even when the corresponding lambdas are 0.
        l2_loss_val = torch.tensor(0.0, device=global_config.device)
        loss_lpips = torch.tensor(0.0, device=global_config.device)

        if hyperparameters.use_mask:
            real_images, generated_images = self.mask(real_images, generated_images)

        if hyperparameters.pt_l2_lambda > 0:
            l2_loss_val = l2_loss.l2_loss(generated_images, real_images, gray=False)
            if self.use_wandb:
                wandb.log(
                    {f"MSE_loss_val_{log_name}": l2_loss_val.detach().cpu()},
                    step=global_config.training_step,
                )
            loss += l2_loss_val * hyperparameters.pt_l2_lambda

        if hyperparameters.pt_lpips_lambda > 0:
            loss_lpips = self.lpips_loss(real_images, generated_images)
            loss_lpips = torch.squeeze(loss_lpips)
            if self.use_wandb:
                wandb.log(
                    {f"LPIPS_loss_val_{log_name}": loss_lpips.detach().cpu()},
                    step=global_config.training_step,
                )
            loss += loss_lpips * hyperparameters.pt_lpips_lambda

        if hyperparameters.color_transfer_lambda > 0:
            # self.years and self.color_losses are expected to be set by the
            # subclass that enables the color-transfer loss.
            for y in self.years:
                color_loss = self.color_losses[y](rgbs[y])
                loss += color_loss * hyperparameters.color_transfer_lambda

        if hyperparameters.id_lambda > 0:
            loss_id = self.id_loss(real_images, generated_images)
            loss_id = torch.squeeze(loss_id)
            loss += loss_id * hyperparameters.id_lambda

        if use_ball_holder and hyperparameters.use_locality_regularization:
            ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(
                new_G, w_batch, use_wandb=self.use_wandb
            )
            loss += ball_holder_loss_val

        return loss, l2_loss_val, loss_lpips

    def synthesis_block(self, block, x, img, ws, force_fp32=False, fused_modconv=None):
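        """Run one StyleGAN2 synthesis block, additionally returning the
        block's ToRGB output so per-resolution RGB contributions can be
        inspected or regularized.

        Closely mirrors ``SynthesisBlock.forward`` from the StyleGAN2(-ADA)
        codebase; returns ``(x, img, y)`` where ``y`` is the ToRGB result
        (None when the block emits no RGB)."""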
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if block.use_fp16 and not force_fp32 else torch.float32
        memory_format = (
            torch.channels_last
            if block.channels_last and not force_fp32
            else torch.contiguous_format
        )
        if fused_modconv is None:
            with misc.suppress_tracer_warnings():  # this value will be treated as a constant
                fused_modconv = (not block.training) and (
                    dtype == torch.float32 or int(x.shape[0]) == 1
                )

        # Input.
        if block.in_channels == 0:
            x = block.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(
                x,
                [None, block.in_channels, block.resolution // 2, block.resolution // 2],
            )
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if block.in_channels == 0:
            x = block.conv1(x, next(w_iter), fused_modconv=fused_modconv)
        elif block.architecture == "resnet":
            skip = block.skip(x, gain=np.sqrt(0.5))
            x = block.conv0(x, next(w_iter), fused_modconv=fused_modconv)
            x = block.conv1(
                x,
                next(w_iter),
                fused_modconv=fused_modconv,
                gain=np.sqrt(0.5),
            )
            x = skip.add_(x)
        else:
            x = block.conv0(x, next(w_iter), fused_modconv=fused_modconv)
            x = block.conv1(x, next(w_iter), fused_modconv=fused_modconv)

        # ToRGB.
        y = None  # holds this block's ToRGB output, if any
        if img is not None:
            misc.assert_shape(
                img,
                [
                    None,
                    block.img_channels,
                    block.resolution // 2,
                    block.resolution // 2,
                ],
            )
            img = upfirdn2d.upsample2d(img, block.resample_filter)
        if block.is_last or block.architecture == "skip":
            y = block.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img, y

    def forward(self, w):
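        """Synthesize images from latents ``w`` with the tuned generator."""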
        generated_images = self.G.synthesis(w, noise_mode="const", force_fp32=True)
        return generated_images

    def forward_sibling(self, G_sibling, w):
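        """Run ``G_sibling``'s synthesis network block by block via
        ``synthesis_block``, collecting each block's ToRGB output alongside
        the final image.

        Follows the w-slicing scheme of ``SynthesisNetwork.forward``: each
        block sees ``num_conv + num_torgb`` latents, while the read index
        advances by only ``num_conv`` so the ToRGB latent is shared with the
        next block."""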
        block_ws = []
        rgbs = []
        ws = w.to(torch.float32)
        w_idx = 0
        for res in G_sibling.block_resolutions:
            block = getattr(G_sibling, f"b{res}")
            block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
            w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(G_sibling.block_resolutions, block_ws):
            block = getattr(G_sibling, f"b{res}")
            x, img, rgb_mod = self.synthesis_block(block, x, img, cur_ws)
            rgbs.append(rgb_mod)
        return img, rgbs

    def initialize_e4e(self):
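        """Load the pretrained e4e encoder used for "w+" first inversions and
        freeze its weights."""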
        ckpt = torch.load(paths_config.e4e, map_location="cpu")
        opts = ckpt["opts"]
        opts["batch_size"] = hyperparameters.train_batch_size
        opts["checkpoint_path"] = paths_config.e4e
        opts = Namespace(**opts)
        self.e4e_inversion_net = pSp(opts)
        self.e4e_inversion_net.eval()
        self.e4e_inversion_net = self.e4e_inversion_net.to(global_config.device)
        toogle_grad(self.e4e_inversion_net, False)

    def get_e4e_inversion(self, image):
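        """Encode ``image`` (in [-1, 1]) to a latent with the frozen e4e
        encoder, resizing and normalizing to its expected 256x256 input."""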
        image = (image + 1) / 2  # map from [-1, 1] to [0, 1] for the transform
        new_image = self.e4e_image_transform(image[0]).to(global_config.device)
        _, w = self.e4e_inversion_net(
            new_image.unsqueeze(0),
            randomize_noise=False,
            return_latents=True,
            resize=False,
            input_code=False,
        )
        if self.use_wandb:
            log_image_from_w(w, self.G, "First e4e inversion")
        return w