Spaces:
Runtime error
Runtime error
File size: 2,423 Bytes
d7dbcdd 12979fc d7dbcdd 12979fc d7dbcdd 12979fc d7dbcdd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import torch
from torch import nn
from models.networks import latent_transformer
from models.stylegan2.model import Generator
import numpy as np
def get_keys(d, name):
    """Extract the sub-state-dict belonging to submodule ``name``.

    If ``d`` contains a ``'state_dict'`` entry, that nested mapping is used
    instead of ``d`` itself.

    Args:
        d: Mapping of parameter names to values (or a checkpoint dict that
            wraps one under ``'state_dict'``).
        name: Submodule prefix, e.g. ``'encoder'``.

    Returns:
        A new dict holding only the entries whose key starts with
        ``name + '.'``, with that prefix stripped from the key.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on "name." rather than a bare "name": otherwise a sibling key such
    # as "encoderx.w" would also be picked up, with its first char chopped off.
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
class StyleGANControler(nn.Module):
    """Latent-transformer encoder paired with a StyleGAN2 generator (decoder).

    The number of style vectors, and from it the generator resolution, is
    inferred from keywords in ``opts.stylegan_weights``. Weights are then
    loaded either from a full checkpoint (``opts.checkpoint_path``) or, if
    that is ``None``, from a pretrained StyleGAN2 file.

    Args:
        opts: Options object. This class reads ``stylegan_weights``,
            ``checkpoint_path`` and ``device`` from it, and passes it whole to
            ``latent_transformer.Network``.
    """

    # (keyword, style_num) pairs checked in order against the weights path;
    # the first keyword contained in the path wins.
    _STYLE_NUM_BY_KEYWORD = (
        ('ffhq', 18),
        ('car', 16),
        ('cat', 14),
        ('church', 14),
        ('anime', 16),
    )
    # Fallback when no keyword matches.
    # Please modify to adjust network architecture to your pre-trained StyleGAN2
    _DEFAULT_STYLE_NUM = 18
    # style_num -> first positional argument of Generator (output resolution).
    _RESOLUTION_BY_STYLE_NUM = {18: 1024, 16: 512, 14: 256}

    def __init__(self, opts):
        super().__init__()
        self.set_opts(opts)
        # Define architecture
        self.style_num = self._infer_style_num(self.opts.stylegan_weights)
        self.encoder = self.set_encoder()
        resolution = self._RESOLUTION_BY_STYLE_NUM[self.style_num]
        self.decoder = Generator(resolution, 512, 8, channel_multiplier=2)
        # Pools images to 256x256 (not used inside this class).
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    @classmethod
    def _infer_style_num(cls, weights_path):
        """Return the style-vector count implied by ``weights_path``.

        Keywords are matched as substrings of the whole path, so directory
        names count too (e.g. 'car' also matches '/home/oscar/...').
        """
        for keyword, style_num in cls._STYLE_NUM_BY_KEYWORD:
            if keyword in weights_path:
                return style_num
        return cls._DEFAULT_STYLE_NUM

    def set_encoder(self):
        """Build and return the latent-transformer encoder from ``self.opts``."""
        encoder = latent_transformer.Network(self.opts)
        return encoder

    def load_weights(self):
        """Load weights from ``opts.checkpoint_path`` or pretrained StyleGAN2.

        Full checkpoint: restores encoder and decoder (strict) from the
        ``encoder.*`` / ``decoder.*`` entries, and ``latent_avg`` as stored.
        Otherwise: restores only the decoder from ``ckpt['g_ema']`` and tiles
        ``latent_avg`` ``self.style_num`` times.
        """
        if self.opts.checkpoint_path is not None:
            print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
            self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            print('Loading decoder weights from pretrained!')
            # map to CPU so GPU-saved weights also load on CPU-only hosts;
            # latent_avg is moved to opts.device afterwards.
            ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
            self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
            # Repeat to match the architecture actually built in __init__.
            self.__load_latent_avg(ckpt, repeat=self.style_num)

    def set_opts(self, opts):
        """Store the options object on the module."""
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
        """Set ``self.latent_avg`` from ``ckpt['latent_avg']`` (or ``None``).

        The tensor is moved to ``opts.device``. If ``repeat`` is given, it is
        tiled along a new leading dim (a 1-D vector of length D becomes
        ``(repeat, D)``).
        """
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None
|