"""Streamlit page that retrieves the most relevant tweet image for a text query using PLIP."""

import html
from io import BytesIO

import numpy as np
import pandas as pd
import requests
import streamlit as st
import streamlit.components.v1 as components
import tokenizers
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoProcessor,
    AutoTokenizer,
    CLIPModel,
    VisionTextDualEncoderModel,
)

from plip_support import embed_text

# HTML snippet rendered for the best match; ``%s`` receives the (escaped) link.
# NOTE(review): the template found in this file was empty, which made the
# ``%`` formatting below raise ``TypeError`` on every query. This markup is
# the standard Twitter embed snippet -- confirm it matches the intended widget.
_TWEET_EMBED_TEMPLATE = '''
<blockquote class="twitter-tweet">
    <a href="%s"></a>
</blockquote>
<script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
'''


def embed_texts(model, texts, processor):
    """Return the text-feature embeddings of ``texts``.

    Args:
        model: Object exposing ``get_text_features(input_ids=..., attention_mask=...)``.
        texts: Sequence of strings to embed.
        processor: Callable tokenizer/processor; invoked as
            ``processor(text=texts, padding="longest")`` and expected to return
            a mapping with ``"input_ids"`` and ``"attention_mask"``.

    Returns:
        Whatever ``model.get_text_features`` returns for the tokenized batch
        (computed with gradients disabled).
    """
    # Pad to the longest text so the batch can be packed into dense tensors.
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings


@st.cache
def load_embeddings(embeddings_path):
    """Load the pre-computed embedding array stored at ``embeddings_path``.

    Wrapped in ``st.cache`` so the file is not re-read on every rerun.
    """
    print("loading embeddings")
    return np.load(embeddings_path)


@st.cache(
    hash_funcs={
        # These objects are mapped to a constant hash so that st.cache does not
        # attempt to hash model weights / tokenizer internals.
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    }
)
def load_path_clip():
    """Load the ``vinid/plip`` CLIP model and its processor.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor


def app():
    """Render the PLIP image-search page.

    Reads a text query, embeds it, scores it against the stored image
    embeddings and embeds the best-matching tweet in the page. The two TSV
    files and the embedding array are indexed by the same row position.
    """
    st.title('PLIP Image Search')

    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")

    model, processor = load_path_clip()
    image_embedding = load_embeddings("tweet_eval_embeddings.npy")

    query = st.text_input('Search Query', '')

    if query:
        text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
        # Unit-normalize the query vector.
        # NOTE(review): presumably the stored image embeddings are already
        # unit-normalized, otherwise the dot product is not a cosine -- confirm.
        text_embedding = text_embedding / np.linalg.norm(text_embedding)

        # Sort IDs by cosine-similarity from high to low
        similarity_scores = text_embedding.dot(image_embedding.T)
        id_sorted = np.argsort(similarity_scores)[::-1]

        best_id = id_sorted[0]
        score = similarity_scores[best_id]
        target_url = plip_imgURL.iloc[best_id]["imageURL"]
        target_weblink = plip_weblink.iloc[best_id]["weblink"]

        st.caption('Most relevant image (similarity = %.4f)' % score)

        # Alternative rendering that downloads and shows the raw image:
        # response = requests.get(target_url)
        # img = Image.open(BytesIO(response.content))
        # st.image(img)

        # Escape the link before placing it inside the href attribute so that
        # characters such as '&' or '"' cannot break the markup.
        safe_weblink = html.escape(str(target_weblink), quote=True)
        components.html(_TWEET_EMBED_TEMPLATE % safe_weblink, height=600)