# webplip/helper.py
# Author: huangzhii
# Adds text embedding, allowing input to be compared with both text and images.
import streamlit as st
import pandas as pd
from plip_support import embed_text
import numpy as np
from PIL import Image
import requests
import tokenizers
import os
from io import BytesIO
import pickle
import base64
import torch
from transformers import (
VisionTextDualEncoderModel,
AutoFeatureExtractor,
AutoTokenizer,
CLIPModel,
AutoProcessor
)
import streamlit.components.v1 as components
from st_clickable_images import clickable_images #pip install st-clickable-images
@st.cache(
    hash_funcs={
        # These objects are unhashable by Streamlit's cache; hashing them to
        # None skips them when computing the cache key.
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None
    }
)
def load_path_clip():
    """Download and return the PLIP CLIP model and its processor.

    Cached by Streamlit so the (slow) Hugging Face download and model
    construction happen only once per session.

    Returns:
        (model, processor): the ``vinid/plip`` CLIP model and matching processor.
    """
    repo_id = "vinid/plip"
    model = CLIPModel.from_pretrained(repo_id)
    processor = AutoProcessor.from_pretrained(repo_id)
    return model, processor
@st.cache
def init():
    """Load the precomputed tweet dataset from the pickled asset file.

    Reads ``data/twitter.asset`` (a trusted, repo-shipped pickle — do not point
    this at untrusted input) and unpacks its contents.

    Returns:
        meta: metadata table (index reset to a clean 0..n-1 range).
        image_embedding: precomputed image embeddings.
        text_embedding: precomputed text embeddings.
        validation_subset_index: boolean mask over ``meta`` rows selecting
            entries whose ``source`` column equals ``'Val_Tweets'``.
    """
    with open('data/twitter.asset', 'rb') as fh:
        bundle = pickle.load(fh)

    meta = bundle['meta'].reset_index(drop=True)
    img_emb = bundle['image_embedding']
    txt_emb = bundle['text_embedding']
    # Quick sanity log of dataset sizes on first (uncached) load.
    print(meta.shape, img_emb.shape)

    is_validation = meta['source'].values == 'Val_Tweets'
    return meta, img_emb, txt_emb, is_validation
def embed_images(model, images, processor):
    """Embed a batch of images with the CLIP vision tower.

    Args:
        model: a CLIP-style model exposing ``get_image_features``.
        images: batch of PIL images (or arrays) accepted by the processor.
        processor: the matching Hugging Face processor/feature extractor.

    Returns:
        Tensor of image embeddings, one row per input image.
    """
    # Ask the processor for PyTorch tensors directly instead of the original
    # list -> np.array -> torch.tensor round trip; this is the documented
    # transformers idiom and avoids an extra copy.
    inputs = processor(images=images, return_tensors="pt")
    with torch.no_grad():  # inference only — no autograd graph needed
        embeddings = model.get_image_features(pixel_values=inputs["pixel_values"])
    return embeddings
def embed_texts(model, texts, processor):
    """Embed a batch of texts with the CLIP text tower.

    Args:
        model: a CLIP-style model exposing ``get_text_features``.
        texts: list of strings to embed.
        processor: the matching Hugging Face processor/tokenizer.

    Returns:
        Tensor of text embeddings, one row per input string.
    """
    # return_tensors="pt" yields padded torch tensors directly, replacing the
    # original manual torch.tensor(...) conversions of the tokenizer's lists.
    inputs = processor(text=texts, padding="longest", return_tensors="pt")
    with torch.no_grad():  # inference only — no autograd graph needed
        embeddings = model.get_text_features(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    return embeddings