import os
from typing import List
import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.datasets.utils import download_url
from .longclip import longclip
from .viclip import get_viclip
from .video_utils import extract_frames
# All metrics.
__all__ = ["VideoCLIPXLScore"]
_MODELS = {
    "ViClip-InternVid-10M-FLT": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViClip-InternVid-10M-FLT.pth",
    "LongCLIP-L": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/longclip-L.pt",
    "VideoCLIP-XL-v2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/VideoCLIP-XL-v2.bin",
}
_MD5 = {
    "ViClip-InternVid-10M-FLT": "b1ebf538225438b3b75e477da7735cd0",
    "LongCLIP-L": "5478b662f6f85ca0ebd4bb05f9b592f3",
    "VideoCLIP-XL-v2": "cebda0bab14b677ec061a57e80791f35",
}
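
# torchvision's download_url verifies each file against the MD5 above and skips re-downloading
# a checkpoint that is already present and intact in the cache directory.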
def normalize(
    data: np.ndarray,
    mean: List[float] = [0.485, 0.456, 0.406],
    std: List[float] = [0.229, 0.224, 0.225],
):
    # Scale [0, 255] pixel values to [0, 1] and normalize with the standard ImageNet mean/std.
    v_mean = np.array(mean).reshape(1, 1, 3)
    v_std = np.array(std).reshape(1, 1, 3)
    return (data / 255.0 - v_mean) / v_std
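
# A minimal sketch of what `normalize` expects (shapes and values are illustrative only):
#   frames = np.zeros((8, 224, 224, 3), dtype=np.uint8)  # [T, H, W, C] pixels in [0, 255]
#   out = normalize(frames)                               # float64, mean-subtracted and std-scaled
#   out.shape                                             # -> (8, 224, 224, 3)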
class VideoCLIPXL(nn.Module):
    def __init__(self, root: str = "~/.cache/clip"):
        super(VideoCLIPXL, self).__init__()
        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        # LongCLIP supplies the text encoder.
        k = "LongCLIP-L"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = longclip.load(os.path.join(self.root, filename), device="cpu")[0].float()

        # ViCLIP supplies the video encoder.
        k = "ViClip-InternVid-10M-FLT"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.viclip_model = get_viclip("l", os.path.join(self.root, filename))["viclip"].float()

        # Delete the unused halves: LongCLIP's visual tower and ViCLIP's text encoder.
        del self.model.visual
        del self.viclip_model.text_encoder
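
# VideoCLIPXLScore loads the fused VideoCLIP-XL-v2 state dict into VideoCLIPXL and scores each
# (text, video) pair by the similarity between LongCLIP text features and ViCLIP video features.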
class VideoCLIPXLScore:
    def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        # Download the fused checkpoint and load it into the two-encoder wrapper.
        k = "VideoCLIP-XL-v2"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = VideoCLIPXL()
        state_dict = torch.load(os.path.join(self.root, filename), map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.device = device
    def __call__(self, videos: List[List[Image.Image]], texts: List[str]):
        assert len(videos) == len(texts)

        # Use cv2.resize in accordance with the official demo. Resize and normalize => B * [T, 224, 224, 3].
        videos = [[cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in v] for v in videos]
        resized_videos = [[cv2.resize(f, (224, 224)) for f in v] for v in videos]
        resized_normalized_videos = [normalize(np.stack(v)) for v in resized_videos]
        video_inputs = torch.stack([torch.from_numpy(v) for v in resized_normalized_videos])
        video_inputs = video_inputs.float().permute(0, 1, 4, 2, 3).to(self.device, non_blocking=True)  # BTCHW

        with torch.no_grad():
            # Encode each clip with the ViCLIP video encoder.
            vid_features = torch.stack(
                [self.model.viclip_model.get_vid_features(x.unsqueeze(0)).float() for x in video_inputs]
            )
            vid_features.squeeze_()
            # vid_features = self.model.viclip_model.get_vid_features(video_inputs).float()

            # Encode and L2-normalize the prompts with the LongCLIP text encoder.
            text_inputs = longclip.tokenize(texts, truncate=True).to(self.device)
            text_features = self.model.model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # Pairwise similarity matrix; keep only the matched (text_i, video_i) scores.
        scores = text_features @ vid_features.T
        return scores.tolist() if len(videos) == 1 else scores.diagonal().tolist()
    def __repr__(self):
        return "videoclipxl_score"
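
# Minimal demo; "your_video_path" is a placeholder and must point to an actual video file.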
if __name__ == "__main__":
    videos = ["your_video_path"] * 3
    texts = [
        "a joker",
        "glasses and flower",
        "The video opens with a view of a white building with multiple windows, partially obscured by leafless tree branches. The scene transitions to a closer view of the same building, with the tree branches more prominent in the foreground. The focus then shifts to a street sign that reads 'Abesses' in bold, yellow letters against a green background. The sign is attached to a metal structure, possibly a tram or bus stop. The sign is illuminated by a light source above it, and the background reveals a glimpse of the building and tree branches from earlier shots. The colors are muted, with the yellow sign standing out against the grey and green hues.",
    ]

    video_clip_xl_score = VideoCLIPXLScore(device="cuda")
    batch_frames = []
    for v in videos:
        # extract_frames returns a pair; the second element holds the sampled PIL frames.
        sampled_frames = extract_frames(v, sample_method="uniform", num_sampled_frames=8)[1]
        batch_frames.append(sampled_frames)
    print(video_clip_xl_score(batch_frames, texts))