File size: 1,948 Bytes
61448a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import streamlit as st
import pandas as pd
from plip_support import embed_text
import numpy as np
from PIL import Image
import requests
import transformers
import tokenizers
from io import BytesIO
import streamlit as st
from transformers import CLIPModel
import clip
import torch
from transformers import (
    VisionTextDualEncoderModel,
    AutoFeatureExtractor,
    AutoTokenizer
)
from transformers import AutoProcessor


def embed_texts(model, texts, processor):
    """Return the text features ``model`` produces for ``texts``.

    Args:
        model: Object exposing ``get_text_features(input_ids=...,
            attention_mask=...)``.
        texts: Texts handed to ``processor`` as its ``text`` argument.
        processor: Callable invoked with ``text=texts, padding="longest"``;
            its result must provide ``"input_ids"`` and ``"attention_mask"``
            entries that ``torch.tensor`` can convert.

    Returns:
        Whatever ``model.get_text_features`` returns, computed with gradient
        tracking disabled.
    """
    encoded = processor(text=texts, padding="longest")
    # The processor output is indexed by key and converted here, so it does
    # not need to return tensors itself.
    model_kwargs = {
        key: torch.tensor(encoded[key])
        for key in ("input_ids", "attention_mask")
    }

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model.get_text_features(**model_kwargs)

@st.cache
def load_embeddings(embeddings_path):
    """Load the array stored at ``embeddings_path`` with ``np.load``.

    The function is wrapped in ``st.cache``. A message is printed each time
    the body actually executes.
    """
    print("loading embeddings")
    embeddings = np.load(embeddings_path)
    return embeddings

@st.cache(
    # Every listed type gets the same hash function, which always returns
    # None, so st.cache does not hash their contents.
    hash_funcs=dict.fromkeys(
        (
            torch.nn.parameter.Parameter,
            tokenizers.Tokenizer,
            tokenizers.AddedToken,
        ),
        lambda _: None,
    )
)
def load_path_clip():
    """Load the ``vinid/plip`` CLIP model and its processor.

    Returns:
        A ``(model, processor)`` tuple obtained from
        ``CLIPModel.from_pretrained`` and ``AutoProcessor.from_pretrained``.
    """
    checkpoint = "vinid/plip"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = AutoProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor


def app():
    """Render the PLIP image-search page.

    Reads ``tweet_eval_retrieval.tsv`` and ``tweet_eval_embeddings.npy``,
    embeds the text typed into the search box, and shows the image whose
    stored embedding has the largest dot product with the normalised query
    embedding. If the image cannot be downloaded or decoded, an error
    message is shown instead of letting the exception abort the page.
    """
    st.title('PLIP Image Search')

    plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")

    model, processor = load_path_clip()

    image_embedding = load_embeddings("tweet_eval_embeddings.npy")

    query = st.text_input('Search Query', '')

    # Nothing to search for until the user has typed something.
    if not query:
        return

    text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
    text_embedding = text_embedding / np.linalg.norm(text_embedding)

    # NOTE(review): only the query vector is normalised here; presumably the
    # stored image embeddings are already unit-length -- confirm.
    best_id = np.argmax(text_embedding.dot(image_embedding.T))
    url = plip_dataset.iloc[best_id]["imageURL"]

    try:
        # A timeout keeps an unresponsive host from hanging the app.
        response = requests.get(url, timeout=10)
        # Fail on 4xx/5xx instead of trying to decode an error page.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data is caught here.
        img.load()
    except (requests.RequestException, OSError) as err:
        # Pillow's decode errors are OSError subclasses.
        st.error(f"Could not load the best-matching image from {url}: {err}")
        return

    st.image(img)