# Source: echen01 — "working demo" (commit 2fec875)
# (scraped GitHub page residue: raw · history blame · 11 kB)
import abc
import os
from argparse import Namespace
import wandb
import os.path
from criteria.localitly_regulizer import Space_Regulizer
import torch
from torchvision import transforms
from lpips import LPIPS
from training.projectors import w_projector # w_plus_projector as w_projector
from configs import global_config, paths_config, hyperparameters
from criteria import l2_loss
from criteria import mask
from criteria import id_loss
from models.e4e.psp import pSp
from utils.log_utils import log_image_from_w
from utils.models_utils import toogle_grad, load_old_G
from torch_utils import misc
from torch_utils.ops import upfirdn2d
import numpy as np
import pickle
import copy
class BaseCoach:
    """Shared state and helpers for generator fine-tuning ("coach") classes.

    Holds the generator being tuned (``self.G``), a second copy loaded the
    same way (``self.original_G``, used by the locality regularizer), the loss
    modules, the optimizer, and helpers that obtain/cache latent pivots either
    through the e4e encoder or through ``w_projector.project``.
    Subclasses implement :meth:`train`.
    """

    def __init__(self, data_loader, use_wandb):
        """Build losses, networks, optimizer and the checkpoint directory.

        Args:
            data_loader: Stored on ``self.data_loader``; not used by BaseCoach.
            use_wandb: When true, losses and images are logged to wandb.
        """
        self.use_wandb = use_wandb
        self.data_loader = data_loader
        # Pivots loaded from disk, memoized by image name (see load_inversions).
        self.w_pivots = {}
        self.image_counter = 0

        if hyperparameters.first_inv_type == "w+":
            self.initilize_e4e()

        # e4e input preprocessing: resize to 256x256, normalize with mean/std 0.5.
        self.e4e_image_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        # Initialize losses.
        self.lpips_loss = (
            LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()
        )
        self.id_loss = (
            id_loss.IDLoss(
                paths_config.ir_se50,
                official=False,
            )
            .to(global_config.device)
            .eval()
        )
        if hyperparameters.use_mask:
            # Only created when masking is on; calc_loss checks the same flag.
            self.mask = mask.Mask()

        # Must run after self.lpips_loss exists: restart_training passes it to
        # the locality regularizer.
        self.restart_training()

        # Initialize checkpoint dir.
        self.checkpoint_dir = paths_config.checkpoints_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def restart_training(self):
        """(Re)load the generators, the locality regularizer and the optimizer."""
        self.G = load_old_G()
        toogle_grad(self.G, True)
        # Second load_old_G() result; serves as the regularizer's reference.
        self.original_G = load_old_G()
        self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss)
        self.optimizer = self.configure_optimizers()

    def get_inversion(self, w_path_dir, image_name, image):
        """Return the latent pivot for ``image`` on ``global_config.device``.

        If ``hyperparameters.use_last_w_pivots`` is set, a cached/saved pivot is
        reused when available. Otherwise (or when none is found) the pivot is
        computed with :meth:`calc_inversions` and saved to
        ``{w_path_dir}/{pti_results_keyword}/{image_name}/0.pt``.
        """
        embedding_dir = f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
        os.makedirs(embedding_dir, exist_ok=True)

        w_pivot = None
        if hyperparameters.use_last_w_pivots:
            w_pivot = self.load_inversions(w_path_dir, image_name)

        # w_pivot is still None when reuse is disabled or nothing was found.
        if w_pivot is None:
            w_pivot = self.calc_inversions(image, image_name)
            # NOTE(review): for first_inv_type == "w+", pivots are saved under
            # pti_results_keyword here but load_inversions reads them from
            # e4e_results_keyword — confirm this asymmetry is intended.
            torch.save(w_pivot, f"{embedding_dir}/0.pt")

        return w_pivot.to(global_config.device)

    def load_inversions(self, w_path_dir, image_name):
        """Return a memoized or on-disk pivot for ``image_name``, else ``None``.

        The directory keyword depends on ``hyperparameters.first_inv_type``:
        ``e4e_results_keyword`` for ``"w+"``, ``pti_results_keyword`` otherwise.
        Pivots loaded from disk are stored in ``self.w_pivots``.
        """
        if image_name in self.w_pivots:
            return self.w_pivots[image_name]

        if hyperparameters.first_inv_type == "w+":
            results_keyword = paths_config.e4e_results_keyword
        else:
            results_keyword = paths_config.pti_results_keyword
        w_potential_path = f"{w_path_dir}/{results_keyword}/{image_name}/0.pt"

        if not os.path.isfile(w_potential_path):
            return None

        w = torch.load(w_potential_path, map_location=global_config.device).to(
            global_config.device
        )
        self.w_pivots[image_name] = w
        return w

    def calc_inversions(self, image, image_name):
        """Compute a latent pivot for ``image``.

        Uses the e4e encoder when ``first_inv_type == "w+"``. Otherwise it runs
        ``w_projector.project`` against ``self.G`` for
        ``hyperparameters.first_inv_steps`` steps.
        """
        if hyperparameters.first_inv_type == "w+":
            return self.get_e4e_inversion(image)

        # (x + 1) / 2 * 255 maps [-1, 1] to [0, 255]; assumes `image` is in
        # [-1, 1] — TODO confirm against the data loader.
        id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255
        return w_projector.project(
            self.G,
            id_image,
            device=torch.device(global_config.device),
            w_avg_samples=600,
            num_steps=hyperparameters.first_inv_steps,
            w_name=image_name,
            use_wandb=self.use_wandb,
        )

    @abc.abstractmethod
    def train(self):
        """Run the tuning loop; to be implemented by subclasses.

        NOTE: BaseCoach does not use ``abc.ABC``/``ABCMeta``, so this decorator
        is not enforced when the class is instantiated.
        """
        pass

    def configure_optimizers(self):
        """Create an Adam optimizer over ``self.G.synthesis`` parameters.

        Parameters whose name contains ``"rgb"`` are excluded, so they are not
        updated by this optimizer.
        """
        params = [
            p for n, p in self.G.synthesis.named_parameters() if "rgb" not in n
        ]
        return torch.optim.Adam(params, lr=hyperparameters.pti_learning_rate)

    def calc_loss(
        self,
        generated_images,
        real_images,
        log_name,
        new_G,
        use_ball_holder,
        w_batch,
        rgbs,
    ):
        """Compute the weighted tuning loss.

        Each term is enabled when its hyperparameter weight is > 0:
        L2, LPIPS, per-year color transfer and identity. The locality
        regularizer is added when ``use_ball_holder`` and
        ``hyperparameters.use_locality_regularization`` are both true.

        Args:
            rgbs: Indexed by the entries of ``self.years`` (presumably a dict
                of ToRGB outputs keyed by year — TODO confirm with caller).

        Returns:
            ``(loss, l2_loss_val, loss_lpips)``. ``l2_loss_val`` / ``loss_lpips``
            are ``None`` when their term is disabled.
        """
        loss = 0.0
        # Bound up front so the return is valid when a term is disabled
        # (previously this raised UnboundLocalError).
        l2_loss_val = None
        loss_lpips = None

        if hyperparameters.use_mask:
            real_images, generated_images = self.mask(real_images, generated_images)

        if hyperparameters.pt_l2_lambda > 0:
            l2_loss_val = l2_loss.l2_loss(generated_images, real_images, gray=False)
            if self.use_wandb:
                wandb.log(
                    {f"MSE_loss_val_{log_name}": l2_loss_val.detach().cpu()},
                    step=global_config.training_step,
                )
            loss += l2_loss_val * hyperparameters.pt_l2_lambda

        if hyperparameters.pt_lpips_lambda > 0:
            loss_lpips = torch.squeeze(self.lpips_loss(real_images, generated_images))
            if self.use_wandb:
                wandb.log(
                    {f"LPIPS_loss_val_{log_name}": loss_lpips.detach().cpu()},
                    step=global_config.training_step,
                )
            loss += loss_lpips * hyperparameters.pt_lpips_lambda

        if hyperparameters.color_transfer_lambda > 0:
            # self.years and self.color_losses are not set anywhere in
            # BaseCoach; they must be provided elsewhere (e.g. a subclass).
            for year in self.years:
                color_loss = self.color_losses[year](rgbs[year])
                loss += color_loss * hyperparameters.color_transfer_lambda

        if hyperparameters.id_lambda > 0:
            loss_id = torch.squeeze(self.id_loss(real_images, generated_images))
            loss += loss_id * hyperparameters.id_lambda

        if use_ball_holder and hyperparameters.use_locality_regularization:
            ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(
                new_G, w_batch, use_wandb=self.use_wandb
            )
            loss += ball_holder_loss_val

        return loss, l2_loss_val, loss_lpips

    def synthesis_block(self, block, x, img, ws, force_fp32=False, fused_modconv=None):
        """Run one synthesis block's layers and also return its ToRGB output.

        Args:
            block: Block whose ``const``/``conv0``/``conv1``/``skip``/``torgb``
                layers are invoked directly.
            x: Feature input; ignored for an input block (``in_channels == 0``).
            img: RGB image accumulated by earlier blocks, or ``None``.
            ws: Latents for this block; split along dim 1 and consumed in order.
            force_fp32: Force float32 and contiguous memory format.
            fused_modconv: Passed to the conv layers; derived when ``None``.

        Returns:
            ``(x, img, rgb)``, where ``rgb`` is this block's ToRGB output, or
            ``None`` when the block runs no ToRGB layer.
        """
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if block.use_fp16 and not force_fp32 else torch.float32
        memory_format = (
            torch.channels_last
            if block.channels_last and not force_fp32
            else torch.contiguous_format
        )
        if fused_modconv is None:
            with misc.suppress_tracer_warnings():  # this value will be treated as a constant
                fused_modconv = (not block.training) and (
                    dtype == torch.float32 or int(x.shape[0]) == 1
                )

        # Input.
        if block.in_channels == 0:
            x = block.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(
                x,
                [None, block.in_channels, block.resolution // 2, block.resolution // 2],
            )
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if block.in_channels == 0:
            x = block.conv1(x, next(w_iter), fused_modconv=fused_modconv)
        elif block.architecture == "resnet":
            # Kept separate from `rgb` so the skip tensor is never returned
            # as this block's ToRGB output.
            skip = block.skip(x, gain=np.sqrt(0.5))
            x = block.conv0(x, next(w_iter), fused_modconv=fused_modconv)
            x = block.conv1(
                x,
                next(w_iter),
                fused_modconv=fused_modconv,
                gain=np.sqrt(0.5),
            )
            x = skip.add_(x)
        else:
            x = block.conv0(x, next(w_iter), fused_modconv=fused_modconv)
            x = block.conv1(x, next(w_iter), fused_modconv=fused_modconv)

        # ToRGB.
        if img is not None:
            misc.assert_shape(
                img,
                [
                    None,
                    block.img_channels,
                    block.resolution // 2,
                    block.resolution // 2,
                ],
            )
            img = upfirdn2d.upsample2d(img, block.resample_filter)

        rgb = None  # stays None for blocks that run no ToRGB layer
        if block.is_last or block.architecture == "skip":
            rgb = block.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            rgb = rgb.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(rgb) if img is not None else rgb

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img, rgb

    def forward(self, w):
        """Synthesize images from ``w`` with ``self.G`` (const noise, fp32)."""
        return self.G.synthesis(w, noise_mode="const", force_fp32=True)

    def forward_sibling(self, G_sibling, w):
        """Run ``G_sibling``'s blocks on ``w``, collecting per-block ToRGB outputs.

        ``G_sibling`` must expose ``block_resolutions`` and ``b{res}`` block
        attributes (as accessed below).

        Returns:
            ``(img, rgbs)``: the final image and a list with one ToRGB output
            per block (``None`` for blocks that run no ToRGB layer).
        """
        ws = w.to(torch.float32)

        # Slice latents per block. w_idx only advances by num_conv, so
        # consecutive slices overlap by num_torgb entries.
        block_ws = []
        w_idx = 0
        for res in G_sibling.block_resolutions:
            block = getattr(G_sibling, f"b{res}")
            block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
            w_idx += block.num_conv

        x = img = None
        rgbs = []
        for res, cur_ws in zip(G_sibling.block_resolutions, block_ws):
            block = getattr(G_sibling, f"b{res}")
            x, img, rgb_mod = self.synthesis_block(block, x, img, cur_ws)
            rgbs.append(rgb_mod)
        return img, rgbs

    def initilize_e4e(self):
        """Load the e4e (pSp) encoder from ``paths_config.e4e``.

        The checkpoint's ``opts`` get ``batch_size`` and ``checkpoint_path``
        overridden. The network is put in eval mode, moved to
        ``global_config.device`` and has gradients disabled.
        """
        ckpt = torch.load(paths_config.e4e, map_location="cpu")
        opts = ckpt["opts"]
        opts["batch_size"] = hyperparameters.train_batch_size
        opts["checkpoint_path"] = paths_config.e4e
        opts = Namespace(**opts)
        self.e4e_inversion_net = pSp(opts)
        self.e4e_inversion_net.eval()
        self.e4e_inversion_net = self.e4e_inversion_net.to(global_config.device)
        toogle_grad(self.e4e_inversion_net, False)

    def get_e4e_inversion(self, image):
        """Encode the first image of ``image`` with e4e and return its latent.

        ``image`` is rescaled with ``(image + 1) / 2`` before the e4e transform
        (assumes input in [-1, 1] — TODO confirm). When wandb is enabled, the
        image generated from the latent is logged.
        """
        image = (image + 1) / 2
        new_image = self.e4e_image_transform(image[0]).to(global_config.device)
        _, w = self.e4e_inversion_net(
            new_image.unsqueeze(0),
            randomize_noise=False,
            return_latents=True,
            resize=False,
            input_code=False,
        )
        if self.use_wandb:
            log_image_from_w(w, self.G, "First e4e inversion")
        return w