File size: 5,151 Bytes
f6ee7e4
 
 
 
6cc012f
 
a8d91bf
f6ee7e4
6cc012f
f6ee7e4
6cc012f
 
578e499
6cc012f
92f215b
 
 
07c95a1
f6ee7e4
07c95a1
f6ee7e4
07c95a1
 
 
 
f6ee7e4
07c95a1
 
f6ee7e4
07c95a1
 
 
 
 
 
 
f6ee7e4
 
 
b10f4fc
 
f6ee7e4
 
92f215b
f6ee7e4
 
 
 
 
92f215b
a254501
3396f01
578e499
f6ee7e4
 
 
6cc012f
 
 
f6ee7e4
 
 
 
6cc012f
 
f6ee7e4
 
6cc012f
f6ee7e4
6cc012f
 
 
 
b10f4fc
6cc012f
 
 
07c95a1
6cc012f
 
 
 
 
f6ee7e4
6cc012f
 
f6ee7e4
6cc012f
 
 
f6ee7e4
 
 
 
 
 
b82b2c8
 
 
 
f6ee7e4
 
 
6cc012f
 
 
b10f4fc
6cc012f
b10f4fc
 
f6ee7e4
b10f4fc
f6ee7e4
6cc012f
 
578e499
f6ee7e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578e499
f6ee7e4
578e499
f6ee7e4
578e499
 
 
 
 
f6ee7e4
 
 
 
 
 
 
 
6cc012f
 
578e499
7dbbdfd
f6ee7e4
6cc012f
578e499
7dbbdfd
f6ee7e4
 
 
7dbbdfd
 
6cc012f
f6ee7e4
 
 
 
 
b82b2c8
f6ee7e4
7dbbdfd
 
f6ee7e4
 
6cc012f
 
b82b2c8
4474721
f6ee7e4
6cc012f
 
f6ee7e4
 
 
 
5ffe5df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os

import jax
import jax.numpy as jnp
import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

from model import FlaxHybridCLIP

# ---- Sidebar: title, links, motivation and search settings ----
st.sidebar.title("CLIP React Demo")
st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")

logo_col, tagline_col = st.sidebar.columns(2)
logo_col.image("./huggingface_explode3.png", width=150)
# Two empty lines of vertical spacing before the tagline.
for _ in range(2):
    tagline_col.write(" ")
tagline_col.markdown("## Researching fun")

with st.sidebar.expander("Motivation", expanded=True):
    st.markdown(
        """
        Reaction GIFs became an integral part of communication.
        They convey complex emotions with many levels, in a short compact format.

        If a picture is worth a thousand words then a GIF is worth more.
        A lot of people would agree it is not always easy to find the perfect reaction GIF.
        This is just a first step in the more ambitious goal of GIF/Image generation.
        """
    )
# Number of index hits to display (only used for text-only search).
top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
# Width of every image grid on the page.
col_count = 4
# Sorted listing of ./jpg; positions in this list are the ids stored in the image index.
file_names = sorted(os.listdir("./jpg"))

if st.sidebar.button("show all validation set images"):
    gallery_cols = st.sidebar.columns(col_count)
    for idx, fname in enumerate(file_names):
        gallery_cols[idx % col_count].image(f"./jpg/{fname}")

# ---- Main page header ----
st.write("# Search Reaction GIFs with CLIP ")
for _ in range(2):
    st.write("    ")


@st.cache_resource()
def load_model():
    """Load the CLIP-style model and its processor (cached by Streamlit).

    Returns:
        A ``(model, processor)`` tuple. The processor is the stock
        ``openai/clip-vit-base-patch32`` CLIPProcessor, with its tokenizer
        replaced by the ``cardiffnlp/twitter-roberta-base`` tokenizer.
    """
    clip_model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")

    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Text goes through a RoBERTa tokenizer rather than CLIP's own one.
    roberta_tokenizer = AutoTokenizer.from_pretrained(
        "cardiffnlp/twitter-roberta-base"
    )
    clip_processor.tokenizer = roberta_tokenizer

    return clip_model, clip_processor


@st.cache_resource()
def load_image_index():
    """Load the prebuilt HNSW image-embedding index (cached by Streamlit).

    Returns:
        An nmslib index using the ``cosinesimil`` space, read from
        ``./features/image_embeddings`` together with its stored data.
    """
    hnsw_index = nmslib.init(method="hnsw", space="cosinesimil")
    # load_data=True also restores the stored vectors, not just the graph.
    hnsw_index.loadIndex("./features/image_embeddings", load_data=True)
    return hnsw_index


# These run on every script execution. Both loaders are decorated with
# st.cache_resource, so the expensive loading happens once and later runs
# reuse the cached objects (per Streamlit docs these are shared across
# sessions, so mutating them affects every user).
image_index = load_image_index()
model, processor = load_model()


# TODO: unfinished feature, not called anywhere yet.
# NOTE(review): newly added points are presumably not searchable until the
# HNSW index is rebuilt (nmslib `createIndex`) — verify. Their ids also have
# no entry in `file_names`, so the results grid could not display them.
def add_image_emb(image):
    """Embed ``image`` and append it to the shared ``image_index``.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or file-like object).

    Returns:
        The integer id assigned to the new data point in ``image_index``.
    """
    with Image.open(image) as src:
        rgb = src.convert("RGB")

    # The model's __call__ needs a text input too, so a dummy empty string is passed.
    inputs = processor(text=[""], images=rgb, return_tensors="jax", padding=True)
    # Reorder pixel_values from channels-first to channels-last, as elsewhere in this app.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    features = np.asarray(model(**inputs).image_embeds)

    # nmslib's addDataPoint takes (id, data); the previous call passed only
    # the data. Ids are assigned sequentially after the existing points.
    new_id = len(image_index)
    image_index.addDataPoint(new_id, features[0])
    return new_id


def query_with_images(query_images, query_text):
    """Rank the given images by how well they match ``query_text``.

    Args:
        query_images: Iterable of image sources accepted by ``PIL.Image.open``
            (e.g. the files returned by ``st.file_uploader``, or paths).
        query_text: Text every image is scored against.

    Returns:
        A pair ``(images, probs)`` of equal-length tuples, sorted by descending
        probability. ``images`` holds the RGB ``PIL.Image`` objects. ``probs``
        holds each image's softmax probability (as a float), computed across
        the given images. Both tuples are empty when no images are given.
    """
    images = []
    for source in query_images:
        with Image.open(source) as img:
            # Pick the first frame *before* converting: only the first frame
            # is used in training. Seeking the converted copy (as before) had
            # no effect. seek(0) is valid for single-frame formats too, so no
            # file-name check is needed.
            img.seek(0)
            images.append(img.convert("RGB"))

    if not images:
        # Nothing to rank; avoid calling the processor with an empty batch.
        return (), ()

    inputs = processor(
        text=[query_text], images=images, return_tensors="jax", padding=True
    )
    # Reorder pixel_values from channels-first to channels-last (NHWC) for the model.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image.reshape(-1)
    # Softmax across the images turns the scores into a ranking distribution.
    probs = np.asarray(jax.nn.softmax(logits_per_image)).tolist()

    ranked = sorted(zip(images, probs), key=lambda pair: pair[1], reverse=True)
    ranked_images, ranked_probs = zip(*ranked)
    return ranked_images, ranked_probs


q_cols = st.columns([5, 2, 5])

examples = [
    "OMG that is disgusting",
    "I'm so scared right now",
    " I got the job 🎉",
    "Congratulations to all the flax-community week teams",
    "You're awesome",
    "I love you ❤️",
]
example_input = q_cols[0].radio(
    "Example Queries :",
    examples,
    index=4,
    help="These are examples I wrote off the top of my head. They don't occur in the dataset",
)
q_cols[2].markdown(
    """
    Searches among the validation set images if not specified

    (There may be non-exact duplicates)

"""
)

# The selected example pre-fills the free-text query box.
query_text = q_cols[0].text_input(
    "Write text you want to get reaction for", value=example_input
)
query_images = q_cols[2].file_uploader(
    "(optional) Upload images to rank them",
    type=["jpg", "jpeg", "gif"],
    accept_multiple_files=True,
)

# Two result modes feed the same display loop below:
#  * uploads present -> `ids` are the images themselves and `dists` are
#    their probabilities from query_with_images (higher is better);
#  * no uploads -> `ids` are integer positions in `file_names` from the
#    nmslib index and `dists` are cosine distances (lower is better).
if query_images:
    st.write("Ranking your uploaded images with respect to input text:")
    with st.spinner("Calculating..."):
        ids, dists = query_with_images(query_images, query_text)
else:
    st.write("Found these images within validation set:")
    with st.spinner("Calculating..."):
        proc = processor(
            text=[query_text], images=None, return_tensors="jax", padding=True
        )
        vec = np.asarray(model.get_text_features(**proc))
        ids, dists = image_index.knnQuery(vec, k=top_k)

show_gif = st.checkbox(
    "Play GIFs",
    value=True,
    help="Will play the original animation. Only first frame is used in training!",
)
ext = "gif" if show_gif else "jpg"
res_cols = st.columns(col_count)


for i, (id_, dist) in enumerate(zip(ids, dists)):
    j = i % col_count
    with res_cols[j]:
        # Accept any integer type; checking only np.int32 was brittle.
        if isinstance(id_, (int, np.integer)):
            # Index hit: show the same-named file from ./jpg or ./gif.
            # splitext handles any extension length; the old `[:-4]` slice
            # broke on names such as "x.jpeg".
            stem = os.path.splitext(file_names[id_])[0]
            st.image(f"./{ext}/{stem}.{ext}")
            # Convert cosine distance back to similarity for display.
            st.write(1.0 - dist)
        else:
            # Uploaded image: id_ is the PIL image, dist its probability.
            st.image(id_)
            st.write(dist)


# Credits
st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")