# bilma_AR / modeling_bilma.py
# guillermoruiz's picture
# Update modeling_bilma.py
# c4f22f8 verified
from transformers import TFPreTrainedModel, PreTrainedTokenizer, BatchEncoding
from tensorflow.keras.models import Model, load_model, Sequential
from tensorflow.keras.layers import Layer, Dense, concatenate, Input, add, Dropout, LayerNormalization, MultiHeadAttention, Embedding
import tensorflow as tf
import numpy as np
from typing import Dict
import re
import unicodedata
from .configuration_bilma import BilmaConfig
# Text pre-processing constants (copied from preprocessing.py).
BLANK = ' '

# Every pattern below is compiled case-insensitive, multi-line and dot-all.
RE_OPS = re.IGNORECASE | re.MULTILINE | re.DOTALL
RE_USR = re.compile(r"@\S+", RE_OPS)                      # "@" followed by non-space characters
RE_TAG = re.compile(r"#\S+", RE_OPS)                      # "#" followed by non-space characters
RE_URL = re.compile(r"(http|ftp|https)://\S+", RE_OPS)    # http/ftp/https URL up to the next space
RE_NUM = re.compile(r"[-+]?\d+\.?\d*", RE_OPS)            # optionally signed integer or decimal

# Bracket and question/exclamation characters, plus the full punctuation set.
SYMBOLS_ = "()[]¿?¡!{}~<>|"
SYMBOLS = set(';:,.@\\-"/') | set(SYMBOLS_)
# ------------------
# Class declaration
# ------------------
class TFBilma(TFPreTrainedModel):
    """Transformers-compatible wrapper around the Keras model built by `bilma`.

    The wrapped Keras model is stored in `self.model`; `call` forwards the
    token ids to it and returns the result in a one-entry dictionary.
    """

    config_class = BilmaConfig
    main_input_name = "input_ids"
    #base_model_prefix = "bilma"

    def __init__(self, config):
        """Build the underlying Keras model from a `BilmaConfig`.

        Args:
            config: Configuration object; the fields read here are
                `seq_max_length`, `include_top`, `add_head`,
                `num_hidden_layers`, `hidden_size`, `num_attention_heads`,
                `vocab_size`, `hidden_dropout_prob` and `pooling`.
        """
        # These three attributes are set before the parent constructor runs
        # (presumably because it may consult `input_signature`, which reads
        # `seq_max_length` -- TODO confirm).
        self.seq_max_length = config.seq_max_length
        self.include_top = config.include_top
        self.add_head = config.add_head
        super().__init__(config)

        # Note that the feed-forward width is also taken from `hidden_size`.
        self.model = bilma(
            num_enc=config.num_hidden_layers,
            embed_dim=config.hidden_size,
            max_length=config.seq_max_length,
            num_heads=config.num_attention_heads,
            ff_dim=config.hidden_size,
            vocab_size=config.vocab_size,
            rate=config.hidden_dropout_prob,
            include_top=config.include_top,
            add_head=config.add_head,
            pooling=config.pooling,
        )

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """Return all-ones tensors matching `input_signature`.

        Unknown dimensions are replaced by 2, except an unknown leading
        dimension, which becomes 1.
        """
        dummies = {}
        for key, spec in self.input_signature.items():
            shape = [2 if dim is None else dim for dim in spec.shape]
            if spec.shape[0] is None:
                shape[0] = 1
            dummies[key] = tf.ones(shape=shape, dtype=spec.dtype)
        return dummies

    @property
    def input_signature(self) -> Dict[str, tf.TensorSpec]:
        """Return the spec of the single `input_ids` input: int32, `[None, seq_max_length]`."""
        return {
            "input_ids": tf.TensorSpec([None, self.seq_max_length], tf.int32, name="input_ids"),
        }

    def call(self, inputs):
        """Run the wrapped Keras model.

        Args:
            inputs: Either a dict / `BatchEncoding` holding `"input_ids"`
                (which is cast to float32), or a value passed to the model
                unchanged.

        Returns:
            A dict with exactly one entry whose key is `"logits"` when
            `include_top` is set, otherwise `"last_hidden_state"` when no
            extra head is configured, otherwise `"label"`.
        """
        if isinstance(inputs, (Dict, BatchEncoding)):
            ins = tf.cast(inputs["input_ids"], tf.float32)
        else:
            ins = inputs

        if self.include_top:
            key = "logits"
        elif self.add_head is None:
            key = "last_hidden_state"
        else:
            key = "label"
        return {key: self.model(ins)}
def get_loss_function():
    """Return the masked loss produced by `loss_function` with its default `ignore_id`."""
    # The factory is named `loss_function`; any other spelling raises NameError.
    return loss_function()
def get_acc_function():
    """Return the masked accuracy metric produced by `accuracy_function` with its default `ignore_id`."""
    metric = accuracy_function()
    return metric
# copied from bilma_model.py
# --------------------------
def loss_function(ignore_id=0):
    """Build a sparse categorical cross-entropy loss that skips `ignore_id` targets.

    Args:
        ignore_id: Target value whose positions are excluded from the loss.

    Returns:
        A callable `loss(real, pred)`. `pred` is interpreted as logits. The
        per-position losses are zeroed where `real == ignore_id`, summed over
        axis 1 and divided by the number of kept positions on that axis
        (the result is 0 where nothing is kept).
    """
    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')

    def loss(real, pred):
        keep = tf.math.logical_not(tf.math.equal(real, ignore_id))
        per_position = cross_entropy(real, pred)
        weights = tf.cast(keep, dtype=per_position.dtype)
        masked_total = tf.reduce_sum(per_position * weights, axis=1)
        kept_count = tf.reduce_sum(weights, axis=1)
        return tf.math.divide_no_nan(masked_total, kept_count)

    return loss
def accuracy_function(ignore_id=0):
    """Build an accuracy metric that skips `ignore_id` targets.

    Args:
        ignore_id: Target value whose positions are excluded from the metric.

    Returns:
        A callable `acc_mlm(real, pred)` returning a scalar: the number of
        kept positions where `argmax(pred, axis=2)` equals `real`, divided by
        the number of kept positions (0 when nothing is kept).
    """
    def acc_mlm(real, pred):
        hits = tf.equal(tf.cast(real, tf.int64), tf.argmax(pred, axis=2))
        keep = tf.math.logical_not(tf.math.equal(real, ignore_id))
        # Only count hits on positions that are not ignored.
        kept_hits = tf.cast(tf.math.logical_and(keep, hits), dtype=tf.float32)
        kept_total = tf.cast(keep, dtype=tf.float32)
        return tf.math.divide_no_nan(tf.reduce_sum(kept_hits), tf.reduce_sum(kept_total))

    return acc_mlm
def mean_vectors(inputs, enc_vectors, max_length):
    """Average the encoder vectors that precede the token with id 3.

    Args:
        inputs: Token ids; assumed 2-D with exactly one id 3 per row
            (presumably an end-of-sequence marker -- TODO confirm).
        enc_vectors: Encoder output; the mask below is reshaped to
            `(-1, max_length, 1)` before multiplying it.
        max_length: Sequence length used to build the mask.

    Returns:
        For each row, the sum of the vectors at positions before the marker
        divided by the marker's column index.
    """
    marker_idx = tf.where(inputs == 3)
    # Second coordinate of each match, i.e. the column of the marker token.
    marker_col = tf.transpose(marker_idx)[1]
    prefix = tf.sequence_mask(marker_col, maxlen=max_length, dtype=tf.float32)
    prefix = tf.reshape(prefix, (-1, max_length, 1))
    prefix_sum = tf.reduce_sum(enc_vectors * prefix, 1)
    return prefix_sum / tf.expand_dims(tf.cast(marker_col, tf.float32), (1))
def mean_diff_vectors(inputs, enc_vectors, max_length):
    """Pool encoder vectors as the averaged difference from their masked mean.

    Args:
        inputs: Token ids; assumed 2-D with exactly one id 3 per row
            (presumably an end-of-sequence marker -- TODO confirm).
        enc_vectors: Encoder output; the mask below is reshaped to
            `(-1, max_length, 1)` before multiplying it.
        max_length: Sequence length used to build the mask.

    Returns:
        For each row, `sum_over_positions(mu - vecs) / count`, where `vecs`
        are the encoder vectors zeroed from the marker position onwards,
        `mu` is their sum divided by `count`, and `count` is the marker's
        column index. Zeroed positions contribute `mu` to the sum.
    """
    p = tf.where(inputs == 3)
    # Second coordinate of each match, i.e. the column of the marker token.
    pos = tf.transpose(p)[1]
    count = tf.expand_dims(tf.cast(pos, tf.float32), (1))
    C = tf.sequence_mask(pos, maxlen=max_length, dtype=tf.float32)
    C = tf.reshape(C, (-1, max_length, 1))
    vecs = enc_vectors * C
    mu = tf.reduce_sum(vecs, 1) / count
    # `mu` has no length axis while `vecs` does; insert one so the
    # subtraction broadcasts per row for any batch size rather than aligning
    # the batch axis of `mu` with the length axis of `vecs`.
    deviations = tf.expand_dims(mu, 1) - vecs
    return tf.reduce_sum(deviations, 1) / count
def max_vectors(inputs, enc_vectors, max_length):
    """Take the element-wise maximum over the masked encoder vectors.

    Args:
        inputs: Token ids; assumed 2-D with exactly one id 3 per row
            (presumably an end-of-sequence marker -- TODO confirm).
        enc_vectors: Encoder output; the mask below is reshaped to
            `(-1, max_length, 1)` before multiplying it.
        max_length: Sequence length used to build the mask.

    Returns:
        The maximum along axis 1 of the encoder vectors after positions from
        the marker onwards have been multiplied by zero (so those zeros take
        part in the maximum).
    """
    marker_idx = tf.where(inputs == 3)
    marker_col = tf.transpose(marker_idx)[1]
    prefix = tf.sequence_mask(marker_col, maxlen=max_length, dtype=tf.float32)
    prefix = tf.reshape(prefix, (-1, max_length, 1))
    masked = enc_vectors * prefix
    return tf.reduce_max(masked, 1)
def cls_vectors(inputs, enc_vectors, max_length):
    """Return the encoder vector at position 0 of every sequence.

    `inputs` and `max_length` are unused; they are accepted so this function
    has the same signature as the other pooling functions.
    """
    first_position = enc_vectors[:, 0:1, :]
    return tf.squeeze(first_position, axis=1)
def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, vocab_size=9739, rate=0.1, include_top=True, add_head=None, pooling=None):
    """Assemble the Keras functional model: embedding -> encoder -> optional head.

    Args:
        num_enc: Number of encoder blocks.
        embed_dim: Embedding size.
        max_length: Fixed input sequence length.
        num_heads: Attention heads per encoder block.
        ff_dim: Width passed to the encoder's feed-forward layers.
        vocab_size: Vocabulary size (embedding rows and, with
            `include_top`, width of the final dense layer).
        rate: Forwarded to `Encoder` as `rate`.
        include_top: When true, a dense layer of `vocab_size` units is
            applied to the encoder output and nothing else is added.
        add_head: Only used without `include_top`. `None` for no head, or a
            sequence of layer sizes: all but the last become ReLU dense
            layers, the last a softmax dense layer.
        pooling: Only used without `include_top`. `"mean"`, `"cls"` or
            `"max"` selects a pooling function; any other value leaves the
            encoder output unpooled.

    Returns:
        A Keras `Model` named `"bilma_model"` with a single `input_ids` input.
    """
    token_ids = Input(shape=(max_length, ), name='input_ids')
    embedding_layer = Embedding(vocab_size, embed_dim, mask_zero=False, name="bilma/embedding")
    embedded = embedding_layer(token_ids)
    encoder = Encoder(num_enc, embed_dim, max_length, num_heads, ff_dim, rate=rate, name="bilma/encoder")
    encoded = encoder(embedded)

    if include_top:
        outputs = Dense(vocab_size, use_bias=True, name="bilma/dense_final")(encoded)
        return Model(inputs=token_ids, outputs=outputs, name="bilma_model")

    # Optional pooling of the encoder output.
    features = encoded
    for pooling_name, pool_fn in (("mean", mean_vectors), ("cls", cls_vectors), ("max", max_vectors)):
        if pooling == pooling_name:
            features = pool_fn(token_ids, features, max_length)
            break

    # Optional classification head.
    if add_head is None:
        outputs = features
    else:
        for i, units in enumerate(add_head[:-1]):
            features = Dense(units, use_bias=True, activation="relu", name=f"bilma/dense_ex_{i}")(features)
        outputs = Dense(add_head[-1], use_bias=True, activation="softmax", name="bilma/dense_ex_final")(features)

    return Model(inputs=token_ids, outputs=outputs, name="bilma_model")
def load(model_file):
    """Load a saved Keras model, registering this module's custom layers, loss and metric.

    Args:
        model_file: Passed straight to `load_model`.

    Returns:
        Whatever `load_model` returns for `model_file`.
    """
    custom_objects = {
        "EncoderBlock": EncoderBlock,
        "Encoder": Encoder,
        "loss": loss_function(),
        "acc_mlm": accuracy_function(),
    }
    return load_model(model_file, custom_objects=custom_objects)
#
# Copied from transformer_text.py
# -------------------------------
class EncoderBlock(Layer):
    """One encoder block: self-attention and a feed-forward net, each followed by add + layer norm."""

    def __init__(self, layer_num, patch_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        """Create the sub-layers.

        Args:
            layer_num: Index of this block; used in the sub-layer names.
            patch_dim: Attention `key_dim` and output width of the
                feed-forward net.
            num_heads: Number of attention heads.
            ff_dim: Hidden width of the feed-forward net.
            rate: Dropout rate for both dropout layers.
            **kwargs: Forwarded to `Layer.__init__`.
        """
        super().__init__(**kwargs)
        # Constructor arguments are kept for get_config().
        self.ln = layer_num
        self.p_d = patch_dim
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate

        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=patch_dim, name=f"bilma/MHA_{layer_num}")
        self.ffn = Sequential([
            Dense(ff_dim, activation=tf.nn.gelu, name=f"bilma/dense1_{layer_num}"),
            Dense(patch_dim, name=f"bilma/dense2_{layer_num}"),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6, name=f"ln1_{layer_num}")
        self.layernorm2 = LayerNormalization(epsilon=1e-6, name=f"ln2_{layer_num}")
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def get_config(self):
        """Return the base layer config extended with this block's constructor arguments."""
        config = super().get_config()
        config.update(
            layer_num=self.ln,
            patch_dim=self.p_d,
            num_heads=self.n_h,
            ff_dim=self.f_d,
            rate=self.rate,
        )
        return config

    def call(self, inputs, training=False):
        """Apply self-attention and the feed-forward net with residual connections.

        Args:
            inputs: Tensor used as both query and value of the attention.
            training: Forwarded to the dropout layers.

        Returns:
            The layer-normalised output of the second residual connection.
        """
        attended = self.att(inputs, inputs)
        attended = self.dropout1(attended, training=training)
        attn_residual = self.layernorm1(add([inputs, attended]))

        projected = self.ffn(attn_residual)
        projected = self.dropout2(projected, training=training)
        return self.layernorm2(add([attn_residual, projected]))
class DecoderBlock(Layer):
    """One decoder block: masked self-attention, cross-attention and a feed-forward net."""

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        """Create the sub-layers.

        Args:
            embed_dim: Attention `key_dim` and output width of the
                feed-forward net.
            num_heads: Number of heads in both attention layers.
            ff_dim: Hidden width of the feed-forward net.
            rate: Dropout rate for the three dropout layers.
            **kwargs: Forwarded to `Layer.__init__`.
        """
        super().__init__(**kwargs)
        # Constructor arguments are kept for get_config().
        self.e_d = embed_dim
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate

        self.att1 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.att2 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = Sequential([
            Dense(ff_dim, activation=tf.nn.gelu),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
        self.dropout3 = Dropout(rate)

    def get_config(self):
        """Return the base layer config extended with this block's constructor arguments."""
        config = super().get_config()
        config.update(
            embed_dim=self.e_d,
            num_heads=self.n_h,
            ff_dim=self.f_d,
            rate=self.rate,
        )
        return config

    def call(self, inputs, encoder_output, look_ahead_mask, padding_mask, training=None):
        """Run the three sub-layers with residual connections.

        Args:
            inputs: Decoder-side tensor (query and value of the self-attention).
            encoder_output: Value tensor of the cross-attention.
            look_ahead_mask: Attention mask for the self-attention.
            padding_mask: Attention mask for the cross-attention.
            training: Forwarded to the dropout layers.

        Returns:
            A tuple `(output, self_attention_scores, cross_attention_scores)`.
        """
        # Masked self-attention.
        self_out, self_scores = self.att1(
            inputs, inputs, attention_mask=look_ahead_mask, return_attention_scores=True)
        self_out = self.dropout1(self_out, training=training)
        out1 = self.layernorm1(add([inputs, self_out]))

        # Cross-attention over the encoder output.
        cross_out, cross_scores = self.att2(
            out1, encoder_output, attention_mask=padding_mask, return_attention_scores=True)
        cross_out = self.dropout2(cross_out, training=training)
        # NOTE(review): layernorm1 is applied a second time here (it already
        # normalised the self-attention residual above); a dedicated layer may
        # have been intended -- confirm before changing, since adding one
        # would change this layer's weights.
        out2 = self.layernorm1(add([out1, cross_out]))

        # Position-wise feed-forward net.
        ffn_out = self.ffn(out2)
        ffn_out = self.dropout3(ffn_out, training=training)
        final_output = self.layernorm2(out2 + ffn_out)
        return final_output, self_scores, cross_scores
class Encoder(Layer):
    """Stack of `EncoderBlock` layers preceded by scaling and a fixed positional encoding."""

    def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
        """Create the encoder blocks and the positional-encoding table.

        Args:
            n: Number of encoder blocks.
            embed_dim: Width of the input vectors.
            max_length: Number of positions in the positional-encoding table.
            num_heads: Attention heads per block.
            ff_dim: Feed-forward hidden width per block.
            rate: Dropout rate used by every block.
            **kwargs: Forwarded to `Layer.__init__`.
        """
        super(Encoder, self).__init__(**kwargs)
        # Constructor arguments are kept for get_config().
        self.n = n
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        # Pass the configured dropout rate to the blocks so that it matches
        # the value reported by get_config().
        self._layers = [EncoderBlock(i, embed_dim, num_heads, ff_dim, rate=rate, name=f"enc_block_{i}") for i in range(n)]
        self.pe = positional_encoding(self.max_length, self.embed_dim)

    def get_config(self):
        """Return the base layer config extended with this layer's constructor arguments."""
        config = super(Encoder, self).get_config()
        config.update({"n": self.n, "embed_dim":self.embed_dim, "max_length": self.max_length, "num_heads":self.n_h, "ff_dim":self.f_d, "rate":self.rate})
        return config

    def call(self, x, training=False):
        """Encode a batch of embedded sequences.

        Args:
            x: Embedded input; axis 1 is the sequence axis and may be at most
                `max_length` long.
            training: Forwarded to every encoder block.

        Returns:
            The output of the last encoder block.
        """
        # Scale by sqrt(embed_dim), then add the encoding of the first
        # `seq_len` positions.
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        x = x + self.pe[:, :tf.shape(x)[1], :]
        for layer in self._layers:
            x = layer(x, training)
        return x
class Decoder(Layer):
    """Stack of `DecoderBlock` layers preceded by scaling and a fixed positional encoding."""

    def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
        """Create the decoder blocks and the positional-encoding table.

        Args:
            n: Number of decoder blocks.
            embed_dim: Width of the input vectors.
            max_length: Number of positions in the positional-encoding table.
            num_heads: Attention heads per block.
            ff_dim: Feed-forward hidden width per block.
            rate: Dropout rate used by every block.
            **kwargs: Forwarded to `Layer.__init__`.
        """
        super(Decoder, self).__init__(**kwargs)
        # Constructor arguments are kept for get_config().
        self.n = n
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        # Pass the configured dropout rate to the blocks so that it matches
        # the value reported by get_config().
        self._layers = [DecoderBlock(embed_dim, num_heads, ff_dim, rate=rate) for _ in range(n)]
        self.pe = positional_encoding(self.max_length, self.embed_dim)

    def get_config(self):
        """Return the base layer config extended with this layer's constructor arguments."""
        config = super(Decoder, self).get_config()
        config.update({"n": self.n, "embed_dim":self.embed_dim, "max_length": self.max_length, "num_heads":self.n_h, "ff_dim":self.f_d, "rate":self.rate})
        return config

    def call(self, x, encoder_output, look_ahead_mask, padding_mask, training):
        """Decode a batch of embedded target sequences.

        Args:
            x: Embedded target input; axis 1 is the sequence axis and may be
                at most `max_length` long.
            encoder_output: Passed to the cross-attention of every block.
            look_ahead_mask: Self-attention mask for every block.
            padding_mask: Cross-attention mask for every block.
            training: Forwarded to every block.

        Returns:
            The output of the last decoder block; the attention scores
            returned by the blocks are discarded.
        """
        # Scale by sqrt(embed_dim), then add the encoding of the first
        # `seq_len` positions.
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        x = x + self.pe[:, :tf.shape(x)[1], :]
        for layer in self._layers:
            x, _self_att, _enc_att = layer(x, encoder_output, look_ahead_mask, padding_mask, training)
        return x
# =========================================
# M A S K S
# =========================================
def create_padding_mask(seq):
    """Build the self-attention padding mask.

    A position counts as valid when any of its features is non-zero; the
    per-position validity row is then repeated for every query position.

    seq shape (bs, max_length, emb_dim)
    output shape (bs, max_length, max_length), float32, with
    output[b, i, j] == 1.0 when position j of sequence b is valid.
    """
    length = seq.shape[1]
    nonzero = tf.cast(tf.not_equal(seq, 0), tf.bool)
    valid = tf.reduce_any(nonzero, 2)
    # Repeat each sequence's row `length` times, then regroup per sequence.
    tiled = tf.repeat(valid, length, 0)
    tiled = tf.reshape(tiled, (-1, length, length))
    return tf.cast(tiled, tf.float32)
def create_cross_padding_mask(seq, target_seq):
    """Build the cross-attention padding mask.

    A target position counts as valid when any of its features is non-zero;
    that validity is repeated across the k source positions.

    seq shape (bs, k, image_features)
    target_seq shape (bs, max_length, emb_dim)
    output shape (bs, max_length, k), boolean, with
    output[b, i, j] true when target position i of sequence b is valid.
    """
    nonzero = tf.cast(tf.not_equal(target_seq, 0), tf.bool)
    valid = tf.reduce_any(nonzero, 2)
    # One copy of each target row per source position.
    tiled = tf.repeat(valid, seq.shape[1], 0)
    tiled = tf.reshape(tiled, (-1, tf.shape(seq)[1], tf.shape(target_seq)[1]))
    return tf.transpose(tiled, [0, 2, 1])
def create_look_ahead_mask(seq):
    """Build the look-ahead (lower-triangular) mask.

    seq shape (bs, max_length, emb_dim)
    output shape (bs, max_length, max_length): for every sequence the same
    matrix with ones on the diagonal and below, zeros above.
    """
    size = seq.shape[1]
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    batched = tf.expand_dims(lower_triangle, 0)
    # One copy of the triangle per sequence in the batch.
    return tf.repeat(batched, tf.shape(seq)[0], 0)
def create_masks(seq, target_seq):
    """Return `(decoder_mask, cross_att_mask)` for a decoder with look-ahead masking.

    `decoder_mask` is the target padding mask multiplied by the look-ahead
    mask; `cross_att_mask` is the cross-attention padding mask.
    """
    padding = create_padding_mask(target_seq)
    look_ahead = create_look_ahead_mask(target_seq)
    cross_att_mask = create_cross_padding_mask(seq, target_seq)
    return padding * look_ahead, cross_att_mask
def create_masks_looking_ahead(seq, target_seq):
    """Return `(decoder_mask, cross_att_mask)` without look-ahead masking.

    `decoder_mask` is only the target padding mask, so later positions stay
    visible; `cross_att_mask` is the cross-attention padding mask.
    """
    return (
        create_padding_mask(target_seq),
        create_cross_padding_mask(seq, target_seq),
    )
# =========================================
# P O S I T I O N A L E N C O D I N G
# =========================================
def get_angles(pos, i, d_model):
    """Return the sinusoidal positional-encoding angles.

    Computes `pos * 1 / 10000 ** (2 * (i // 2) / d_model)`, so each pair of
    consecutive feature indices (2k, 2k + 1) shares one rate.

    Args:
        pos: Position index or array of position indices.
        i: Feature index or array of feature indices.
        d_model: Total number of features.

    Returns:
        The angles, broadcast over `pos` and `i`.
    """
    exponent = (2 * (i//2)) / np.float32(d_model)
    rates = 1 / np.power(10000, exponent)
    return pos * rates
@tf.autograph.experimental.do_not_convert
def positional_encoding(position, d_model):
    """Build the fixed sinusoidal positional-encoding table.

    Args:
        position: Number of positions (rows) in the table.
        d_model: Number of features (columns) in the table.

    Returns:
        A float32 tensor of shape `(1, position, d_model)`: sine of the angle
        in even columns, cosine in odd columns.
    """
    positions = np.arange(position)[:, np.newaxis]
    features = np.arange(d_model)[np.newaxis, :]
    table = get_angles(positions, features, d_model)

    # Even feature indices (2i) take the sine, odd ones (2i+1) the cosine.
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])

    # Add a leading axis of size 1 so the table broadcasts over a batch.
    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)
class PatchEncoder(Layer):
    """Project patches with a dense layer and add a learned position embedding."""

    def __init__(self, num_patches, projection_dim, **kwargs):
        """Create the projection and position-embedding layers.

        Args:
            num_patches: Number of positions in the embedding table.
            projection_dim: Output width of the projection and the embedding.
            **kwargs: Forwarded to `Layer.__init__`.
        """
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = Dense(units=projection_dim)
        self.position_embedding = Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def get_config(self):
        """Return the base layer config extended with this layer's constructor arguments."""
        config = super().get_config()
        config.update(
            num_patches=self.num_patches,
            projection_dim=self.projection_dim,
        )
        return config

    def call(self, patch):
        """Return the projected patches plus the embeddings of positions `0..num_patches-1`."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        projected = self.projection(patch)
        return projected + self.position_embedding(positions)