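"""PLIP Image Search: a Streamlit demo for text-to-image retrieval.

A free-text query is embedded with the PLIP text encoder and compared against
precomputed embeddings of tweet images; the best-matching tweet is rendered inline.
"""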
import streamlit as st
import pandas as pd
from plip_support import embed_text
import numpy as np
from PIL import Image
import requests
import tokenizers
from io import BytesIO
import torch
from transformers import (
    VisionTextDualEncoderModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    CLIPModel,
    AutoProcessor
)
import streamlit.components.v1 as components


def embed_texts(model, texts, processor):
    """Tokenize a list of texts and return their PLIP/CLIP text embeddings."""
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    # Inference only, so gradients are not needed.
    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings

@st.cache
def load_embeddings(embeddings_path):
    """Load the precomputed image embeddings once and cache them."""
    print("loading embeddings")
    return np.load(embeddings_path)

# Model parameters and tokenizer objects are not hashable, so Streamlit's cache
# is told to ignore them when building the cache key.
@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None
    }
)
def load_path_clip():
    """Load the PLIP model and its processor from the Hugging Face Hub."""
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor


def app():
    """Embed a text query and show the most similar image from the retrieval set."""
    st.title('PLIP Image Search')

    # Image URLs and source-tweet links, row-aligned with the precomputed
    # image embeddings loaded below.
    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")

    model, processor = load_path_clip()

    image_embedding = load_embeddings("tweet_eval_embeddings.npy")

    query = st.text_input('Search Query', '')

    if query:

        # Embed the query and normalize it to unit length.
        text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
        text_embedding = text_embedding / np.linalg.norm(text_embedding)

        # Rank images by cosine similarity, highest first (the stored image
        # embeddings are assumed to be L2-normalized already).
        similarity_scores = text_embedding.dot(image_embedding.T)
        id_sorted = np.argsort(similarity_scores)[::-1]

        # Look up the URL and source tweet of the best-matching image.
        best_id = id_sorted[0]
        score = similarity_scores[best_id]
        target_url = plip_imgURL.iloc[best_id]["imageURL"]
        target_weblink = plip_weblink.iloc[best_id]["weblink"]

        st.caption('Most relevant image (similarity = %.4f)' % score)
        # Alternative: fetch the image directly and display it with st.image.
        #response = requests.get(target_url)
        #img = Image.open(BytesIO(response.content))
        #st.image(img)


        # Render the source tweet with Twitter's embedded-tweet widget.
        components.html('''
            <blockquote class="twitter-tweet">
                <a href="%s"></a>
            </blockquote>
            <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
            </script>
            ''' % target_weblink,
        height=600)
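

# Entry point (assumed): render the app when the script is run with `streamlit run`.
if __name__ == "__main__":
    app()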