"""Model definitions for CapPa (https://arxiv.org/abs/2306.07915). |
|
|
|
Abbreviations used for dimension annotations:
|
B: batch size. |
|
H: image height. |
|
W: image width. |
|
P: number of patches (PH/PW: number of patches in height/width dimensions). |
|
E: embedding size. |
|
L: sequence length of text tokens. |
|
V: vocab size. |
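
Example usage (an illustrative sketch, not part of the original file; the
image size and token length are made up, all other values are the defaults):

  model = Model()
  img = jnp.zeros([1, 224, 224, 3])
  txt = jnp.zeros([1, 256], jnp.int32)
  params = model.init(jax.random.PRNGKey(0), img, txt)["params"]
  logits = model.apply({"params": params}, img, txt)  # [B, L, V]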
|
""" |
|
|
|
from collections.abc import Sequence |
|
|
|
from big_vision import utils |
|
from big_vision.models import common |
|
from big_vision.models import vit |
|
import flax |
|
import flax.linen as nn |
|
from flax.linen import partitioning |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
|
|
def shift_right(x, axis=1, constant_values=0): |
|
"""Shift to the right on given axis with padding value 0.""" |
|
pad_widths = [(0, 0)] * len(x.shape) |
|
pad_widths[axis] = (1, 0) |
|
padded = jnp.pad(x, pad_widths, constant_values=constant_values) |
|
|
|
|
|
return padded[tuple(slice(-1 if i == axis else None) for i in range(x.ndim))] |
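
# A quick illustration of `shift_right` (values are just an example):
#   shift_right(jnp.asarray([[1, 2, 3]]))  # -> [[0, 1, 2]]
# Targets are moved into teacher-forcing position: the decoder sees token
# t-1 as input when predicting token t, starting from the padding value.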
|
|
|
|
|
class MlpBlock(nn.Module): |
|
"""Transformer MLP / feed-forward block with option to deactivate bias.""" |
|
mlp_dim: int | None = None |
|
dropout: float = 0.0 |
|
use_bias: bool = True |
|
|
|
@nn.compact |
|
def __call__(self, x, deterministic=True): |
|
"""Applies Transformer MlpBlock module.""" |
|
inits = dict( |
|
kernel_init=nn.initializers.xavier_uniform(), |
|
bias_init=nn.initializers.normal(stddev=1e-6), |
|
) |
|
|
|
    _, _, d = x.shape  # x: [B, L, E].
|
x = nn.Dense(self.mlp_dim or 4 * d, use_bias=self.use_bias, **inits)(x) |
|
x = nn.gelu(x) |
|
x = nn.Dropout(rate=self.dropout)(x, deterministic) |
|
x = nn.Dense(d, use_bias=self.use_bias, **inits)(x) |
|
return x |
|
|
|
|
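# Illustrative MlpBlock shapes: with the default mlp_dim=None, an input of
# shape [B, L, E] is expanded to [B, L, 4*E] and projected back to [B, L, E]:
#   y = MlpBlock().apply({"params": params}, x)  # y.shape == x.shape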
|
class EncoderDecoderBlock(nn.Module): |
|
"""Transformer encoder-decoder layer.""" |
|
mlp_dim: int |
|
num_heads: int |
|
dropout_rate: float = 0. |
|
decode: bool = False |
|
use_bias: bool = True |
|
|
|
@nn.compact |
|
def __call__(self, targets, encoded, decoder_mask=None, deterministic=True): |
|
"""Applies EncoderDecoder1DBlock module. |
|
|
|
Args: |
|
targets: target text embeddings [B, L, E]. |
|
encoded: encoded image patches from encoder [B, P, E]. |
|
decoder_mask: decoder self-attention mask. |
|
deterministic: bool, deterministic or not (to apply dropout). |
|
|
|
Returns: |
|
output after transformer encoder-decoder block [B, L, E]. |
|
""" |
|
    def wlc(f):
      # Constrain activations to logical sharding axes (batch, length, emb).
      dim_names = ("act_batch", "act_len", "act_emb")
      return nn.with_logical_constraint(f, dim_names)
|
|
|
|
|
x = wlc(nn.LayerNorm(name="LayerNorm1", use_bias=self.use_bias)(targets)) |
|
x = wlc(nn.SelfAttention( |
|
num_heads=self.num_heads, use_bias=False, broadcast_dropout=False, |
|
dropout_rate=self.dropout_rate, decode=self.decode, name="SelfAttn")( |
|
x, decoder_mask, deterministic=deterministic)) |
|
x = wlc(nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)) |
|
x = wlc(x + targets) |
|
|
|
    if encoded is not None:
      # Cross-attention from the text stream to the encoded image patches.
      y = wlc(nn.LayerNorm(name="LayerNorm2", use_bias=self.use_bias)(x))
|
y = wlc(nn.MultiHeadDotProductAttention( |
|
num_heads=self.num_heads, use_bias=False, broadcast_dropout=False, |
|
dropout_rate=self.dropout_rate, name="CrossAttn")( |
|
y, encoded, deterministic=deterministic)) |
|
y = wlc( |
|
nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)) |
|
y = wlc(y + x) |
|
else: |
|
y = x |
|
|
|
|
|
z = wlc(nn.LayerNorm(name="LayerNorm3", use_bias=self.use_bias)(y)) |
|
z = wlc(MlpBlock( |
|
mlp_dim=self.mlp_dim, dropout=self.dropout_rate, use_bias=self.use_bias, |
|
name="MLP")(z, deterministic=deterministic)) |
|
|
|
    # The dummy second output makes this block usable as an nn.scan body,
    # which expects a (carry, per-step output) pair.
    return wlc(y + z), None
|
|
|
|
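# For reference, the `decoder_mask` consumed above is typically the causal
# mask built in `Model.decode` below, shaped [B, 1, L, L]:
#   decoder_mask = nn.make_causal_mask(jnp.zeros([B, L], jnp.int32))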
|
class Decoder(nn.Module): |
|
"""Transformer decoder with parallel prediction.""" |
|
emb_dim: int |
|
mlp_dim: int |
|
num_heads: int |
|
num_layers: int |
|
dropout_rate: float = 0. |
|
output_vocab_size: int = 32_000 |
|
|
|
|
|
  # Parallel ("masked") prediction: with probability `masked_pred_prob` an
  # example is trained to reconstruct masked tokens in parallel rather than
  # autoregressively; `masking_ratio` is the fraction of tokens masked.
  masked_pred_prob: float = 0.
  masking_ratio: float = 0.
|
|
|
|
|
use_bias: bool = True |
|
|
|
scan: bool = False |
|
remat_policy: str = "nothing_saveable" |
|
|
|
@nn.compact |
|
def __call__(self, |
|
encoded, |
|
targets, |
|
pos_emb, |
|
decoder_mask=None, |
|
decode=False, |
|
deterministic=True, |
|
max_decode_length=None): |
|
"""Applies Transformer model on the inputs. |
|
|
|
Args: |
|
encoded: encoded image patches from encoder [B, P, E]. |
|
targets: target text tokens [B, L]. |
|
pos_emb: positional embeddings. |
|
decoder_mask: decoder self-attention mask. |
|
decode: bool, whether to perform fast autoregressive decoding with cache. |
|
deterministic: bool, deterministic or not (to apply dropout). |
|
max_decode_length: optional max length for positional embeddings. |
|
|
|
Returns: |
|
output of a transformer decoder [B, L, V]. |
|
""" |
|
y = targets.astype("int32") |
|
if not decode: |
|
if self.masked_pred_prob > 0.0 and not deterministic: |
|
|
|
|
|
        def _add_random_masks(a):
          """Replaces a random `masking_ratio` fraction of tokens by mask id."""
          n_masked = int(self.masking_ratio * a.shape[1])
|
mask_locations = jnp.zeros(a.shape[:2], dtype=jnp.int32) |
|
mask_locations = mask_locations.at[:, :n_masked].set(1) |
|
mask_locations = jax.random.permutation( |
|
self.make_rng("dropout"), mask_locations, axis=1, independent=True |
|
) |
|
|
|
a_masked = jnp.where(mask_locations, self.output_vocab_size, a) |
|
return a_masked |
|
|
|
def where(mask, x, y): |
|
mask = mask.reshape((-1,) + (1,) * (x.ndim - 1)) |
|
return jnp.where(mask, x, y) |
|
|
|
do_masked_pred = ( |
|
jax.random.uniform(self.make_rng("dropout"), (len(y),)) |
|
< self.masked_pred_prob |
|
) |
|
y = where(do_masked_pred, _add_random_masks(y), shift_right(y)) |
|
decoder_mask = where( |
|
do_masked_pred, jnp.ones_like(decoder_mask), decoder_mask |
|
) |
|
|
|
else: |
|
y = shift_right(y) |
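    # When training (not decoding), each batch element is now either (a) a
    # randomly masked copy of the targets, predicted in parallel under a
    # fully-visible attention mask, or (b) the usual right-shifted sequence
    # for autoregressive teacher forcing. E.g. with masking_ratio=0.5 and
    # L=4, targets [7, 8, 9, 5] may become [7, V, 9, V], where
    # V == output_vocab_size is the mask-token id.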
|
|
|
    embed = nn.Embed(
        # One extra id is reserved for the mask token used in masked parallel
        # prediction (see `_add_random_masks` above).
        self.output_vocab_size + (1 if self.masked_pred_prob > 0.0 else 0),
|
self.emb_dim, |
|
name="EmbedTargets", |
|
embedding_init=nn.initializers.normal(stddev=1.0), |
|
) |
|
y = embed(y) |
|
|
|
y = common.AddPositionEmbs( |
|
decode=decode, name="PosEmbedTargets")(y, pos_emb) |
|
|
|
|
|
|
|
|
|
    if self.scan:
      # Stack the decoder blocks into a single scanned, rematerialized module
      # to reduce compile time and activation memory.
      enc_dec_block_remat = nn.remat(
|
EncoderDecoderBlock, |
|
prevent_cse=False, |
|
static_argnums=(-1,), |
|
policy=getattr(jax.checkpoint_policies, self.remat_policy, None)) |
|
|
|
initializing = self.is_mutable_collection("params") |
|
param_scan_axis = 1 |
|
params_spec = (param_scan_axis if initializing |
|
else partitioning.ScanIn(param_scan_axis)) |
|
dec_scanned = nn.scan(enc_dec_block_remat, |
|
variable_axes={ |
|
"params": params_spec, |
|
"cache": 0, |
|
}, |
|
split_rngs={"params": True, "dropout": True}, |
|
in_axes=nn.broadcast, |
|
length=self.num_layers) |
|
|
|
y, _ = dec_scanned(num_heads=self.num_heads, mlp_dim=self.mlp_dim, |
|
dropout_rate=self.dropout_rate, decode=decode, |
|
use_bias=self.use_bias, name="EncDecBlock")( |
|
y, encoded, decoder_mask, deterministic) |
|
else: |
|
for lyr in range(self.num_layers): |
|
y, _ = EncoderDecoderBlock( |
|
num_heads=self.num_heads, mlp_dim=self.mlp_dim, |
|
dropout_rate=self.dropout_rate, decode=decode, |
|
use_bias=self.use_bias, name=f"EncDecBlock{lyr}")( |
|
y, encoded, decoder_mask=decoder_mask, |
|
deterministic=deterministic) |
|
|
|
y = nn.LayerNorm(name="LayerNorm")(y) |
|
|
|
logits = nn.Dense( |
|
self.output_vocab_size, |
|
kernel_init=nn.initializers.zeros, |
|
name="LogitsDense", |
|
)(y) |
|
return logits |
|
|
|
|
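# Shape summary for Decoder: `encoded` [B, P, E] and `targets` [B, L] produce
# logits [B, L, V] (see the dimension key in the module docstring).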
|
class Model(nn.Module): |
|
"""Transformer Model for sequence to sequence translation.""" |
|
|
|
  # Shared dims; the decoder_* fields below default to these when left at 0.
  num_heads: int = 8
|
num_layers: int = 6 |
|
mlp_dim: int = 2048 |
|
emb_dim: int = 512 |
|
enc_dropout_rate: float = 0. |
|
vocab_size: int = 32_000 |
|
seq_len: int = 256 |
|
|
|
|
|
  # Encoder (ViT) params.
  patches: Sequence[int] = (16, 16)
|
input_seq_len: int = 768 |
|
posemb_type: str = "learn" |
|
patch_dropout: float = 0. |
|
|
|
|
|
  # Decoder params; 0 means "fall back to the shared value above".
  decoder_num_heads: int = 0
|
decoder_num_layers: int = 0 |
|
decoder_mlp_dim: int = 0 |
|
decoder_emb_dim: int = 0 |
|
dec_dropout_rate: float = 0. |
|
|
|
masked_pred_prob: float = 0. |
|
|
|
masking_ratio: float = 0. |
|
|
|
decoder_bias: bool = True |
|
|
|
scan: bool = False |
|
remat_policy: str = "nothing_saveable" |
|
|
|
def setup(self): |
|
|
|
    # Image encoder: a plain ViT that returns per-patch features ("encoded").
    self.encoder = vit.Model(
|
patch_size=self.patches, |
|
width=self.emb_dim, |
|
depth=self.num_layers, |
|
mlp_dim=self.mlp_dim, |
|
num_heads=self.num_heads, |
|
dropout=self.enc_dropout_rate, |
|
posemb=self.posemb_type, |
|
scan=self.scan, |
|
remat_policy=self.remat_policy, |
|
) |
|
|
|
self.pos_emb_for_decoder = vit.get_posemb( |
|
self, |
|
self.posemb_type, |
|
(1, self.seq_len), |
|
self.decoder_emb_dim or self.emb_dim, |
|
"pos_embedding_decoder", |
|
) |
|
self.decoder = Decoder( |
|
num_layers=self.decoder_num_layers or self.num_layers, |
|
mlp_dim=self.decoder_mlp_dim or self.mlp_dim, |
|
num_heads=self.decoder_num_heads or self.num_heads, |
|
dropout_rate=self.dec_dropout_rate, |
|
emb_dim=self.decoder_emb_dim or self.emb_dim, |
|
output_vocab_size=self.vocab_size, |
|
masked_pred_prob=self.masked_pred_prob, |
|
masking_ratio=self.masking_ratio, |
|
use_bias=self.decoder_bias, |
|
scan=self.scan, |
|
remat_policy=self.remat_policy, |
|
) |
|
|
|
def encode(self, image, train=False, return_enc_features=False): |
|
"""Encodes input image or embeddings.""" |
|
|
|
_, out = self.encoder(image, train=train) |
|
encoded = out["encoded"] |
|
|
|
|
|
if return_enc_features: |
|
return encoded, out |
|
|
|
return encoded |
|
|
|
def decode(self, encoded, targets, decode=False, train=False, |
|
max_decode_length=None): |
|
"""Applies Transformer decoder-branch on encoded-input and target. |
|
|
|
Args: |
|
encoded: encoded image patches from encoder [B, P, E]. |
|
targets: target text tokens [B, L]. |
|
decode: whether to prepare and use an autoregressive cache. |
|
train: whether it is training. |
|
max_decode_length: optional max length for positional embeddings. |
|
|
|
Returns: |
|
logits array from transformer decoder [B, L, V]. |
|
""" |
|
    # With fast decoding the attention cache enforces causality, so no
    # explicit mask is needed.
    decoder_mask = None if decode else nn.make_causal_mask(targets)
|
logits = self.decoder( |
|
encoded, |
|
targets, |
|
pos_emb=self.pos_emb_for_decoder, |
|
decoder_mask=decoder_mask, |
|
decode=decode, |
|
deterministic=not train, |
|
max_decode_length=max_decode_length) |
|
return logits |
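
  # Sketch of cache-based autoregressive decoding (illustrative only; the
  # actual decoding loop lives in big_vision's decoding utilities):
  #   variables = model.init(rng, img, txt, decode=True)  # creates "cache"
  #   logits, new_vars = model.apply(variables, img, txt_step, decode=True,
  #                                  mutable=["cache"])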
|
|
|
def __call__(self, image, text, *, decode=False, |
|
train=False, return_enc_features=False): |
|
"""Applies Transformer model on the inputs. |
|
|
|
Args: |
|
image: batch of images [B, H, W, 3]. |
|
text: batch of tokenized texts [B, L]. |
|
decode: whether to prepare and use an autoregressive cache. |
|
train: whether it is training. |
|
return_enc_features: whether to return the encoder features. |
|
|
|
Returns: |
|
logits array from full transformer [B, L, V]. |
|
""" |
|
if return_enc_features: |
|
encoded, out = self.encode(image, train=train, return_enc_features=True) |
|
return encoded, out |
|
|
|
encoded = self.encode(image, train=train) |
|
|
|
decoded = self.decode(encoded, text, decode=decode, train=train) |
|
return decoded |
|
|
|
|
|
def load(init_params, init_files, model_params=None, |
|
dont_load=("head/kernel", "head/bias", "cls")): |
|
"""Loads params from init checkpoint and merges into init_params.""" |
|
|
|
if isinstance(init_files, str): |
|
|
|
ckpt_params = utils.load_params(init_files) |
|
ckpt_params = flax.training.checkpoints.convert_pre_linen(ckpt_params) |
|
ckpt_params = common.merge_params(ckpt_params, init_params, dont_load) |
|
|
|
|
|
if (model_params.get("scan") and |
|
"encoderblock" not in ckpt_params["encoder"]["Transformer"]): |
|
raise NotImplementedError("Loading a non-scan checkpoint into a " |
|
"scan model is not supported yet!") |
|
if (not model_params.get("scan") |
|
and "encoderblock" in ckpt_params["encoder"]["Transformer"]): |
|
assert "decoder.*" in dont_load or "decoder/.*" in dont_load, ( |
|
"Converting scan decoder to a non-scan one is not supported yet!") |
|
ckpt_params["encoder"] = utils.jit_cpu()( |
|
vit.scan_to_pyloop)(ckpt_params["encoder"]) |
|
|
|
else: |
|
assert set(init_files) == {"encoder"}, "Only encoder init supported" |
|
enc_init = init_files["encoder"] |
|
ckpt_params = flax.core.freeze(init_params).unfreeze() |
|
vit_params = ckpt_params["encoder"] |
|
encoder_params = vit.load( |
|
vit_params, enc_init, model_cfg={}, |
|
dont_load=dont_load) |
|
ckpt_params["encoder"] = encoder_params |
|
|
|
ckpt_params["encoder"]["pos_embedding"] = vit.resample_posemb( |
|
old=ckpt_params["encoder"]["pos_embedding"], |
|
new=init_params["encoder"]["pos_embedding"]) |
|
|
|
return ckpt_params |
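
# Illustrative calls to `load` (checkpoint paths are hypothetical): either
# initialize the whole model from a full CapPa checkpoint,
#   params = load(init_params, "gs://bucket/cappa.npz",
#                 model_params=dict(scan=False))
# or initialize only the image encoder from a ViT checkpoint:
#   params = load(init_params, {"encoder": "gs://bucket/vit.npz"})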
|
|