"""Decoder-only and encoder-decoder GIVT model. |
|
|
|
Abbreviations used for dimension annotations:
|
B: batch size. |
|
E: embedding size. |
|
L: (soft) token sequence length. |
|
D: soft token dimension. |
|
P: number of patches (extracted by a ViT encoder in GIVT-based UViM) |
|
""" |
|
|
|
import enum |
|
import itertools |
|
from typing import Literal, Optional, Sequence, Any, Mapping |
|
|
|
from absl import logging |
|
from big_vision import utils |
|
from big_vision.models import common |
|
from big_vision.models import vit |
|
import distrax |
|
import einops |
|
import flax.linen as nn |
|
from flax.linen import partitioning |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
|
|
class _SpecialLabel(enum.Enum): |
|
|
|
MASK = "mask" |
|
NOMASK = "nomask" |
|
REPLACE = "replace" |
|
NOLABEL = "nolabel" |
|
|
|
|
|
def _random_mask_with_ratios(rng, ratios: jax.Array, seq_len: int): |
|
"""Generates masks where a fraction of tokens is uncovered. |
|
|
|
Args: |
|
rng: RNG. |
|
ratios: Ratios, a 1D array of shape (B,). Values must be in [0, 1];
ratios[i] gives the fraction of tokens in the i-th mask that are
uncovered (i.e. equal to `True`).
|
seq_len: How many tokens this mask has to cover. |
|
|
|
Returns: |
|
Mask of dtype bool, shape (B, L). |
|
|
|
Raises: |
|
ValueError: Incorrect inputs. |
|
""" |
|
if ratios.ndim != 1: |
|
raise ValueError("Ratios must have shape (B,)!") |
|
ratios = jnp.clip(ratios, 0, 1) |
|
indices = jnp.arange(seq_len, dtype=jnp.float32) |
|
ratios = ratios[:, jnp.newaxis] * seq_len |
|
|
|
mask = (indices < ratios).astype(jnp.bool_) |
|
|
|
# Shuffle each row so the uncovered positions land at random indices.
# (jax.random.shuffle is deprecated in newer JAX releases;
# jax.random.permutation(rng, mask, axis=-1, independent=True) is the usual
# drop-in replacement.)
return jax.random.shuffle(rng, mask, axis=-1)
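# Illustrative sketch (values assumed, not from this module): for
# `ratios=[0.25, 1.0]` and `seq_len=4`, the pre-shuffle rows are
# [True, False, False, False] and [True, True, True, True]; the shuffle then
# randomizes *where* the uncovered tokens sit without changing how many
# there are per row.
#
#   rng = jax.random.PRNGKey(0)
#   mask = _random_mask_with_ratios(rng, jnp.array([0.25, 1.0]), seq_len=4)
#   assert mask.shape == (2, 4) and mask.dtype == jnp.bool_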
|
|
|
|
|
def apply_mask_schedule(ratio: float | jax.Array, method: str) -> jax.Array: |
|
"""Generate a mask rate by scheduling mask functions R.""" |
|
if method == "cosine": |
|
mask_ratio = jax.lax.cos(jnp.pi / 2. * ratio) |
|
elif "pow:" in method: |
|
exponent = float(method.replace("pow:", "")) |
|
mask_ratio = 1. - ratio**exponent |
|
else: |
|
raise NotImplementedError(method) |
|
|
|
mask_ratio = jnp.clip(mask_ratio, 1e-6, 1.) |
|
return mask_ratio |
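# Worked examples of the schedule (illustrative values): with "cosine", a
# progress ratio of 0.0 maps to a mask rate of 1.0 (everything masked) and a
# ratio of 1.0 maps to the clipped minimum of 1e-6; "pow:3" gives 1 - ratio**3.
#
#   apply_mask_schedule(jnp.asarray(0.5), "cosine")  # cos(pi/4) ~= 0.707
#   apply_mask_schedule(jnp.asarray(0.5), "pow:3")   # 1 - 0.5**3 = 0.875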
|
|
|
|
|
class EncoderDecoderBlock(nn.Module): |
|
"""Transformer encoder-decoder layer.""" |
|
mlp_dim: int |
|
num_heads: int |
|
dropout_rate: float = 0. |
|
decode: bool = False |
|
|
|
@nn.compact |
|
def __call__( |
|
self, |
|
targets: jax.Array, |
|
encoded: jax.Array | None = None, |
|
decoder_mask: jax.Array | None = None, |
|
deterministic: bool = True, |
|
) -> tuple[jax.Array, jax.Array]: |
|
"""Applies EncoderDecoderBlock module. |
|
|
|
Args: |
|
targets: embedded target (soft) tokens [B, L, E].
|
encoded: encoded image patches from encoder [B, P, E]. |
|
decoder_mask: decoder self-attention mask. |
|
deterministic: bool, deterministic or not (to apply dropout). |
|
|
|
Returns: |
|
output after transformer encoder-decoder block [B, L, E]. |
|
""" |
|
|
|
def wlc(f): |
|
dim_names = ("act_batch", "act_len", "act_emb") |
|
return nn.with_logical_constraint(f, dim_names) |
|
|
|
x = wlc(nn.LayerNorm(name="LayerNorm1", use_bias=False)(targets)) |
|
x = wlc(nn.SelfAttention( |
|
num_heads=self.num_heads, use_bias=False, broadcast_dropout=False, |
|
dropout_rate=self.dropout_rate, decode=self.decode, name="SelfAttn")( |
|
x, decoder_mask, deterministic=deterministic)) |
|
x = wlc(nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)) |
|
x = wlc(x + targets) |
|
|
|
if encoded is None: |
|
y = x |
|
else: |
|
|
|
y = wlc(nn.LayerNorm(name="LayerNorm2", use_bias=False)(x)) |
|
y = wlc(nn.MultiHeadDotProductAttention( |
|
num_heads=self.num_heads, use_bias=False, broadcast_dropout=False, |
|
dropout_rate=self.dropout_rate, name="CrossAttn")( |
|
y, encoded, deterministic=deterministic)) |
|
y = wlc( |
|
nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)) |
|
y = wlc(y + x) |
|
|
|
|
|
z = wlc(nn.LayerNorm(name="LayerNorm3", use_bias=False)(y)) |
|
z = wlc(vit.MlpBlock(mlp_dim=self.mlp_dim, dropout=self.dropout_rate, |
|
name="MLP")(z, deterministic=deterministic)) |
|
|
|
|
|
out = wlc(y + z) |
|
return out, out |
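# Minimal usage sketch for a single block (shapes and sizes below are
# assumptions, not dictated by this module); the cross-attention branch is
# only built when `encoded` is passed:
#
#   block = EncoderDecoderBlock(mlp_dim=2048, num_heads=8)
#   y = jnp.zeros((2, 16, 512))    # [B, L, E] embedded (soft) tokens
#   enc = jnp.zeros((2, 64, 512))  # [B, P, E] encoder output
#   variables = block.init(jax.random.PRNGKey(0), y, enc)
#   out, _ = block.apply(variables, y, enc, deterministic=True)  # [B, L, E]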
|
|
|
|
|
class Decoder(nn.Module): |
|
"""Transformer decoder model with optional cross-attention.""" |
|
emb_dim: int |
|
mlp_dim: int |
|
num_heads: int |
|
num_layers: int |
|
out_dim: int |
|
seq_len: int |
|
style: Literal["ar", "masked"] |
|
dropout_rate: float = 0. |
|
zero_embedding_init: bool = False |
|
|
|
scan: bool = False |
|
remat_policy: str = "nothing_saveable" |
|
|
|
@nn.compact |
|
def __call__( |
|
self, |
|
targets: jax.Array, |
|
encoded: jax.Array | None = None, |
|
decoder_mask: jax.Array | None = None, |
|
decode: bool = False, |
|
deterministic: bool = True, |
|
return_reps: bool = False, |
|
) -> jax.Array | tuple[jax.Array, Mapping[str, jax.Array]]: |
|
"""Applies Transformer model on the inputs. |
|
|
|
Args: |
|
targets: embedded target (soft) tokens [B, L, E].
|
encoded: encoded sequence from an encoder [B, P, E]. |
|
decoder_mask: decoder self-attention mask. |
|
decode: bool, whether to perform fast autoregressive decoding with cache. |
|
deterministic: bool, deterministic or not (to apply dropout). |
|
return_reps: bool, whether to return intermediate representations. |
|
|
|
Returns: |
|
output of a transformer decoder [B, L, out_dim], where out_dim is usually |
|
a multiple of D. |
|
""" |
|
if self.style == "masked" and decode: |
|
raise ValueError("Cannot run masked model in cached mode!") |
|
|
|
pos_emb = vit.get_posemb( |
|
self, "learn", self.seq_len, self.emb_dim, |
|
"pos_emb") |
|
|
|
y = common.AddPositionEmbs( |
|
decode=decode, name="PosEmbedTargets")(targets, pos_emb) |
|
|
|
out = {} |
|
if self.scan: |
|
|
|
|
|
|
|
|
|
|
|
enc_dec_block_remat = nn.remat( |
|
EncoderDecoderBlock, |
|
prevent_cse=False, |
|
static_argnums=(-1, -2), |
|
policy=getattr(jax.checkpoint_policies, self.remat_policy, None)) |
|
|
|
initializing = self.is_mutable_collection("params") |
|
param_scan_axis = 1 |
|
params_spec = (param_scan_axis if initializing |
|
else partitioning.ScanIn(param_scan_axis)) |
|
dec_scanned = nn.scan(enc_dec_block_remat, |
|
variable_axes={ |
|
"params": params_spec, |
|
"cache": 0, |
|
}, |
|
split_rngs={"params": True, "dropout": True}, |
|
in_axes=nn.broadcast, |
|
length=self.num_layers) |
|
|
|
y, out = dec_scanned(num_heads=self.num_heads, mlp_dim=self.mlp_dim, |
|
dropout_rate=self.dropout_rate, decode=decode, |
|
name="EncDecBlock")( |
|
y, encoded, decoder_mask, deterministic) |
|
|
|
|
|
|
|
|
|
|
|
assert out.shape[0] == self.num_layers and ( |
|
decode or out.shape[2] == self.seq_len), ( |
|
(out.shape, self.num_layers, self.seq_len)) |
|
out = {f"block{l}_rep": jnp.mean(out[l], axis=1) |
|
for l in range(self.num_layers)} |
|
else: |
|
for lyr in range(self.num_layers): |
|
y, _ = EncoderDecoderBlock( |
|
num_heads=self.num_heads, mlp_dim=self.mlp_dim, |
|
dropout_rate=self.dropout_rate, decode=decode, |
|
name=f"EncDecBlock{lyr}")(y, encoded, decoder_mask=decoder_mask, |
|
deterministic=deterministic) |
|
out[f"block{lyr}_rep"] = jnp.mean(y, axis=1) |
|
y = nn.LayerNorm(name="LayerNorm")(y) |
|
out["pre_logits"] = jnp.mean(y, axis=1) |
|
|
|
logits = nn.Dense( |
|
self.out_dim, |
|
kernel_init=nn.initializers.zeros, |
|
name="LogitsDense", |
|
)(y) |
|
out["logits"] = logits |
|
if return_reps: |
|
return logits, out |
|
return logits |
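# Minimal usage sketch for the decoder alone (all sizes are illustrative
# assumptions). The inputs are already-embedded tokens of width `emb_dim`;
# the output has width `out_dim`:
#
#   dec = Decoder(emb_dim=512, mlp_dim=2048, num_heads=8, num_layers=2,
#                 out_dim=384, seq_len=16, style="ar")
#   y = jnp.zeros((2, 16, 512))                     # [B, L, E]
#   variables = dec.init(jax.random.PRNGKey(0), y)
#   logits = dec.apply(variables, y)                # [B, L, out_dim]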
|
|
|
|
|
class Model(nn.Module): |
|
"""GIVT model supporting decoder-only and encoder-decoder applications.""" |
|
num_heads: int = 8 |
|
|
|
num_layers: int = 0 |
|
num_decoder_layers: int = 6 |
|
mlp_dim: int = 2048 |
|
enc_dropout_rate: float = 0. |
|
dec_dropout_rate: float = 0. |
|
|
|
emb_dim: int = 512 |
|
num_labels: Optional[int] = 1000 |
|
seq_len: int = 256 |
|
|
|
patches: Sequence[int] = (16, 16) |
|
input_size: Sequence[int] = (256, 256) |
|
posemb_type: Literal["learn", "sincos2d"] = "learn" |
|
zero_decoder_seq: bool = False |
|
style: Literal["ar", "masked"] = "ar" |
|
|
|
zero_embedding_init: bool = False |
|
|
|
num_mixtures: int = 4 |
|
multivariate: bool = False |
|
out_dim: int = 32 |
|
scale_tol: float = 1e-6 |
|
|
|
|
|
mask_schedule_train: str = "cosine" |
|
|
|
min_masking_rate_training: float = 0.3 |
|
|
|
|
|
|
|
|
|
|
|
mask_style: str = "replace" |
|
|
|
|
|
drop_labels_probability: float = 0.0 |
|
|
|
fix_square_plus: bool = False |
|
|
|
|
|
|
|
per_channel_mixtures: bool = True |
|
|
|
scan: bool = False |
|
remat_policy: str = "nothing_saveable" |
|
|
|
@property |
|
def has_encoder(self) -> bool: |
|
return self.num_layers > 0 |
|
|
|
@property |
|
def num_logits(self) -> int: |
|
if self.multivariate: |
|
assert self.num_mixtures == 1 |
|
|
|
|
|
return round(self.out_dim ** 2) + self.out_dim |
|
|
|
elif self.per_channel_mixtures: |
|
|
|
|
|
|
|
return 3 * self.num_mixtures * self.out_dim |
|
|
|
else: |
|
|
|
return self.num_mixtures + 2 * self.num_mixtures * self.out_dim |
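# Worked examples for `num_logits` (using this file's defaults where noted):
# with per_channel_mixtures=True, num_mixtures=4, out_dim=32 the decoder
# predicts 3 * 4 * 32 = 384 values per token (mixture weights, means, and
# scales for every channel); the multivariate case needs 32**2 + 32 = 1056
# (full scale matrix plus mean vector); the shared-mixture case needs
# 4 + 2 * 4 * 32 = 260.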
|
|
|
def setup(self) -> None: |
|
assert self.posemb_type == "learn" |
|
assert self.num_mixtures > 0 |
|
|
|
if self.multivariate and self.num_mixtures != 1: |
|
raise ValueError("Cannot do multivariate GMM!") |
|
|
|
if self.num_layers > 0: |
|
grid_size = np.array(self.input_size) // np.array(self.patches) |
|
|
|
self.pos_emb_for_encoder = vit.get_posemb( |
|
self, self.posemb_type, grid_size, self.emb_dim, |
|
"pos_embedding_encoder") |
|
|
|
self.conv = nn.Conv(self.emb_dim, self.patches, padding="VALID", |
|
strides=self.patches, name="EmbedPatches") |
|
|
|
self.encoder = vit.Encoder( |
|
depth=self.num_layers, |
|
mlp_dim=self.mlp_dim, |
|
num_heads=self.num_heads, |
|
dropout=self.enc_dropout_rate, |
|
scan=self.scan, |
|
remat_policy=self.remat_policy,) |
|
else: |
|
self.encoder = None |
|
|
|
|
|
next_label = itertools.count(self.num_labels or 0) |
|
special_labels = {} |
|
|
|
if self.style == "ar": |
|
pass |
|
elif self.style == "masked": |
|
if self.mask_style == "replace": |
|
special_labels = {_SpecialLabel.MASK: next(next_label)} |
|
elif self.mask_style == "concat": |
|
special_labels = { |
|
_SpecialLabel.MASK: next(next_label), |
|
_SpecialLabel.NOMASK: next(next_label), |
|
_SpecialLabel.REPLACE: next(next_label), |
|
} |
|
else: |
|
raise NotImplementedError(self.mask_style) |
|
else: |
|
raise NotImplementedError(self.style) |
|
|
|
if self.drop_labels_probability > 0: |
|
special_labels[_SpecialLabel.NOLABEL] = next(next_label) |
|
|
|
self.special_labels = special_labels |
|
lookup_size = (self.num_labels or 1) + len(self.special_labels) |
|
|
|
self.labels_emb = nn.Embed( |
|
lookup_size, |
|
self.emb_dim, |
|
name="EmbedLabels", |
|
embedding_init=nn.initializers.zeros |
|
if self.zero_embedding_init |
|
else nn.initializers.normal(stddev=1.0), |
|
) |
|
|
|
self.targets_emb = nn.Dense(self.emb_dim, name="EmbedTargets") |
|
|
|
self.decoder = Decoder( |
|
num_layers=self.num_decoder_layers or self.num_layers, |
|
mlp_dim=self.mlp_dim, |
|
num_heads=self.num_heads, |
|
out_dim=self.num_logits, |
|
|
|
# The masked style prepends the label token without shifting the targets,
# so it needs one extra position (and positional embedding).
seq_len=self.seq_len + int(self.style == "masked"),
|
dropout_rate=self.dec_dropout_rate, |
|
emb_dim=self.emb_dim, |
|
zero_embedding_init=self.zero_embedding_init, |
|
style=self.style, |
|
scan=self.scan, |
|
remat_policy=self.remat_policy, |
|
) |
|
|
|
def encode(self, image: jax.Array, train: bool = False) -> jax.Array: |
|
"""Encodes input image or embeddings.""" |
|
emb = self.conv(image) |
|
patch_embeddings = einops.rearrange(emb, "B PH PW E -> B (PH PW) E") |
|
encoded, _ = self.encoder( |
|
patch_embeddings + self.pos_emb_for_encoder, deterministic=not train) |
|
return encoded |
|
|
|
def embed_labels( |
|
self, |
|
labels: jax.Array | None = None, |
|
batch_size: int | None = None, |
|
) -> jax.Array: |
|
if labels is not None: |
|
|
|
return self.labels_emb(labels)[:, None, :] |
|
|
|
assert ((self.num_labels == 1 or self.num_labels is None) |
|
and batch_size is not None) |
|
|
|
return self.labels_emb(jnp.zeros((batch_size,), jnp.int32))[:, None, :] |
|
|
|
def prefill( |
|
self, labels=None, batch_size=None, encoded=None, drop_labels=None |
|
): |
|
labels = self._drop_labels(drop_labels, labels) |
|
labels_for_prefill = self.embed_labels(labels=labels, batch_size=batch_size) |
|
return self.decoder( |
|
labels_for_prefill, |
|
encoded=encoded, |
|
decode=True) |
|
|
|
def _decode_ar( |
|
self, |
|
targets: jax.Array, |
|
labels: jax.Array | None = None, |
|
encoded: jax.Array | None = None, |
|
decode: bool = False, |
|
train: bool = False, |
|
) -> tuple[jax.Array, Mapping[str, jax.Array]]: |
|
"""Autoregressive decoding.""" |
|
targets_embedded = self.targets_emb(targets) |
|
|
|
if decode: |
|
decoder_mask = None |
|
else: |
|
decoder_mask = nn.make_causal_mask(targets[:, :, 0]) |
|
b = targets.shape[0] |
|
labels_embedded = self.embed_labels(labels, b) |
|
assert labels_embedded.shape == (b, 1, self.emb_dim), ( |
|
labels_embedded.shape, (b, 1, self.emb_dim)) |
|
targets_embedded = jnp.concatenate( |
|
[labels_embedded, targets_embedded[:, : -1]], axis=1) |
|
|
|
logits, out = self.decoder( |
|
targets_embedded, |
|
encoded=encoded, |
|
decoder_mask=decoder_mask, |
|
decode=decode, |
|
deterministic=not train, |
|
return_reps=True) |
|
|
|
return logits, out |
|
|
|
def _get_special_label(self, size, label: _SpecialLabel): |
|
return self.labels_emb( |
|
jnp.full(size, self.special_labels[label], jnp.int32) |
|
) |
|
|
|
def _decode_masked( |
|
self, |
|
targets, |
|
input_mask, |
|
labels=None, |
|
encoded=None, |
|
train=False, |
|
): |
|
"""Masked decoding.""" |
|
b, s, _ = targets.shape |
|
assert input_mask.shape == (b, s) |
|
|
|
if self.mask_style == "replace": |
|
targets_embedded = jnp.where( |
|
input_mask[:, :, None], |
|
self._get_special_label((b, s), _SpecialLabel.MASK), |
|
self.targets_emb(targets), |
|
) |
|
elif self.mask_style == "concat": |
|
masks = jnp.where( |
|
input_mask[:, :, None], |
|
self._get_special_label((b, s), _SpecialLabel.MASK), |
|
self._get_special_label((b, s), _SpecialLabel.NOMASK), |
|
) |
|
embedded_targets = self.targets_emb(targets) |
|
targets_embedded = jnp.where( |
|
input_mask[:, :, None], |
|
self._get_special_label((b, s), _SpecialLabel.REPLACE), |
|
embedded_targets, |
|
) |
|
|
|
# Keep every other channel of the mask/nomask embedding and of the
# (replaced) target embedding so the concatenation stays emb_dim wide.
targets_embedded = jnp.concatenate(
|
[masks[..., ::2], targets_embedded[..., ::2]], axis=-1 |
|
) |
|
else: |
|
raise ValueError(self.mask_style) |
|
|
|
labels_embedded = self.embed_labels(labels, b) |
|
assert labels_embedded.shape == (b, 1, self.emb_dim) |
|
|
|
|
|
# Prepend the label embedding as the first sequence element.
targets_embedded = jnp.concatenate(
|
[labels_embedded, targets_embedded], axis=1) |
|
|
|
logits = self.decoder( |
|
targets_embedded, |
|
encoded=encoded, |
|
decoder_mask=None, |
|
decode=False, |
|
deterministic=not train) |
|
|
|
# Drop the logits at the label position so there is one prediction per token.
logits = logits[:, 1:, ...]
|
assert logits.shape[:2] == (b, s) |
|
return logits |
|
|
|
def _drop_labels(self, drop_labels_mask, labels): |
|
if labels is None: |
|
return None |
|
if self.drop_labels_probability >= 0.999: |
|
logging.warning("Dropping all labels...") |
|
return jnp.full_like(labels, self.special_labels[_SpecialLabel.NOLABEL]) |
|
if drop_labels_mask is None: |
|
return labels |
|
assert _SpecialLabel.NOLABEL in self.special_labels |
|
nolabel = jnp.full_like( |
|
labels, self.special_labels[_SpecialLabel.NOLABEL] |
|
) |
|
return jnp.where(drop_labels_mask, nolabel, labels) |
|
|
|
def decode( |
|
self, |
|
targets: jax.Array, |
|
labels: jax.Array | None = None, |
|
encoded: jax.Array | None = None, |
|
decode: bool = False, |
|
train: bool = False, |
|
max_decode_length: int | None = None, |
|
input_mask: jax.Array | None = None, |
|
drop_labels: jax.Array | None = None, |
|
return_reps: bool = False, |
|
) -> jax.Array | tuple[jax.Array, Mapping[str, jax.Array]]: |
|
"""Applies Transformer decoder-branch on encoded-input and target. |
|
|
|
Args: |
|
targets: target soft tokens [B, L, D], with D == out_dim.
labels: optional class labels [B].
|
encoded: encoded image patches from encoder [B, P, E]. |
|
decode: whether to prepare and use an autoregressive cache. |
|
train: whether it is training. |
|
max_decode_length: optional max length for positional embeddings. |
|
input_mask: If given, mask input. Required for style=="masked". |
|
Shape [B, L], bool tensor. True means the token will be removed |
|
from the input. |
|
drop_labels: Drop labels at corresponding locations [B]. |
|
return_reps: whether to return intermediate representations. |
|
|
|
Returns: |
|
logits array from transformer decoder [B, L, num_logits] (for the default
per-channel mixtures, num_logits == 3 * num_mixtures * out_dim).
|
""" |
|
del max_decode_length |
|
labels = self._drop_labels(drop_labels, labels) |
|
if self.style == "ar": |
|
logits, out = self._decode_ar( |
|
targets, labels, encoded, decode, train) |
|
if return_reps: |
|
return logits, out |
|
return logits |
|
elif self.style == "masked": |
|
assert not decode |
|
assert input_mask is not None |
|
assert not return_reps |
|
return self._decode_masked(targets, input_mask, labels, encoded, train) |
|
else: |
|
raise NotImplementedError(self.style) |
|
|
|
def _square_plus(self, x): |
|
|
|
if self.fix_square_plus: |
|
return (x + jnp.sqrt(jnp.square(x) + 4)) / 2 |
|
else: |
|
# Legacy formula (missing parentheses around the sum); setting
# `fix_square_plus=True` enables the corrected squareplus above.
return x + jnp.sqrt(jnp.square(x) + 4) / 2
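# Worked note on the corrected form: squareplus(x) = (x + sqrt(x**2 + 4)) / 2
# is a smooth, strictly positive reparameterization similar to softplus. It
# evaluates to 1.0 at x = 0, e.g. squareplus(3) = (3 + sqrt(13)) / 2 ~= 3.30,
# approaches x for large positive x, and decays towards 0 for large negative
# x, which makes it a convenient way to predict scale parameters.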
|
|
|
def get_pdf( |
|
self, |
|
logits: jax.Array, |
|
temperature_scales: float | None = None, |
|
temperature_probs: float | None = None, |
|
) -> distrax.Distribution: |
|
assert logits.shape[-1] == self.num_logits |
|
if self.multivariate: |
|
scales = logits[..., :self.out_dim ** 2] |
|
locs = logits[..., self.out_dim ** 2:] |
|
assert locs.shape[-1] == self.out_dim |
|
scales = self._square_plus(scales) |
|
|
|
*leading, _ = scales.shape |
|
scales = scales.reshape(*leading, self.out_dim, self.out_dim) |
|
|
|
diag_scale_tol = jnp.eye(self.out_dim) * self.scale_tol |
|
scales = jnp.maximum(scales, diag_scale_tol) |
|
if (t := temperature_scales) is not None: |
|
scales = scales * t |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# distrax.MultivariateNormalTri reads only the (lower) triangular part of
# `scales`, which therefore acts as a Cholesky-style scale factor.
return distrax.MultivariateNormalTri(locs, scales)
|
|
|
elif self.per_channel_mixtures: |
|
|
|
logits = jnp.reshape(logits, logits.shape[: -1] + (-1, self.num_mixtures)) |
|
|
|
probs, locs, scales = jnp.split(logits, 3, axis=-2) |
|
if (t := temperature_probs) is not None: |
|
probs = probs * t |
|
|
|
|
|
probs = nn.softmax(probs) |
|
scales = self._square_plus(scales) |
|
|
|
scales = jnp.maximum(scales, self.scale_tol) |
|
if (t := temperature_scales) is not None: |
|
scales = scales * t |
|
|
|
|
|
|
|
|
|
return distrax.MixtureSameFamily( |
|
mixture_distribution=distrax.Categorical(probs=probs), |
|
components_distribution=distrax.Normal(loc=locs, scale=scales), |
|
) |
|
else: |
|
*shape, num_logits = logits.shape |
|
assert num_logits == self.num_logits, (num_logits, self.num_logits) |
|
prob_logits, other_logits = ( |
|
logits[..., : self.num_mixtures], |
|
logits[..., self.num_mixtures :], |
|
) |
|
if (t := temperature_probs) is not None: |
|
prob_logits = prob_logits * t |
|
other_logits = jnp.reshape( |
|
other_logits, (*shape, self.num_mixtures, 2, self.out_dim) |
|
) |
|
locs = other_logits[..., 0, :] |
|
scales = self._square_plus(other_logits[..., 1, :]) |
|
|
|
scales = jnp.maximum(scales, self.scale_tol) |
|
if (t := temperature_scales) is not None: |
|
scales = scales * t |
|
|
|
|
|
|
|
assert prob_logits.ndim == locs.ndim - 1, (prob_logits.shape, locs.shape) |
|
assert locs.shape == scales.shape, (locs.shape, scales.shape) |
|
|
|
|
|
|
|
|
|
|
|
return distrax.MixtureSameFamily( |
|
mixture_distribution=distrax.Categorical(logits=prob_logits), |
|
components_distribution=distrax.MultivariateNormalDiag( |
|
loc=locs, scale_diag=scales |
|
), |
|
) |
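# Minimal sketch of consuming the returned distribution (hedged: `logits` and
# `targets` are placeholders, and in practice `get_pdf` is reached through the
# bound module / `Model.__call__`):
#
#   pdf = self.get_pdf(logits)
#   nll = -pdf.log_prob(targets)   # per-token (and, for per-channel
#                                  # mixtures, per-channel) log-likelihoods
#   sample = pdf.sample(seed=jax.random.PRNGKey(0))  # [B, L, D]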
|
|
|
def __call__( |
|
self, |
|
sequence: jax.Array, |
|
labels: jax.Array | None = None, |
|
*, |
|
image: jax.Array | None = None, |
|
decode: bool = False, |
|
input_mask: jax.Array | None = None, |
|
drop_labels: jax.Array | None = None, |
|
train: bool = False, |
|
) -> tuple[jax.Array, distrax.Distribution]: |
|
"""Applies Transformer model on the inputs. |
|
|
|
Args: |
|
sequence: batch of soft-token sequences [B, L, D].
|
labels: class labels for class conditional generation [B]. |
|
image: batch of images [B, H, W, 3]. |
|
decode: whether to prepare and use an autoregressive cache. |
|
input_mask: If given, mask input. Required for style=="masked" [B, L]. |
|
drop_labels: If given, drop labels of the corresponding batches [B]. |
|
train: whether it is training. |
|
|
|
Returns: |
|
A tuple of the logits from the full transformer [B, L, num_logits] and the
corresponding output distribution.
|
""" |
|
if self.style == "masked" and input_mask is None: |
|
raise ValueError("Cannot run masked model without input mask!") |
|
|
|
if self.encoder is not None: |
|
assert image is not None |
|
encoded = self.encode(image, train=train) |
|
else: |
|
assert image is None |
|
encoded = None |
|
|
|
logits = self.decode(sequence, labels=labels, encoded=encoded, |
|
decode=decode, input_mask=input_mask, train=train) |
|
pdf = self.get_pdf(logits) |
|
return logits, pdf |
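# Minimal teacher-forced training sketch (hedged: `soft_tokens`, `labels`, and
# all sizes below are illustrative, not defined in this file). With the
# default per-channel mixtures, `log_prob` returns one value per channel,
# which is reduced over the channel axis for the NLL loss:
#
#   givt = Model(num_layers=0, style="ar", out_dim=4, seq_len=16)
#   soft_tokens = jnp.zeros((2, 16, 4))   # [B, L, D], D == out_dim
#   labels = jnp.zeros((2,), jnp.int32)
#   variables = givt.init(jax.random.PRNGKey(0), soft_tokens, labels)
#   _, pdf = givt.apply(variables, soft_tokens, labels)
#   nll = -jnp.sum(pdf.log_prob(soft_tokens), axis=-1)  # [B, L]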
|
|
|
def get_input_mask_training( |
|
self, |
|
rng: jax.Array, |
|
shape: tuple[int, int], |
|
) -> jax.Array | None: |
|
"""Creates a random maask of shape (B, L) for training masked models.""" |
|
if self.style == "ar": |
|
return None |
|
b, s = shape |
|
|
|
keep = jax.random.uniform( |
|
rng, shape=(b,), maxval=1.0 - self.min_masking_rate_training |
|
) |
|
mask_ratio = apply_mask_schedule(keep, self.mask_schedule_train) |
|
return _random_mask_with_ratios(rng, ratios=mask_ratio, seq_len=s) |
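# Sketch for the masked style (hedged: `variables`, `seq`, `labels`, and the
# rng handling are placeholders, not defined in this file). A fresh mask is
# drawn per step and routed into `__call__` via `input_mask`:
#
#   mask = model.apply(variables, rng, (batch, model.seq_len),
#                      method=Model.get_input_mask_training)
#   logits, pdf = model.apply(variables, seq, labels, input_mask=mask,
#                             train=True, rngs={"dropout": dropout_rng})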
|
|
|
def get_input_mask_teacher_forced( |
|
self, |
|
shape: tuple[int, int], |
|
) -> jax.Array | None: |
|
"""Creates a random maask of shape (B, L) for training masked models.""" |
|
if self.style == "ar": |
|
return None |
|
return jnp.zeros(shape, dtype=jnp.bool_) |
|
|
|
def get_drop_labels( |
|
self, |
|
rng: jax.Array, |
|
batch_size: int, |
|
) -> jax.Array | None: |
|
if (p := self.drop_labels_probability) > 0: |
|
return jax.random.uniform(rng, shape=(batch_size,)) <= p |
|
else: |
|
return None |
|
|
|
|
|
def load( |
|
init_params: Any, |
|
init_files: str | Mapping[str, str], |
|
model_params: Any = None, |
|
dont_load: Sequence[str] = (), |
|
resample_encoder_posemb: bool = False, |
|
trim_decoder_posemb: bool = False, |
|
) -> Any: |
|
"""Loads params from init checkpoint and merges into init_params.""" |
|
del model_params |
|
if isinstance(init_files, str): |
|
ckpt_params = utils.load_params(init_files) |
|
ckpt_params = common.merge_params(ckpt_params, init_params, dont_load) |
|
|
|
if resample_encoder_posemb: |
|
if init_params and "pos_embedding_encoder" in init_params: |
|
ckpt_params["pos_embedding_encoder"] = vit.resample_posemb( |
|
old=ckpt_params["pos_embedding_encoder"], |
|
new=init_params["pos_embedding_encoder"]) |
|
|
|
if trim_decoder_posemb: |
|
if init_params and "pos_embedding_decoder" in init_params: |
|
ckpt_params["pos_embedding_decoder"] = ( |
|
ckpt_params["pos_embedding_decoder"][ |
|
:, :init_params["pos_embedding_decoder"].shape[1], :]) |
|
|
|
else: |
|
init_files = {**init_files}  # Shallow copy so we can pop entries below.
|
|
|
enc_init = init_files.pop("encoder", None) |
|
if enc_init: |
|
ckpt_params = init_params.copy() |
|
vit_params = { |
|
"pos_embedding": ckpt_params["pos_embedding_encoder"], |
|
"Transformer": ckpt_params["encoder"], |
|
"embedding": ckpt_params["EmbedPatches"], |
|
} |
|
encoder_params = vit.load( |
|
vit_params, enc_init, model_cfg={}, |
|
dont_load=dont_load) |
|
ckpt_params["encoder"] = encoder_params["Transformer"] |
|
ckpt_params["pos_embedding_encoder"] = encoder_params["pos_embedding"] |
|
ckpt_params["EmbedPatches"] = encoder_params["embedding"] |
|
else: |
|
raise ValueError("Only encoder init is supported: {}.".format(init_files)) |
|
|
|
return ckpt_params |
|
|