Flip-Sketch / gifs_filter.py
g-ronimo's picture
Upload 3 files
d6748e0 verified
raw
history blame
2.29 kB
# Filter generated GIFs by the CLIP similarity of their frames to an input image.
from PIL import Image, ImageSequence
import requests
from tqdm import tqdm
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Pick ``clip_len`` frame indices from a randomly placed window.

    The window spans ``int(clip_len * frame_sample_rate)`` frames and its
    (exclusive) end is drawn uniformly at random so that the whole window
    fits inside ``[0, seg_len)``.

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Stride factor; the window covers
            ``clip_len * frame_sample_rate`` frames.
        seg_len: Total number of frames available.

    Returns:
        A 1-D ``np.int64`` array of ``clip_len`` non-decreasing indices.

    Raises:
        ValueError: Raised by ``np.random.randint`` when ``seg_len`` is not
            strictly greater than the window length.
    """
    window = int(clip_len * frame_sample_rate)
    # Exclusive upper bound of the window; the lower bound follows from it.
    stop = np.random.randint(window, seg_len)
    start = stop - window
    positions = np.linspace(start, stop, num=clip_len)
    # linspace includes ``stop`` itself, so clamp to the last valid frame
    # before truncating to integers.
    return np.clip(positions, start, stop - 1).astype(np.int64)
def load_frames(image: Image.Image, mode: str = 'RGBA') -> np.ndarray:
    """Decode every frame of a (possibly animated) image into one array.

    Args:
        image: An opened PIL image; all of its frames are iterated with
            ``ImageSequence.Iterator``.
        mode: PIL mode each frame is converted to before it is turned into
            an array.

    Returns:
        A NumPy array holding one entry per frame. For multi-channel modes
        such as the default ``'RGBA'`` this is expected to have shape
        ``(num_frames, height, width, channels)``.
    """
    # ``convert`` returns a new image per frame, so each array is an
    # independent snapshot even though the iterator seeks within ``image``.
    return np.array([
        np.array(frame.convert(mode))
        for frame in ImageSequence.Iterator(image)
    ])
# The CLIP model and its preprocessor must come from the same checkpoint,
# so the identifier is spelled once and shared by both loaders.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
img_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
img_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
def filter(gifs, input_image, min_similarity=0.9):
    """Return the GIFs whose frames resemble ``input_image`` under CLIP.

    For every GIF, all frames except the first are embedded with the
    module-level CLIP model and compared with the embedding of
    ``input_image``. A GIF is kept when the mean per-frame cosine
    similarity is strictly greater than ``min_similarity``.

    Args:
        gifs: Sized collection of GIF paths or file objects accepted by
            ``Image.open``.
        input_image: Path or file object of the reference image.
        min_similarity: Threshold the mean cosine similarity must exceed.
            Defaults to 0.9, the value that used to be hard-coded.

    Returns:
        List of the entries of ``gifs`` that passed the threshold, in their
        original order.
    """
    # Nothing to score; also avoids opening ``input_image`` needlessly.
    if len(gifs) == 0:
        return []

    # The reference embedding is the same for every GIF, so compute it once
    # and release the file handle as soon as preprocessing is done.
    with Image.open(input_image) as image:
        inputs_base = img_processor(images=image, return_tensors="pt", padding=False)
    with torch.no_grad():
        feat_img_base = img_model.get_image_features(
            pixel_values=inputs_base["pixel_values"])

    matching_gifs = []
    for gif in tqdm(gifs, total=len(gifs)):
        with Image.open(gif) as im:
            frames = load_frames(im)

        # Drop the alpha channel, move channels first, and skip frame 0.
        # NOTE(review): frame 0 is presumably the input image itself - confirm.
        frames = np.transpose(frames[:, :, :, :3], (0, 3, 1, 2))[1:]
        if len(frames) == 0:
            # A single-frame GIF has nothing left to compare; skip it rather
            # than fail on an empty batch / division by zero.
            continue

        inputs = img_processor(images=frames, return_tensors="pt", padding=False)
        with torch.no_grad():
            feat_img_vid = img_model.get_image_features(
                pixel_values=inputs["pixel_values"])

        # Broadcasting compares the single reference embedding with every
        # frame embedding (the old loop only ever looked at frame 0).
        per_frame = torch.nn.functional.cosine_similarity(
            feat_img_base, feat_img_vid, dim=1)
        cos_avg = per_frame.mean().item()

        print("Current cosine similarity: ", cos_avg)
        print("Max cosine similarity: ", min_similarity)
        if cos_avg > min_similarity:
            matching_gifs.append(gif)
    return matching_gifs