Spaces:
Runtime error
Runtime error
File size: 5,151 Bytes
f6ee7e4 6cc012f a8d91bf f6ee7e4 6cc012f f6ee7e4 6cc012f 578e499 6cc012f 92f215b 07c95a1 f6ee7e4 07c95a1 f6ee7e4 07c95a1 f6ee7e4 07c95a1 f6ee7e4 07c95a1 f6ee7e4 b10f4fc f6ee7e4 92f215b f6ee7e4 92f215b a254501 3396f01 578e499 f6ee7e4 6cc012f f6ee7e4 6cc012f f6ee7e4 6cc012f f6ee7e4 6cc012f b10f4fc 6cc012f 07c95a1 6cc012f f6ee7e4 6cc012f f6ee7e4 6cc012f f6ee7e4 b82b2c8 f6ee7e4 6cc012f b10f4fc 6cc012f b10f4fc f6ee7e4 b10f4fc f6ee7e4 6cc012f 578e499 f6ee7e4 578e499 f6ee7e4 578e499 f6ee7e4 578e499 f6ee7e4 6cc012f 578e499 7dbbdfd f6ee7e4 6cc012f 578e499 7dbbdfd f6ee7e4 7dbbdfd 6cc012f f6ee7e4 b82b2c8 f6ee7e4 7dbbdfd f6ee7e4 6cc012f b82b2c8 4474721 f6ee7e4 6cc012f f6ee7e4 5ffe5df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import os
import jax
import jax.numpy as jnp
import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor
from model import FlaxHybridCLIP
# Sidebar header: title, model-card link, then a logo next to a short tagline.
st.sidebar.title("CLIP React Demo")
st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")
logo_col, tagline_col = st.sidebar.columns(2)
logo_col.image("./huggingface_explode3.png", width=150)
# Two blank writes precede the tagline (presumably to line it up with the logo).
for _ in range(2):
    tagline_col.write(" ")
tagline_col.markdown("## Researching fun")

# Text shown inside the "Motivation" expander, which starts out expanded.
motivation_text = """
Reaction GIFs became an integral part of communication.
They convey complex emotions with many levels, in a short compact format.
If a picture is worth a thousand words then a GIF is worth more.
A lot of people would agree it is not always easy to find the perfect reaction GIF.
This is just a first step in the more ambitious goal of GIF/Image generation.
"""
with st.sidebar.expander("Motivation", expanded=True):
    st.markdown(motivation_text)
# Number of neighbours requested from the image index further down.
top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
# Grid width shared by the sidebar gallery and the result grid.
col_count = 4
# Sorted listing of ./jpg; the result grid looks files up by position in it.
file_names = sorted(os.listdir("./jpg"))

# Optional gallery: lay every validation image out round-robin over the columns.
if st.sidebar.button("show all validation set images"):
    gallery_cols = st.sidebar.columns(col_count)
    for position, file_name in enumerate(file_names):
        gallery_cols[position % col_count].image("./jpg/" + file_name)

# Main page heading followed by two blank spacer lines.
st.write("# Search Reaction GIFs with CLIP ")
for _ in range(2):
    st.write(" ")
@st.cache_resource()
def load_model():
    """Load the hybrid CLIP model and its matching processor.

    The model comes from the "ceyda/clip-reply" checkpoint. The processor is
    the stock "openai/clip-vit-base-patch32" CLIP processor whose tokenizer
    is swapped for the "cardiffnlp/twitter-roberta-base" tokenizer.

    Returns:
        A ``(model, processor)`` tuple.
    """
    clip_model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Replace the processor's default tokenizer with the RoBERTa one.
    roberta_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    clip_processor.tokenizer = roberta_tokenizer
    return clip_model, clip_processor
@st.cache_resource()
def load_image_index():
    """Create an HNSW / cosine-similarity nmslib index and load it from disk.

    Returns:
        The nmslib index populated from "./features/image_embeddings"
        (loaded with ``load_data=True``).
    """
    hnsw_index = nmslib.init(method="hnsw", space="cosinesimil")
    hnsw_index.loadIndex("./features/image_embeddings", load_data=True)
    return hnsw_index
# Module-level handles used by the functions and the query code below.
# Both loaders are wrapped in st.cache_resource.
image_index = load_image_index()
model, processor = load_model()
# TODO: not called anywhere in this file yet.
def add_image_emb(image):
    """Embed one image with the CLIP model and add it to the image index.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a file-like
            object).

    Returns:
        None. The module-level ``image_index`` is mutated in place.
    """
    image = Image.open(image).convert("RGB")
    # The model call needs a text input as well, so an empty string is passed.
    inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
    # Move axis 1 to the end (presumably NCHW -> NHWC for the Flax model; confirm).
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    features = model(**inputs).image_embeds
    # nmslib's addDataPoint takes (id, data); the original call passed only the
    # data and would fail. Use the next free position as the id and hand over a
    # flat NumPy vector (assumes a single-image embedding, e.g. (1, dim)).
    vector = np.asarray(features).reshape(-1)
    image_index.addDataPoint(len(image_index), vector)
    # NOTE(review): a point added to an already-loaded HNSW index may not be
    # searchable until the index is rebuilt, and its id has no entry in
    # `file_names`. Confirm both before wiring this into the UI.
def query_with_images(query_images, query_text):
    """Rank uploaded images by how well they match the query text.

    Each upload is opened as an RGB image. The text and all images go through
    the processor and the model in one batch, and the flattened
    ``logits_per_image`` are turned into probabilities with a softmax.

    Args:
        query_images: Iterable of uploaded files; each must be readable by
            ``PIL.Image.open`` and expose a ``name`` attribute.
        query_text: The text to score the images against.

    Returns:
        The iterator produced by ``zip(*results)``: first a tuple of images,
        then a tuple of their probabilities, both ordered from highest to
        lowest probability. The caller in this file only invokes this with a
        non-empty upload list.
    """
    frames = []
    for upload in query_images:
        picture = Image.open(upload).convert("RGB")
        if upload.name.endswith(".gif"):
            # Explicitly select frame 0 for GIF uploads.
            picture.seek(0)
        frames.append(picture)

    batch = processor(
        text=[query_text], images=frames, return_tensors="jax", padding=True
    )
    # Move axis 1 to the end (presumably NCHW -> NHWC for the Flax model; confirm).
    batch["pixel_values"] = jnp.transpose(batch["pixel_values"], axes=[0, 2, 3, 1])

    flat_logits = model(**batch).logits_per_image.reshape(-1)
    scores = jax.nn.softmax(flat_logits)

    # Best match first.
    ranked = sorted(zip(frames, scores), key=lambda pair: pair[1], reverse=True)
    return zip(*ranked)
# Three-column row: text controls on the left, a narrow unused middle column,
# and the upload controls on the right.
text_col, _spacer_col, upload_col = st.columns([5, 2, 5])

examples = [
    "OMG that is disgusting",
    "I'm so scared right now",
    " I got the job 🎉",
    "Congratulations to all the flax-community week teams",
    "You're awesome",
    "I love you ❤️",
]
# The selected example pre-fills the free-text box created further down.
example_input = text_col.radio(
    "Example Queries :",
    examples,
    index=4,
    help="These are examples I wrote off the top of my head. They don't occur in the dataset",
)

upload_note = """
Searches among the validation set images if not specified
(There may be non-exact duplicates)
"""
upload_col.markdown(upload_note)

query_text = text_col.text_input(
    "Write text you want to get reaction for", value=example_input
)
# With no uploads, the search runs against the prebuilt image index instead.
query_images = upload_col.file_uploader(
    "(optional) Upload images to rank them",
    type=["jpg", "jpeg", "gif"],
    accept_multiple_files=True,
)
# Produce `ids` / `dists` for the result grid: either nearest neighbours from
# the prebuilt index, or the user's own uploads ranked against the text.
if not query_images:
    st.write("Found these images within validation set:")
    with st.spinner("Calculating..."):
        text_inputs = processor(
            text=[query_text], images=None, return_tensors="jax", padding=True
        )
        # The text features are converted to a NumPy array before the index query.
        text_vec = np.asarray(model.get_text_features(**text_inputs))
        ids, dists = image_index.knnQuery(text_vec, k=top_k)
else:
    st.write("Ranking your uploaded images with respect to input text:")
    with st.spinner("Calculating..."):
        # Here `ids` holds the images themselves and `dists` their probabilities.
        ids, dists = query_with_images(query_images, query_text)
show_gif = st.checkbox(
    "Play GIFs",
    value=True,
    help="Will play the original animation. Only first frame is used in training!",
)
# Results are served from ./gif when animations are wanted, otherwise ./jpg.
ext = "gif" if show_gif else "jpg"

# Lay the results out round-robin over `col_count` columns.
res_cols = st.columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    with res_cols[i % col_count]:
        # Index hits arrive as NumPy integers that index into `file_names`.
        # Any integer width is accepted rather than int32 only.
        if isinstance(id_, np.integer):
            # splitext drops the extension whatever its length, so a ".jpeg"
            # entry no longer produces a broken path.
            stem = os.path.splitext(file_names[id_])[0]
            st.image(f"./{ext}/{stem}.{ext}")
            # The index was built with the "cosinesimil" space; 1 - distance
            # is shown (presumably as a similarity score).
            st.write(1.0 - dist)
        else:
            # Uploaded-image branch: `id_` is the image and `dist` its probability.
            st.image(id_)
            st.write(dist)

# Credits
st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")
|