File size: 4,035 Bytes
7925ce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import streamlit as st
import io
import sys
import time
import json
sys.path.append("./virtex/")
from model import *

# TODO:
# - Reformat the model introduction
# - Center the images using the 3-column method
# - Implement iterative text generation

def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Generate one caption for the current image and render it on the page.

    Reads the module-level ``virtexModel`` and ``image_dict`` globals, so it
    must only be called after both have been assigned.

    Args:
        sub_prompt: Subreddit to condition the caption on, or ``None`` when
            no subreddit was chosen.
        cap_prompt: Text the caption should start with; ``""`` means no
            custom prompt.
    """
    with st.spinner("Generating Caption"):
        # Compare by value: ``is not ""`` tests object identity, which is not
        # a reliable emptiness check and triggers a SyntaxWarning on 3.8+.
        if sub_prompt is None and cap_prompt != "":
            st.write("Without a specified subreddit we default to /r/pics")
        subreddit, caption = virtexModel.predict(
            image_dict, sub_prompt=sub_prompt, prompt=cap_prompt
        )
        # NOTE(review): the subreddit and caption are interpolated into
        # markup rendered with unsafe_allow_html=True; the caption may echo
        # the user's prompt, so consider escaping it - confirm intent.
        st.markdown(
            f"""
            <style>
                red{{
                    color:#c62828
                }}
                mono{{
                    font-family: "Inconsolata";
                }}
            </style>

            ### <red> r/{subreddit} </red>  {caption}
            """,
            unsafe_allow_html=True,
        )
    

# Main-column heading for the demo page.
st.title("Image Captioning Demo from RedCaps")

# Sidebar blurb describing what the demo offers; held in a named string so
# the rendering call below stays short.
sidebar_intro = """
    ### Image Captioning Model from VirTex trained on RedCaps
    
    Use this page to caption your own images or try out some of our samples.
    You can also generate captions as if they are from specific subreddits,
    as if they start with a particular prompt, or even both.
    
    Share your results on twitter with #redcaps or with a friend.
    """
st.sidebar.markdown(sidebar_intro)

# Build the model, image loader, sample image list and subreddit list while
# a spinner is shown; all four come from a single create_objects() call.
with st.spinner("Loading Model"):
    loaded_objects = create_objects()
    (virtexModel, imageLoader, sample_images, valid_subs) = loaded_objects
    

# staggered = st.sidebar.checkbox("Iteratively Generate Captions")

# if staggered:
#     pass
# else:

# Sidebar section for choosing one of the bundled sample images.
st.sidebar.title("Select a sample image")

# ``select_idx`` is only set on a run where the button was clicked;
# otherwise it stays None, which later code treats as "no new sample picked".
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)
else:
    select_idx = None

# Without a fresh pick, fall back to the first sample.
sample_image = sample_images[select_idx if select_idx is not None else 0]


# NOTE: a form-based uploader (clear_on_submit=True plus a submit button) was
# tried here previously and is disabled; the plain uploader is used instead.
uploaded_image = None
uploaded_file = st.sidebar.file_uploader("Choose a file")
if uploaded_file is not None:
    try:
        # The uploader accepts any file type, so the bytes may not be an
        # image at all.
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    except OSError:
        # Pillow's "cannot identify image file" error derives from OSError.
        # Report it and leave uploaded_image as None so the rest of the
        # script falls back to the cached or sample image instead of crashing.
        st.sidebar.error("Could not read the uploaded file as an image.")
    else:
        select_idx = None  # set this to help rewrite the cache

# class OnChange():
#     def __init__(self, idx):
#         self.idx = idx

#     def __call__(self):
#         st.write(f"the idx is: {self.idx}")
#         st.write(f"the sample_image is {sample_image}")

# sample_image = st.sidebar.selectbox(
#     "",
#     sample_images,
#     index = 0 if select_idx is None else select_idx,
#     on_change=OnChange(0 if select_idx is None else select_idx)
# )

# Sidebar control: which subreddit (if any) to condition the caption on.
st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    label="Type below to condition on a subreddit. Select None for a predicted subreddit",
    options=valid_subs,
)

# Sidebar control: optional text the caption should begin with.
st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input(
    label="Write the start of your caption below",
    value="",
)

# The click result is deliberately discarded; presumably the button exists
# only to make Streamlit rerun the script and produce a new caption.
_ = st.sidebar.button("Regenerate Caption")

# Optional controls, shown only when the "Advanced Options" box is ticked.
advanced = st.sidebar.checkbox("Advanced Options")
num_captions = 1  # default when the advanced controls are hidden
if advanced:
    num_captions = st.sidebar.select_slider(
        "Number of Captions to Predict", options=[1, 2, 3, 4, 5], value=1
    )
    # Label typo fixed ("Nucelus" -> "Nucleus").
    nuc_size = st.sidebar.slider(
        "Nucleus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05
    )
    # NOTE(review): this mutates the decoder in place; if the model object is
    # reused across reruns, the value may persist after the box is unticked -
    # confirm against create_objects().
    virtexModel.model.decoder.nucleus_size = nuc_size

# Choose the image to caption. A fresh upload wins; otherwise reuse the image
# stored in the session by an earlier run, unless a new random sample was
# requested on this run (select_idx is set); otherwise open the sample file.
# (The former `if False:` "please upload" branch was unreachable and has been
# removed.)
if uploaded_image is not None:
    image = uploaded_image
elif select_idx is None and 'image' in st.session_state:
    image = st.session_state['image']
else:
    image = Image.open(sample_image)

image = image.convert("RGB")

# Remember the image so the next rerun can reuse it.
st.session_state['image'] = image

# Model input and the version shown on the page, both from the image loader.
image_dict = imageLoader.transform(image)
show_image = imageLoader.show_resize(image)

# NOTE(review): the second call presumably replaces the first element in
# place with a captioned copy - confirm before collapsing into one call.
show = st.image(show_image)
show.image(show_image, "Your Image")

# One caption per requested sample; gen_show_caption reads image_dict above.
for _ in range(num_captions):
    gen_show_caption(sub, imageLoader.text_transform(cap_prompt))