# Page residue, apparently from the web file viewer (not program code):
# "aswinkvj's picture", commit 0482489
import streamlit as st
import requests
import io
# Designing the interface
# Page title and author link in the main column. The emoji in the title were
# mis-decoded (mojibake) in the source text and are restored here.
st.title("🖼️ Image Captioning Demo 📝")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

# Short description of the model, rendered as markdown in the sidebar.
st.sidebar.markdown(
"""
An image captioning model by combining ViT model with GPT2 model.
The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision-To-Text Encoder-Decoder
framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
[Follow-up work of [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
"""
)
# Keep a spinner on screen while the model module is imported; going by the
# spinner text, this is where the ViT-GPT2 model is loaded and compiled.
# NOTE(review): the star import presumably also supplies `os`, `Image`,
# `sample_dir`, `sample_image_ids`, `get_random_image_id` and `predict`,
# none of which are imported explicitly in this file - confirm in model.py.
with st.spinner('Loading and compiling ViT-GPT2 model ...'):
    from model import *

# Fallback image id for this run, used when no sample is selected.
random_image_id = get_random_image_id()

# --- Sidebar: choose one of the predefined sample images ---
st.sidebar.title("Select a sample image")
sample_image_id = st.sidebar.selectbox(
    label="Please choose a sample image",
    options=sample_image_ids,
)

# --- Sidebar: ask for a random image instead ---
# Drawing a fresh id and setting the selection to the "None" sentinel makes
# the random id take precedence over the selectbox value further down.
random_clicked = st.sidebar.button("Random COCO 2017 (val) images")
if random_clicked:
    random_image_id = get_random_image_id()
    sample_image_id = "None"

# --- Sidebar: upload a custom image ---
# `bytes_data` stays None unless the form was submitted with a file attached.
bytes_data = None
with st.sidebar.form(key="file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Upload")
    if uploaded_file is not None and submitted:
        bytes_data = io.BytesIO(uploaded_file.getvalue())
# Largest size (in pixels) at which the selected image is shown on the page.
MAX_DISPLAY_HEIGHT = 384
MAX_DISPLAY_WIDTH = 512


def _fit_display_size(width, height, max_width=MAX_DISPLAY_WIDTH, max_height=MAX_DISPLAY_HEIGHT):
    """Return ``(width, height)`` shrunk to fit the display box, keeping aspect ratio.

    The height cap is applied first, then the width cap. Sizes already inside
    the box are returned unchanged. Each dimension is clamped to at least 1 so
    an extreme aspect ratio cannot round down to an empty image.
    """
    if height > max_height:
        width = int(width / height * max_height)
        height = max_height
    if width > max_width:
        # Scale the height with the current width *before* overwriting it;
        # assigning the width first leaves the height unscaled.
        height = int(height / width * max_width)
        width = max_width
    return max(width, 1), max(height, 1)


if bytes_data is None and submitted:
    # The upload form was submitted without a file attached.
    st.write("No file is selected to upload")
else:
    # Image source priority: uploaded file, then the selected sample id,
    # then the random id. "None" is the sentinel for "no sample selected".
    image_id = random_image_id
    if sample_image_id != "None":
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(sample_image_id, int):
            raise TypeError(
                f"sample_image_id must be an int, got {type(sample_image_id).__name__}"
            )
        image_id = sample_image_id

    # COCO val2017 file names are the image id zero-padded to 12 digits.
    padded_id = str(image_id).zfill(12)
    sample_path = os.path.join(sample_dir, f"COCO_val2017_{padded_id}.jpg")
    image_url = f"http://images.cocodataset.org/val2017/{padded_id}.jpg"

    if bytes_data is not None:
        image = Image.open(bytes_data)
    elif os.path.isfile(sample_path):
        image = Image.open(sample_path)
    else:
        # Not available locally: download it. Fail on HTTP errors and on a
        # stalled connection rather than handing a bad stream to the decoder.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))

    try:
        # Always resize into a separate image (even at unchanged size) so that
        # closing the preview does not affect `image`, which is still needed
        # for prediction.
        resized = image.resize(size=_fit_display_size(*image.size))

        if bytes_data is None:
            # Link to the original COCO image (not applicable to uploads).
            st.markdown(f"[{padded_id}.jpg]({image_url})")
        st.image(resized, caption='\n\nSelected Image')
        resized.close()

        # For newline
        st.sidebar.write('\n')

        with st.spinner('Generating image caption ...'):
            caption = predict(image)
            st.header('Predicted caption:\n\n')
            st.subheader(caption)

        st.sidebar.header("ViT-GPT2 predicts: ")
        st.sidebar.write(f"{caption}")
    finally:
        # Release the source image even if display or prediction fails.
        image.close()