File size: 1,882 Bytes
6709fc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import csv
from options import Settings
import os

class GanSpace():
    """Apply GANSpace (PCA-direction) edits to latent codes.

    Edit presets are read at construction time from ``ganspace_configs.csv``
    inside ``Settings.ganspace_directions``. Every tab-separated row has the
    form ``name, int, int, ...``; ``edit_ganspace`` unpacks a preset as
    ``(pca_idx, start, end, strength)``.

    The PCA checkpoint ``ffhq_pca.pt`` (a mapping with ``'mean'``, ``'comp'``
    and ``'std'`` entries) is loaded lazily the first time an edit is
    requested.
    """

    # PCA checkpoint; stays None until load_ganspace_pca() reads it from disk.
    pca = None

    def __init__(self) -> None:
        """Parse the preset table into ``self.gan_space_configs``."""
        self.gan_space_configs = {}

        config_path = os.path.join(Settings.ganspace_directions, 'ganspace_configs.csv')
        # newline="" is what the csv module asks for when handed a file object.
        with open(config_path, "r", newline="") as f:
            for row in csv.reader(f, delimiter="\t"):
                if not row:
                    # Blank lines come back as an empty row: no key to read.
                    continue
                key, *values = row
                self.gan_space_configs[key] = [int(value) for value in values]

    def edit(self, latent, cfg):
        """Apply the preset named ``cfg.edit`` to a batch of latents.

        Args:
            latent: Iterable of latent tensors (typically a batched tensor);
                passed straight to ``edit_ganspace``.
            cfg: Object exposing ``edit`` (preset name) and ``strength``
                (replaces the last entry of the preset for this call).

        Returns:
            The stacked, edited latents produced by ``edit_ganspace``.

        Raises:
            KeyError: If ``cfg.edit`` is not a known preset.
        """
        with torch.no_grad():
            self.load_ganspace_pca()
            preset = self.gan_space_configs[cfg.edit]
            # Build a per-call config instead of writing cfg.strength into the
            # stored preset, so the values parsed from the CSV stay intact.
            config = [*preset[:-1], cfg.strength]
            return self.edit_ganspace(latent, config)

    def load_ganspace_pca(self):
        """Load ``ffhq_pca.pt`` into ``self.pca`` unless it is already set."""
        if self.pca is not None:
            return
        self.pca = torch.load(os.path.join(Settings.ganspace_directions, 'ffhq_pca.pt'))

    def edit_ganspace(self, latents, config):
        """Shift rows ``start:end`` of every latent along one PCA direction.

        Args:
            latents: Iterable of 2-D latent tensors; must be non-empty because
                the results are combined with ``torch.stack``.
            config: Sequence ``(pca_idx, start, end, strength)``.

        Returns:
            A tensor stacking one edited latent per input latent.
        """
        pca_idx, start, end, strength = config
        edit_latents = []
        for latent in latents:
            delta = self.get_delta(latent, pca_idx, strength)
            # Allocate on the latent's own device; the addition below requires
            # both operands to live on the same device anyway.
            delta_padded = torch.zeros(latent.shape, device=latent.device)
            # Only rows start:end receive the offset; the rest stay zero.
            delta_padded[start:end] += delta.repeat(end - start, 1)
            edit_latents.append(latent + delta_padded)
        return torch.stack(edit_latents)

    def get_delta(self, latent, idx, strength):
        """Return the offset that moves ``latent`` to ``strength`` along component ``idx``.

        The first row of the mean-centred latent is projected onto PCA
        component ``idx`` and divided by that component's standard deviation.
        The result is the difference between ``strength`` and that coordinate,
        scaled back by the component and its standard deviation.

        Args:
            latent: A single latent tensor (the original comment describes it
                as a ``(16, 512)`` w+ code); only ``latent[0]`` is projected.
            idx: Index of the PCA component.
            strength: Target coordinate, in standard deviations.

        Returns:
            A tensor shaped like ``self.pca['comp'][idx]``.
        """
        # Keep the PCA tensors on the same device as the latent they act on.
        device = latent.device
        w_centered = latent - self.pca['mean'].to(device)
        lat_comp = self.pca['comp'].to(device)
        lat_std = self.pca['std'].to(device)
        direction = lat_comp[idx]
        scale = lat_std[idx]
        w_coord = torch.sum(w_centered[0].reshape(-1) * direction.reshape(-1)) / scale
        return (strength - w_coord) * direction * scale