# StyleRes: models/styleres.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.e4e import E4E_Inversion
from models.stylegan2 import Generator
from editings.editor import Editor
from options import Settings


class StyleRes(nn.Module):
    def __init__(self):
        super().__init__()
        # e4e encoder: inverts a 256x256 image into latents plus 64x64 feature skips
        self.encoder = E4E_Inversion(resolution=256, num_layers=50, mode='ir_se', out_res=64)
        self.generator = Generator(z_dim=512, w_dim=512, c_dim=0, resolution=1024, img_channels=3,
                                   fused_modconv_default='inference_only', embed_res=64)
        # Generator keyword arguments for eval mode
        self.G_kwargs_val = {'noise_mode': 'const', 'force_fp32': True}
        self.device = Settings.device
        self.editor = Editor()

    def load_ckpt(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location='cpu')
        self.encoder.basic_encoder.load_state_dict(ckpt['e4e'], strict=True)
        self.encoder.latent_avg = ckpt['latent_avg']
        self.generator.load_state_dict(ckpt['generator_smooth'], strict=True)
        print("Model successfully loaded")

    def send_to_device(self):
        self.encoder.to(self.device)
        self.generator.to(self.device)
        if self.device != 'cpu':
            # latent_avg is a plain tensor, so it must be moved explicitly
            self.encoder.latent_avg = self.encoder.latent_avg.to(self.device)
"""
Inputs: Input images and edit configs
Returns: Edited images together with the randomly generated image when the edit is interpolation.
"""
def edit_images(self, image, cfg):
image = image.to(self.device)
with torch.no_grad():
latents, skips = self.encoder(image)
# GradCtrl requires gradients, others do not
latents_edited = self.editor.edit(latents, cfg)
with torch.no_grad():
# Get F space features F_orig, for the original image
skips['inversion'], _ = self.generator(latents, skips, return_f = True, **self.G_kwargs_val)
# Transform F_orig to incoming image
images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val)
return images
    # def edit_demo_image(self, image, edit, factor):
    #     from utils import AttrDict
    #     cfg = AttrDict()
    #     edit = edit.lower()
    #     if edit in ['pose', 'age', 'smile']:
    #         cfg.method = 'interfacegan'
    #         cfg.edit = edit
    #         cfg.strength = factor
    #     image = image.to(self.device)
    #     with torch.no_grad():
    #         latents, skips = self.encoder(image)
    #     latents_edited = self.editor.edit(latents, cfg)
    #     with torch.no_grad():
    #         # Get the F-space features F_orig of the original image
    #         skips['inversion'], _ = self.generator(latents, skips, return_f=True, **self.G_kwargs_val)
    #         # Transform F_orig to match the edited latents and synthesize the result
    #         images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val)
    #     return images
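

# A minimal usage sketch, assuming `AttrDict` from utils (as in the commented
# edit_demo_image above). The checkpoint path, the edit strength, and the
# random `image` tensor are placeholders; a real input would be a preprocessed
# 1x3x256x256 face batch, typically normalized to [-1, 1].
if __name__ == '__main__':
    from utils import AttrDict

    model = StyleRes()
    model.load_ckpt('checkpoints/styleres_ffhq.pt')  # hypothetical filename
    model.send_to_device()
    model.eval()

    cfg = AttrDict()
    cfg.method = 'interfacegan'  # InterFaceGAN edit, as in edit_demo_image
    cfg.edit = 'smile'
    cfg.strength = 3             # placeholder strength

    image = torch.randn(1, 3, 256, 256)  # stands in for a preprocessed face image
    edited = model.edit_images(image, cfg)
    print(edited.shape)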