File size: 5,216 Bytes
4d6b877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# python3.7
"""Contains the generator class of ProgressiveGAN.

Basically, this class is derived from the `BaseGenerator` class defined in
`base_generator.py`.
"""

import os
import numpy as np

import torch

from . import model_settings
from .pggan_generator_model import PGGANGeneratorModel
from .base_generator import BaseGenerator

__all__ = ['PGGANGenerator']


class PGGANGenerator(BaseGenerator):
  """Defines the generator class of ProgressiveGAN.

  The class wraps a `PGGANGeneratorModel` and implements building the model,
  loading pytorch weights, converting the official tensorflow checkpoint,
  sampling / preprocessing latent codes, and synthesizing images.

  Attributes such as `gan_type`, `resolution`, `fused_scale`,
  `output_channels`, `model_path`, `tf_model_path`, `latent_space_dim`,
  `batch_size`, `run_device`, `cpu_device` and `logger` are read from `self`
  but not defined in this class; they are expected to be provided by
  `BaseGenerator`.
  """

  def __init__(self, model_name, logger=None):
    """Initializes the generator and checks that its GAN type is `pggan`.

    Args:
      model_name: Name of the model, forwarded to `BaseGenerator`.
      logger: Logger, forwarded to `BaseGenerator`. (default: None)
    """
    super().__init__(model_name, logger)
    assert self.gan_type == 'pggan'

  def build(self):
    """Builds the underlying `PGGANGeneratorModel` and stores it as `model`."""
    self.check_attr('fused_scale')
    self.model = PGGANGeneratorModel(resolution=self.resolution,
                                     fused_scale=self.fused_scale,
                                     output_channels=self.output_channels)

  def load(self):
    """Loads pytorch weights from `self.model_path` into `self.model`.

    After loading, the `lod` tensor of the model is copied to `self.lod` as a
    plain python value.
    """
    self.logger.info(f'Loading pytorch model from `{self.model_path}`.')
    # Map all storages to CPU so that a checkpoint saved from a GPU can still
    # be read on a machine without CUDA. `load_state_dict()` copies the values
    # into the existing parameters, so the device of the model is unchanged.
    state_dict = torch.load(self.model_path, map_location=self.cpu_device)
    self.model.load_state_dict(state_dict)
    self.logger.info('Successfully loaded!')
    self.lod = self.model.lod.to(self.cpu_device).tolist()
    self.logger.info(f'  `lod` of the loaded model is {self.lod}.')

  def convert_tf_model(self, test_num=10):
    """Converts the official tensorflow model to a pytorch checkpoint.

    The tensorflow model is unpickled from `self.tf_model_path`, its variables
    are copied (and re-laid-out) into the state dict of `self.model`, the
    result is saved to `self.model_path` and then loaded back via `load()`.
    Optionally, the outputs of both models are compared on random inputs.

    Args:
      test_num: Number of random latent codes used to compare the tensorflow
        output with the pytorch output. The comparison is skipped if this is
        not positive, or if tensorflow is not built with CUDA. (default: 10)
    """
    import sys
    import pickle
    # Set before the import so that the setting is already in place when
    # tensorflow initializes its logging.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    import tensorflow as tf
    # The pickled network refers to modules of the official code base, hence
    # that directory must be importable. Avoid duplicated entries on re-runs.
    tf_code_dir = os.path.join(model_settings.BASE_DIR, 'pggan_tf_official')
    if tf_code_dir not in sys.path:
      sys.path.append(tf_code_dir)

    self.logger.info(f'Loading tensorflow model from `{self.tf_model_path}`.')
    # `InteractiveSession` installs itself as the default session on
    # construction. NOTE(review): presumably required both for unpickling and
    # for `tf_model.run()` below, which is why it is not closed here.
    tf.InteractiveSession()
    # NOTE: `pickle.load()` can execute arbitrary code. Only convert
    # checkpoints that come from a trusted source.
    with open(self.tf_model_path, 'rb') as f:
      _, _, tf_model = pickle.load(f)
    self.logger.info('Successfully loaded!')

    self.logger.info('Converting tensorflow model to pytorch version.')
    tf_vars = dict(tf_model.__getstate__()['variables'])
    state_dict = self.model.state_dict()
    to_rgb_prefix = 'ToRGB_lod'
    # NOTE(review): presumably the name mapping is written for a 1024x1024
    # (2 ** 10) model, so the `lod` index is shifted for lower resolutions.
    lod_shift = 10 - int(np.log2(self.resolution))
    for pth_var_name, tf_var_name in self.model.pth_to_tf_var_mapping.items():
      if to_rgb_prefix in tf_var_name:
        # Rewrite exactly the single digit that follows the prefix, instead of
        # every occurrence of that digit in the variable name.
        lod_pos = tf_var_name.index(to_rgb_prefix) + len(to_rgb_prefix)
        lod = int(tf_var_name[lod_pos])
        tf_var_name = (tf_var_name[:lod_pos] + f'{lod - lod_shift}' +
                       tf_var_name[lod_pos + 1:])
      if tf_var_name not in tf_vars:
        self.logger.debug(f'Variable `{tf_var_name}` does not exist in '
                          f'tensorflow model.')
        continue
      self.logger.debug(f'  Converting `{tf_var_name}` to `{pth_var_name}`.')
      var = torch.from_numpy(np.array(tf_vars[tf_var_name]))
      if 'weight' in pth_var_name:
        if 'layer0.conv' in pth_var_name:
          # NOTE(review): looks like the first (dense) layer of the tensorflow
          # model is re-expressed as a 4x4 convolution kernel -- confirm.
          var = var.view(var.shape[0], -1, 4, 4).permute(1, 0, 2, 3).flip(2, 3)
        elif 'Conv0_up' in tf_var_name:
          # Only the last two axes are swapped for the upsampling convolution.
          var = var.permute(0, 1, 3, 2)
        else:
          # NOTE(review): presumably converts the tensorflow kernel layout
          # [height, width, in, out] to pytorch [out, in, height, width].
          var = var.permute(3, 2, 0, 1)
      state_dict[pth_var_name] = var
    self.logger.info('Successfully converted!')

    self.logger.info(f'Saving pytorch model to `{self.model_path}`.')
    torch.save(state_dict, self.model_path)
    self.logger.info('Successfully saved!')

    self.load()

    # Official tensorflow model can only run on GPU.
    if test_num <= 0 or not tf.test.is_built_with_cuda():
      return
    self.logger.info('Testing conversion results.')
    self.model.eval().to(self.run_device)
    label_dim = tf_model.input_shapes[1][1]
    tf_fake_label = np.zeros((1, label_dim), np.float32)
    total_distance = 0.0
    for i in range(test_num):
      latent_code = self.easy_sample(1)
      tf_output = tf_model.run(latent_code, tf_fake_label)
      pth_output = self.synthesize(latent_code)['image']
      distance = np.average(np.abs(tf_output - pth_output))
      self.logger.debug(f'  Test {i:03d}: distance {distance:.6e}.')
      total_distance += distance
    self.logger.info(f'Average distance is {total_distance / test_num:.6e}.')

  def sample(self, num):
    """Samples latent codes from the standard normal distribution.

    Args:
      num: Number of latent codes to sample. Must be positive.

    Returns:
      A `numpy.ndarray` with shape [num, latent_space_dim] and dtype
      `np.float32`.
    """
    assert num > 0
    return np.random.randn(num, self.latent_space_dim).astype(np.float32)

  def preprocess(self, latent_codes):
    """Normalizes latent codes to have norm `sqrt(latent_space_dim)`.

    Args:
      latent_codes: A `numpy.ndarray` that can be reshaped to
        [-1, latent_space_dim].

    Returns:
      A `numpy.ndarray` with shape [num, latent_space_dim] and dtype
      `np.float32`. All-zero codes are returned as zeros.

    Raises:
      ValueError: If `latent_codes` is not a `numpy.ndarray`.
    """
    if not isinstance(latent_codes, np.ndarray):
      raise ValueError('Latent codes should be with type `numpy.ndarray`!')

    latent_codes = latent_codes.reshape(-1, self.latent_space_dim)
    norm = np.linalg.norm(latent_codes, axis=1, keepdims=True)
    # An all-zero code has zero norm and dividing by it would fill the whole
    # row with NaN. Dividing by one instead keeps such rows at zero, while all
    # other rows are normalized exactly as before.
    norm[norm == 0] = 1.0
    latent_codes = latent_codes / norm * np.sqrt(self.latent_space_dim)
    return latent_codes.astype(np.float32)

  def synthesize(self, latent_codes):
    """Synthesizes images from the given latent codes.

    Args:
      latent_codes: A `numpy.ndarray` with shape
        [batch_size, latent_space_dim], where the first dimension must not
        exceed `self.batch_size`.

    Returns:
      A dictionary with key `z` (the input latent codes, unchanged) and key
      `image` (the model output converted by `self.get_value()`).

    Raises:
      ValueError: If `latent_codes` is not a `numpy.ndarray` or has an
        unsupported shape.
    """
    if not isinstance(latent_codes, np.ndarray):
      raise ValueError('Latent codes should be with type `numpy.ndarray`!')
    latent_codes_shape = latent_codes.shape
    if not (len(latent_codes_shape) == 2 and
            latent_codes_shape[0] <= self.batch_size and
            latent_codes_shape[1] == self.latent_space_dim):
      raise ValueError(f'Latent_codes should be with shape [batch_size, '
                       f'latent_space_dim], where `batch_size` no larger than '
                       f'{self.batch_size}, and `latent_space_dim` equal to '
                       f'{self.latent_space_dim}!\n'
                       f'But {latent_codes_shape} received!')

    zs = torch.from_numpy(latent_codes).type(torch.FloatTensor)
    zs = zs.to(self.run_device)
    # Pure inference: the output is only handed to `get_value()`, so there is
    # no need to record the autograd graph (saves memory).
    with torch.no_grad():
      images = self.model(zs)
    results = {
        'z': latent_codes,
        'image': self.get_value(images),
    }
    return results