File size: 2,981 Bytes
6709fc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.e4e import E4E_Inversion
from models.stylegan2 import Generator
from editings.editor import Editor
from options import Settings

class StyleRes(nn.Module):
    """StyleRes inversion/editing pipeline.

    Wraps an e4e inversion encoder, a StyleGAN2 generator and an editing
    module to perform latent edits while preserving image details via
    feature-space (F-space) skip connections.
    """

    def __init__(self):
        super().__init__()
        # e4e encoder: inverts 256x256 images into W+ latents plus 64x64
        # feature skips that the generator consumes.
        self.encoder = E4E_Inversion(resolution=256, num_layers=50, mode='ir_se', out_res=64)
        self.generator = Generator(z_dim=512, w_dim=512, c_dim=0, resolution=1024, img_channels=3,
                                   fused_modconv_default='inference_only', embed_res=64)
        # Generator keyword arguments used in eval/inference mode.
        self.G_kwargs_val = {'noise_mode': 'const', 'force_fp32': True}
        self.device = Settings.device
        self.editor = Editor()

    def load_ckpt(self, ckpt_path):
        """Load encoder and generator weights from a combined checkpoint.

        Args:
            ckpt_path: Path to a checkpoint dict with keys 'e4e',
                'latent_avg' and 'generator_smooth'.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        self.encoder.basic_encoder.load_state_dict(ckpt['e4e'], strict=True)
        self.encoder.latent_avg = ckpt['latent_avg']
        self.generator.load_state_dict(ckpt['generator_smooth'], strict=True)
        print("Model successfully loaded")  # fixed typo: "succesfully"

    def send_to_device(self):
        """Move encoder, generator and the cached latent average to the configured device."""
        self.encoder.to(self.device)
        self.generator.to(self.device)
        # latent_avg is a plain tensor attribute, so Module.to() does not
        # move it; move it explicitly. Use .to(self.device) rather than
        # .cuda() so any non-CPU device (not just CUDA) works.
        if self.device != 'cpu':
            self.encoder.latent_avg = self.encoder.latent_avg.to(self.device)

    def edit_images(self, image, cfg):
        """Apply the configured edit to a batch of input images.

        Args:
            image: Batch of input images (moved to self.device internally).
            cfg: Edit configuration forwarded to the editor.

        Returns:
            Edited images, together with the randomly generated image when
            the edit is an interpolation.
        """
        image = image.to(self.device)
        with torch.no_grad():
            latents, skips = self.encoder(image)

        # GradCtrl requires gradients, other edit methods do not — so the
        # edit itself runs outside the no_grad context.
        latents_edited = self.editor.edit(latents, cfg)

        with torch.no_grad():
            # Get F-space features F_orig for the original image.
            skips['inversion'], _ = self.generator(latents, skips, return_f=True, **self.G_kwargs_val)
            # Transform F_orig to match the edited latents.
            images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val)

        return images