"""Simple vision-text transformer with encoder-decoder architecture. |
|
|
|
Used abbreviations for dimension annotations: |
|
B: batch size. |
|
H: image height. |
|
W: image width. |
|
P: number of patches (PH/PW: number of patches in height/width dimensions). |
|
E: embedding size. |
|
L: sequence length of text tokens. |
|
V: vocab size. |
|
""" |

from typing import Sequence

from big_vision import utils
from big_vision.models import common
from big_vision.models import vit
import einops
import flax
import flax.linen as nn
import jax.numpy as jnp
import ml_collections
import numpy as np


def shift_right(x, axis=1):
  """Shifts `x` one position to the right along `axis`, padding with 0."""
  pad_widths = [(0, 0)] * x.ndim
  pad_widths[axis] = (1, 0)
  padded = jnp.pad(x, pad_widths, constant_values=0)
  # Drop the last element along `axis` so the output keeps the input length.
  slices = [slice(None)] * x.ndim
  slices[axis] = slice(0, -1)
  return padded[tuple(slices)]
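
# Example (illustrative): for token ids [[3, 7, 9]],
#   shift_right(jnp.array([[3, 7, 9]]))  ->  [[0, 3, 7]]
# A zero enters on the left (a BOS placeholder) and the last token is dropped,
# so during teacher forcing position i is predicted from tokens < i only.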


class EncoderDecoderBlock(nn.Module):
  """Transformer encoder-decoder layer."""
  mlp_dim: int
  num_heads: int
  dropout_rate: float = 0.
  decode: bool = False

  @nn.compact
  def __call__(self, targets, encoded, decoder_mask=None, deterministic=True):
    """Applies the EncoderDecoderBlock module.

    Args:
      targets: target text embeddings [B, L, E].
      encoded: encoded image patches from encoder [B, P, E].
      decoder_mask: decoder self-attention mask.
      deterministic: bool, deterministic or not (to apply dropout).

    Returns:
      Output after transformer encoder-decoder block [B, L, E].
    """
    # Causal self-attention over the target sequence (pre-LN, residual).
    x = nn.LayerNorm(name="LayerNorm1")(targets)
    x = nn.SelfAttention(
        num_heads=self.num_heads, use_bias=False, broadcast_dropout=False,
        dropout_rate=self.dropout_rate, decode=self.decode, name="SelfAttn")(
            x, decoder_mask, deterministic=deterministic)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    x = x + targets

    # Cross-attention: text positions attend to the encoded image patches.
    y = nn.LayerNorm(name="LayerNorm2")(x)
    y = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads, use_bias=False, broadcast_dropout=False,
        dropout_rate=self.dropout_rate, name="CrossAttn")(
            y, encoded, deterministic=deterministic)
    y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)
    y = y + x

    # Position-wise MLP, again with a residual connection.
    z = nn.LayerNorm(name="LayerNorm3")(y)
    z = vit.MlpBlock(mlp_dim=self.mlp_dim, dropout=self.dropout_rate,
                     name="MLP")(z, deterministic=deterministic)
    return y + z
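
# Illustrative shapes (values assumed): with targets [2, 32, 512] and encoded
# patches [2, 196, 512], the block returns [2, 32, 512]; only cross-attention
# mixes information from the 196 image patches into the 32 text positions.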


class Decoder(nn.Module):
  """Transformer decoder stack mapping target tokens to next-token logits."""
  emb_dim: int
  mlp_dim: int
  num_heads: int
  num_layers: int
  dropout_rate: float = 0.
  output_vocab_size: int = 32000
  zero_decoder_seq: bool = False

  @nn.compact
  def __call__(self,
               encoded,
               targets,
               pos_emb,
               decoder_mask=None,
               decode=False,
               deterministic=True,
               max_decode_length=None):
    """Applies the Transformer decoder on the inputs.

    Args:
      encoded: encoded image patches from encoder [B, P, E].
      targets: target text tokens [B, L].
      pos_emb: positional embeddings.
      decoder_mask: decoder self-attention mask.
      decode: bool, whether to perform fast autoregressive decoding with cache.
      deterministic: bool, deterministic or not (to apply dropout).
      max_decode_length: optional max length for positional embeddings.

    Returns:
      Output of the transformer decoder [B, L, V].
    """
    del max_decode_length  # Unused here; kept for interface compatibility.
    y = targets.astype("int32")
    if not decode:
      # Teacher forcing: shift targets right so position i is predicted from
      # tokens < i. In fast decoding, tokens arrive one step at a time.
      y = shift_right(y)
    y = nn.Embed(self.output_vocab_size, self.emb_dim, name="EmbedTargets",
                 embedding_init=nn.initializers.normal(stddev=1.0))(y)
    if self.zero_decoder_seq:
      # If requested, zero out the target embeddings (e.g., for ablations).
      y = jnp.zeros_like(y)
    y = common.AddPositionEmbs(
        decode=decode, name="PosEmbedTargets")(y, pos_emb)
    y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

    for lyr in range(self.num_layers):
      y = EncoderDecoderBlock(
          num_heads=self.num_heads, mlp_dim=self.mlp_dim,
          dropout_rate=self.dropout_rate, decode=decode,
          name=f"EncDecBlock{lyr}")(y, encoded, decoder_mask=decoder_mask,
                                    deterministic=deterministic)
    y = nn.LayerNorm(name="LayerNorm")(y)
    # Zero-initialized kernel => uniform predictions at initialization.
    logits = nn.Dense(self.output_vocab_size, kernel_init=nn.initializers.zeros,
                      name="LogitsDense")(y)
    return logits
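
# Illustrative call (assumed shapes): with encoded [B, 196, E] and targets
# [B, L], the decoder returns logits [B, L, V]; with decode=False the targets
# are shifted right internally, so logits[:, i] predicts targets[:, i].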


class Model(nn.Module):
  """Vision-text transformer: ViT encoder plus autoregressive text decoder."""
  patches: ml_collections.ConfigDict

  num_heads: int = 8
  num_layers: int = 6
  mlp_dim: int = 2048
  dropout_rate: float = 0.

  emb_dim: int = 512
  vocab_size: int = 32000
  seq_len: int = 256

  input_size: Sequence[int] = (256, 256)
  posemb_type: str = "sincos2d"
  zero_decoder_seq: bool = False

  def setup(self):
    grid_size = np.array(self.input_size) // np.array(self.patches.size)
    self.pos_emb_for_encoder = vit.get_posemb(
        self, self.posemb_type, grid_size, self.emb_dim,
        "pos_embedding_encoder")
    self.pos_emb_for_decoder = vit.get_posemb(
        self, self.posemb_type, (1, self.seq_len), self.emb_dim,
        "pos_embedding_decoder")

    self.encoder = vit.Encoder(
        depth=self.num_layers,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout_rate)
    self.decoder = Decoder(
        num_layers=self.num_layers,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        emb_dim=self.emb_dim,
        output_vocab_size=self.vocab_size,
        zero_decoder_seq=self.zero_decoder_seq,
    )
    # Patchify images with a non-overlapping strided convolution.
    self.conv = nn.Conv(self.emb_dim, self.patches.size, padding="VALID",
                        strides=self.patches.size, name="EmbedPatches")

  def encode(self, image, train=False):
    """Encodes an input image (or embedding map) into patches [B, P, E]."""
    emb = self.conv(image)
    patch_embeddings = einops.rearrange(emb, "B PH PW E -> B (PH PW) E")
    encoded, _ = self.encoder(
        patch_embeddings + self.pos_emb_for_encoder, deterministic=not train)
    return encoded
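
  # Shape example (assumed config): a 256x256 image with 16x16 patches gives
  # a 16x16 grid, so `encoded` is [B, 256, E].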

  def decode(self, encoded, targets, decode=False, train=False,
             max_decode_length=None):
    """Applies the Transformer decoder branch to encoded input and targets.

    Args:
      encoded: encoded image patches from encoder [B, P, E].
      targets: target text tokens [B, L].
      decode: whether to prepare and use an autoregressive cache.
      train: whether it is training.
      max_decode_length: optional max length for positional embeddings.

    Returns:
      Logits array from transformer decoder [B, L, V].
    """
    # In fast decode mode the cache enforces causality, so no mask is needed.
    decoder_mask = None if decode else nn.make_causal_mask(targets)
    logits = self.decoder(
        encoded,
        targets,
        pos_emb=self.pos_emb_for_decoder,
        decoder_mask=decoder_mask,
        decode=decode,
        deterministic=not train,
        max_decode_length=max_decode_length)
    return logits
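
  # Sketch of fast autoregressive use (illustrative; assumes standard Flax
  # cache mechanics and a `model`/`params` pair created by the caller):
  #
  #   logits, vars_ = model.apply(params, encoded, tokens, decode=True,
  #                               method=Model.decode, mutable=["cache"])
  #
  # where `tokens` is fed one position per step and `vars_["cache"]` carries
  # the attention state between steps.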

  def __call__(self, image, text, *, decode=False, train=False):
    """Applies the full Transformer model to the inputs.

    Args:
      image: batch of images [B, H, W, 3].
      text: batch of tokenized texts [B, L].
      decode: whether to prepare and use an autoregressive cache.
      train: whether it is training.

    Returns:
      Logits array from the full transformer [B, L, V].
    """
    encoded = self.encode(image, train=train)
    return self.decode(encoded, text, decode=decode, train=train)


def load(init_params, init_files, model_params=None,
         dont_load=("head/kernel", "head/bias", "cls")):
  """Loads params from an init checkpoint and merges them into init_params."""
  del model_params
  if isinstance(init_files, str):
    # A single path means a full checkpoint of this model.
    ckpt_params = utils.load_params(None, init_files)
    ckpt_params = flax.training.checkpoints.convert_pre_linen(ckpt_params)
    if init_params is not None:
      ckpt_params = common.merge_params(ckpt_params, init_params, dont_load)
  else:
    init_files = {**init_files}  # Shallow copy so we can pop entries below.

    enc_init = init_files.pop("encoder", None)
    if enc_init:
      # Initialize only the ViT encoder from a standalone ViT checkpoint.
      ckpt_params = init_params.copy()
      vit_params = {
          "pos_embedding": ckpt_params["pos_embedding_encoder"],
          "Transformer": ckpt_params["encoder"],
          "embedding": ckpt_params["EmbedPatches"],
      }
      encoder_params = vit.load(
          vit_params, enc_init, model_cfg={},
          dont_load=dont_load)
      ckpt_params["encoder"] = encoder_params["Transformer"]
      ckpt_params["pos_embedding_encoder"] = encoder_params["pos_embedding"]
      ckpt_params["EmbedPatches"] = encoder_params["embedding"]
    else:
      raise ValueError(f"Only encoder init is supported: {init_files}.")

  return ckpt_params
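
# Usage sketch (illustrative; the checkpoint path is hypothetical):
#   params = load(init_params, {"encoder": "gs://bucket/vit_b16.npz"})
# initializes only the encoder tower from a ViT checkpoint and leaves the
# decoder parameters as given in `init_params`.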
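

if __name__ == "__main__":
  # Minimal smoke test (illustrative): a tiny made-up config on zero inputs,
  # assuming 16x16 patches on a 256x256 image.
  import jax

  patches = ml_collections.ConfigDict({"size": (16, 16)})
  model = Model(patches=patches, num_layers=2, num_heads=4, mlp_dim=512,
                emb_dim=128, vocab_size=100, seq_len=16)
  image = jnp.zeros((1, 256, 256, 3))
  text = jnp.zeros((1, 16), jnp.int32)
  params = model.init(jax.random.PRNGKey(0), image, text)
  logits = model.apply(params, image, text)
  assert logits.shape == (1, 16, 100)  # [B, L, V]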