# NOTE(review): removed non-Python scraper residue (file-size banner, commit
# hashes, and GitHub gutter line numbers) that had been captured alongside the
# source; it was not part of the module and made the file unparseable.
import torch
from torch import nn
from models.networks import latent_transformer
from models.stylegan2.model import Generator
import numpy as np

def get_keys(d, name):
	"""Return the sub-state-dict of ``d`` whose keys begin with ``name``.

	If ``d`` is a full checkpoint wrapping its weights under a
	``'state_dict'`` entry, that inner dict is used. The ``name`` prefix
	plus the following separator character (the ``'.'``) is stripped from
	every surviving key.
	"""
	state = d['state_dict'] if 'state_dict' in d else d
	skip = len(name) + 1  # prefix plus the '.' separator
	return {key[skip:]: tensor
	        for key, tensor in state.items()
	        if key.startswith(name)}


class StyleGANControler(nn.Module):
	"""Encoder + pre-trained StyleGAN2 generator pair for latent-space control.

	The number of style layers (``style_num``) is inferred from the filename
	of the pre-trained StyleGAN2 weights and determines the generator's
	output resolution: 18 layers -> 1024px, 16 -> 512px, 14 -> 256px.
	Weights are loaded either from a training checkpoint
	(``opts.checkpoint_path``) or from the raw StyleGAN2 weights
	(``opts.stylegan_weights``).
	"""

	def __init__(self, opts):
		super(StyleGANControler, self).__init__()
		self.set_opts(opts)
		# Define architecture: infer the style-layer count from the
		# pretrained-weights filename.
		if 'ffhq' in self.opts.stylegan_weights:
			self.style_num = 18
		elif 'car' in self.opts.stylegan_weights:
			self.style_num = 16
		elif 'cat' in self.opts.stylegan_weights:
			self.style_num = 14
		elif 'church' in self.opts.stylegan_weights:
			self.style_num = 14
		elif 'anime' in self.opts.stylegan_weights:
			self.style_num = 16
		else:
			self.style_num = 18 #Please modify to adjust network architecture to your pre-trained StyleGAN2

		self.encoder = self.set_encoder()
		# Output resolution follows the layer count (see class docstring).
		if self.style_num == 18:
			self.decoder = Generator(1024, 512, 8, channel_multiplier=2)
		elif self.style_num == 16:
			self.decoder = Generator(512, 512, 8, channel_multiplier=2)
		elif self.style_num == 14:
			self.decoder = Generator(256, 512, 8, channel_multiplier=2)

		self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))

		# Load weights if needed
		self.load_weights()

	def set_encoder(self):
		"""Build and return the latent-transformer encoder network."""
		encoder = latent_transformer.Network(self.opts)
		return encoder

	def load_weights(self):
		"""Load encoder/decoder weights from a checkpoint, or decoder-only
		weights from the raw pre-trained StyleGAN2 file."""
		if self.opts.checkpoint_path is not None:
			print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
			ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
			self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
			self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
			self.__load_latent_avg(ckpt)
		else:
			print('Loading decoder weights from pretrained!')
			ckpt = torch.load(self.opts.stylegan_weights)
			self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
			# BUG FIX: the layer count computed in __init__ is stored on
			# `self`, not on `opts`; the original `self.opts.style_num`
			# raised AttributeError (or read a stale value) on this path.
			self.__load_latent_avg(ckpt, repeat=self.style_num)

	def set_opts(self, opts):
		"""Attach the options namespace to this module."""
		self.opts = opts

	def __load_latent_avg(self, ckpt, repeat=None):
		"""Cache the generator's average latent from `ckpt` (or None if absent),
		optionally tiling it to one row per style layer."""
		if 'latent_avg' in ckpt:
			self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
			if repeat is not None:
				self.latent_avg = self.latent_avg.repeat(repeat, 1)
		else:
			self.latent_avg = None