File size: 3,993 Bytes
ed00004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from pathlib import Path
from urllib.parse import urlparse

import numpy as np
from PIL import Image


def _blank_frame():
    """Return a new black 384x384 RGB placeholder image."""
    return Image.fromarray(np.zeros((384, 384, 3), dtype=np.uint8))


def get_video_frames(video_pth, frames_video=15):
    """Sample ``frames_video`` evenly spaced RGB frames from a local video file.

    Args:
        video_pth: Path (``str`` or ``pathlib.Path``) to a video file readable
            by OpenCV.
        frames_video: Number of frames to return.

    Returns:
        ``(frames, f_idxs)``: a list of ``frames_video`` PIL images and the
        matching list of source frame indices. If the video has fewer frames
        than requested, the tail is padded with black 384x384 images whose
        index is ``-1``. If any sampled frame fails to decode, every entry is
        a black placeholder and every index is ``-1``.

    Raises:
        AssertionError: if ``video_pth`` does not exist.
    """
    import cv2

    # NOTE(review): kept as ``assert`` so existing callers that catch
    # AssertionError still work; note it is stripped under ``python -O``.
    assert Path(video_pth).exists(), f"Video {video_pth} does not exist"

    video_pth = str(video_pth)

    cap = cv2.VideoCapture(video_pth)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_idxs = sample_frames(total_frames, n_frames=frames_video)

        frames = []
        f_idxs = []
        for frame_idx in frame_idxs:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()

            if not ret or frame is None:
                print(f"Video {video_pth} is corrupted")
                # Distinct placeholder objects so in-place edits on one
                # frame do not leak into the others.
                return (
                    [_blank_frame() for _ in range(frames_video)],
                    [-1] * frames_video,
                )

            # OpenCV decodes to BGR; PIL expects RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            f_idxs.append(frame_idx)
    finally:
        # Release the decoder/file handle on every path, including the
        # early return above.
        cap.release()

    # Pad short videos so callers always get exactly ``frames_video`` items.
    n_missing = frames_video - len(frames)
    frames.extend(_blank_frame() for _ in range(n_missing))
    f_idxs.extend([-1] * n_missing)

    return frames, f_idxs


def sample_frames(vlen, n_frames=15):
    """Return indices of up to ``n_frames`` evenly spaced frames of a video.

    ``[0, vlen)`` is split into ``min(vlen, n_frames)`` near-equal contiguous
    segments, and the floor midpoint of each segment is returned.

    Args:
        vlen: Total number of frames in the video. Non-positive values (OpenCV
            can report 0 or a negative count for unreadable or streamed
            files) yield ``[]``.
        n_frames: Maximum number of indices to return. Non-positive values
            yield ``[]``.

    Returns:
        A list of ``int`` frame indices, each in ``[0, vlen - 1]``.
    """
    # Clamp at 0 so a negative count cannot reach np.linspace as a negative
    # ``num`` (which raises ValueError).
    acc_samples = max(min(vlen, n_frames), 0)
    bounds = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
    starts = bounds[:-1]
    ends = bounds[1:] - 1  # inclusive end of each segment
    return ((starts + ends) // 2).tolist()


def concat_h_imgs(im_list, resample=Image.Resampling.BICUBIC):
    """Concatenate images left to right after scaling them to a common height.

    Each image is resized to the height of the shortest one, keeping its
    aspect ratio (the new width is truncated to an int). The results are
    pasted side by side onto a new RGB canvas.

    Args:
        im_list: Non-empty sequence of PIL images.
        resample: PIL resampling filter used for the resize.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    target_h = min(img.height for img in im_list)

    scaled = []
    for img in im_list:
        new_w = int(img.width * target_h / img.height)
        scaled.append(img.resize((new_w, target_h), resample=resample))

    canvas = Image.new("RGB", (sum(img.width for img in scaled), target_h))
    offset = 0
    for img in scaled:
        canvas.paste(img, (offset, 0))
        offset += img.width
    return canvas


def extract_frames(url, n_frames=10):
    """Download a video from ``url`` and sample up to ``n_frames`` RGB frames.

    The video is downloaded once, streamed to a temporary local file, and
    decoded with OpenCV from that file. OpenCV therefore needs no network
    support, and seeking is done on a local file rather than a remote stream.
    The temporary file is deleted before returning.

    Args:
        url: HTTP(S) URL of the video.
        n_frames: Maximum number of evenly spaced frames to sample.

    Returns:
        A list of PIL images. Frames that OpenCV fails to decode are skipped,
        so the list may be shorter than ``n_frames``, or empty.

    Raises:
        urllib.error.URLError: if the download fails (``HTTPError`` for an
            unsuccessful HTTP status).
    """
    import shutil
    import tempfile
    import urllib.request

    import cv2

    # Keep the URL's file extension (if any) to help OpenCV pick a demuxer.
    suffix = Path(urlparse(str(url)).path).suffix or ".mp4"

    with tempfile.TemporaryDirectory() as tmp_dir:
        local_pth = Path(tmp_dir) / f"video{suffix}"

        # Stream to disk instead of holding the whole video in memory; the
        # response is closed by the context manager.
        with urllib.request.urlopen(url) as resp, open(local_pth, "wb") as fh:
            shutil.copyfileobj(resp, fh)

        video = cv2.VideoCapture(str(local_pth))
        try:
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

            frames = []
            for frame_number in sample_frames(total_frames, n_frames=n_frames):
                video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
                ret, frame = video.read()
                if not ret:
                    continue
                # OpenCV decodes to BGR; PIL expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(np.uint8(frame)))
        finally:
            # Release before the temporary directory is removed (required on
            # platforms that lock open files).
            video.release()

    return frames


def visualize_url_video(url, n_frames=10):
    """Render sampled frames of a remote video as a single image.

    Returns the first sampled frame as-is when ``n_frames == 1``. Otherwise
    returns the sampled frames concatenated horizontally by ``concat_h_imgs``.
    """
    sampled = extract_frames(url, n_frames=n_frames)
    return sampled[0] if n_frames == 1 else concat_h_imgs(sampled)


def visualize_pth_video(video_pth, n_frames=10):
    """Render sampled frames of a local video file as a single image.

    Returns the first sampled frame as-is when ``n_frames == 1``. Otherwise
    returns all ``n_frames`` frames (including any black padding added by
    ``get_video_frames``) concatenated horizontally.
    """
    sampled, _frame_idxs = get_video_frames(video_pth, frames_video=n_frames)
    if n_frames != 1:
        return concat_h_imgs(sampled)
    return sampled[0]


def visualize_video(video, n_frames=10):
    """Render a video (http(s) URL or local path) as a single preview image.

    Dispatches to ``visualize_url_video`` when ``is_url(video)`` is true, and
    to ``visualize_pth_video`` otherwise.
    """
    renderer = visualize_url_video if is_url(video) else visualize_pth_video
    return renderer(video, n_frames=n_frames)


def is_url(url_or_filename):
    """Return ``True`` if the argument (str or path-like) uses an http/https scheme."""
    scheme = urlparse(str(url_or_filename)).scheme
    return scheme == "http" or scheme == "https"