from typing import Any
import functools
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
from einops import rearrange
###########################
### Helper Modules
### https://github.com/google-research/maskgit/blob/main/maskgit/nets/layers.py
###########################
def get_norm_layer(norm_type):
"""Normalization layer."""
if norm_type == 'BN':
raise NotImplementedError
elif norm_type == 'LN':
norm_fn = functools.partial(nn.LayerNorm)
elif norm_type == 'GN':
norm_fn = functools.partial(nn.GroupNorm)
else:
raise NotImplementedError
return norm_fn
def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
pool_sum = jax.lax.reduce_window(x, 0.0, jax.lax.add,
(1,) + window_shape + (1,),
(1,) + strides + (1,), padding)
pool_denom = jax.lax.reduce_window(
jnp.ones_like(x), 0.0, jax.lax.add, (1,) + window_shape + (1,),
(1,) + strides + (1,), padding)
return pool_sum / pool_denom
def upsample(x, factor=2):
n, h, w, c = x.shape
x = jax.image.resize(x, (n, h * factor, w * factor, c), method='nearest')
return x
def dsample(x):
return tensorflow_style_avg_pooling(x, (2, 2), strides=(2, 2), padding='same')
def squared_euclidean_distance(a: jnp.ndarray,
b: jnp.ndarray,
b2: jnp.ndarray = None) -> jnp.ndarray:
"""Computes the pairwise squared Euclidean distance.
Args:
a: float32: (n, d): An array of points.
b: float32: (m, d): An array of points.
    b2: float32: (1, m): optional precomputed squared norms of the rows of b,
      i.e. jnp.sum(b.T**2, axis=0, keepdims=True); computed here if not given.
Returns:
d: float32: (n, m): Where d[i, j] is the squared Euclidean distance between
a[i] and b[j].
"""
if b2 is None:
b2 = jnp.sum(b.T**2, axis=0, keepdims=True)
a2 = jnp.sum(a**2, axis=1, keepdims=True)
ab = jnp.matmul(a, b.T)
d = a2 - 2 * ab + b2
return d
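# Hedged sanity check (not part of the original file): verifies that the expansion
# used above, |a - b|^2 = |a|^2 - 2 a.b + |b|^2, matches a direct pairwise
# computation. Shapes and values are illustrative only.
def _check_squared_euclidean_distance():
    a = jnp.arange(6.0).reshape(2, 3)    # (n=2, d=3)
    b = jnp.arange(12.0).reshape(4, 3)   # (m=4, d=3)
    direct = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    assert jnp.allclose(squared_euclidean_distance(a, b), direct)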
def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
"""Calculates the entropy loss. Affinity is the similarity/distance matrix."""
flat_affinity = affinity.reshape(-1, affinity.shape[-1])
flat_affinity /= temperature
probs = jax.nn.softmax(flat_affinity, axis=-1)
log_probs = jax.nn.log_softmax(flat_affinity + 1e-5, axis=-1)
if loss_type == "softmax":
target_probs = probs
elif loss_type == "argmax":
codes = jnp.argmax(flat_affinity, axis=-1)
onehots = jax.nn.one_hot(
codes, flat_affinity.shape[-1], dtype=flat_affinity.dtype)
onehots = probs - jax.lax.stop_gradient(probs - onehots)
target_probs = onehots
else:
raise ValueError("Entropy loss {} not supported".format(loss_type))
avg_probs = jnp.mean(target_probs, axis=0)
avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5))
sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))
loss = sample_entropy - avg_entropy
return loss
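# Hedged illustration (not part of the original file): with the default "softmax"
# targets the loss is per-sample entropy minus the entropy of the batch-averaged
# distribution, so confident-but-diverse code assignments score lower (better)
# than assignments collapsed onto a single code. The 4x4 affinities are made up.
def _compare_entropy_losses():
    diverse = jnp.eye(4) * 10.0  # each sample strongly prefers a different code
    collapsed = jnp.tile(jnp.array([[10.0, 0.0, 0.0, 0.0]]), (4, 1))  # all prefer code 0
    return entropy_loss_fn(diverse), entropy_loss_fn(collapsed)  # first value is lower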
def sg(x):
return jax.lax.stop_gradient(x)
###########################
### Modules
###########################
class ResBlock(nn.Module):
"""Basic Residual Block."""
filters: int
norm_fn: Any
activation_fn: Any
@nn.compact
def __call__(self, x):
input_dim = x.shape[-1]
residual = x
x = self.norm_fn()(x)
x = self.activation_fn(x)
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
x = self.norm_fn()(x)
x = self.activation_fn(x)
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
if input_dim != self.filters:#Basically if input doesn't match output, use a skip
            # Project the input (not x) so the skip connection remains a true residual.
            residual = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(residual)
return x + residual
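# Hedged usage sketch (not part of the original file): how ResBlock is wired up
# with the helpers above. Filter counts and shapes are illustrative assumptions.
def _example_resblock():
    block = ResBlock(filters=64, norm_fn=get_norm_layer('LN'), activation_fn=nn.swish)
    x = jnp.zeros((1, 16, 16, 32))  # input channels differ from `filters`, so the skip is projected
    params = block.init(jax.random.PRNGKey(0), x)
    return block.apply(params, x)   # (1, 16, 16, 64)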
class Encoder(nn.Module):
"""From [H,W,D] image to [H',W',D'] embedding. Using Conv layers."""
config: ml_collections.ConfigDict
def setup(self):
self.filters = self.config.filters#filters is the original setup
self.num_res_blocks = self.config.num_res_blocks
self.channel_multipliers = self.config.channel_multipliers
self.embedding_dim = self.config.embedding_dim
self.norm_type = self.config.norm_type
self.activation_fn = nn.swish
    def pixels(self, x):
        """Space-to-channel shortcut: pixel-unshuffle by 2, then average the four
        spatial positions folded into each channel (halves H and W, keeps C)."""
        x = pixel_unshuffle(x, 2)
        B, H, W, C = x.shape
        x = jnp.reshape(x, (B, H, W, C // 4, 4))
        x = jnp.mean(x, axis=-1)
        return x
@nn.compact
def __call__(self, x):
print("Initializing encoder.")
norm_fn = get_norm_layer(norm_type=self.norm_type)
block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
print("Incoming encoder shape", x.shape)
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
print('Encoder layer', x.shape)
num_blocks = len(self.channel_multipliers)
        # Stable Diffusion's encoder runs 2 ResBlocks per stage, downsamples 3 times
        # (8x total), adds one more ResBlock, and only then projects 512 channels down
        # to the 4- or 8-dim latent. DC-AE uses a similar stack (more ResBlocks per
        # stage, EfficientViT blocks in the deeper stages) with residual
        # space-to-channel shortcuts around each downsample, which `pixels` provides here.
for i in range(num_blocks):
filters = self.filters * self.channel_multipliers[i]
for _ in range(self.num_res_blocks):
x = ResBlock(filters, **block_args)(x)
if i < num_blocks - 1:#For each block *except end* do downsample
print("doing downsample")
                # DC-AE-style downsampling: average the strided pooling path with the
                # space-to-channel shortcut from `pixels` instead of pooling alone.
if self.channel_multipliers[i] != -1:
print("pre pixels", x.shape)
pixel_x = self.pixels(x)
print("pixel_x", pixel_x.shape)
x = dsample(x) + pixel_x
print("post", x.shape)
else:
x = dsample(x)
print("other post", x.shape)
print('Encoder layer', x.shape)
        # After the downsampling stages, run a final stack of ResBlocks as the mid-block.
for _ in range(self.num_res_blocks):
x = ResBlock(filters, **block_args)(x)
print('Encoder layer final', x.shape)
x = norm_fn()(x)
x = self.activation_fn(x)
last_dim = self.embedding_dim*2 if self.config['quantizer_type'] == 'kl' else self.embedding_dim
x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
print("Final embeddings are size", x.shape)
return x
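# Hedged shape walkthrough (every config value below is an assumption for
# illustration, not the project's real config): with three channel multipliers the
# encoder downsamples twice (every stage except the last), so a 256x256 RGB input
# becomes a 64x64 map with `embedding_dim` channels.
def _example_encoder():
    cfg = ml_collections.ConfigDict(dict(
        filters=64, num_res_blocks=2, channel_multipliers=[1, 2, 4],
        embedding_dim=8, norm_type='GN', quantizer_type='ae'))
    enc = Encoder(config=cfg)
    x = jnp.zeros((1, 256, 256, 3))
    params = enc.init(jax.random.PRNGKey(0), x)
    return enc.apply(params, x)  # (1, 64, 64, 8)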
class Decoder(nn.Module):
"""From [H',W',D'] embedding to [H,W,D] embedding. Using Conv layers."""
config: ml_collections.ConfigDict
def setup(self):
self.filters = self.config.filters
self.num_res_blocks = self.config.num_res_blocks
self.channel_multipliers = self.config.channel_multipliers
self.norm_type = self.config.norm_type
self.image_channels = self.config.image_channels
self.activation_fn = nn.swish
    def pixels(self, x):
        """Channel-to-space shortcut: repeat each channel 4x, then pixel-shuffle by 2
        (doubles H and W, keeps C)."""
        print("pixels shape", x.shape)
        x = jnp.repeat(x, 4, axis=-1)
        x = pixel_shuffle(x, 2)
        print("pixels output shape", x.shape)
        return x
@nn.compact
def __call__(self, x):
norm_fn = get_norm_layer(norm_type=self.norm_type)
block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn,)
num_blocks = len(self.channel_multipliers)
filters = self.filters * self.channel_multipliers[-1]
print("Decoder incoming shape", x.shape)
        # Project the latent back up to the widest channel count (filters * last multiplier).
x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
print("Decoder input", x.shape)
#This is the mid block
for _ in range(self.num_res_blocks):
x = ResBlock(filters, **block_args)(x)
print('Mid Block Decoder layer', x.shape)
        # The deepest stages keep the widest channel count; shallower stages reduce
        # channels and upsample spatially, mirroring the encoder.
for i in reversed(range(num_blocks)):
filters = self.filters * self.channel_multipliers[i]
for _ in range(self.num_res_blocks + 1):
x = ResBlock(filters, **block_args)(x)
if i > 0:
                # Mirror the encoder: add a channel-to-space shortcut (`pixels`) around
                # every spatial upsample.
pixel = self.pixels(x)
print("pre up", x.shape)
x = upsample(x, 2)
print("post up", x.shape)
x = x + pixel
x = nn.Conv(filters, kernel_size=(3, 3))(x)
print('Decoder layer', x.shape)
x = norm_fn()(x)
x = self.activation_fn(x)
x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
return x
class VectorQuantizer(nn.Module):
"""Basic vector quantizer."""
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
codebook_size = self.config.codebook_size
emb_dim = x.shape[-1]
codebook = self.param(
"codebook",
jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform"),
(codebook_size, emb_dim))
codebook = jnp.asarray(codebook) # (codebook_size, emb_dim)
distances = jnp.reshape(
squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
x.shape[:-1] + (codebook_size,)) # [x, codebook_size] similarity matrix.
encoding_indices = jnp.argmin(distances, axis=-1)
encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
quantized = self.quantize(encoding_onehot)
result_dict = dict()
if self.train:
e_latent_loss = jnp.mean((sg(quantized) - x)**2) * self.config.commitment_cost
q_latent_loss = jnp.mean((quantized - sg(x))**2)
entropy_loss = 0.0
if self.config.entropy_loss_ratio != 0:
entropy_loss = entropy_loss_fn(
-distances,
loss_type=self.config.entropy_loss_type,
temperature=self.config.entropy_temperature
) * self.config.entropy_loss_ratio
e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
loss = e_latent_loss + q_latent_loss + entropy_loss
result_dict = dict(
quantizer_loss=loss,
e_latent_loss=e_latent_loss,
q_latent_loss=q_latent_loss,
entropy_loss=entropy_loss)
quantized = x + jax.lax.stop_gradient(quantized - x)
result_dict.update({
"z_ids": encoding_indices,
})
return quantized, result_dict
def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
codebook = jnp.asarray(self.variables["params"]["codebook"])
return jnp.dot(encoding_onehot, codebook)
def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
codebook = self.variables["params"]["codebook"]
return jnp.take(codebook, ids, axis=0)
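# Hedged usage sketch (codebook size and shapes are assumptions): nearest-codebook
# lookup at eval time, plus decoding the returned indices back to embeddings.
def _example_vector_quantizer():
    cfg = ml_collections.ConfigDict(dict(codebook_size=16))
    vq = VectorQuantizer(config=cfg, train=False)
    z = jnp.zeros((1, 8, 8, 4))  # encoder features; emb_dim inferred from the last axis
    params = vq.init(jax.random.PRNGKey(0), z)
    quantized, aux = vq.apply(params, z)  # aux["z_ids"]: (1, 8, 8) integer code indices
    decoded = vq.apply(params, aux["z_ids"], method=VectorQuantizer.decode_ids)  # (1, 8, 8, 4)
    return quantized, decoded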
class KLQuantizer(nn.Module):
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
emb_dim = x.shape[-1] // 2 # Use half as means, half as logvars.
means = x[..., :emb_dim]
logvars = x[..., emb_dim:]
if not self.train:
result_dict = dict()
return means, result_dict
else:
noise = jax.random.normal(self.make_rng("noise"), means.shape)
stds = jnp.exp(0.5 * logvars)
z = means + stds * noise
kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
result_dict = dict(quantizer_loss=kl_loss)
return z, result_dict
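# Hedged usage sketch (shapes are assumptions): the KL quantizer expects the
# encoder to emit 2 * emb_dim channels (means and logvars), and in training mode
# it needs a "noise" rng stream for the reparameterization sample.
def _example_kl_quantizer():
    klq = KLQuantizer(config=ml_collections.ConfigDict(), train=True)
    x = jnp.zeros((1, 8, 8, 16))  # 16 channels -> 8-dim means and 8-dim logvars
    params = klq.init({"params": jax.random.PRNGKey(0), "noise": jax.random.PRNGKey(1)}, x)
    z, aux = klq.apply(params, x, rngs={"noise": jax.random.PRNGKey(2)})
    return z, aux["quantizer_loss"]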
class AEQuantizer(nn.Module): #cooking
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
result_dict = dict()
return x, result_dict
def pixel_unshuffle(x, factor):
    """Fold each factor x factor spatial block into the channel axis (C -> C*factor**2)."""
    return rearrange(x, '... (h b1) (w b2) c -> ... h w (c b1 b2)', b1=factor, b2=factor)
def pixel_shuffle(x, factor):
    """Inverse of pixel_unshuffle: spread channels back into space (C -> C/factor**2)."""
    return rearrange(x, '... h w (c b1 b2) -> ... (h b1) (w b2) c', b1=factor, b2=factor)
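# Hedged sanity check (not part of the original file): pixel_unshuffle trades
# spatial extent for channels and pixel_shuffle inverts it exactly.
def _check_pixel_shuffle_roundtrip():
    x = jnp.arange(2 * 8 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 8, 3)
    y = pixel_unshuffle(x, 2)  # (2, 4, 4, 12)
    assert jnp.allclose(pixel_shuffle(y, 2), x)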
class KLQuantizerTwo(nn.Module):
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
        # Unlike KLQuantizer, this variant does not split channels into means/logvars.
        # Instead it regularizes the statistics of the latent itself: per-sample mean
        # and std are taken over all spatial positions and channels, and the
        # (un-noised) latent x is passed straight through.
if not self.train:
result_dict = dict()
return x, result_dict
else:
            # Per-sample statistics over spatial and channel axes, shape (B,).
            # (An earlier run averaged over axis 0 instead.)
means = jnp.mean(x, axis = [1,2,3])
stds = jnp.std(x, axis = [1,2,3])
noise = jax.random.normal(self.make_rng("noise"), means.shape)
            # log-variance = log(std**2) = 2 * log(std).
            logvars = 2.0 * jnp.log(stds)
z = means + stds * noise
            # z is sampled for reference but not returned; this "denoising" variant only
            # regularizes x's statistics toward N(0, 1) via the KL term below.
kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
result_dict = dict(quantizer_loss=kl_loss)
return x, result_dict
class FSQuantizer(nn.Module):
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
assert self.config['fsq_levels'] % 2 == 1, "FSQ levels must be odd."
z = jnp.tanh(x) # [-1, 1]
z = z * (self.config['fsq_levels']-1) / 2 # [-fsq_levels/2, fsq_levels/2]
zhat = jnp.round(z) # e.g. [-2, -1, 0, 1, 2]
quantized = z + jax.lax.stop_gradient(zhat - z)
quantized = quantized / (self.config['fsq_levels'] // 2) # [-1, 1], but quantized.
result_dict = dict()
# Diagnostics for codebook usage.
zhat_scaled = zhat + self.config['fsq_levels'] // 2
basis = jnp.concatenate((jnp.array([1]), jnp.cumprod(jnp.array([self.config['fsq_levels']] * (x.shape[-1]-1))))).astype(jnp.uint32)
idx = (zhat_scaled * basis).sum(axis=-1).astype(jnp.uint32)
idx_flat = idx.reshape(-1)
usage = jnp.bincount(idx_flat, length=self.config['fsq_levels']**x.shape[-1])
result_dict.update({
"z_ids": zhat,
'usage': usage
})
return quantized, result_dict
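# Hedged worked example (fsq_levels and shapes are assumptions): each latent
# channel is squashed by tanh, scaled, rounded to one of `fsq_levels` integer
# levels with a straight-through gradient, then rescaled to [-1, 1]. With 3
# channels and 5 levels there are 5**3 = 125 possible codes.
def _example_fsq():
    cfg = ml_collections.ConfigDict(dict(fsq_levels=5))
    fsq = FSQuantizer(config=cfg, train=False)
    x = jnp.zeros((1, 4, 4, 3))
    params = fsq.init(jax.random.PRNGKey(0), x)  # no learned parameters, kept for API uniformity
    quantized, aux = fsq.apply(params, x)  # aux["usage"]: histogram over the 125 codes
    return quantized, aux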
class VQVAE(nn.Module):
"""VQVAE model."""
config: ml_collections.ConfigDict
train: bool
def setup(self):
"""VQVAE setup."""
if self.config['quantizer_type'] == 'vq':
self.quantizer = VectorQuantizer(config=self.config, train=self.train)
elif self.config['quantizer_type'] == 'kl':
self.quantizer = KLQuantizer(config=self.config, train=self.train)
elif self.config['quantizer_type'] == 'fsq':
self.quantizer = FSQuantizer(config=self.config, train=self.train)
elif self.config['quantizer_type'] == 'ae':
self.quantizer = AEQuantizer(config=self.config, train=self.train)
        elif self.config["quantizer_type"] == "kl_two":
            self.quantizer = KLQuantizerTwo(config=self.config, train=self.train)
        else:
            raise NotImplementedError("Unknown quantizer_type: {}".format(self.config['quantizer_type']))
self.encoder = Encoder(config=self.config)
self.decoder = Decoder(config=self.config)
def encode(self, image):
encoded_feature = self.encoder(image)
quantized, result_dict = self.quantizer(encoded_feature)
print("After quant", quantized.shape)
return quantized, result_dict
def decode(self, z_vectors):
print("z_vectors shape", z_vectors.shape)
reconstructed = self.decoder(z_vectors)
return reconstructed
def decode_from_indices(self, z_ids):
z_vectors = self.quantizer.decode_ids(z_ids)
reconstructed_image = self.decode(z_vectors)
return reconstructed_image
def encode_to_indices(self, image):
encoded_feature = self.encoder(image)
_, result_dict = self.quantizer(encoded_feature)
ids = result_dict["z_ids"]
return ids
def __call__(self, input_dict):
quantized, result_dict = self.encode(input_dict)
outputs = self.decoder(quantized)
return outputs, result_dict
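# Hedged end-to-end sketch (all config values below are assumptions for
# illustration; the real training configs live elsewhere): build the model with
# the pass-through 'ae' quantizer and run one reconstruction forward pass.
def _example_vqvae():
    cfg = ml_collections.ConfigDict(dict(
        filters=64, num_res_blocks=2, channel_multipliers=[1, 2, 4],
        embedding_dim=8, norm_type='GN', image_channels=3, quantizer_type='ae'))
    model = VQVAE(config=cfg, train=False)
    images = jnp.zeros((1, 256, 256, 3))
    params = model.init(jax.random.PRNGKey(0), images)
    recon, aux = model.apply(params, images)  # recon: (1, 256, 256, 3)
    return recon, aux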