"""Abstract VAE model class. |
|
|
|
Gaussian encoder and decoder (the latter assumed to have constant variance). |
|
|
|
Inspiration drawn from https://github.com/pytorch/examples/tree/main/vae. |
|
""" |

import abc
from typing import Optional, Mapping

import flax.linen as nn
import jax
import jax.numpy as jnp
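

# Not part of the original module: a minimal sketch of the ELBO regularizer a
# Gaussian VAE like this is typically trained with (standard-normal prior; with
# the constant-variance Gaussian decoder noted above, the reconstruction term
# reduces to a scaled squared error). The helper name `kl_divergence` is
# hypothetical and purely illustrative.
def kl_divergence(mu: jax.Array, logvar: jax.Array) -> jax.Array:
  """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over the last axis."""
  return -0.5 * jnp.sum(
      1.0 + logvar - jnp.square(mu) - jnp.exp(logvar), axis=-1)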


class Model(nn.Module, metaclass=abc.ABCMeta):
  """Abstract VAE model class."""

  codeword_dim: Optional[int] = None  # Dimensionality of each latent codeword.
  code_len: int = 256  # Number of latent codewords per example.
  code_dropout: str = "none"  # One of "none", "seq", "random".

  @abc.abstractmethod
  def encode(
      self,
      x: jax.Array,
      *,
      train: bool = False,
  ) -> tuple[jax.Array, jax.Array]:
    """Encodes `x` into the mean and log-variance of the latent codes."""
    ...

  def reparametrize(
      self,
      mu: jax.Array,
      logvar: jax.Array,
      rng: jax.Array | None = None,
  ) -> jax.Array:
    """Draws z = mu + std * eps, eps ~ N(0, I) (reparametrization trick)."""
    std = jnp.exp(0.5 * logvar)
    if rng is None:
      rng = self.make_rng("dropout")
    eps = jax.random.normal(rng, shape=std.shape, dtype=std.dtype)
    return mu + std * eps

  @abc.abstractmethod
  def decode(
      self,
      x: jax.Array,
      train: bool = False,
  ) -> jax.Array | Mapping[str, jax.Array]:
    """Decodes latent codes back to data space (an array or a dict of arrays)."""
    ...

  def code_dropout_fn(self, z: jax.Array, *, train: bool = False) -> jax.Array:
    """Masks trailing ("seq") or random ("random") codes during training."""
    assert self.code_dropout in ["none", "seq", "random"]
    if train and self.code_dropout != "none":
      # Importance decreases linearly over code positions, strictly within (0, 1).
      importance = jnp.linspace(1.0, 0.0, self.code_len + 2)[1:-1]
      # Per-example threshold: positions whose importance falls below it are dropped.
      thr = jax.random.uniform(self.make_rng("dropout"), z.shape[:1])
      mask = importance[None, :] > thr[:, None]
      if self.code_dropout == "random":
        # Drop a random subset of the same size instead of a contiguous suffix.
        mask = jax.random.permutation(
            self.make_rng("dropout"), mask, axis=-1, independent=True)
      z = z * mask[:, :, None]
    return z

  def __call__(
      self,
      x: jax.Array,
      *,
      train: bool = False,
  ) -> tuple[jax.Array | Mapping[str, jax.Array], Mapping[str, jax.Array]]:
    mu, logvar = self.encode(x, train=train)

    # Sample the latent during training; use the posterior mean at eval time.
    if train:
      z = self.reparametrize(mu, logvar)
    else:
      z = mu
    z = self.code_dropout_fn(z, train=train)
    x = self.decode(z, train=train)
    return x, {"mu": mu, "logvar": logvar, "z": z}
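

# Illustrative usage sketch, not part of the original module. It shows one way
# a concrete subclass might implement `encode`/`decode` and how the model could
# be initialized and applied. `ToyVAE`, `hidden_dim`, and `x_dim` are
# hypothetical names introduced here purely for illustration.
class ToyVAE(Model):
  hidden_dim: int = 128
  x_dim: int = 64

  def setup(self):
    # Submodules are created in `setup` because `encode`/`decode` are plain
    # (non-compact) methods called from the base class `__call__`.
    self.enc_hidden = nn.Dense(self.hidden_dim)
    self.enc_mu = nn.Dense(self.codeword_dim)
    self.enc_logvar = nn.Dense(self.codeword_dim)
    self.dec_hidden = nn.Dense(self.hidden_dim)
    self.dec_out = nn.Dense(self.x_dim)

  def encode(self, x, *, train=False):
    h = nn.relu(self.enc_hidden(x))
    return self.enc_mu(h), self.enc_logvar(h)

  def decode(self, x, train=False):
    h = nn.relu(self.dec_hidden(x))
    return self.dec_out(h)


if __name__ == "__main__":
  # Shapes follow the (batch, code_len, feature) layout code_dropout_fn expects.
  model = ToyVAE(codeword_dim=16, code_len=256, code_dropout="seq")
  x = jnp.zeros((2, 256, 64))  # (batch, code_len, x_dim)
  variables = model.init(
      {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
      x, train=True)
  x_rec, aux = model.apply(
      variables, x, train=True, rngs={"dropout": jax.random.PRNGKey(2)})
  print(x_rec.shape, {k: v.shape for k, v in aux.items()})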