Spaces:
Runtime error
Runtime error
File size: 5,216 Bytes
4d6b877 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# python3.7
"""Contains the generator class of ProgressiveGAN.
Basically, this class is derived from the `BaseGenerator` class defined in
`base_generator.py`.
"""
import os
import numpy as np
import torch
from . import model_settings
from .pggan_generator_model import PGGANGeneratorModel
from .base_generator import BaseGenerator
# Public API of this module: only the generator wrapper is exported by
# `from ... import *`.
__all__ = ['PGGANGenerator']
class PGGANGenerator(BaseGenerator):
    """Defines the generator class of ProgressiveGAN.

    Wraps a PyTorch `PGGANGeneratorModel` and provides weight loading,
    conversion from the official TensorFlow release, latent-code sampling and
    preprocessing, and image synthesis on top of `BaseGenerator`.
    """

    def __init__(self, model_name, logger=None):
        """Initializes the generator.

        Args:
            model_name: Name of the model. Attributes such as `gan_type`,
                `resolution` and `fused_scale` are presumably resolved from it
                by `BaseGenerator.__init__` (not visible here).
            logger: Optional logger, forwarded to `BaseGenerator`.

        Raises:
            ValueError: If the resolved `gan_type` is not `'pggan'`.
        """
        super().__init__(model_name, logger)
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if self.gan_type != 'pggan':
            raise ValueError(f'`gan_type` should be `pggan`, '
                             f'but `{self.gan_type}` received!')

    def build(self):
        """Builds the PyTorch `PGGANGeneratorModel` from the model settings."""
        self.check_attr('fused_scale')
        self.model = PGGANGeneratorModel(resolution=self.resolution,
                                         fused_scale=self.fused_scale,
                                         output_channels=self.output_channels)

    def load(self):
        """Loads PyTorch weights from `self.model_path` into `self.model`.

        Also stores the model's `lod` tensor, converted with `.tolist()`, in
        `self.lod`.
        """
        self.logger.info(f'Loading pytorch model from `{self.model_path}`.')
        # Deserialize onto the CPU so checkpoints saved from a GPU can still be
        # read on CPU-only machines; `load_state_dict` then copies the values
        # into the model's own parameters.
        state_dict = torch.load(self.model_path, map_location=self.cpu_device)
        self.model.load_state_dict(state_dict)
        self.logger.info('Successfully loaded!')
        self.lod = self.model.lod.to(self.cpu_device).tolist()
        self.logger.info(f' `lod` of the loaded model is {self.lod}.')

    def convert_tf_model(self, test_num=10):
        """Converts the official TensorFlow model to PyTorch weights.

        Steps:
            1. Unpickles the TensorFlow model from `self.tf_model_path`.
            2. Copies every variable listed in
               `self.model.pth_to_tf_var_mapping` into the PyTorch state dict,
               rearranging weight tensors.
            3. Saves the state dict to `self.model_path` and reloads it with
               `load()`.
            4. Optionally compares both models' outputs on random latent codes.

        Args:
            test_num: Number of random latent codes used to compare the
                TensorFlow and PyTorch outputs. Testing is skipped when this is
                not positive or when TensorFlow is not built with CUDA.
        """
        import pickle
        import re
        import sys
        import tensorflow as tf

        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        # The pickle presumably references classes from the official PGGAN
        # TensorFlow code, which therefore must be importable. Add the
        # directory only once, even across repeated calls.
        tf_code_dir = os.path.join(model_settings.BASE_DIR, 'pggan_tf_official')
        if tf_code_dir not in sys.path:
            sys.path.append(tf_code_dir)

        self.logger.info(f'Loading tensorflow model from `{self.tf_model_path}`.')
        sess = tf.InteractiveSession()
        try:
            # NOTE(security): `pickle.load` can execute arbitrary code; only
            # convert model files obtained from a trusted source.
            with open(self.tf_model_path, 'rb') as f:
                _, _, tf_model = pickle.load(f)
            self.logger.info('Successfully loaded!')

            self.logger.info('Converting tensorflow model to pytorch version.')
            tf_vars = dict(tf_model.__getstate__()['variables'])
            state_dict = self.model.state_dict()
            # Presumably the mapping numbers ToRGB layers by lod as for a
            # 1024x1024 (2**10) network; shift the index so it matches a model
            # of `self.resolution` -- confirm against PGGANGeneratorModel.
            lod_shift = 10 - int(np.log2(self.resolution))
            for pth_var_name, tf_var_name in (
                    self.model.pth_to_tf_var_mapping.items()):
                lod_match = re.search(r'ToRGB_lod(\d+)', tf_var_name)
                if lod_match is not None:
                    lod = int(lod_match.group(1))
                    # Rewrite only the lod index itself, not every occurrence
                    # of the same digit elsewhere in the name.
                    tf_var_name = (tf_var_name[:lod_match.start(1)] +
                                   str(lod - lod_shift) +
                                   tf_var_name[lod_match.end(1):])
                if tf_var_name not in tf_vars:
                    self.logger.debug(f'Variable `{tf_var_name}` does not exist in '
                                      'tensorflow model.')
                    continue
                self.logger.debug(f' Converting `{tf_var_name}` to `{pth_var_name}`.')
                var = torch.from_numpy(np.array(tf_vars[tf_var_name]))
                if 'weight' in pth_var_name:
                    if 'layer0.conv' in pth_var_name:
                        # First layer: reshape the 2-D weight to
                        # (in, out, 4, 4), swap to (out, in, 4, 4) and flip the
                        # spatial axes -- layout expected by the PyTorch first
                        # layer (verify in PGGANGeneratorModel).
                        var = var.view(var.shape[0], -1, 4, 4)
                        var = var.permute(1, 0, 2, 3).flip(2, 3)
                    elif 'Conv0_up' in tf_var_name:
                        # Fused-upscale kernel: only the last two axes swap;
                        # presumably the PyTorch layer rearranges the rest.
                        var = var.permute(0, 1, 3, 2)
                    else:
                        # TF conv kernels are [kh, kw, in, out]; PyTorch
                        # expects [out, in, kh, kw].
                        var = var.permute(3, 2, 0, 1)
                state_dict[pth_var_name] = var
            self.logger.info('Successfully converted!')

            self.logger.info(f'Saving pytorch model to `{self.model_path}`.')
            torch.save(state_dict, self.model_path)
            self.logger.info('Successfully saved!')

            self.load()

            # Official tensorflow model can only run on GPU.
            if test_num <= 0 or not tf.test.is_built_with_cuda():
                return
            self.logger.info('Testing conversion results.')
            self.model.eval().to(self.run_device)
            label_dim = tf_model.input_shapes[1][1]
            tf_fake_label = np.zeros((1, label_dim), np.float32)
            total_distance = 0.0
            for i in range(test_num):
                latent_code = self.easy_sample(1)
                tf_output = tf_model.run(latent_code, tf_fake_label)
                pth_output = self.synthesize(latent_code)['image']
                distance = np.average(np.abs(tf_output - pth_output))
                self.logger.debug(f' Test {i:03d}: distance {distance:.6e}.')
                total_distance += distance
            self.logger.info(f'Average distance is {total_distance / test_num:.6e}.')
        finally:
            # Release the TensorFlow session (also on early return or error).
            sess.close()

    def sample(self, num):
        """Draws `num` latent codes from a standard normal distribution.

        Args:
            num: Number of codes to draw; must be positive.

        Returns:
            A float32 array of shape [num, latent_space_dim].

        Raises:
            ValueError: If `num` is not positive.
        """
        if num <= 0:
            raise ValueError(f'`num` should be positive, but {num} received!')
        return np.random.randn(num, self.latent_space_dim).astype(np.float32)

    def preprocess(self, latent_codes):
        """Reshapes and rescales latent codes.

        Each code is reshaped to a row of length `latent_space_dim` and scaled
        to have L2 norm `sqrt(latent_space_dim)`. An all-zero code has zero
        norm and therefore yields NaNs.

        Args:
            latent_codes: A `numpy.ndarray` whose size is a multiple of
                `latent_space_dim`.

        Returns:
            A float32 array of shape [N, latent_space_dim].

        Raises:
            ValueError: If `latent_codes` is not a `numpy.ndarray`.
        """
        if not isinstance(latent_codes, np.ndarray):
            raise ValueError('Latent codes should be with type `numpy.ndarray`!')
        latent_codes = latent_codes.reshape(-1, self.latent_space_dim)
        norm = np.linalg.norm(latent_codes, axis=1, keepdims=True)
        latent_codes = latent_codes / norm * np.sqrt(self.latent_space_dim)
        return latent_codes.astype(np.float32)

    def synthesize(self, latent_codes):
        """Synthesizes images from latent codes.

        Args:
            latent_codes: A `numpy.ndarray` of shape
                [batch, latent_space_dim] with batch <= `self.batch_size`.

        Returns:
            A dict with key `'z'` (the input codes) and key `'image'` (the
            model output converted with `self.get_value`).

        Raises:
            ValueError: If `latent_codes` has the wrong type or shape.
        """
        if not isinstance(latent_codes, np.ndarray):
            raise ValueError('Latent codes should be with type `numpy.ndarray`!')
        latent_codes_shape = latent_codes.shape
        if not (len(latent_codes_shape) == 2 and
                latent_codes_shape[0] <= self.batch_size and
                latent_codes_shape[1] == self.latent_space_dim):
            raise ValueError(f'Latent_codes should be with shape [batch_size, '
                             f'latent_space_dim], where `batch_size` no larger than '
                             f'{self.batch_size}, and `latent_space_dim` equal to '
                             f'{self.latent_space_dim}!\n'
                             f'But {latent_codes_shape} received!')

        zs = torch.from_numpy(latent_codes).type(torch.FloatTensor)
        zs = zs.to(self.run_device)
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            images = self.model(zs)
        results = {
            'z': latent_codes,
            'image': self.get_value(images),
        }
        return results
|