import torch
import torch.nn as nn
import torch.nn.functional as F

from models.e4e import E4E_Inversion
from models.stylegan2 import Generator
from editings.editor import Editor
from options import Settings


class StyleRes(nn.Module):
    """Image inversion + editing pipeline built from an encoder and a generator.

    The encoder (``E4E_Inversion``) maps an input image to latents and a dict
    of skip features. ``Editor`` modifies the latents according to an edit
    config. The generator (``Generator``) then renders the edited latents,
    reusing features computed from the un-edited inversion.
    """

    def __init__(self):
        super().__init__()
        self.encoder = E4E_Inversion(resolution=256, num_layers=50, mode='ir_se', out_res=64)
        self.generator = Generator(
            z_dim=512,
            w_dim=512,
            c_dim=0,
            resolution=1024,
            img_channels=3,
            fused_modconv_default='inference_only',
            embed_res=64,
        )
        # Generator keyword arguments used for every eval-mode forward pass below.
        self.G_kwargs_val = {'noise_mode': 'const', 'force_fp32': True}
        self.device = Settings.device
        self.editor = Editor()

    def load_ckpt(self, ckpt_path):
        """Load encoder, average-latent and generator weights from a checkpoint.

        Args:
            ckpt_path: Path to a checkpoint containing the keys ``'e4e'``,
                ``'latent_avg'`` and ``'generator_smooth'``.

        All tensors are mapped to CPU; call ``send_to_device()`` afterwards to
        move the model to ``self.device``. State dicts are loaded with
        ``strict=True``, so missing or unexpected keys raise.
        """
        # NOTE(security): torch.load unpickles arbitrary Python objects;
        # only load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        self.encoder.basic_encoder.load_state_dict(ckpt['e4e'], strict=True)
        self.encoder.latent_avg = ckpt['latent_avg']
        self.generator.load_state_dict(ckpt['generator_smooth'], strict=True)
        print("Model succesfully loaded")

    def send_to_device(self):
        """Move encoder, generator and the encoder's ``latent_avg`` to ``self.device``."""
        self.encoder.to(self.device)
        self.generator.to(self.device)
        # latent_avg is assigned directly as an attribute in load_ckpt. It is
        # presumably not a registered buffer, so Module.to() may not move it.
        # Move it explicitly. Use .to(self.device) rather than .cuda() so the
        # configured device (e.g. 'cuda:1', 'mps', or a torch.device) is honored.
        latent_avg = getattr(self.encoder, 'latent_avg', None)
        if latent_avg is not None:
            self.encoder.latent_avg = latent_avg.to(self.device)

    def edit_images(self, image, cfg):
        """Invert ``image``, apply the edit described by ``cfg`` and re-render.

        Args:
            image: Input image batch; moved to ``self.device`` before encoding.
            cfg: Edit configuration forwarded unchanged to ``Editor.edit``.

        Returns:
            The first output of the generator for the edited latents.
            NOTE(review): an earlier comment claimed that a randomly generated
            image is also returned for interpolation edits. This method
            returns only ``images``, so any such image would have to be part
            of the batch produced by ``Editor.edit`` — confirm.
        """
        image = image.to(self.device)
        with torch.no_grad():
            latents, skips = self.encoder(image)

        # Editing runs outside no_grad: GradCtrl requires gradients, the
        # other edit methods do not.
        latents_edited = self.editor.edit(latents, cfg)

        with torch.no_grad():
            # Compute F-space features (F_orig) for the original, un-edited
            # latents and stash them in `skips` for the second pass.
            skips['inversion'], _ = self.generator(
                latents, skips, return_f=True, **self.G_kwargs_val
            )
            # Render the edited latents, transforming F_orig to match them.
            images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val)
        return images

    # NOTE: disabled demo helper, kept for reference.
    # def edit_demo_image(self, image, edit, factor):
    #     from utils import AttrDict
    #     cfg = AttrDict()
    #     edit = edit.lower()
    #     if edit in ['pose', 'age', 'smile']:
    #         cfg.method = 'interfacegan'
    #         cfg.edit = edit
    #         cfg.strength = factor
    #     image = image.to(self.device)
    #     with torch.no_grad():
    #         latents, skips = self.encoder(image)
    #     latents_edited = self.editor.edit(latents, cfg)
    #     with torch.no_grad():
    #         # Get F space features F_orig, for the original image
    #         skips['inversion'], _ = self.generator(latents, skips, return_f = True, **self.G_kwargs_val)
    #         # Transform F_orig to incoming image
    #         images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val)
    #     return images