|
import streamlit as st |
|
import requests |
|
import io |
|
|
|
|
|
|
|
# --- Page header -------------------------------------------------------------
# Main title shown at the top of the app.
st.title("🖼️ Image Captioning Demo 📝")

# Author credit linking to the Hugging Face profile.
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

# Sidebar blurb describing the model architecture and training setup.
# NOTE(review): the Discourse link at the bottom looks truncated (no topic id
# after the trailing slash) — verify it resolves.
st.sidebar.markdown(
    """
    An image captioning model by combining ViT model with GPT2 model.
    The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision-To-Text Encoder-Decoder
    framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
    The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
    The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
    [Follow-up work of [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
    """
)
|
|
|
# Importing `model` triggers the (slow) model load/compilation, so it is done
# lazily under a spinner.  The wildcard import is also what brings the helpers
# used below into scope — presumably `os`, `Image`, `predict`,
# `get_random_image_id`, `sample_image_ids` and `sample_dir`; verify against
# model.py.
with st.spinner('Loading and compiling ViT-GPT2 model ...'):
    from model import *

# Initial random COCO val2017 image id, shown before any user interaction.
random_image_id = get_random_image_id()
|
|
|
# --- Sidebar: image selection -------------------------------------------------
st.sidebar.title("Select a sample image")

# `sample_image_ids` comes from model.py; entries are COCO image ids (the
# widget may also expose a "None" placeholder — see the check further below).
sample_image_id = st.sidebar.selectbox(
    "Please choose a sample image",
    sample_image_ids
)

# Clicking the button picks a fresh random image and overrides the selectbox
# choice for this run (the widget itself keeps showing its prior value).
if st.sidebar.button("Random COCO 2017 (val) images"):
    random_image_id = get_random_image_id()
    sample_image_id = "None"
|
|
|
# --- Sidebar: file upload -----------------------------------------------------
# Holds the raw bytes of a user-uploaded image, or None when nothing was
# uploaded this run.
bytes_data = None

# A form ensures the file is only consumed when "Upload" is pressed;
# clear_on_submit resets the uploader widget afterwards.
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Upload")

# Copy the upload into an in-memory buffer once the form is submitted.
if submitted and uploaded_file is not None:
    bytes_data = io.BytesIO(uploaded_file.getvalue())
|
|
|
if (bytes_data is None) and submitted:
    # The form was submitted without choosing a file — nothing to caption.
    st.write("No file is selected to upload")
else:
    # Fall back to the random image unless a sample was explicitly chosen.
    image_id = random_image_id
    if sample_image_id != "None":
        # Selectbox entries (other than the "None" placeholder) are COCO ids.
        # An explicit raise instead of `assert`: asserts are stripped under -O.
        if not isinstance(sample_image_id, int):
            raise TypeError(
                f"expected an int sample image id, got {type(sample_image_id).__name__}"
            )
        image_id = sample_image_id

    # COCO val2017 file name, zero-padded to 12 digits.
    sample_name = f"COCO_val2017_{str(image_id).zfill(12)}.jpg"
    sample_path = os.path.join(sample_dir, sample_name)

    # Image source priority: user upload > local sample file > COCO download.
    if bytes_data is not None:
        image = Image.open(bytes_data)
    elif os.path.isfile(sample_path):
        image = Image.open(sample_path)
    else:
        url = f"http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg"
        # timeout added so a stalled download cannot hang the app forever.
        image = Image.open(requests.get(url, stream=True, timeout=30).raw)

    # Work on a same-size copy so the original full-resolution image is what
    # gets fed to the model below.
    width, height = image.size
    resized = image.resize(size=(width, height))
    if height > 384:
        # Shrink so the height fits 384 px, preserving the aspect ratio.
        width = int(width / height * 384)
        height = 384
        resized = resized.resize(size=(width, height))
    width, height = resized.size
    if width > 512:
        # BUG FIX: the new height must be derived from the *old* width before
        # clamping it; the original code clamped `width` to 512 first, making
        # `int(height / width * 512)` a no-op and distorting the image.
        height = int(height / width * 512)
        width = 512
        resized = resized.resize(size=(width, height))

    # For non-uploaded images, link back to the source image on COCO.
    if bytes_data is None:
        st.markdown(f"[{str(image_id).zfill(12)}.jpg](http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg)")
    show = st.image(resized)
    show.image(resized, '\n\nSelected Image')
    # The display copy is no longer needed once rendered.
    resized.close()

    st.sidebar.write('\n')

    # Run the ViT-GPT2 model on the original (unresized) image.
    with st.spinner('Generating image caption ...'):
        caption = predict(image)

        caption_en = caption
        st.header(f'Predicted caption:\n\n')
        st.subheader(caption_en)

    # Echo the caption in the sidebar as well.
    st.sidebar.header("ViT-GPT2 predicts: ")
    st.sidebar.write(f"{caption}")

    image.close()
|
|