"""Streamlit app: text-to-image search with the ``vinid/plip`` CLIP model.

The user's query is embedded with PLIP's text encoder, scored by dot product
against precomputed image embeddings loaded from a ``.npy`` file, and the
image URL of the best-scoring row of the retrieval TSV is downloaded and shown.
"""
from io import BytesIO

import clip
import numpy as np
import pandas as pd
import requests
import streamlit as st
import tokenizers
import torch
import transformers
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoProcessor,
    AutoTokenizer,
    CLIPModel,
    VisionTextDualEncoderModel,
)

from plip_support import embed_text

MODEL_NAME = "vinid/plip"
DATASET_PATH = "tweet_eval_retrieval.tsv"
EMBEDDINGS_PATH = "tweet_eval_embeddings.npy"
IMAGE_URL_COLUMN = "imageURL"
# Upper bound (seconds) on the image download so a slow host cannot hang the app.
REQUEST_TIMEOUT_SECONDS = 10


def embed_texts(model, texts, processor):
    """Embed ``texts`` with ``model.get_text_features``.

    Args:
        model: CLIP-style model exposing ``get_text_features``.
        texts: list of strings to embed.
        processor: processor whose text tokenizer produces ``input_ids`` and
            ``attention_mask``.

    Returns:
        The tensor returned by ``get_text_features`` (presumably shaped
        ``(len(texts), embed_dim)`` — confirm against the model config).
    """
    # truncation=True: CLIP text encoders have a fixed number of position
    # embeddings, so an over-long query would otherwise raise inside the model
    # rather than being searched on its leading tokens.
    inputs = processor(
        text=texts,
        padding="longest",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    return embeddings


@st.cache
def load_embeddings(embeddings_path):
    """Load the precomputed image-embedding matrix from an ``.npy`` file."""
    print("loading embeddings")
    return np.load(embeddings_path)


@st.cache
def load_dataset(dataset_path):
    """Read the tab-separated retrieval table (cached across Streamlit reruns)."""
    return pd.read_csv(dataset_path, sep="\t")


# NOTE(review): st.cache is deprecated in recent Streamlit releases in favour of
# st.cache_data / st.cache_resource; kept here to match the hash_funcs usage.
@st.cache(
    hash_funcs={
        # Model weights and tokenizer objects are not meaningfully hashable;
        # returning None makes Streamlit skip hashing them.
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    }
)
def load_path_clip():
    """Load the PLIP CLIP model and its processor from the Hugging Face hub."""
    model = CLIPModel.from_pretrained(MODEL_NAME)
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    return model, processor


def _fetch_image(url, timeout=REQUEST_TIMEOUT_SECONDS):
    """Download ``url`` and decode the payload as a PIL image.

    Raises:
        requests.RequestException: on connection failure, timeout, or an HTTP
            error status.
        OSError: if the payload cannot be decoded as an image (PIL's
            ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


st.title('PLIP Image Search')

plip_dataset = load_dataset(DATASET_PATH)
model, processor = load_path_clip()
image_embedding = load_embeddings(EMBEDDINGS_PATH)

query = st.text_input('Search Query', '')

# A whitespace-only query carries no search intent; skip the model call
# instead of showing an arbitrary "best" image.
if query.strip():
    text_embedding = embed_texts(model, [query], processor)[0].cpu().numpy()
    text_embedding = text_embedding / np.linalg.norm(text_embedding)
    # NOTE(review): these scores are cosine similarities only if the stored
    # image embeddings are already L2-normalized — confirm how the .npy was
    # built. Also assumes row i of the matrix matches row i of the TSV.
    scores = text_embedding.dot(image_embedding.T)
    best_id = int(np.argmax(scores))
    url = plip_dataset.iloc[best_id][IMAGE_URL_COLUMN]
    try:
        img = _fetch_image(url)
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load image from {url}: {err}")
    else:
        st.image(img)