# Scraped page header (appears to be Hugging Face Spaces status text; not part of the app):
# Spaces: Sleeping / Sleeping
import functools
import time
from textwrap import wrap

import streamlit as st
import tensorflow as tf
import cv2
import numpy as np
from PIL import Image, ImageOps
import imageio.v3 as iio
import matplotlib.pylab as plt
import numpy as np  # NOTE(review): duplicate of the numpy import above
import tensorflow as tf  # NOTE(review): duplicate of the tensorflow import above
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    GRU,
    Add,
    AdditiveAttention,
    Attention,
    Concatenate,
    Dense,
    Embedding,
    LayerNormalization,
    Reshape,
    StringLookup,
    TextVectorization,
)
@functools.lru_cache(maxsize=None)
def load_image_model():
    """Load and return the Keras model stored in ``./image_caption_model.h5``.

    The result is memoized, so repeated calls (e.g. every Predict click that
    goes through ``model_prediction``) reuse one model object instead of
    re-reading the .h5 file from disk each time.

    NOTE(review): Streamlit re-executes this script on every rerun, which
    re-creates this function and therefore its cache. For caching across
    reruns, ``st.cache_resource`` is the Streamlit-native option if the
    installed Streamlit version provides it.
    """
    return tf.keras.models.load_model('./image_caption_model.h5')
@functools.lru_cache(maxsize=None)
def load_decoder_model():
    """Load and return the caption decoder stored in ``./decoder_pred_model.h5``.

    Memoized: ``predict_caption`` asks for the decoder on every call, and
    without the cache each call re-read the .h5 file from disk.

    NOTE(review): the cache lives only as long as this function object;
    Streamlit re-executes the script on each rerun, which re-creates it.
    """
    return tf.keras.models.load_model('./decoder_pred_model.h5')
@functools.lru_cache(maxsize=None)
def load_encoder_model():
    """Load and return the image encoder stored in ``./encoder_model.h5``.

    Memoized: ``predict_caption`` asks for the encoder on every call, and
    without the cache each call re-read the .h5 file from disk.

    NOTE(review): the cache lives only as long as this function object;
    Streamlit re-executes the script on each rerun, which re-creates it.
    """
    return tf.keras.models.load_model('./encoder_model.h5')
# ---------------------------------------------------------------------------
# Page header: title, banner image, subtitle, then the upload widget.
# ---------------------------------------------------------------------------
st.title(":blue[Nishant Guvvada's] :red[AI Journey] Image Caption Generation")

image = Image.open('./title.jpg')  # banner displayed right under the title
st.image(image)

st.write("""
# Multi-Modal Machine Learning
""")

# ``file`` is None until the user uploads a PNG/JPG; later code in this
# script reads this module-level name.
file = st.file_uploader(
    label="Upload any image and the model will try to provide a caption to it!",
    type=['png', 'jpg'],
)
MAX_CAPTION_LEN = 64 | |
MINIMUM_SENTENCE_LENGTH = 5 | |
IMG_HEIGHT = 299 | |
IMG_WIDTH = 299 | |
IMG_CHANNELS = 3 | |
ATTENTION_DIM = 512 # size of dense layer in Attention | |
VOCAB_SIZE = 20000 | |
# Custom standardization for TextVectorization. The default one would also
# strip "<" and ">", destroying the <start> and <end> marker tokens, so this
# version removes every other punctuation character instead.
def standardize(inputs):
    """Lower-case ``inputs`` and delete punctuation except ``<`` and ``>``.

    Args:
        inputs: String tensor of captions.

    Returns:
        String tensor with the characters listed in the regex class removed.
    """
    # Inside the class, ",-/" is a range covering , - . and /.
    punctuation = r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?"
    lowered = tf.strings.lower(inputs)
    return tf.strings.regex_replace(lowered, punctuation, "")
# Tokenizer: keeps at most VOCAB_SIZE of the most frequent words, strips
# punctuation via the custom ``standardize`` function, and pads/truncates every
# caption to MAX_CAPTION_LEN tokens.
#
# NOTE(review): no ``adapt()`` or ``set_vocabulary()`` call on ``tokenizer`` is
# visible in this file. An un-adapted TextVectorization should report only its
# mask and OOV entries from ``get_vocabulary()``, so both lookup tables below
# would be near-empty. Confirm where the training vocabulary is meant to be
# loaded.
tokenizer = TextVectorization(
    standardize=standardize,
    max_tokens=VOCAB_SIZE,
    output_sequence_length=MAX_CAPTION_LEN,
)

# Forward lookup: word -> integer index.
word_to_index = StringLookup(
    vocabulary=tokenizer.get_vocabulary(),
    mask_token="",
)

# Inverse lookup: integer index -> word.
index_to_word = StringLookup(
    vocabulary=tokenizer.get_vocabulary(),
    mask_token="",
    invert=True,
)
## Probabilistic prediction using the trained model
def predict_caption(file):
    """Generate one randomly sampled caption for the JPEG image at ``file``.

    Decoding is token by token. At each step the decoder's 10 highest-scoring
    tokens are kept and one is drawn with ``tf.random.categorical``, so
    repeated calls on the same image can return different captions.

    Args:
        file: Path (anything ``tf.io.read_file`` accepts) of a JPEG image.

    Returns:
        ``(img, result)``:
            ``img`` is the image resized to (IMG_HEIGHT, IMG_WIDTH) and
            divided by 255.
            ``result`` is the list of generated token strings. If decoding
            stopped on the ``<end>`` id, that token is the last element;
            otherwise ``result`` holds MAX_CAPTION_LEN tokens.
    """
    gru_state = tf.zeros((1, ATTENTION_DIM))

    # Read the image named by the ``file`` argument. The previous version read
    # the module-level ``filename`` global and ignored its argument.
    img = tf.image.decode_jpeg(tf.io.read_file(file), channels=IMG_CHANNELS)
    img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH))
    img = img / 255

    encoder = load_encoder_model()
    decoder_pred_model = load_decoder_model()
    features = encoder(tf.expand_dims(img, axis=0))

    # Loop invariants, fetched once instead of once per generated token.
    # NOTE(review): if ``tokenizer`` was never adapted, ``vocabulary`` is tiny
    # and indexing it with a decoder-chosen id will fail. Confirm the
    # vocabulary is loaded.
    vocabulary = tokenizer.get_vocabulary()
    end_id = word_to_index("<end>")

    dec_input = tf.expand_dims([word_to_index("<start>")], 1)
    result = []
    for _ in range(MAX_CAPTION_LEN):
        predictions, gru_state = decoder_pred_model(
            [dec_input, gru_state, features]
        )

        # Sample among the top-10 candidates. tf.random.categorical treats
        # the scores as unnormalized log-probabilities (assumes the decoder
        # emits logits; TODO confirm).
        top_scores, top_idxs = tf.math.top_k(
            input=predictions[0][0], k=10, sorted=False
        )
        chosen_id = tf.random.categorical([top_scores], 1)[0].numpy()
        predicted_id = top_idxs.numpy()[chosen_id][0]

        result.append(vocabulary[predicted_id])
        if predicted_id == end_id:
            break

        # Feed the sampled token back in as the next decoder input.
        dec_input = tf.expand_dims([predicted_id], 1)

    return img, result
# Sample-image demo: print five sampled captions for a bundled image to stdout
# (the server log, not the Streamlit page).
# NOTE(review): the relative path expects a ``sample_images`` directory beside
# the app's parent directory. Confirm it exists in the deployment.
filename = "../sample_images/surf.jpeg"
for _ in range(5):
    image, caption = predict_caption(filename)
    # Strip the "<end>" marker only when decoding actually produced it; a
    # caption that reached MAX_CAPTION_LEN keeps its final word.
    if caption and caption[-1] == "<end>":
        caption = caption[:-1]
    print(" ".join(caption) + ".")

# NOTE(review): no st.pyplot call is visible in this file, so this Matplotlib
# figure is presumably never shown in the Streamlit UI.
img = tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS)
plt.imshow(img)
plt.axis("off")

# ``file`` is None until the user uploads an image. Image.open(None) raised and
# aborted the script on first load, before the Predict button was rendered.
if file is not None:
    filename = np.array(Image.open(file).convert('RGB'))
def model_prediction(path):
    """Run the image model on one image and return its raw ``predict`` output.

    Despite its name, ``path`` receives image pixels, not a file path: the
    caller in this file (``on_click``) passes ``np.array(image)``. The pixels
    are resized to 256x256, divided by 255 and given a leading batch axis.
    """
    resized = tf.image.resize(path, (256, 256))

    # Show a spinner while the model is fetched.
    with st.spinner('Model is being loaded..'):
        classifier = load_image_model()

    batch = np.expand_dims(resized / 255, 0)
    return classifier.predict(batch)
def on_click():
    """Predict-button callback: show the uploaded image and the model's verdict.

    Reads the module-level ``file`` uploader value. With no upload it shows a
    prompt and stops. Otherwise it displays the image, runs
    ``model_prediction`` on its RGB pixels and writes one of two verdicts,
    depending on whether the output exceeds 0.5.
    """
    if file is None:
        st.text("Please upload an image file")
        return

    uploaded = Image.open(file)
    st.image(uploaded, use_column_width=True)
    rgb_pixels = np.array(uploaded.convert('RGB'))
    predictions = model_prediction(rgb_pixels)

    # NOTE(review): the truth test assumes a single-element model output
    # (multi-element arrays raise here). The page title advertises caption
    # generation, but this verdict concerns an implant. Confirm the intended
    # model.
    message = (
        """# Prediction : Implant is loose"""
        if predictions > 0.5
        else """# Prediction : Implant is in control"""
    )
    st.write(message)
# Predict button. Streamlit calls ``on_click`` as a callback when it is pressed.
st.button(label='Predict', on_click=on_click)