# Page residue captured with this script (not code): "Spaces: Build error".
import pickle

import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
# Build the captioning model once when the script runs; generate_caption()
# reads this module-level object on every call.
# NOTE(review): get_caption_model is not defined or imported in the visible
# part of this file -- confirm the project module that provides it is imported.
caption_model = get_caption_model()

# Load the index lookup used by generate_caption() to turn a predicted token
# index into its token string. The path is relative to the working directory.
# SECURITY: pickle.load can execute arbitrary code while unpickling; only
# deploy an index_lookup.pkl that comes from a trusted source.
with open('index_lookup.pkl', 'rb') as f:
    index_lookup = pickle.load(f)

# Upper bound on the number of decoding steps performed by generate_caption().
max_decoded_sentence_length = 40
def generate_caption(img):
    """Greedily decode a caption for a single image.

    The image is given a leading batch axis, run through
    ``caption_model.cnn_model`` and ``caption_model.encoder``, and the
    decoder is then queried once per step.  At each step the arg-max token
    is looked up in ``index_lookup`` and appended to the running caption,
    until ``"<end>"`` is predicted or ``max_decoded_sentence_length`` steps
    have run.

    Args:
        img: One image without a batch axis.
            NOTE(review): the caller in this file passes a float array
            resized to 299x299 -- confirm this matches the model's input.

    Returns:
        The caption text with the ``"<start> "`` marker removed and
        surrounding whitespace stripped.
    """
    # NOTE(review): `vectorization` is not defined in the visible part of
    # this file -- confirm it is provided at module level.
    batch = tf.expand_dims(img, 0)
    features = caption_model.cnn_model(batch)
    encoder_out = caption_model.encoder(features, training=False)

    text = "<start> "
    for step in range(max_decoded_sentence_length):
        # Tokenize everything decoded so far, dropping the final column.
        token_ids = vectorization([text])[:, :-1]
        # True wherever the token id is non-zero (presumably 0 is padding).
        nonzero_mask = tf.math.not_equal(token_ids, 0)
        decoder_out = caption_model.decoder(
            token_ids, encoder_out, training=False, mask=nonzero_mask
        )

        # Greedy choice: highest-scoring entry at the current position.
        best_index = np.argmax(decoder_out[0, step, :])
        word = index_lookup[best_index]
        if word == "<end>":
            break
        text = text + " " + word

    return text.replace("<start> ", "").replace(" <end>", "").strip()
st.title("Image Captioning")

# Upload an image
uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Display the uploaded image as it was uploaded.
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Uploaded files are not guaranteed to be RGB: PNGs may be RGBA or
    # palette-based, and both PNG and JPEG may be grayscale.  Convert to RGB
    # so the array handed to the model always has three channels
    # (presumably what caption_model.cnn_model expects -- confirm).
    rgb_image = image.convert("RGB")

    # Generate the caption
    img = tf.keras.preprocessing.image.img_to_array(rgb_image)
    img = tf.image.resize(img, (299, 299))
    caption = generate_caption(img)

    # Display the generated caption
    st.write("Generated Caption:", caption)