""" | |
This file defines the core research contribution | |
""" | |
import matplotlib | |
matplotlib.use("Agg") | |
import math | |
import torch | |
from torch import nn | |
from pixel2style2pixel.models.encoders import psp_encoders | |
from pixel2style2pixel.models.stylegan2.model import Generator | |
from pixel2style2pixel.configs.paths_config import model_paths | |
def get_keys(d, name): | |
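    """Extract the sub-dict of `d` belonging to sub-module `name`.

    Unwraps a full training checkpoint (one with a "state_dict" entry) if needed,
    then strips the "name." prefix from matching keys so the result fits that
    sub-module's own load_state_dict(), e.g.
    "encoder.input_layer.0.weight" -> "input_layer.0.weight".
    """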
if "state_dict" in d: | |
d = d["state_dict"] | |
d_filt = {k[len(name) + 1 :]: v for k, v in d.items() if k[: len(name)] == name} | |
return d_filt | |


class pSp(nn.Module):
    def __init__(self, opts):
        super(pSp, self).__init__()
        self.set_opts(opts)
        # compute number of style inputs based on the output resolution
        # (e.g. output_size=1024 -> int(log2(1024)) * 2 - 2 = 18 style vectors)
        self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_encoder(self):
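        """Build the encoder selected by opts.encoder_type (all variants use a
        50-layer ir_se backbone from psp_encoders)."""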
        if self.opts.encoder_type == "GradualStyleEncoder":
            encoder = psp_encoders.GradualStyleEncoder(50, "ir_se", self.opts)
        elif self.opts.encoder_type == "BackboneEncoderUsingLastLayerIntoW":
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(
                50, "ir_se", self.opts
            )
        elif self.opts.encoder_type == "BackboneEncoderUsingLastLayerIntoWPlus":
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(
                50, "ir_se", self.opts
            )
        else:
            raise Exception("{} is not a valid encoder type".format(self.opts.encoder_type))
        return encoder

    def load_weights(self):
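        """Load weights for the encoder and decoder.

        If a full pSp checkpoint is given, both sub-networks are restored from it
        (strict=True). Otherwise the encoder is initialised from the pretrained
        ir_se50 face-recognition backbone and the decoder from the pretrained
        StyleGAN2 generator ("g_ema"); strict=False there because those
        checkpoints do not cover every parameter of the pSp modules.
        """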
        if self.opts.checkpoint_path is not None:
            print("Loading pSp from checkpoint: {}".format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
            self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
            self.decoder.load_state_dict(get_keys(ckpt, "decoder"), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            print("Loading encoder weights from irse50!")
            encoder_ckpt = torch.load(model_paths["ir_se50"])
            # if input to encoder is not an RGB image, do not load the input layer weights
            if self.opts.label_nc != 0:
                encoder_ckpt = {
                    k: v for k, v in encoder_ckpt.items() if "input_layer" not in k
                }
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print("Loading decoder weights from pretrained!")
            ckpt = torch.load(self.opts.stylegan_weights)
            self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
            if self.opts.learn_in_w:
                self.__load_latent_avg(ckpt, repeat=1)
            else:
                self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)

    def forward(
        self,
        x,
        resize=True,
        latent_mask=None,
        input_code=False,
        randomize_noise=True,
        inject_latent=None,
        return_latents=False,
        alpha=None,
    ):
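        """Map an input image (or a latent code, if input_code=True) to an image.

        latent_mask selects style indices to overwrite: with inject_latent they
        are replaced by (or, given alpha, blended with) the injected code;
        without it they are zeroed. The output is average-pooled to 256x256
        unless resize=False.
        """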
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # normalize with respect to the center of an average face
            if self.opts.start_from_latent_avg:
                if self.opts.learn_in_w:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = (
                            alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                        )
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        input_is_latent = not input_code
        images, result_latent = self.decoder(
            [codes],
            input_is_latent=input_is_latent,
            randomize_noise=randomize_noise,
            return_latents=return_latents,
        )

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images

    def set_opts(self, opts):
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
        if "latent_avg" in ckpt:
            self.latent_avg = ckpt["latent_avg"].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None
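

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the model code).
    # The opts fields below mirror the attributes read in this file plus the
    # fields the encoder classes typically expect (e.g. input_nc); the checkpoint
    # paths are placeholders and must point at real files for this to run.
    from argparse import Namespace

    opts = Namespace(
        encoder_type="GradualStyleEncoder",
        input_nc=3,
        label_nc=0,
        output_size=1024,
        checkpoint_path="pretrained_models/psp_ffhq_encode.pt",  # placeholder path
        stylegan_weights="pretrained_models/stylegan2-ffhq-config-f.pt",  # placeholder path
        learn_in_w=False,
        start_from_latent_avg=True,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    net = pSp(opts).to(opts.device).eval()
    x = torch.randn(1, 3, 256, 256, device=opts.device)  # typical pSp input resolution
    with torch.no_grad():
        images, latents = net(x, randomize_noise=False, return_latents=True)
    # With output_size=1024: images -> (1, 3, 256, 256) after face_pool,
    # latents -> (1, 18, 512).
    print(images.shape, latents.shape)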