"""Streamlit app that captions an uploaded image.

The pipeline is: InceptionV3 feature extractor -> single Transformer encoder
layer -> single Transformer decoder layer with greedy decoding. Vocabulary and
trained weights are loaded from disk at import time.
"""
import os

# Set environment variable
# This has to happen BEFORE TensorFlow is imported: TensorFlow imports
# protobuf, and protobuf picks its backend when it is first imported, so
# setting the variable afterwards has no effect.
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

import io  # noqa: E402
import pickle  # noqa: E402

import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import requests  # noqa: E402
import streamlit as st  # noqa: E402
import tensorflow as tf  # noqa: E402
from PIL import Image  # noqa: E402

# Constants
MAX_LENGTH = 40
BATCH_SIZE = 32
BUFFER_SIZE = 1000
EMBEDDING_DIM = 512
UNITS = 512

# Load vocabulary
# NOTE(security): pickle can execute arbitrary code while loading; only ship
# a vocabulary file that comes from a trusted source.
with open('saved_vocabulary/vocab_coco.file', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

tokenizer = tf.keras.layers.TextVectorization(
    standardize=None,
    output_sequence_length=MAX_LENGTH,
    vocabulary=vocab,
)

idx2word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True,
)

# Inverse of ``idx2word``: TextVectorization has no token -> id method, so the
# mapping is derived from the same vocabulary list (list index == token id).
_TOKEN_TO_ID = {
    token: index for index, token in enumerate(tokenizer.get_vocabulary())
}


def _token_id(token):
    """Return the integer id of *token* in the tokenizer vocabulary.

    Raises:
        ValueError: if *token* is not part of the loaded vocabulary.
    """
    try:
        return _TOKEN_TO_ID[token]
    except KeyError:
        raise ValueError(
            f"Token {token!r} is not in the loaded vocabulary"
        ) from None


# Model Definitions
def CNN_Encoder():
    """Build the image feature extractor.

    Returns:
        A Keras model mapping the InceptionV3 input to its last feature map
        (``include_top=False``, ImageNet weights) with every axis between the
        batch axis and the channel axis flattened into one.
    """
    inception_v3 = tf.keras.applications.InceptionV3(
        include_top=False,
        weights='imagenet',
    )

    output = inception_v3.output
    # Collapse the spatial grid into a sequence of feature vectors.
    output = tf.keras.layers.Reshape((-1, output.shape[-1]))(output)

    cnn_model = tf.keras.models.Model(inception_v3.input, output)
    return cnn_model


class TransformerEncoderLayer(tf.keras.layers.Layer):
    """LayerNorm -> Dense(relu) projection -> self-attention with residual."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = tf.keras.layers.LayerNormalization()
        self.layer_norm_2 = tf.keras.layers.LayerNormalization()
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")

    def call(self, x, training):
        """Project *x* to ``embed_dim`` and apply unmasked self-attention."""
        x = self.layer_norm_1(x)
        x = self.dense(x)

        attn_output = self.attention(
            query=x,
            value=x,
            key=x,
            attention_mask=None,
            training=training,
        )

        x = self.layer_norm_2(x + attn_output)
        return x


class Embeddings(tf.keras.layers.Layer):
    """Sum of learned token embeddings and learned position embeddings."""

    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        self.token_embeddings = tf.keras.layers.Embedding(
            vocab_size, embed_dim)
        self.position_embeddings = tf.keras.layers.Embedding(
            max_len, embed_dim, input_shape=(None, max_len))

    def call(self, input_ids):
        """Embed *input_ids*; positions are ``0..length-1`` along the last axis."""
        length = tf.shape(input_ids)[-1]
        position_ids = tf.range(start=0, limit=length, delta=1)
        # Leading axis of size 1 so the position embeddings broadcast over
        # the batch when added to the token embeddings.
        position_ids = tf.expand_dims(position_ids, axis=0)

        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        return token_embeddings + position_embeddings


class TransformerDecoderLayer(tf.keras.layers.Layer):
    """Decoder block: masked self-attention, cross-attention, FFN, softmax head.

    The vocabulary size is taken from the module-level ``tokenizer`` and the
    maximum sequence length from ``MAX_LENGTH``.
    """

    def __init__(self, embed_dim, units, num_heads):
        super().__init__()
        self.embedding = Embeddings(
            tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)

        self.attention_1 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )

        self.layernorm_1 = tf.keras.layers.LayerNormalization()
        self.layernorm_2 = tf.keras.layers.LayerNormalization()
        self.layernorm_3 = tf.keras.layers.LayerNormalization()

        self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
        self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)

        self.out = tf.keras.layers.Dense(
            tokenizer.vocabulary_size(), activation="softmax")

        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.5)

    def call(self, input_ids, encoder_output, training, mask=None):
        """Return per-position softmax distributions over the vocabulary.

        Args:
            input_ids: token ids of the caption generated so far.
            encoder_output: output of the encoder for the image features.
            training: forwarded to the attention and dropout layers.
            mask: optional padding mask over ``input_ids``; when omitted no
                attention mask at all (causal or padding) is applied.
        """
        embeddings = self.embedding(input_ids)

        combined_mask = None
        padding_mask = None

        if mask is not None:
            causal_mask = self.get_causal_attention_mask(embeddings)
            # (batch, seq, 1): masks padded query positions in cross-attention.
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            # (batch, 1, seq) AND-ed (via minimum) with the causal mask.
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)

        attn_output_1 = self.attention_1(
            query=embeddings,
            value=embeddings,
            key=embeddings,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(embeddings + attn_output_1)

        attn_output_2 = self.attention_2(
            query=out_1,
            value=encoder_output,
            key=encoder_output,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attn_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        """Return a ``(batch, seq, seq)`` int32 lower-triangular mask.

        Entry ``[b, i, j]`` is 1 when ``i >= j`` (position *i* may attend to
        position *j*) and 0 otherwise. ``batch`` and ``seq`` are the first two
        dimensions of *inputs*.
        """
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        # Tile multiples [batch_size, 1, 1]: repeat the mask once per example.
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1),
             tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)


class ImageCaptioningModel(tf.keras.Model):
    """Wires the CNN, encoder and decoder together with custom train/test steps.

    Only the encoder and decoder variables are updated in ``train_step``; the
    CNN is used as a fixed feature extractor there.
    """

    def __init__(self, cnn_model, encoder, decoder, image_aug=None):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.image_aug = image_aug
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")

    def calculate_loss(self, y_true, y_pred, mask):
        """Return the compiled loss averaged over non-masked positions."""
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        """Return token-level accuracy over non-masked positions."""
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def compute_loss_and_acc(self, img_embed, captions, training=True):
        """Run encoder + decoder with teacher forcing; return ``(loss, acc)``.

        Args:
            img_embed: image features produced by ``self.cnn_model``.
            captions: token-id sequences; id 0 is treated as padding.
            training: forwarded to the encoder and decoder so that dropout is
                disabled during evaluation.
        """
        encoder_output = self.encoder(img_embed, training=training)
        # Teacher forcing: predict token t+1 from tokens up to t.
        y_input = captions[:, :-1]
        y_true = captions[:, 1:]
        mask = (y_true != 0)
        y_pred = self.decoder(
            y_input, encoder_output, training=training, mask=mask
        )
        loss = self.calculate_loss(y_true, y_pred, mask)
        acc = self.calculate_accuracy(y_true, y_pred, mask)
        return loss, acc

    def train_step(self, batch):
        """Run one optimisation step on an ``(images, captions)`` batch."""
        imgs, captions = batch

        if self.image_aug:
            imgs = self.image_aug(imgs)

        # Computed outside the tape: no gradients flow into the CNN.
        img_embed = self.cnn_model(imgs)

        with tf.GradientTape() as tape:
            loss, acc = self.compute_loss_and_acc(
                img_embed, captions
            )

        train_vars = (
            self.encoder.trainable_variables
            + self.decoder.trainable_variables
        )
        grads = tape.gradient(loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))
        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(),
                "acc": self.acc_tracker.result()}

    def test_step(self, batch):
        """Evaluate one ``(images, captions)`` batch without updating weights."""
        imgs, captions = batch

        img_embed = self.cnn_model(imgs)

        loss, acc = self.compute_loss_and_acc(
            img_embed, captions, training=False
        )

        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(),
                "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        """Trackers listed here are the ones returned by this property."""
        return [self.loss_tracker, self.acc_tracker]


def load_image_from_path(img_path):
    """Read an image file and return it preprocessed for InceptionV3.

    The image is decoded to 3 channels, resized to 299x299 and passed through
    ``inception_v3.preprocess_input``.
    """
    img = tf.io.read_file(img_path)
    # decode_image detects the format from the file contents, so PNG uploads
    # work as well as JPEG; expand_animations=False keeps the result 3-D.
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    img = tf.keras.layers.Resizing(299, 299)(img)
    img = tf.keras.applications.inception_v3.preprocess_input(img)
    return img


def generate_caption(img, caption_model, add_noise=False):
    """Greedily decode a caption for *img*.

    Args:
        img: a file path (``str``) or an already preprocessed image tensor
            without a batch dimension.
        caption_model: object exposing ``cnn_model``, ``encoder`` and
            ``decoder`` callables.
        add_noise: when true, add Gaussian noise (std 0.1) and rescale the
            image to the ``[0, 1]`` range before encoding.

    Returns:
        The generated words joined by single spaces, without the ``[START]``
        and ``[END]`` markers.
    """
    if isinstance(img, str):
        img = load_image_from_path(img)

    if add_noise:
        noise = tf.random.normal(img.shape) * 0.1
        img = (img + noise)
        img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img))

    img = tf.expand_dims(img, 0)  # Add batch dimension
    img_embed = caption_model.cnn_model(img, training=False)
    encoder_output = caption_model.encoder(img_embed, training=False)

    end_id = _token_id("[END]")
    caption = [_token_id("[START]")]
    for _ in range(MAX_LENGTH):
        input_caption = tf.convert_to_tensor([caption], dtype=tf.int32)
        pred = caption_model.decoder(input_caption, encoder_output, training=False)
        # Greedy choice: most probable token at the last position.
        pred = tf.argmax(pred[0, -1, :]).numpy()
        if pred == end_id:
            # [END] is never appended, so nothing has to be trimmed later and
            # the final word survives when MAX_LENGTH is hit without [END].
            break
        caption.append(pred)

    return ' '.join(
        idx2word(word).numpy().decode('utf-8') for word in caption[1:]
    )


# Load saved model weights
# NOTE(review): Streamlit re-executes this script on every interaction, so the
# model is rebuilt and the weights reloaded each time; consider the caching
# decorator offered by the installed Streamlit version.
cnn_model = CNN_Encoder()
encoder = TransformerEncoderLayer(embed_dim=EMBEDDING_DIM, num_heads=8)
decoder = TransformerDecoderLayer(embed_dim=EMBEDDING_DIM, units=UNITS, num_heads=8)
caption_model = ImageCaptioningModel(cnn_model=cnn_model, encoder=encoder, decoder=decoder)
caption_model.load_weights('saved_model_weights/caption_model')

# Streamlit App
st.title('Image Captioning with Transformer')
st.write('Upload an image to generate a caption.')

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.write("")
    st.write("Generating caption...")

    # The directory may not exist on a fresh checkout.
    os.makedirs("temp", exist_ok=True)
    # basename() drops any directory components from the client-supplied
    # name so the file cannot be written outside "temp".
    img_path = os.path.join("temp", os.path.basename(uploaded_file.name))
    try:
        with open(img_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        img = load_image_from_path(img_path)
        caption = generate_caption(img, caption_model)
        st.write("Caption:", caption)
    finally:
        # Remove temp file after captioning, even when captioning fails.
        if os.path.exists(img_path):
            os.remove(img_path)