Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
from plip_support import embed_text | |
import numpy as np | |
from PIL import Image | |
import requests | |
import tokenizers | |
from io import BytesIO | |
import torch | |
from transformers import ( | |
VisionTextDualEncoderModel, | |
AutoFeatureExtractor, | |
AutoTokenizer, | |
CLIPModel, | |
AutoProcessor | |
) | |
import streamlit.components.v1 as components | |
def embed_images(model, images, processor):
    """Encode a batch of images into feature vectors with ``model``.

    Args:
        model: object exposing ``get_image_features(pixel_values=...)``
            (here a CLIP-style model).
        images: sequence of images accepted by ``processor``.
        processor: callable taking ``images=`` and returning a mapping with a
            ``"pixel_values"`` entry.

    Returns:
        The tensor returned by ``model.get_image_features``; presumably one
        row per input image (TODO confirm against the model). It is computed
        under ``torch.no_grad()``, so no autograd graph is attached.
    """
    batch_pixels = processor(images=images)["pixel_values"]
    # Stack the per-image arrays into a single tensor for one forward pass.
    pixel_tensor = torch.tensor(np.array(batch_pixels))
    with torch.no_grad():
        features = model.get_image_features(pixel_values=pixel_tensor)
    return features
def load_embeddings(embeddings_path):
    """Read a saved NumPy embedding array from disk.

    Prints a short progress message to stdout before reading.

    Args:
        embeddings_path: path to a file readable by ``np.load``
            (e.g. a ``.npy`` file).

    Returns:
        The loaded array.
    """
    print("loading embeddings")
    matrix = np.load(embeddings_path)
    return matrix
def load_path_clip(model_name="vinid/plip"):
    """Load a CLIP model and its matching processor from one checkpoint.

    Args:
        model_name: checkpoint identifier passed to ``from_pretrained``
            (a Hugging Face Hub id or a local directory). Defaults to
            ``"vinid/plip"``, the PLIP checkpoint this app was built for.

    Returns:
        ``(model, processor)``: a ``CLIPModel`` and an ``AutoProcessor``
        loaded from the same ``model_name``.
    """
    # Use a single name for both so the model and its preprocessing
    # can never be loaded from different checkpoints.
    model = CLIPModel.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)
    return model, processor
def _cosine_best_match(query_vec, embeddings):
    """Return the row of ``embeddings`` most cosine-similar to ``query_vec``.

    Args:
        query_vec: 1-D query embedding.
        embeddings: 2-D array with one stored embedding per row.

    Returns:
        ``(index, score)``: row index of the best match and its cosine
        similarity to ``query_vec``.
    """
    query = query_vec / np.linalg.norm(query_vec)
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Normalise the stored rows as well. Normalising only the query makes the
    # dot product scale with each row's length, so it is not a cosine
    # similarity. All-zero rows stay zero (score 0) instead of producing NaN.
    unit_rows = embeddings / np.where(row_norms == 0, 1.0, row_norms)
    scores = unit_rows @ query
    # Only the single best hit is shown, so argmax replaces a full sort.
    best = int(np.argmax(scores))
    return best, float(scores[best])


def app():
    """Render the PLIP image-search page.

    Loads the PLIP model/processor and the precomputed image embeddings, then
    lets the user upload a query image. The uploaded image is embedded and
    compared to the stored embeddings by cosine similarity. The tweet link at
    the same row of ``tweet_eval_retrieval_twlnk.tsv`` is rendered as an
    embedded tweet. This assumes the embedding rows and TSV rows share the
    same order — TODO confirm.
    """
    import html

    from PIL import UnidentifiedImageError

    st.title('PLIP Image Search')
    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")
    # NOTE(review): Streamlit re-executes scripts on widget interaction, so the
    # model and embeddings are likely reloaded on every rerun; consider
    # caching them (e.g. st.cache_resource, if the installed Streamlit has it).
    model, processor = load_path_clip()
    image_embedding = load_embeddings("tweet_eval_embeddings.npy")

    query = st.file_uploader("Choose a file")
    if not query:
        return

    try:
        image = Image.open(query)
    except UnidentifiedImageError:
        # Non-image uploads would otherwise crash the page with a traceback.
        st.error("The uploaded file could not be read as an image.")
        return

    query_vec = embed_images(model, [image], processor)[0].detach().cpu().numpy()
    best_id, score = _cosine_best_match(query_vec, image_embedding)
    target_weblink = plip_weblink.iloc[best_id]["weblink"]

    st.caption('Most relevant image (similarity = %.4f)' % score)
    # Escape the link: it is interpolated into an HTML attribute.
    safe_link = html.escape(str(target_weblink), quote=True)
    components.html('''
<blockquote class="twitter-tweet">
<a href="%s"></a>
</blockquote>
<script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
</script>
''' % safe_link,
        height=600)