Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pickle | |
import tensorflow as tf | |
import cv2 | |
import numpy as np | |
from PIL import Image, ImageOps | |
import imageio.v3 as iio | |
import time | |
from textwrap import wrap | |
import matplotlib.pylab as plt | |
from tensorflow.keras import Input | |
from tensorflow.keras.layers import ( | |
GRU, | |
Add, | |
AdditiveAttention, | |
Attention, | |
Concatenate, | |
Dense, | |
Embedding, | |
LayerNormalization, | |
Reshape, | |
StringLookup, | |
TextVectorization, | |
) | |
# --- Model / data configuration ---------------------------------------------

# Caption-related sizes. MAX_CAPTION_LEN bounds both the decoder's word input
# length and the number of decoding steps performed at prediction time.
MAX_CAPTION_LEN = 64
MINIMUM_SENTENCE_LENGTH = 5

# Image tensor dimensions expected by the encoder input.
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 299, 299, 3

ATTENTION_DIM = 512  # size of dense layer in Attention
VOCAB_SIZE = 20000
FEATURES_SHAPE = (8, 8, 1536)
def load_image_model():
    """Load and return the Keras model saved at ``./image_caption_model.h5``.

    The path is relative to the current working directory.
    """
    return tf.keras.models.load_model('./image_caption_model.h5')
# @st.cache_resource() | |
# def load_decoder_model(): | |
# decoder_model=tf.keras.models.load_model('./decoder_pred_model.h5') | |
# return decoder_model | |
# @st.cache_resource() | |
# def load_encoder_model(): | |
# encoder=tf.keras.models.load_model('./encoder_model.h5') | |
# return encoder | |
# **** ENCODER ****
# The encoder maps an image to a sequence of ATTENTION_DIM-sized vectors that
# the decoder's attention layer can attend over.
#
# NOTE(review): the original code fed an undefined name `x` into the Dense
# layer, which raises NameError at import time. The backbone + reshape below
# are reconstructed from FEATURES_SHAPE (8, 8, 1536) -- the output shape of
# InceptionResNetV2 without its top for 299x299 inputs -- and from the
# otherwise unused `Reshape` import. Confirm this matches the network used
# during training.
# NOTE(review): no trained weights are loaded for the Dense layer here, so it
# starts from its random initialisation unless weights are restored elsewhere.
FEATURE_EXTRACTOR = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(
    include_top=False, weights="imagenet"
)
FEATURE_EXTRACTOR.trainable = False

image_input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
image_features = FEATURE_EXTRACTOR(image_input)

# Flatten the spatial grid into a sequence of feature vectors:
# (batch, 8, 8, 1536) -> (batch, 64, 1536).
x = Reshape((FEATURES_SHAPE[0] * FEATURES_SHAPE[1], FEATURES_SHAPE[2]))(
    image_features
)
encoder_output = Dense(ATTENTION_DIM, activation="relu")(x)
encoder = tf.keras.Model(inputs=image_input, outputs=encoder_output)
# **** ENCODER ****
# **** DECODER ****
# Builds the decoder layers once, then wires them a second time into a
# prediction model that takes (and returns) the GRU state explicitly so that
# captions can be generated one step at a time.
#
# `shape` must be a tuple: `(MAX_CAPTION_LEN)` is just the int 64, which newer
# Keras versions reject, so the one-element tuple form is used instead.
# NOTE(review): predict_caption feeds a single token per step (shape (1, 1))
# while this input is declared with length MAX_CAPTION_LEN -- confirm the
# deployed Keras version accepts that mismatch.
word_input = Input(shape=(MAX_CAPTION_LEN,), name="words")
embed_x = Embedding(VOCAB_SIZE, ATTENTION_DIM)(word_input)

decoder_gru = GRU(
    ATTENTION_DIM,
    return_sequences=True,
    return_state=True,
)
gru_output, gru_state = decoder_gru(embed_x)

# Attention with the GRU outputs as query and the encoder output as value.
decoder_attention = Attention()
context_vector = decoder_attention([gru_output, encoder_output])

addition = Add()([gru_output, context_vector])

layer_norm = LayerNormalization(axis=-1)
layer_norm_out = layer_norm(addition)

# No activation: this layer emits one unnormalised score per vocabulary entry.
decoder_output_dense = Dense(VOCAB_SIZE)

# -----------
gru_state_input = Input(shape=(ATTENTION_DIM,), name="gru_state_input")

# Reuse trained GRU, but update it so that it can receive states.
gru_output, gru_state = decoder_gru(embed_x, initial_state=gru_state_input)

# Reuse other layers as well
context_vector = decoder_attention([gru_output, encoder_output])
addition_output = Add()([gru_output, context_vector])
layer_norm_output = layer_norm(addition_output)
decoder_output = decoder_output_dense(layer_norm_output)

# Define prediction Model with state input and output
decoder_pred_model = tf.keras.Model(
    inputs=[word_input, gru_state_input, encoder_output],
    outputs=[decoder_output, gru_state],
)
# **** DECODER ****
# --- Page header -------------------------------------------------------------
st.title(":blue[Nishant Guvvada's] :red[AI Journey] Image Caption Generation")

# Banner image, read from the working directory.
image = Image.open('./title.jpg')
st.image(image)

st.write("\n# Multi-Modal Machine Learning\n")

# Uploaded image (or None while nothing has been uploaded); read by on_click.
file = st.file_uploader(
    "Upload any image and the model will try to provide a caption to it!",
    type=['png', 'jpg'],
)
# Custom standardization for TextVectorization: unlike the default one, it
# leaves the "<" and ">" characters alone so that the <start> and <end>
# tokens are preserved.
def standardize(inputs):
    """Lower-case ``inputs`` and remove punctuation, keeping ``<`` and ``>``.

    Args:
        inputs: String tensor (or value convertible to one).

    Returns:
        The lower-cased strings with every character matched by the
        punctuation pattern replaced by the empty string.
    """
    punctuation_pattern = r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?"
    lowered = tf.strings.lower(inputs)
    return tf.strings.regex_replace(lowered, punctuation_pattern, "")
# Load the pickled vocabulary. Despite the ".txt" extension the file is read
# as a binary pickle. It is indexed by integer word id elsewhere in this file,
# so it is presumably a list of words -- TODO confirm.
# NOTE(security): pickle.load can execute arbitrary code; only ship a
# vocabulary file that was produced by a trusted source.
with open('./tokenizer_vocab.txt', 'rb') as vocab:
    tokenizer = pickle.load(vocab)

# Lookup table: Word -> Index
word_to_index = StringLookup(mask_token="", vocabulary=tokenizer)
## Probabilistic prediction using the trained model
def predict_caption(file):
    """Generate one randomly sampled caption for an image.

    Args:
        file: Anything ``PIL.Image.open`` accepts (a path or a file-like
            object such as a Streamlit upload).

    Returns:
        A tuple ``(img, result)``. ``img`` is the image converted to RGB,
        resized to ``IMG_HEIGHT x IMG_WIDTH`` and divided by 255. ``result``
        is the list of generated words; its last element is the ``<end>``
        word when the decoder produced it within ``MAX_CAPTION_LEN`` steps,
        otherwise the list simply contains ``MAX_CAPTION_LEN`` words.
    """
    # Force three channels so grayscale / RGBA uploads match the encoder input.
    image = np.array(Image.open(file).convert('RGB'))
    resize = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH))
    img = resize / 255

    # The image is encoded once; every decoding step reuses the same features.
    features = encoder(tf.expand_dims(img, axis=0))

    gru_state = tf.zeros((1, ATTENTION_DIM))
    dec_input = tf.expand_dims([word_to_index("<start>")], 1)
    # Loop-invariant: look the terminator id up once instead of on every step.
    end_id = word_to_index("<end>")

    result = []
    for _ in range(MAX_CAPTION_LEN):
        predictions, gru_state = decoder_pred_model(
            [dec_input, gru_state, features]
        )

        # Restrict sampling to the 10 highest-scoring words. The decoder's
        # final Dense layer has no activation, so these are logits, which is
        # what tf.random.categorical expects.
        top_logits, top_idxs = tf.math.top_k(
            input=predictions[0][0], k=10, sorted=False
        )
        chosen_id = tf.random.categorical([top_logits], 1)[0].numpy()
        predicted_id = top_idxs.numpy()[chosen_id][0]

        result.append(tokenizer[predicted_id])

        if predicted_id == end_id:
            break

        # Feed the sampled word back in as the next decoder input.
        dec_input = tf.expand_dims([predicted_id], 1)

    return img, result
def on_click():
    """Button callback: show the uploaded image and five sampled captions.

    Reads the module-level ``file`` set by the uploader. When nothing has been
    uploaded a hint is displayed instead.
    """
    if file is None:
        st.text("Please upload an image file")
        return

    image = Image.open(file)
    st.image(image, use_column_width=True)

    # Sampling is random, so each pass can yield a different caption.
    for _ in range(5):
        _, caption = predict_caption(file)
        # Drop the terminator only when it is actually there: a caption that
        # hit the length limit has no "<end>" and would otherwise lose its
        # last real word.
        if caption and caption[-1] == "<end>":
            caption = caption[:-1]
        st.write(" ".join(caption) + ".")


st.button('Generate', on_click=on_click)