import os
from typing import List

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.datasets.utils import download_url

from .longclip import longclip
from .viclip import get_viclip
from .video_utils import extract_frames

# Public metrics exported by this module.
__all__ = ["VideoCLIPXLScore"]

_MODELS = {
    "ViClip-InternVid-10M-FLT": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViClip-InternVid-10M-FLT.pth",
    "LongCLIP-L": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/longclip-L.pt",
    "VideoCLIP-XL-v2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/VideoCLIP-XL-v2.bin",
}

_MD5 = {
    "ViClip-InternVid-10M-FLT": "b1ebf538225438b3b75e477da7735cd0",
    "LongCLIP-L": "5478b662f6f85ca0ebd4bb05f9b592f3",
    "VideoCLIP-XL-v2": "cebda0bab14b677ec061a57e80791f35",
}
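
# Note: torchvision's download_url verifies each file against the MD5 sums above and
# skips the download when a valid copy already exists under `root`.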


def normalize(
    data: np.ndarray,
    mean: List[float] = [0.485, 0.456, 0.406],
    std: List[float] = [0.229, 0.224, 0.225],
):
    # Scale uint8 pixels to [0, 1], then standardize per channel (ImageNet statistics).
    v_mean = np.array(mean).reshape(1, 1, 3)
    v_std = np.array(std).reshape(1, 1, 3)
    return (data / 255.0 - v_mean) / v_std
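
# Illustrative example (values hypothetical): a [T, H, W, 3] uint8 frame stack becomes
# zero-centered floats, roughly within [-2.2, 2.7] for the ImageNet statistics above.
#
#   frames = np.random.randint(0, 256, size=(8, 224, 224, 3), dtype=np.uint8)
#   out = normalize(frames)  # float64, same shape, per-channel standardized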


class VideoCLIPXL(nn.Module):
    """Backbone for VideoCLIP-XL: LongCLIP-L as the text tower and ViCLIP as the video tower."""

    def __init__(self, root: str = "~/.cache/clip"):
        super(VideoCLIPXL, self).__init__()
        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        # Text tower: LongCLIP-L accepts captions beyond CLIP's 77-token limit.
        k = "LongCLIP-L"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = longclip.load(os.path.join(self.root, filename), device="cpu")[0].float()

        # Video tower: ViCLIP encodes the sampled frame sequence.
        k = "ViClip-InternVid-10M-FLT"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.viclip_model = get_viclip("l", os.path.join(self.root, filename))["viclip"].float()

        # Delete the unused half of each model: LongCLIP's visual encoder and
        # ViCLIP's text encoder are never called.
        del self.model.visual
        del self.viclip_model.text_encoder


class VideoCLIPXLScore:
    def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        # Load the fused VideoCLIP-XL-v2 weights on top of the LongCLIP/ViCLIP skeleton.
        k = "VideoCLIP-XL-v2"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = VideoCLIPXL()
        state_dict = torch.load(os.path.join(self.root, filename), map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(device)

        self.device = device

    def __call__(self, videos: List[List[Image.Image]], texts: List[str]):
        assert len(videos) == len(texts)

        # Use cv2.resize in accordance with the official demo. Resize and normalize => B * [T, 224, 224, 3].
        videos = [[cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in v] for v in videos]
        resized_videos = [[cv2.resize(f, (224, 224)) for f in v] for v in videos]
        resized_normalized_videos = [normalize(np.stack(v)) for v in resized_videos]
        video_inputs = torch.stack([torch.from_numpy(v) for v in resized_normalized_videos])
        video_inputs = video_inputs.float().permute(0, 1, 4, 2, 3).to(self.device, non_blocking=True)  # BTCHW

        with torch.no_grad():
            # Encode videos one by one; the batched alternative is kept below, commented out.
            vid_features = torch.stack(
                [self.model.viclip_model.get_vid_features(x.unsqueeze(0)).float() for x in video_inputs]
            )
            vid_features.squeeze_()
            # vid_features = self.model.viclip_model.get_vid_features(video_inputs).float()

            text_inputs = longclip.tokenize(texts, truncate=True).to(self.device)
            text_features = self.model.model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # Similarity between each text and its paired video. With a single video,
        # `squeeze_` leaves a 1-D feature vector, so `scores` is already 1-D and
        # `diagonal()` would not apply.
        scores = text_features @ vid_features.T
        return scores.tolist() if len(videos) == 1 else scores.diagonal().tolist()

    def __repr__(self):
        return "videoclipxl_score"


if __name__ == "__main__":
    videos = ["your_video_path"] * 3
    texts = [
        "a joker",
        "glasses and flower",
        (
            "The video opens with a view of a white building with multiple windows, partially obscured by "
            "leafless tree branches. The scene transitions to a closer view of the same building, with the "
            "tree branches more prominent in the foreground. The focus then shifts to a street sign that "
            "reads 'Abesses' in bold, yellow letters against a green background. The sign is attached to a "
            "metal structure, possibly a tram or bus stop. The sign is illuminated by a light source above "
            "it, and the background reveals a glimpse of the building and tree branches from earlier shots. "
            "The colors are muted, with the yellow sign standing out against the grey and green hues."
        ),
    ]

    video_clip_xl_score = VideoCLIPXLScore(device="cuda")
    batch_frames = []
    for v in videos:
        # Take the sampled frames (second element of the pair returned by extract_frames).
        sampled_frames = extract_frames(v, sample_method="uniform", num_sampled_frames=8)[1]
        batch_frames.append(sampled_frames)
    print(video_clip_xl_score(batch_frames, texts))
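
    # A minimal smoke test without a real video file (an illustrative sketch, not part
    # of the original pipeline): build dummy RGB frames in memory and score them.
    #
    #   dummy_frames = [Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)) for _ in range(8)]
    #   print(video_clip_xl_score([dummy_frames], ["a black screen"]))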