File size: 3,839 Bytes
6cc012f
 
a8d91bf
6cc012f
 
 
 
 
 
 
578e499
6cc012f
92f215b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3396f01
578e499
6cc012f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578e499
 
92f215b
 
578e499
 
 
3396f01
578e499
 
 
 
 
92f215b
578e499
6cc012f
 
578e499
6cc012f
 
578e499
6cc012f
 
 
 
 
 
 
 
 
 
4474721
578e499
6cc012f
 
578e499
a8d91bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor
from model import FlaxHybridCLIP
from PIL import Image
import jax.numpy as jnp
import os
import jax

col_count = 4  # number of image columns used by the sidebar gallery and the results grid

# Validation-set image file names, sorted so the listing order is deterministic.
# The results grid looks images up by the integer ids returned by the nmslib
# index, which presumably follow this same sorted order (TODO confirm against
# the script that built ./features/image_embeddings).
file_names = sorted(os.listdir("./imgs"))

st.sidebar.title("CLIP React Demo")
st.sidebar.write("Search Reaction GIFs with CLIP [Model Card](https://huggingface.co/flax-community/clip-reply)")
st.sidebar.image("./huggingface_explode3.png",width=150)


top_k=st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)

# The gallery below reads col_count and file_names, so both must already be
# defined at this point: Streamlit executes this script top to bottom on every
# interaction, including the rerun triggered by clicking this button.
show_val=st.sidebar.button("show all validation set images")
if show_val:
    cols = st.sidebar.beta_columns(col_count)
    for i, im in enumerate(file_names):
        # Distribute the images round-robin across the columns.
        cols[i % col_count].image(os.path.join("./imgs", im))

st.write("    ")
st.write("    ")


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the CLIP model and its processor, cached by Streamlit across reruns.

    Returns:
        ``(model, processor)``: the ``ceyda/clip-reply`` FlaxHybridCLIP model and
        an ``openai/clip-vit-base-patch32`` CLIPProcessor whose tokenizer has been
        replaced with the ``cardiffnlp/twitter-roberta-base`` tokenizer.
    """
    model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Image preprocessing comes from the OpenAI CLIP processor, but text goes
    # through a RoBERTa tokenizer -- presumably the text encoder this hybrid
    # model was trained with (TODO confirm against the model card).
    processor.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    return model, processor


@st.cache(allow_output_mutation=True)
def load_image_index():
    """Load the prebuilt nmslib HNSW index of validation-image embeddings.

    Cached by Streamlit across reruns. The index uses the ``cosinesimil`` space,
    so distances returned by ``knnQuery`` are cosine distances
    (1 - cosine similarity).
    """
    index = nmslib.init(method='hnsw', space='cosinesimil')
    # load_data=True loads the stored vectors together with the index structure.
    index.loadIndex("./features/image_embeddings", load_data=True)
    return index


image_index = load_image_index()
model, processor = load_model()


# TODO: not called anywhere yet -- presumably meant for adding user-supplied
# images to the searchable index.
def add_image_emb(image, id_=None):
    """Embed one image with the CLIP model and add it to ``image_index``.

    Args:
        image: an image file object or path accepted by ``PIL.Image.open``.
        id_: id to store with the new data point. Defaults to
            ``len(image_index)``, i.e. the position right after the points
            already in the index.

    Returns:
        The id under which the embedding was added.
    """
    image = Image.open(image).convert("RGB")

    # The model's forward pass needs text inputs too, so a dummy empty query is supplied.
    inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
    # Channels-first -> channels-last, matching how pixel_values are fed to the
    # model everywhere else in this file.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    # image_embeds holds a batch with a single entry; nmslib stores one 1-D
    # vector per data point, so flatten it.
    features = np.asarray(model(**inputs).image_embeds).reshape(-1)

    if id_ is None:
        id_ = len(image_index)
    # nmslib's signature is addDataPoint(id, data): the id argument is required.
    image_index.addDataPoint(id_, features)
    # NOTE(review): for an HNSW index, points added after the index was built are
    # only searchable once createIndex() is called again; this function does not
    # rebuild it. The results grid also maps ids onto file_names, so an id past
    # the end of that list would need its own image lookup.
    return id_


def query_with_images(query_images,query_text):
    """Rank uploaded images by how well each matches ``query_text``.

    Args:
        query_images: iterable of uploaded image files (anything
            ``PIL.Image.open`` accepts).
        query_text: the text query to score the images against.

    Returns:
        A pair ``(images, probs)`` of equal-length tuples: the opened RGB PIL
        images and their softmax probabilities, ordered from most to least
        probable. Both tuples are empty when ``query_images`` is empty.
    """
    images = [Image.open(im).convert("RGB") for im in query_images]
    if not images:
        # zip(*[]) yields nothing, which would break the caller's two-name unpack.
        return (), ()

    inputs = processor(text=[query_text], images=images, return_tensors="jax", padding=True)
    # Channels-first -> channels-last, the pixel layout fed to the model
    # elsewhere in this file as well.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    # With a single text query there is one logit per image; flatten to 1-D.
    logits_per_image = outputs.logits_per_image.reshape(-1)
    # Softmax across the uploaded images, so scores are relative to this batch only.
    probs = np.asarray(jax.nn.softmax(logits_per_image))

    ranked = sorted(zip(images, probs), key=lambda pair: pair[1], reverse=True)
    return tuple(zip(*ranked))

# Layout: left column for the text query, right column for optional uploads.
q_cols = st.beta_columns([5, 2, 5])

examples = ["I'm so scared right now"," I got the job 🎉","OMG that is disgusting","I'm awesome","I love you ❤️"]
example_input = q_cols[0].radio("Example Queries :", examples, index=4)
q_cols[2].markdown(
    """   
    Searches among the validation set images if not specified
    
    (There may be non-exact duplicates)

"""
)

# The selected example pre-fills the text box; the user may overwrite it.
query_text = q_cols[0].text_input("Write text you want to get reaction for", value=example_input)
query_images = q_cols[2].file_uploader("(optional) Upload images to rank them", type=['jpg','jpeg'], accept_multiple_files=True)

if query_images:
    # Rank only the uploaded images against the text query.
    st.write("Ranking your uploaded images with respect to input text:")
    ids, dists = query_with_images(query_images, query_text)
else:
    # Nearest-neighbour search over the prebuilt validation-set index.
    st.write("Found these images within validation set:")
    proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
    # nmslib's knnQuery takes one 1-D query vector. The text batch has a single
    # entry, so the features are presumably shaped (1, dim); flatten them.
    vec = np.asarray(model.get_text_features(**proc)).reshape(-1)
    ids, dists = image_index.knnQuery(vec, k=top_k)

# Show the results in col_count columns, assigning results round-robin.
res_cols = st.beta_columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    with res_cols[i % col_count]:
        # Index hits carry integer ids (numpy ints from knnQuery); uploaded
        # results carry the PIL image itself. Any integer type is accepted, not
        # just np.int32.
        if isinstance(id_, (int, np.integer)):
            st.image(os.path.join("./imgs", file_names[id_]))
            # The index lives in nmslib's cosinesimil space, so dist is a cosine
            # distance and 1 - dist is the cosine similarity.
            st.write(float(1.0 - dist))
        else:
            st.image(id_)
            # Softmax probability among the uploaded images.
            st.write(float(dist))