""" This file defines the core research contribution """ import matplotlib matplotlib.use("Agg") import math import torch from torch import nn from pixel2style2pixel.models.encoders import psp_encoders from pixel2style2pixel.models.stylegan2.model import Generator from pixel2style2pixel.configs.paths_config import model_paths def get_keys(d, name): if "state_dict" in d: d = d["state_dict"] d_filt = {k[len(name) + 1 :]: v for k, v in d.items() if k[: len(name)] == name} return d_filt class pSp(nn.Module): def __init__(self, opts): super(pSp, self).__init__() self.set_opts(opts) # compute number of style inputs based on the output resolution self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 # Define architecture self.encoder = self.set_encoder() self.decoder = Generator(self.opts.output_size, 512, 8) self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) # Load weights if needed self.load_weights() def set_encoder(self): if self.opts.encoder_type == "GradualStyleEncoder": encoder = psp_encoders.GradualStyleEncoder(50, "ir_se", self.opts) elif self.opts.encoder_type == "BackboneEncoderUsingLastLayerIntoW": encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW( 50, "ir_se", self.opts ) elif self.opts.encoder_type == "BackboneEncoderUsingLastLayerIntoWPlus": encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus( 50, "ir_se", self.opts ) else: raise Exception("{} is not a valid encoders".format(self.opts.encoder_type)) return encoder def load_weights(self): if self.opts.checkpoint_path is not None: print("Loading pSp from checkpoint: {}".format(self.opts.checkpoint_path)) ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu") self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True) self.decoder.load_state_dict(get_keys(ckpt, "decoder"), strict=True) self.__load_latent_avg(ckpt) else: print("Loading encoders weights from irse50!") encoder_ckpt = torch.load(model_paths["ir_se50"]) # if input to encoder is not an RGB image, do not load the input layer weights if self.opts.label_nc != 0: encoder_ckpt = { k: v for k, v in encoder_ckpt.items() if "input_layer" not in k } self.encoder.load_state_dict(encoder_ckpt, strict=False) print("Loading decoder weights from pretrained!") ckpt = torch.load(self.opts.stylegan_weights) self.decoder.load_state_dict(ckpt["g_ema"], strict=False) if self.opts.learn_in_w: self.__load_latent_avg(ckpt, repeat=1) else: self.__load_latent_avg(ckpt, repeat=self.opts.n_styles) def forward( self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, inject_latent=None, return_latents=False, alpha=None, ): if input_code: codes = x else: codes = self.encoder(x) # normalize with respect to the center of an average face if self.opts.start_from_latent_avg: if self.opts.learn_in_w: codes = codes + self.latent_avg.repeat(codes.shape[0], 1) else: codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) if latent_mask is not None: for i in latent_mask: if inject_latent is not None: if alpha is not None: codes[:, i] = ( alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] ) else: codes[:, i] = inject_latent[:, i] else: codes[:, i] = 0 input_is_latent = not input_code images, result_latent = self.decoder( [codes], input_is_latent=input_is_latent, randomize_noise=randomize_noise, return_latents=return_latents, ) if resize: images = self.face_pool(images) if return_latents: return images, result_latent else: return images def set_opts(self, opts): self.opts = opts def __load_latent_avg(self, ckpt, repeat=None): if "latent_avg" in ckpt: self.latent_avg = ckpt["latent_avg"].to(self.opts.device) if repeat is not None: self.latent_avg = self.latent_avg.repeat(repeat, 1) else: self.latent_avg = None