import os
from typing import List

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.datasets.utils import download_url

from .longclip import longclip
from .viclip import get_viclip
from .video_utils import extract_frames

# Exported metric classes.
__all__ = ["VideoCLIPXLScore"]

_MODELS = {
    "ViClip-InternVid-10M-FLT": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViClip-InternVid-10M-FLT.pth",
    "LongCLIP-L": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/longclip-L.pt",
    "VideoCLIP-XL-v2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/VideoCLIP-XL-v2.bin",
}
_MD5 = {
    "ViClip-InternVid-10M-FLT": "b1ebf538225438b3b75e477da7735cd0",
    "LongCLIP-L": "5478b662f6f85ca0ebd4bb05f9b592f3",
    "VideoCLIP-XL-v2": "cebda0bab14b677ec061a57e80791f35",
}

def normalize(
    data: np.ndarray,
    mean: List[float] = [0.485, 0.456, 0.406],
    std: List[float] = [0.229, 0.224, 0.225],
):
    """Scale uint8 pixels to [0, 1], then standardize with ImageNet mean/std."""
    v_mean = np.array(mean).reshape(1, 1, 3)
    v_std = np.array(std).reshape(1, 1, 3)

    return (data / 255.0 - v_mean) / v_std
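# Illustrative sanity check (not part of the metric): for a stack of 8 uint8
# frames, `normalize` maps pixel values into roughly [-2.12, 2.64].
#
#   frames = np.random.randint(0, 256, (8, 224, 224, 3), dtype=np.uint8)
#   out = normalize(frames)  # float64, shape (8, 224, 224, 3)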


class VideoCLIPXL(nn.Module):
    def __init__(self, root: str = "~/.cache/clip"):
        super().__init__()

        self.root = os.path.expanduser(root)
        os.makedirs(self.root, exist_ok=True)

        # LongCLIP supplies the text encoder.
        k = "LongCLIP-L"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = longclip.load(os.path.join(self.root, filename), device="cpu")[0].float()

        # ViCLIP supplies the video encoder.
        k = "ViClip-InternVid-10M-FLT"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.viclip_model = get_viclip("l", os.path.join(self.root, filename))["viclip"].float()

        # Drop the unused towers: LongCLIP's image encoder and ViCLIP's text encoder.
        del self.model.visual
        del self.viclip_model.text_encoder


class VideoCLIPXLScore:
    def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
        self.root = os.path.expanduser(root)
        os.makedirs(self.root, exist_ok=True)

        k = "VideoCLIP-XL-v2"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = VideoCLIPXL()
        state_dict = torch.load(os.path.join(self.root, filename), map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        # Scoring is inference-only, so switch off training-mode behavior.
        self.model.eval()

        self.device = device
    
    def __call__(self, videos: List[List[Image.Image]], texts: List[str]):
        assert len(videos) == len(texts)

        # Match the official demo, which reads frames with OpenCV: convert the
        # PIL RGB frames to BGR, then resize and normalize => B * [T, 224, 224, 3].
        videos = [[cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in v] for v in videos]
        resized_videos = [[cv2.resize(f, (224, 224)) for f in v] for v in videos]
        resized_normalized_videos = [normalize(np.stack(v)) for v in resized_videos]

        video_inputs = torch.stack([torch.from_numpy(v) for v in resized_normalized_videos])
        video_inputs = video_inputs.float().permute(0, 1, 4, 2, 3).to(self.device, non_blocking=True)  # BTCHW

        with torch.no_grad():
            # Encode each clip separately with ViCLIP's video tower. Stacking the
            # (1, D) outputs yields (B, 1, D); squeeze(1) keeps the batch
            # dimension even when B == 1.
            vid_features = torch.stack(
                [self.model.viclip_model.get_vid_features(x.unsqueeze(0)).float() for x in video_inputs]
            ).squeeze(1)
            # vid_features = self.model.viclip_model.get_vid_features(video_inputs).float()
            text_inputs = longclip.tokenize(texts, truncate=True).to(self.device)
            text_features = self.model.model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            # Entry (i, j) scores text i against video j; the diagonal pairs
            # each text with its own video.
            scores = text_features @ vid_features.T

        return scores.diagonal().tolist()
    
    def __repr__(self):
        return "videoclipxl_score"


if __name__ == "__main__":
    videos = ["your_video_path"] * 3
    texts = [
        "a joker",
        "glasses and flower",
        "The video opens with a view of a white building with multiple windows, partially obscured by leafless tree branches. The scene transitions to a closer view of the same building, with the tree branches more prominent in the foreground. The focus then shifts to a street sign that reads 'Abesses' in bold, yellow letters against a green background. The sign is attached to a metal structure, possibly a tram or bus stop. The sign is illuminated by a light source above it, and the background reveals a glimpse of the building and tree branches from earlier shots. The colors are muted, with the yellow sign standing out against the grey and green hues."
    ]

    video_clip_xl_score = VideoCLIPXLScore(device="cuda")
    batch_frames = []
    for v in videos:
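        # extract_frames returns a tuple; index 1 holds the sampled PIL frames.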
        sampled_frames = extract_frames(v, sample_method="uniform", num_sampled_frames=8)[1]
        batch_frames.append(sampled_frames)
    print(video_clip_xl_score(batch_frames, texts))