import torch
from torch import nn

from models.StyleCLIP.mapper import latent_mappers
from models.StyleCLIP.models.stylegan2.model import Generator

def get_keys(d, name):
    # Filter a checkpoint dict down to the keys belonging to sub-module `name`,
    # stripping the '<name>.' prefix so the result can be fed to load_state_dict.
    if 'state_dict' in d:
        d = d['state_dict']
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
    return d_filt


class StyleCLIPMapper(nn.Module):

    def __init__(self, opts, run_id):
        super(StyleCLIPMapper, self).__init__()
        self.opts = opts
        # Define architecture
        self.mapper = self.set_mapper()
        self.run_id = run_id
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_mapper(self):
        if self.opts.mapper_type == 'SingleMapper':
            mapper = latent_mappers.SingleMapper(self.opts)
        elif self.opts.mapper_type == 'LevelsMapper':
            mapper = latent_mappers.LevelsMapper(self.opts)
        else:
            raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type))
        return mapper

    def load_weights(self):
        if self.opts.checkpoint_path is not None:
            print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)

    def set_G(self, new_G):
        # Attach the generator used as the decoder; it must expose a
        # `synthesis(ws, noise_mode=...)` method (see forward below).
        self.decoder = new_G

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        if input_code:
            codes = x
        else:
            codes = self.mapper(x)

        # Optionally overwrite or blend the selected style layers with `inject_latent`.
        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        input_is_latent = not input_code  # only relevant to the commented-out decoder call below
        images = self.decoder.synthesis(codes, noise_mode='const')
        result_latent = None
        # images, result_latent = self.decoder([codes],
        #                                      input_is_latent=input_is_latent,
        #                                      randomize_noise=randomize_noise,
        #                                      return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images
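

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): exercises StyleCLIPMapper end
# to end with a stub decoder, since the real decoder is supplied externally
# via `set_G`. `DummyG`, the options namespace, and the tensor shapes are
# illustrative assumptions; in practice the decoder is a StyleGAN2 generator
# exposing `synthesis(ws, noise_mode=...)` and the latents live in W+
# (shape [N, 18, 512] for 1024x1024 models).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    class DummyG(nn.Module):
        """Minimal stand-in for a generator that exposes `synthesis`."""

        def synthesis(self, ws, noise_mode='const'):
            # Return a fake RGB batch matching the batch size of `ws`.
            return torch.zeros(ws.shape[0], 3, 1024, 1024)

    # Hypothetical options: `mapper_type` and `checkpoint_path` are the fields
    # this file reads; the mapper sub-modules may expect additional flags.
    opts = Namespace(mapper_type='SingleMapper', checkpoint_path=None)
    net = StyleCLIPMapper(opts, run_id=0)
    net.set_G(DummyG())

    w = torch.randn(2, 18, 512)      # batch of W+ latent codes
    imgs = net(w, resize=True)       # face_pool resizes the output to 256x256
    print(imgs.shape)                # expected: torch.Size([2, 3, 256, 256])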