|
import io
from pathlib import Path

import requests
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
from streamlit_image_select import image_select

from predict import generate_text
from model import load_model
|
|
|
|
|
|
|
# Must be the first Streamlit command in the script.
st.set_page_config(page_title="Caption Machine", page_icon="📸")


@st.cache_resource
def _load_cached_model():
    """Load the captioning model once and reuse it across script reruns.

    Streamlit re-executes this whole script on every user interaction, so a
    bare top-level ``load_model()`` call would reload the model each time.

    Returns:
        The ``(model, image_transform, tokenizer)`` triple from ``load_model``.
    """
    return load_model()


model, image_transform, tokenizer = _load_cached_model()

# The caption step reads these objects from session_state; only populate
# keys that are not already present for this session.
for _state_key, _state_value in (
    ('model', model),
    ('image_transform', image_transform),
    ('tokenizer', tokenizer),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_value
|
|
|
|
|
|
|
|
|
# Pin every st.columns column to roughly half the row width (minus a 1rem
# gap). Injected as raw HTML, which is why unsafe_allow_html is required.
_HALF_WIDTH_COLUMN_CSS = """<style>
[data-testid="column"] {
    width: calc(50% - 1rem);
    flex: 1 1 calc(50% - 1rem);
    min-width: calc(50% - 1rem);
}
</style>"""
st.markdown(_HALF_WIDTH_COLUMN_CSS, unsafe_allow_html=True)

# Page heading plus a one-line description of the models behind the app.
st.title("Image Captioner")
_APP_DESCRIPTION = (
    "This app generates Image Caption using OpenAI's "
    "[GPT-2](https://openai.com/research/better-language-models) and "
    "[CLIP](https://openai.com/research/clip) model."
)
st.markdown(_APP_DESCRIPTION)
|
|
|
|
|
|
|
|
|
# Sample photos offered in the picker. All are served over HTTPS so the
# browser-side picker does not make mixed-content requests.
SAMPLE_IMAGE_URLS = [
    "https://farm5.staticflickr.com/4084/5093294428_2f50d54acb_z.jpg",
    "https://farm8.staticflickr.com/7044/6855243647_cd204d079c_z.jpg",
    "https://farm4.staticflickr.com/3016/2650267987_f478c8d682_z.jpg",
    "https://farm8.staticflickr.com/7249/6913786280_c145ecc433_z.jpg",
]

# The picked sample (a URL from SAMPLE_IMAGE_URLS) and/or a user upload;
# the caption step below prefers the upload when both are present.
select_file = image_select(
    label="Select a photo:",
    images=SAMPLE_IMAGE_URLS,
)

upload_file = st.file_uploader("Upload an image:", type=['png', 'jpg', 'jpeg'])

st.divider()
|
|
|
|
|
def _fetch_image(url, timeout=10):
    """Download ``url`` and return it as a PIL image.

    The whole body is read via ``response.content`` rather than decoding
    ``response.raw``. That way ``requests`` undoes any ``Content-Encoding``,
    and the connection is released before the image is used.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
        OSError: if the payload cannot be identified as an image.
    """
    with requests.get(url, timeout=timeout) as response:
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))


# NOTE(review): image_select appears to preselect the first sample by default.
# If so, ``select_file`` is always set and this branch always runs.
if upload_file or select_file:
    img = None
    try:
        # An uploaded file takes precedence over the picked sample.
        if upload_file:
            img = Image.open(upload_file)
        else:
            img = _fetch_image(select_file)
        # convert() forces the lazy decode, so corrupt files fail inside this
        # try. It also normalises PNG alpha/palette and greyscale inputs to
        # 3 channels. Assumes the CLIP-style image_transform expects RGB;
        # confirm against predict/model.
        img = img.convert("RGB")
    except (requests.RequestException, OSError) as err:
        # PIL's UnidentifiedImageError is an OSError subclass.
        st.error(f"Could not load the image: {err}")
        img = None

    if img is not None:
        st.image(img)

        with st.spinner('Generating caption...'):
            caption = generate_text(
                st.session_state['model'],
                img,
                st.session_state['tokenizer'],
                st.session_state['image_transform'],
            )

        st.success(f"Result: {caption}")
|
|
|
|
|
|
|
# Collapsible section that shows the model architecture diagram.
with st.expander("See model architecture"):
    st.write("")  # blank spacer line above the diagram

    # Resolve the diagram next to this script rather than the current working
    # directory, so the app works no matter where `streamlit run` was launched.
    model_img_path = Path(__file__).resolve().parent / "model.png"
    # st.image accepts a file path directly, so no PIL file handle is left open.
    st.image(str(model_img_path), width=500)
|
|
|
|
|
|