import os
import pickle
import base64
from io import BytesIO

import numpy as np
import pandas as pd
import requests
import tokenizers
import torch
from PIL import Image
from transformers import (
    VisionTextDualEncoderModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    CLIPModel,
    AutoProcessor,
)

import streamlit as st
import streamlit.components.v1 as components
from st_clickable_images import clickable_images  # pip install st-clickable-images

from plip_support import embed_text


def load_path_clip():
    """Load the PLIP CLIP model and its paired processor from the Hub."""
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor
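

# Illustrative usage (a sketch, not part of the original app). In a deployed
# Streamlit Space this loader would typically be wrapped in a caching
# decorator such as @st.cache_resource, so the weights are downloaded once
# rather than on every script rerun:
#
#     model, processor = load_path_clip()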


def init():
    """Load precomputed tweet metadata and paired image/text embeddings."""
    with open('data/twitter.asset', 'rb') as f:
        data = pickle.load(f)
    meta = data['meta'].reset_index(drop=True)
    image_embedding = data['image_embedding']
    text_embedding = data['text_embedding']
    print(meta.shape, image_embedding.shape)  # sanity check: row counts must align
    # Boolean mask selecting the validation split of the tweet corpus.
    validation_subset_index = meta['source'].values == 'Val_Tweets'
    return meta, image_embedding, text_embedding, validation_subset_index
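

# For reference, the structure init() expects inside data/twitter.asset,
# inferred from the keys accessed above (an assumption, not a documented
# schema): a pickled dict holding a metadata DataFrame with a 'source'
# column, plus two row-aligned embedding matrices.
#
#     data = {
#         'meta': pd.DataFrame(...),      # one row per tweet
#         'image_embedding': np.ndarray,  # shape (N, D)
#         'text_embedding': np.ndarray,   # shape (N, D)
#     }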


def embed_images(model, images, processor):
    """Embed a batch of PIL images into the shared CLIP embedding space."""
    inputs = processor(images=images)
    # The processor returns a list of per-image arrays; stack into one tensor.
    pixel_values = torch.tensor(np.array(inputs["pixel_values"]))
    with torch.no_grad():
        embeddings = model.get_image_features(pixel_values=pixel_values)
    return embeddings
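

# Illustrative usage (a sketch assuming a local file; "example.png" is a
# placeholder, not an asset shipped with the app):
#
#     model, processor = load_path_clip()
#     image = Image.open("example.png").convert("RGB")
#     image_emb = embed_images(model, [image], processor)  # shape (1, D)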


def embed_texts(model, texts, processor):
    """Embed a batch of strings into the shared CLIP embedding space."""
    # Pad every sequence in the batch to the length of the longest one.
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])
    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings
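

# A minimal text-to-image retrieval sketch built from the helpers above.
# top_k_images is a hypothetical helper, not part of the original app: it
# embeds a free-text query with PLIP and ranks the precomputed image
# embeddings (an (N, D) array as returned by init()) by cosine similarity.
def top_k_images(query, model, processor, image_embedding, k=5):
    # Embed the query text and L2-normalise it.
    text_emb = embed_texts(model, [query], processor).numpy()[0]
    text_emb = text_emb / np.linalg.norm(text_emb)
    # Normalise the image embeddings row-wise; the dot product then equals
    # cosine similarity.
    img = image_embedding / np.linalg.norm(image_embedding, axis=1, keepdims=True)
    scores = img @ text_emb
    # Indices of the k best-matching images, best first.
    return np.argsort(-scores)[:k]


# A sketch of how retrieved images could be rendered with
# st_clickable_images (an assumption about the UI, not the original layout):
# clickable_images takes image URLs or data URIs and returns the index of
# the image the user clicked (-1 if none).
def render_results(pil_images):
    data_uris = []
    for im in pil_images:
        buf = BytesIO()
        im.convert("RGB").save(buf, format="JPEG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        data_uris.append(f"data:image/jpeg;base64,{b64}")
    clicked = clickable_images(
        data_uris,
        img_style={"height": "120px", "margin": "4px"},
    )
    return clicked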