farrell236 committed
Commit
2aa6515
1 Parent(s): 3699172

Upload 37 files

Files changed (37)
  1. assets/GauGAN.png +0 -0
  2. assets/RetinaGAN_pipeline.png +0 -0
  3. assets/cStyleGAN.png +0 -0
  4. assets/sample.jpeg +0 -0
  5. assets/sample_images/image_class_0_batch_0_sample_0.png +0 -0
  6. assets/sample_images/image_class_0_batch_0_sample_1.png +0 -0
  7. assets/sample_images/image_class_0_batch_1_sample_0.png +0 -0
  8. assets/sample_images/image_class_0_batch_1_sample_1.png +0 -0
  9. assets/sample_images/image_class_1_batch_0_sample_0.png +0 -0
  10. assets/sample_images/image_class_1_batch_0_sample_1.png +0 -0
  11. assets/sample_images/image_class_1_batch_1_sample_0.png +0 -0
  12. assets/sample_images/image_class_1_batch_1_sample_1.png +0 -0
  13. assets/sample_images/image_class_2_batch_0_sample_0.png +0 -0
  14. assets/sample_images/image_class_2_batch_0_sample_1.png +0 -0
  15. assets/sample_images/image_class_2_batch_1_sample_0.png +0 -0
  16. assets/sample_images/image_class_2_batch_1_sample_1.png +0 -0
  17. assets/sample_images/image_class_3_batch_0_sample_0.png +0 -0
  18. assets/sample_images/image_class_3_batch_0_sample_1.png +0 -0
  19. assets/sample_images/image_class_3_batch_1_sample_0.png +0 -0
  20. assets/sample_images/image_class_3_batch_1_sample_1.png +0 -0
  21. assets/sample_images/image_class_4_batch_0_sample_0.png +0 -0
  22. assets/sample_images/image_class_4_batch_0_sample_1.png +0 -0
  23. assets/sample_images/image_class_4_batch_1_sample_0.png +0 -0
  24. assets/sample_images/image_class_4_batch_1_sample_1.png +0 -0
  25. assets/sample_images/mask_class_0_batch_0.png +0 -0
  26. assets/sample_images/mask_class_0_batch_1.png +0 -0
  27. assets/sample_images/mask_class_1_batch_0.png +0 -0
  28. assets/sample_images/mask_class_1_batch_1.png +0 -0
  29. assets/sample_images/mask_class_2_batch_0.png +0 -0
  30. assets/sample_images/mask_class_2_batch_1.png +0 -0
  31. assets/sample_images/mask_class_3_batch_0.png +0 -0
  32. assets/sample_images/mask_class_3_batch_1.png +0 -0
  33. assets/sample_images/mask_class_4_batch_0.png +0 -0
  34. assets/sample_images/mask_class_4_batch_1.png +0 -0
  35. models/cstylegan.py +530 -0
  36. models/gaugan.py +403 -0
  37. utils.py +71 -0
assets/GauGAN.png ADDED
assets/RetinaGAN_pipeline.png ADDED
assets/cStyleGAN.png ADDED
assets/sample.jpeg ADDED
assets/sample_images/image_class_0_batch_0_sample_0.png ADDED
assets/sample_images/image_class_0_batch_0_sample_1.png ADDED
assets/sample_images/image_class_0_batch_1_sample_0.png ADDED
assets/sample_images/image_class_0_batch_1_sample_1.png ADDED
assets/sample_images/image_class_1_batch_0_sample_0.png ADDED
assets/sample_images/image_class_1_batch_0_sample_1.png ADDED
assets/sample_images/image_class_1_batch_1_sample_0.png ADDED
assets/sample_images/image_class_1_batch_1_sample_1.png ADDED
assets/sample_images/image_class_2_batch_0_sample_0.png ADDED
assets/sample_images/image_class_2_batch_0_sample_1.png ADDED
assets/sample_images/image_class_2_batch_1_sample_0.png ADDED
assets/sample_images/image_class_2_batch_1_sample_1.png ADDED
assets/sample_images/image_class_3_batch_0_sample_0.png ADDED
assets/sample_images/image_class_3_batch_0_sample_1.png ADDED
assets/sample_images/image_class_3_batch_1_sample_0.png ADDED
assets/sample_images/image_class_3_batch_1_sample_1.png ADDED
assets/sample_images/image_class_4_batch_0_sample_0.png ADDED
assets/sample_images/image_class_4_batch_0_sample_1.png ADDED
assets/sample_images/image_class_4_batch_1_sample_0.png ADDED
assets/sample_images/image_class_4_batch_1_sample_1.png ADDED
assets/sample_images/mask_class_0_batch_0.png ADDED
assets/sample_images/mask_class_0_batch_1.png ADDED
assets/sample_images/mask_class_1_batch_0.png ADDED
assets/sample_images/mask_class_1_batch_1.png ADDED
assets/sample_images/mask_class_2_batch_0.png ADDED
assets/sample_images/mask_class_2_batch_1.png ADDED
assets/sample_images/mask_class_3_batch_0.png ADDED
assets/sample_images/mask_class_3_batch_1.png ADDED
assets/sample_images/mask_class_4_batch_0.png ADDED
assets/sample_images/mask_class_4_batch_1.png ADDED
models/cstylegan.py ADDED
@@ -0,0 +1,530 @@
# This file is based on the StyleGAN example by Cheong et al.
# https://keras.io/examples/generative/stylegan/

import numpy as np
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow_addons.layers import InstanceNormalization


def log2(x):
    return int(np.log2(x))


# We use a different batch size for each resolution so that larger images
# still fit into GPU memory. The keys are image resolutions in log2.
batch_sizes = {2: 16, 3: 16, 4: 16, 5: 16, 6: 16, 7: 8, 8: 4, 9: 2, 10: 1}
# We adjust the number of train steps accordingly.
train_step_ratio = {k: batch_sizes[2] / v for k, v in batch_sizes.items()}


def fade_in(alpha, a, b):
    return alpha * a + (1.0 - alpha) * b


def wasserstein_loss(y_true, y_pred):
    return -tf.reduce_mean(y_true * y_pred)


def pixel_norm(x, epsilon=1e-8):
    return x / tf.math.sqrt(tf.reduce_mean(x ** 2, axis=-1, keepdims=True) + epsilon)


def minibatch_std(input_tensor, epsilon=1e-8):
    n, h, w, c = tf.shape(input_tensor)
    group_size = tf.minimum(4, n)
    x = tf.reshape(input_tensor, [group_size, -1, h, w, c])
    group_mean, group_var = tf.nn.moments(x, axes=(0), keepdims=False)
    group_std = tf.sqrt(group_var + epsilon)
    avg_std = tf.reduce_mean(group_std, axis=[1, 2, 3], keepdims=True)
    x = tf.tile(avg_std, [group_size, h, w, 1])
    return tf.concat([input_tensor, x], axis=-1)


class EqualizedConv(layers.Layer):
    def __init__(self, out_channels, kernel=3, gain=2, **kwargs):
        super(EqualizedConv, self).__init__(**kwargs)
        self.kernel = kernel
        self.out_channels = out_channels
        self.gain = gain
        self.pad = kernel != 1

    def build(self, input_shape):
        self.in_channels = input_shape[-1]
        initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        self.w = self.add_weight(
            shape=[self.kernel, self.kernel, self.in_channels, self.out_channels],
            initializer=initializer,
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.out_channels,), initializer="zeros", trainable=True, name="bias"
        )
        fan_in = self.kernel * self.kernel * self.in_channels
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        if self.pad:
            x = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
        else:
            x = inputs
        output = (
            tf.nn.conv2d(x, self.scale * self.w, strides=1, padding="VALID") + self.b
        )
        return output


class EqualizedDense(layers.Layer):
    def __init__(self, units, gain=2, learning_rate_multiplier=1, **kwargs):
        super(EqualizedDense, self).__init__(**kwargs)
        self.units = units
        self.gain = gain
        self.learning_rate_multiplier = learning_rate_multiplier

    def build(self, input_shape):
        self.in_channels = input_shape[-1]
        initializer = keras.initializers.RandomNormal(
            mean=0.0, stddev=1.0 / self.learning_rate_multiplier
        )
        self.w = self.add_weight(
            shape=[self.in_channels, self.units],
            initializer=initializer,
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="zeros", trainable=True, name="bias"
        )
        fan_in = self.in_channels
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        output = tf.add(tf.matmul(inputs, self.scale * self.w), self.b)
        return output * self.learning_rate_multiplier


class AddNoise(layers.Layer):
    def build(self, input_shape):
        n, h, w, c = input_shape[0]
        initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        self.b = self.add_weight(
            shape=[1, 1, 1, c], initializer=initializer, trainable=True, name="kernel"
        )

    def call(self, inputs):
        x, noise = inputs
        output = x + self.b * noise
        return output


class AdaIN(layers.Layer):
    def __init__(self, gain=1, **kwargs):
        super(AdaIN, self).__init__(**kwargs)
        self.gain = gain

    def build(self, input_shapes):
        x_shape = input_shapes[0]
        w_shape = input_shapes[1]

        self.w_channels = w_shape[-1]
        self.x_channels = x_shape[-1]

        self.dense_1 = EqualizedDense(self.x_channels, gain=1)
        self.dense_2 = EqualizedDense(self.x_channels, gain=1)

    def call(self, inputs):
        x, w = inputs
        ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels))
        yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels))
        return ys * x + yb


def Mapping(num_stages, input_shape=512):
    z = layers.Input(shape=(input_shape,))
    w = pixel_norm(z)
    class_embedding = layers.Input(shape=512)
    for i in range(8):
        w = EqualizedDense(512, learning_rate_multiplier=0.01)(w)
        w = w + class_embedding
        w = layers.LeakyReLU(0.2)(w)
    w = tf.tile(tf.expand_dims(w, 1), (1, num_stages, 1))
    return keras.Model([z, class_embedding], w, name="mapping")


class Generator:
    def __init__(self, start_res_log2, target_res_log2):
        self.start_res_log2 = start_res_log2
        self.target_res_log2 = target_res_log2
        self.num_stages = target_res_log2 - start_res_log2 + 1
        # list of generator blocks at increasing resolution
        self.g_blocks = []
        # list of layers to convert g_block activation to RGB
        self.to_rgb = []
        # list of noise inputs of different resolutions into g_blocks
        self.noise_inputs = []
        # filter size to use at each stage, keys are log2(resolution)
        self.filter_nums = {
            0: 512,
            1: 512,
            2: 512,  # 4x4
            3: 512,  # 8x8
            4: 512,  # 16x16
            5: 512,  # 32x32
            6: 256,  # 64x64
            7: 128,  # 128x128
            8: 64,   # 256x256
            9: 32,   # 512x512
            10: 16,  # 1024x1024
        }

        start_res = 2 ** start_res_log2
        self.input_shape = (start_res, start_res, self.filter_nums[start_res_log2])
        self.g_input = layers.Input(self.input_shape, name="generator_input")

        for i in range(start_res_log2, target_res_log2 + 1):
            filter_num = self.filter_nums[i]
            res = 2 ** i
            self.noise_inputs.append(
                layers.Input(shape=(res, res, 1), name=f"noise_{res}x{res}")
            )
            to_rgb = Sequential(
                [
                    layers.InputLayer(input_shape=(res, res, filter_num)),
                    EqualizedConv(7, 1, gain=1),  # changed: 7 output channels instead of 3
                ],
                name=f"to_rgb_{res}x{res}",
            )
            self.to_rgb.append(to_rgb)
            is_base = i == self.start_res_log2
            if is_base:
                input_shape = (res, res, self.filter_nums[i - 1])
            else:
                input_shape = (2 ** (i - 1), 2 ** (i - 1), self.filter_nums[i - 1])
            g_block = self.build_block(
                filter_num, res=res, input_shape=input_shape, is_base=is_base
            )
            self.g_blocks.append(g_block)

    def build_block(self, filter_num, res, input_shape, is_base):
        input_tensor = layers.Input(shape=input_shape, name=f"g_{res}")
        noise = layers.Input(shape=(res, res, 1), name=f"noise_{res}")
        w = layers.Input(shape=512)
        x = input_tensor

        if not is_base:
            x = layers.UpSampling2D((2, 2))(x)
            x = EqualizedConv(filter_num, 3)(x)

        x = AddNoise()([x, noise])
        x = layers.LeakyReLU(0.2)(x)
        x = InstanceNormalization()(x)
        x = AdaIN()([x, w])

        x = EqualizedConv(filter_num, 3)(x)
        x = AddNoise()([x, noise])
        x = layers.LeakyReLU(0.2)(x)
        x = InstanceNormalization()(x)
        x = AdaIN()([x, w])
        return keras.Model([input_tensor, w, noise], x, name=f"genblock_{res}x{res}")

    def grow(self, res_log2):
        res = 2 ** res_log2

        num_stages = res_log2 - self.start_res_log2 + 1
        w = layers.Input(shape=(self.num_stages, 512), name="w")

        alpha = layers.Input(shape=(1), name="g_alpha")
        x = self.g_blocks[0]([self.g_input, w[:, 0], self.noise_inputs[0]])

        if num_stages == 1:
            rgb = self.to_rgb[0](x)
        else:
            for i in range(1, num_stages - 1):
                x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])

            old_rgb = self.to_rgb[num_stages - 2](x)
            old_rgb = layers.UpSampling2D((2, 2))(old_rgb)

            i = num_stages - 1
            x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])

            new_rgb = self.to_rgb[i](x)

            rgb = fade_in(alpha[0], new_rgb, old_rgb)

        return keras.Model(
            [self.g_input, w, self.noise_inputs, alpha],
            rgb,
            name=f"generator_{res}_x_{res}",
        )


class Discriminator:
    def __init__(self, start_res_log2, target_res_log2):
        self.start_res_log2 = start_res_log2
        self.target_res_log2 = target_res_log2
        self.num_stages = target_res_log2 - start_res_log2 + 1
        # filter size to use at each stage, keys are log2(resolution)
        self.filter_nums = {
            0: 512,
            1: 512,
            2: 512,  # 4x4
            3: 512,  # 8x8
            4: 512,  # 16x16
            5: 512,  # 32x32
            6: 256,  # 64x64
            7: 128,  # 128x128
            8: 64,   # 256x256
            9: 32,   # 512x512
            10: 16,  # 1024x1024
        }
        # list of discriminator blocks at increasing resolution
        self.d_blocks = []
        # list of layers to convert RGB into activation for d_blocks inputs
        self.from_rgb = []
        # Conditional embedding
        # self.embedding = layers.Embedding(5, 256)

        for res_log2 in range(self.start_res_log2, self.target_res_log2 + 1):
            res = 2 ** res_log2
            filter_num = self.filter_nums[res_log2]
            from_rgb = Sequential(
                [
                    layers.InputLayer(
                        input_shape=(res, res, 7),  # changed: 7 input channels instead of 3
                        name=f"from_rgb_input_{res}",
                    ),
                    EqualizedConv(filter_num, 1),
                    layers.LeakyReLU(0.2),
                ],
                name=f"from_rgb_{res}",
            )

            self.from_rgb.append(from_rgb)

            input_shape = (res, res, filter_num)
            if len(self.d_blocks) == 0:
                d_block = self.build_base(filter_num, res)
            else:
                d_block = self.build_block(
                    filter_num, self.filter_nums[res_log2 - 1], res
                )

            self.d_blocks.append(d_block)

    def build_base(self, filter_num, res):
        input_tensor = layers.Input(shape=(res, res, filter_num), name=f"d_{res}")
        x = minibatch_std(input_tensor)
        x = EqualizedConv(filter_num, 3)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Flatten()(x)
        x = EqualizedDense(filter_num)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = EqualizedDense(1)(x)
        return keras.Model(input_tensor, x, name=f"d_{res}")

    def build_block(self, filter_num_1, filter_num_2, res):
        input_tensor = layers.Input(shape=(res, res, filter_num_1), name=f"d_{res}")
        x = EqualizedConv(filter_num_1, 3)(input_tensor)
        x = layers.LeakyReLU(0.2)(x)
        x = EqualizedConv(filter_num_2)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.AveragePooling2D((2, 2))(x)
        return keras.Model(input_tensor, x, name=f"d_{res}")

    def grow(self, res_log2):
        res = 2 ** res_log2
        idx = res_log2 - self.start_res_log2
        alpha = layers.Input(shape=(1), name="d_alpha")
        input_image = layers.Input(shape=(res, res, 7), name="input_image")  # changed: 7 channels
        class_embedding = layers.Input(shape=512, name="class_embedding")
        x = self.from_rgb[idx](input_image)
        x = AdaIN()([x, class_embedding])
        x = self.d_blocks[idx](x)
        if idx > 0:
            idx -= 1
            downsized_image = layers.AveragePooling2D((2, 2))(input_image)
            y = self.from_rgb[idx](downsized_image)
            x = fade_in(alpha[0], x, y)

            for i in range(idx, -1, -1):
                x = AdaIN()([x, class_embedding])
                x = self.d_blocks[i](x)
        return keras.Model([input_image, class_embedding, alpha], x, name=f"discriminator_{res}_x_{res}")


class cStyleGAN(tf.keras.Model):
    def __init__(self, z_dim=512, target_res=64, start_res=4):
        super(cStyleGAN, self).__init__()
        self.z_dim = z_dim

        self.target_res_log2 = log2(target_res)
        self.start_res_log2 = log2(start_res)
        self.current_res_log2 = self.target_res_log2
        self.num_stages = self.target_res_log2 - self.start_res_log2 + 1

        self.alpha = tf.Variable(1.0, dtype=tf.float32, trainable=False, name="alpha")

        self.mapping = Mapping(num_stages=self.num_stages)
        self.embedding = layers.Embedding(5, 512)
        self.d_builder = Discriminator(self.start_res_log2, self.target_res_log2)
        self.g_builder = Generator(self.start_res_log2, self.target_res_log2)
        self.g_input_shape = self.g_builder.input_shape

        self.phase = None
        self.train_step_counter = tf.Variable(0, dtype=tf.int32, trainable=False)

        self.loss_weights = {"gradient_penalty": 10, "drift": 0.001}

    def grow_model(self, res):
        tf.keras.backend.clear_session()
        res_log2 = log2(res)
        self.generator = self.g_builder.grow(res_log2)
        self.discriminator = self.d_builder.grow(res_log2)
        self.current_res_log2 = res_log2
        print(f"\nModel resolution: {res}x{res}")

    def compile(
        self, steps_per_epoch, phase, res, d_optimizer, g_optimizer, *args, **kwargs
    ):
        self.loss_weights = kwargs.pop("loss_weights", self.loss_weights)
        self.steps_per_epoch = steps_per_epoch
        if res != 2 ** self.current_res_log2:
            self.grow_model(res)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer

        self.train_step_counter.assign(0)
        self.phase = phase
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")
        super(cStyleGAN, self).compile(*args, **kwargs)

    @property
    def metrics(self):
        return [self.d_loss_metric, self.g_loss_metric]

    def generate_noise(self, batch_size):
        noise = [
            tf.random.normal((batch_size, 2 ** res, 2 ** res, 1))
            for res in range(self.start_res_log2, self.target_res_log2 + 1)
        ]
        return noise

    def gradient_loss(self, grad):
        loss = tf.square(grad)
        loss = tf.reduce_sum(loss, axis=tf.range(1, tf.size(tf.shape(loss))))
        loss = tf.sqrt(loss)
        loss = tf.reduce_mean(tf.square(loss - 1))
        return loss

    def train_step(self, data_tuple):

        real_images, class_label = data_tuple

        self.train_step_counter.assign_add(1)

        if self.phase == "TRANSITION":
            self.alpha.assign(
                tf.cast(self.train_step_counter / self.steps_per_epoch, tf.float32)
            )
        elif self.phase == "STABLE":
            self.alpha.assign(1.0)
        else:
            raise NotImplementedError
        alpha = tf.expand_dims(self.alpha, 0)
        batch_size = tf.shape(real_images)[0]
        real_labels = tf.ones(batch_size)
        fake_labels = -tf.ones(batch_size)

        z = tf.random.normal((batch_size, self.z_dim))
        const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
        noise = self.generate_noise(batch_size)

        # generator
        with tf.GradientTape() as g_tape:
            class_embedding = self.embedding(class_label)
            w = self.mapping([z, class_embedding])
            fake_images = self.generator([const_input, w, noise, alpha])
            pred_fake = self.discriminator([fake_images, class_embedding, alpha])
            g_loss = wasserstein_loss(real_labels, pred_fake)

            trainable_weights = (
                self.embedding.trainable_weights
                + self.mapping.trainable_weights
                + self.generator.trainable_weights
            )
            gradients = g_tape.gradient(g_loss, trainable_weights)
            self.g_optimizer.apply_gradients(zip(gradients, trainable_weights))

        # discriminator
        with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
            # class_embedding = self.embedding(class_label)
            # forward pass
            pred_fake = self.discriminator([fake_images, class_embedding, alpha])
            pred_real = self.discriminator([real_images, class_embedding, alpha])

            epsilon = tf.random.uniform((batch_size, 1, 1, 1))
            interpolates = epsilon * real_images + (1 - epsilon) * fake_images
            gradient_tape.watch(interpolates)
            pred_fake_grad = self.discriminator([interpolates, class_embedding, alpha])

            # calculate losses
            loss_fake = wasserstein_loss(fake_labels, pred_fake)
            loss_real = wasserstein_loss(real_labels, pred_real)
            loss_fake_grad = wasserstein_loss(fake_labels, pred_fake_grad)

            # gradient penalty
            gradients_fake = gradient_tape.gradient(loss_fake_grad, [interpolates])
            gradient_penalty = self.loss_weights[
                "gradient_penalty"
            ] * self.gradient_loss(gradients_fake)

            # drift loss
            all_pred = tf.concat([pred_fake, pred_real], axis=0)
            drift_loss = self.loss_weights["drift"] * tf.reduce_mean(all_pred ** 2)

            d_loss = loss_fake + loss_real + gradient_penalty + drift_loss

            gradients = total_tape.gradient(
                d_loss, self.discriminator.trainable_weights
            )
            self.d_optimizer.apply_gradients(
                zip(gradients, self.discriminator.trainable_weights)
            )

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

    def call(self, inputs: dict):
        style_code = inputs.get("style_code", None)
        z = inputs.get("z", None)
        noise = inputs.get("noise", None)
        class_label = inputs.get("class_label", 0)
        batch_size = inputs.get("batch_size", 1)
        alpha = inputs.get("alpha", 1.0)
        alpha = tf.expand_dims(alpha, 0)
        class_embedding = self.embedding(class_label)
        if style_code is None:
            if z is None:
                z = tf.random.normal((batch_size, self.z_dim))
            style_code = self.mapping([z, class_embedding])

        if noise is None:
            noise = self.generate_noise(batch_size)

        # self.alpha.assign(alpha)

        const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
        images = self.generator([const_input, style_code, noise, alpha])
        # images = np.clip((images * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
        images = tf.clip_by_value((images * 0.5 + 0.5) * 255, 0, 255)

        return images
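
For reference, a minimal sketch (not part of this commit) of how cStyleGAN might be driven through its progressive-growing schedule and then sampled. The dataset pipeline train_ds, the per-phase step budget, and the optimizer settings are assumptions, not code from this repository.

import tensorflow as tf
from models.cstylegan import cStyleGAN

model = cStyleGAN(z_dim=512, target_res=256, start_res=4)

for res in [4, 8, 16, 32, 64, 128, 256]:
    for phase in ["TRANSITION", "STABLE"]:
        if res == 4 and phase == "TRANSITION":
            continue  # the base resolution has no fade-in stage
        steps = 1000  # placeholder step budget per phase (assumption)
        model.compile(
            steps_per_epoch=steps,
            phase=phase,
            res=res,
            d_optimizer=tf.keras.optimizers.Adam(1e-3, beta_1=0.0, beta_2=0.99),
            g_optimizer=tf.keras.optimizers.Adam(1e-3, beta_1=0.0, beta_2=0.99),
        )
        # train_ds is assumed to yield (real_images, class_label) batches at `res`
        # model.fit(train_ds, epochs=1, steps_per_epoch=steps)

# Inference: sample two 7-channel outputs conditioned on class 3.
samples = model({"class_label": tf.constant([3, 3]), "batch_size": 2, "alpha": 1.0})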
models/gaugan.py ADDED
@@ -0,0 +1,403 @@
# This file is based on the GauGAN example by Rakshit et al.
# https://keras.io/examples/generative/gaugan/

import tensorflow as tf
import tensorflow_addons as tfa


class SPADE(tf.keras.layers.Layer):
    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.conv = tf.keras.layers.Conv2D(128, 3, padding="same", activation="relu")
        self.conv_gamma = tf.keras.layers.Conv2D(filters, 3, padding="same")
        self.conv_beta = tf.keras.layers.Conv2D(filters, 3, padding="same")

    def build(self, input_shape):
        self.resize_shape = input_shape[1:3]

    def call(self, input_tensor, raw_mask):
        mask = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
        x = self.conv(mask)
        gamma = self.conv_gamma(x)
        beta = self.conv_beta(x)
        mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
        std = tf.sqrt(var + self.epsilon)
        normalized = (input_tensor - mean) / std
        output = gamma * normalized + beta
        return output


class ResBlock(tf.keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        input_filter = input_shape[-1]
        self.spade_1 = SPADE(input_filter)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        self.conv_2 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        self.learned_skip = False

        if self.filters != input_filter:
            self.learned_skip = True
            self.spade_3 = SPADE(input_filter)
            self.conv_3 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")

    def call(self, input_tensor, mask):
        x = self.spade_1(input_tensor, mask)
        x = self.conv_1(tf.nn.leaky_relu(x, 0.2))
        x = self.spade_2(x, mask)
        x = self.conv_2(tf.nn.leaky_relu(x, 0.2))
        skip = (
            self.conv_3(tf.nn.leaky_relu(self.spade_3(input_tensor, mask), 0.2))
            if self.learned_skip
            else input_tensor
        )
        output = skip + x
        return output


class GaussianSampler(tf.keras.layers.Layer):
    def __init__(self, batch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.latent_dim = latent_dim

    def call(self, inputs):
        means, variance = inputs
        epsilon = tf.random.normal(
            shape=(self.batch_size, self.latent_dim), mean=0.0, stddev=1.0
        )
        samples = means + tf.exp(0.5 * variance) * epsilon
        return samples


def downsample(
    channels,
    kernels,
    strides=2,
    apply_norm=True,
    apply_activation=True,
    apply_dropout=False,
):
    block = tf.keras.Sequential()
    block.add(
        tf.keras.layers.Conv2D(
            channels,
            kernels,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer=tf.keras.initializers.GlorotNormal(),
        )
    )
    if apply_norm:
        block.add(tfa.layers.InstanceNormalization())
    if apply_activation:
        block.add(tf.keras.layers.LeakyReLU(0.2))
    if apply_dropout:
        block.add(tf.keras.layers.Dropout(0.5))
    return block


def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
    input_image = tf.keras.Input(shape=image_shape)
    x = downsample(encoder_downsample_factor, 3, apply_norm=False)(input_image)
    x = downsample(2 * encoder_downsample_factor, 3)(x)
    x = downsample(4 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = downsample(16 * encoder_downsample_factor, 3)(x)
    x = tf.keras.layers.Flatten()(x)
    mean = tf.keras.layers.Dense(latent_dim, name="mean")(x)
    variance = tf.keras.layers.Dense(latent_dim, name="variance")(x)
    return tf.keras.Model(input_image, [mean, variance], name="encoder")


def build_generator(mask_shape, latent_dim=256):
    latent = tf.keras.Input(shape=(latent_dim))
    mask = tf.keras.Input(shape=mask_shape)
    x = tf.keras.layers.Dense(16384)(latent)
    x = tf.keras.layers.Reshape((4, 4, 1024))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=512)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=256)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=128)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=64)(x, mask)            # these two blocks were added
    x = tf.keras.layers.UpSampling2D((2, 2))(x)  # to reach 512x512
    x = ResBlock(filters=32)(x, mask)            # these two blocks were added
    x = tf.keras.layers.UpSampling2D((2, 2))(x)  # to reach 1024x1024
    x = tf.nn.leaky_relu(x, 0.2)
    output_image = tf.nn.sigmoid(tf.keras.layers.Conv2D(3, 4, padding="same")(x))
    return tf.keras.Model([latent, mask], output_image, name="generator")


def build_discriminator(image_shape, downsample_factor=64):
    input_image_A = tf.keras.Input(shape=image_shape, name="discriminator_image_A")
    input_image_B = tf.keras.Input(shape=image_shape, name="discriminator_image_B")
    x = tf.keras.layers.Concatenate()([input_image_A, input_image_B])
    x1 = downsample(downsample_factor, 4, apply_norm=False)(x)
    x2 = downsample(2 * downsample_factor, 4)(x1)
    x3 = downsample(4 * downsample_factor, 4)(x2)
    x4 = downsample(8 * downsample_factor, 4)(x3)
    x5 = downsample(8 * downsample_factor, 4)(x4)
    x6 = downsample(8 * downsample_factor, 4)(x5)
    x7 = downsample(16 * downsample_factor, 4)(x6)
    x8 = tf.keras.layers.Conv2D(1, 4)(x7)
    outputs = [x1, x2, x3, x4, x5, x6, x7, x8]
    return tf.keras.Model([input_image_A, input_image_B], outputs)


def generator_loss(y):
    return -tf.reduce_mean(y)


def kl_divergence_loss(mean, variance):
    return -0.5 * tf.reduce_sum(1 + variance - tf.square(mean) - tf.exp(variance))


class FeatureMatchingLoss(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        loss = 0
        for i in range(len(y_true) - 1):
            loss += self.mae(y_true[i], y_pred[i])
        return loss


class VGGFeatureMatchingLoss(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder_layers = [
            "block1_conv1",
            "block2_conv1",
            "block3_conv1",
            "block4_conv1",
            "block5_conv1",
        ]
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        vgg = tf.keras.applications.VGG19(include_top=False, weights="imagenet")
        layer_outputs = [vgg.get_layer(x).output for x in self.encoder_layers]
        self.vgg_model = tf.keras.Model(vgg.input, layer_outputs, name="VGG")
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        y_true = tf.keras.applications.vgg19.preprocess_input(127.5 * (y_true + 1))
        y_pred = tf.keras.applications.vgg19.preprocess_input(127.5 * (y_pred + 1))
        real_features = self.vgg_model(y_true)
        fake_features = self.vgg_model(y_pred)
        loss = 0
        for i in range(len(real_features)):
            loss += self.weights[i] * self.mae(real_features[i], fake_features[i])
        return loss


class DiscriminatorLoss(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hinge_loss = tf.keras.losses.Hinge()

    def call(self, y, is_real):
        label = 1.0 if is_real else -1.0
        return self.hinge_loss(label, y)


class GauGAN(tf.keras.Model):
    def __init__(
        self,
        image_size,
        num_classes,
        batch_size,
        latent_dim,
        feature_loss_coeff=10,
        vgg_feature_loss_coeff=0.1,
        kl_divergence_loss_coeff=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.image_shape = (image_size, image_size, 3)
        self.mask_shape = (image_size, image_size, num_classes)
        self.feature_loss_coeff = feature_loss_coeff
        self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
        self.kl_divergence_loss_coeff = kl_divergence_loss_coeff

        self.discriminator = build_discriminator(self.image_shape)
        self.generator = build_generator(self.mask_shape, latent_dim=latent_dim)
        self.encoder = build_encoder(self.image_shape, latent_dim=latent_dim)
        self.sampler = GaussianSampler(batch_size, latent_dim)
        self.patch_size, self.combined_model = self.build_combined_generator()

        self.disc_loss_tracker = tf.keras.metrics.Mean(name="disc_loss")
        self.gen_loss_tracker = tf.keras.metrics.Mean(name="gen_loss")
        self.feat_loss_tracker = tf.keras.metrics.Mean(name="feat_loss")
        self.vgg_loss_tracker = tf.keras.metrics.Mean(name="vgg_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.disc_loss_tracker,
            self.gen_loss_tracker,
            self.feat_loss_tracker,
            self.vgg_loss_tracker,
            self.kl_loss_tracker,
        ]

    def build_combined_generator(self):
        # This method builds a model that takes the following inputs:
        # a latent vector, a one-hot encoded segmentation label map, and
        # a segmentation map. It (i) generates an image with the generator and
        # (ii) passes the generated image and the segmentation map to the
        # discriminator. Finally, the model produces two outputs:
        # (a) the discriminator outputs and (b) the generated image.
        # We use this model to simplify the implementation.
        self.discriminator.trainable = False
        mask_input = tf.keras.Input(shape=self.mask_shape, name="mask")
        image_input = tf.keras.Input(shape=self.image_shape, name="image")
        latent_input = tf.keras.Input(shape=(self.latent_dim), name="latent")
        generated_image = self.generator([latent_input, mask_input])
        discriminator_output = self.discriminator([image_input, generated_image])
        patch_size = discriminator_output[-1].shape[1]
        combined_model = tf.keras.Model(
            [latent_input, mask_input, image_input],
            [discriminator_output, generated_image],
        )
        return patch_size, combined_model

    def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
        super().compile(**kwargs)
        self.generator_optimizer = tf.keras.optimizers.Adam(
            gen_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_optimizer = tf.keras.optimizers.Adam(
            disc_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_loss = DiscriminatorLoss()
        self.feature_matching_loss = FeatureMatchingLoss()
        self.vgg_loss = VGGFeatureMatchingLoss()

    def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
        fake_images = self.generator([latent_vector, labels])
        with tf.GradientTape() as gradient_tape:
            pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
            pred_real = self.discriminator([segmentation_map, real_image])[-1]
            loss_fake = self.discriminator_loss(pred_fake, False)
            loss_real = self.discriminator_loss(pred_real, True)
            total_loss = 0.5 * (loss_fake + loss_real)

        self.discriminator.trainable = True
        gradients = gradient_tape.gradient(
            total_loss, self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients, self.discriminator.trainable_variables)
        )
        return total_loss

    def train_generator(
        self, latent_vector, segmentation_map, labels, image, mean, variance
    ):
        # The generator learns through the signal provided by the discriminator.
        # During backpropagation, we only update the generator parameters.
        self.discriminator.trainable = False
        with tf.GradientTape() as tape:
            real_d_output = self.discriminator([segmentation_map, image])
            fake_d_output, fake_image = self.combined_model(
                [latent_vector, labels, segmentation_map]
            )
            pred = fake_d_output[-1]

            # Compute generator losses.
            g_loss = generator_loss(pred)
            kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
            vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
            feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
                real_d_output, fake_d_output
            )
            total_loss = g_loss + kl_loss + vgg_loss + feature_loss

        gradients = tape.gradient(total_loss, self.combined_model.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gradients, self.combined_model.trainable_variables)
        )
        return total_loss, feature_loss, vgg_loss, kl_loss

    def train_step(self, data):
        segmentation_map, image, labels = data
        mean, variance = self.encoder(image)
        latent_vector = self.sampler([mean, variance])
        discriminator_loss = self.train_discriminator(
            latent_vector, segmentation_map, image, labels
        )
        (generator_loss, feature_loss, vgg_loss, kl_loss) = self.train_generator(
            latent_vector, segmentation_map, labels, image, mean, variance
        )

        # Report progress.
        self.disc_loss_tracker.update_state(discriminator_loss)
        self.gen_loss_tracker.update_state(generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        segmentation_map, image, labels = data
        # Obtain the learned moments of the real image distribution.
        mean, variance = self.encoder(image)

        # Sample a latent from the distribution defined by the learned moments.
        latent_vector = self.sampler([mean, variance])

        # Generate the fake images.
        fake_images = self.generator([latent_vector, labels])

        # Calculate the losses.
        pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
        pred_real = self.discriminator([segmentation_map, image])[-1]
        loss_fake = self.discriminator_loss(pred_fake, False)
        loss_real = self.discriminator_loss(pred_real, True)
        total_discriminator_loss = 0.5 * (loss_fake + loss_real)
        real_d_output = self.discriminator([segmentation_map, image])
        fake_d_output, fake_image = self.combined_model(
            [latent_vector, labels, segmentation_map]
        )
        pred = fake_d_output[-1]
        g_loss = generator_loss(pred)
        kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
        vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
        feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
            real_d_output, fake_d_output
        )
        total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss

        # Report progress.
        self.disc_loss_tracker.update_state(total_discriminator_loss)
        self.gen_loss_tracker.update_state(total_generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        latent_vectors, labels = inputs
        return self.generator([latent_vectors, labels])
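
Similarly, a hedged usage sketch (not part of this commit) for the GauGAN wrapper. The dataset is assumed to yield (segmentation_map, image, one_hot_labels) batches in the order expected by train_step, and one_hot_masks is a placeholder input.

import tensorflow as tf
from models.gaugan import GauGAN

# image_size=1024 matches the generator above (4x4 base, eight 2x upsamplings);
# num_classes=7 matches the 7-channel label maps used in this repository.
gaugan = GauGAN(image_size=1024, num_classes=7, batch_size=4, latent_dim=256)
gaugan.compile(gen_lr=1e-4, disc_lr=4e-4)
# gaugan.fit(train_ds, validation_data=val_ds, epochs=15)

# Inference via call(): a random latent conditioned on a one-hot mask batch,
# where one_hot_masks is assumed to have shape (4, 1024, 1024, 7).
latent = tf.random.normal((4, 256))
# fake_images = gaugan([latent, one_hot_masks])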
utils.py ADDED
@@ -0,0 +1,71 @@
import cv2
import numpy as np


# class to RGB colour palette
color_dict = {
    0: (0, 0, 0),        # BG
    1: (239, 164, 0),    # EX
    2: (0, 186, 127),    # HE
    3: (0, 185, 255),    # SE
    4: (34, 80, 242),    # MA
    5: (73, 73, 73),     # OD
    6: (255, 255, 255),  # VB
}


def rgb_to_onehot(rgb_arr, color_dict):
    """
    Converts an RGB label map to a one-hot label map defined by color_dict.
    Parameters:
        rgb_arr (array): RGB label mask with shape (H x W x 3)
        color_dict (dict): dictionary mapping of class to colour
    Returns:
        arr (array): one-hot label map of shape (H x W x n_classes)
    """
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    for i, cls in enumerate(color_dict):
        arr[:, :, i] = np.all(rgb_arr.reshape((-1, 3)) == color_dict[i], axis=1).reshape(shape[:2])
    return arr


def onehot_to_rgb(onehot_arr, color_dict):
    """
    Converts a one-hot label map to an RGB label map defined by color_dict.
    Parameters:
        onehot_arr (array): one-hot label mask with shape (H x W x n_classes)
        color_dict (dict): dictionary mapping of class to colour
    Returns:
        arr (array): RGB label map of shape (H x W x 3)
    """
    shape = onehot_arr.shape[:2]
    mask = np.argmax(onehot_arr, axis=-1)
    arr = np.zeros(shape + (3,), dtype=np.uint8)
    for i, cls in enumerate(color_dict):
        arr = arr + np.tile(color_dict[cls], shape + (1,)) * (mask[..., None] == cls)
    return arr


def fix_pred_label(labels):
    """
    Post-processing fixes for the predicted VB and BG label classes:
    the Vitreous Body should be a consistent circular disc on a black background.
    Parameters:
        labels (tensor): a 4-D array of predicted labels
                         with shape (batch x H x W x 7)
    Returns:
        fixed_labels (array): shape (batch x H x W x 7)
    """
    shape = labels.shape[1:-1]
    VB = np.uint8(cv2.circle(np.zeros(shape), (shape[0] // 2, shape[1] // 2), min(shape) // 2, 1, -1))[..., None]
    BG = np.uint8(VB == 0)

    VB = VB - np.sum(labels[..., 1:-1], axis=-1)[..., None]
    BG = np.broadcast_to(BG, VB.shape)

    fixed_labels = np.concatenate([BG, labels[..., 1:-1], VB], axis=-1)

    return fixed_labels
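
A small round-trip check (not part of this commit) illustrating the two converters: a mask painted with palette colours should survive the RGB -> one-hot -> RGB conversion unchanged.

import numpy as np
from utils import color_dict, rgb_to_onehot, onehot_to_rgb

rgb_mask = np.zeros((64, 64, 3), dtype=np.uint8)  # background (BG) everywhere
rgb_mask[16:48, 16:48] = color_dict[6]            # paint a VB (white) square
onehot = rgb_to_onehot(rgb_mask, color_dict)      # (64, 64, 7)
restored = onehot_to_rgb(onehot, color_dict)      # (64, 64, 3)
assert np.array_equal(rgb_mask, restored)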