import os
import uuid
import scipy.io.wavfile as wav
import ffmpeg
import cv2
from PIL import Image

import decord
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, GaussianBlur, Grayscale, Resize
import torchaudio

decord.bridge.set_bridge('torch')
# Note: set_audio_backend targets older torchaudio releases; it was deprecated
# and later removed in torchaudio 2.x, where a backend is chosen automatically.
torchaudio.set_audio_backend("sox_io")


class AudioEncoder(nn.Module):
    """
    A PyTorch module that encodes raw audio into a sequence of fixed-size
    vectors (one "embedding" per sliding window) using a pre-trained
    TorchScript model. The embeddings are useful for downstream machine
    learning tasks such as classification or similarity matching.
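
    Example (illustrative sketch; the checkpoint path is hypothetical):
        >>> enc = AudioEncoder("audio_encoder.pt")
        >>> emb = enc(torch.zeros(1, 1, 16000))  # one embedding per 40 ms window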
    """
    def __init__(self, path):
        """
        Initialize the AudioEncoder object.

        Args:
            path (str): The file path where the pre-trained model is stored.
        """
        super().__init__()
        self.model = torch.jit.load(path)
        self.register_buffer('hidden', torch.zeros(2, 1, 256))  # recurrent state carried across windows

    def forward(self, audio):
        """
        Encode an audio sample. The waveform is cut into overlapping windows
        of 3200 samples (200 ms at 16 kHz) with a stride of 640 samples
        (40 ms, i.e. one window per frame of 25 fps video), and the model
        produces one embedding per window.

        Args:
            audio (Tensor): A tensor of shape (1, 1, num_samples) containing
                the audio data.

        Returns:
            Tensor: The embeddings of the given audio, one row per window.
        """
        self.reset()
        # Window the waveform: 3200-sample snips with a 640-sample stride,
        # padded at both ends so the input yields one window per 640-sample hop.
        x = create_windowed_sequence(audio, 3200, cutting_stride=640, pad_samples=3200 - 640, cut_dim=1)
        embs = []
        for i in range(x.shape[1]):
            emb, _, self.hidden = self.model(x[:, i], torch.LongTensor([3200]), init_state=self.hidden)
            embs.append(emb)
        return torch.vstack(embs)

    def reset(self):
        """
        Resets the hidden states in the model. Call this function 
        before processing a new audio sample to ensure that there is 
        no state carried over from the previous sample.
        """
        self.hidden = torch.zeros(2, 1, 256).to(self.hidden.device)


def get_audio_emb(audio_path, checkpoint, device):
    """
    This function takes the path of an audio file, loads it into a 
    PyTorch tensor, and returns its embedding.

    Args:
        audio_path (str): The file path of the audio to be loaded.
        checkpoint (str): The file path of the pre-trained model.
        device (str): The computing device ('cpu' or 'cuda').

    Returns:
        Tensor, Tensor: The original audio as a tensor and its corresponding embedding.
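
    Example (illustrative sketch; both file paths are hypothetical):
        >>> audio, emb = get_audio_emb("speech.wav", "audio_encoder.pt", "cpu")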
    """
    audio, audio_rate = torchaudio.load(audio_path, channels_first=False)
    assert audio_rate == 16000, 'Only 16 kHz audio is supported.'
    audio = audio[None, None, :, 0].to(device)  # keep the first channel; shape (1, 1, num_samples)

    audio_encoder = AudioEncoder(checkpoint).to(device)

    emb = audio_encoder(audio)
    return audio, emb


def get_id_frame(path, random=False, resize=128):
    """
    Retrieves a frame from either a video or image file. This frame can 
    serve as an identifier or reference for the video or image.

    Args:
        path (str): File path to the video or image.
        random (bool): Whether to randomly select a frame from the video.
        resize (int): The dimensions to which the frame should be resized.

    Returns:
        Tensor: The image frame as a tensor.
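
    Example (illustrative; the video path is hypothetical):
        >>> frame = get_id_frame("identity.mp4", random=True, resize=128)
        >>> frame.shape
        torch.Size([1, 3, 128, 128])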
    """
    if path.endswith('.mp4'):
        vr = decord.VideoReader(path)
        if random:
            idx = [np.random.randint(len(vr))]
        else:
            idx = [0]
        frame = vr.get_batch(idx).permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    else:
        frame = load_image_to_torch(path).unsqueeze(0)
    
    frame = (frame / 255) * 2 - 1  # scale uint8 pixel values to [-1, 1]
    frame = Resize((resize, resize), antialias=True)(frame).float()
    return frame


def get_motion_transforms(args):
    """
    Builds a transform pipeline (optionally Gaussian blur and grayscale
    conversion) based on the provided arguments. These transforms are
    typically applied to motion frames for preprocessing.

    Args:
        args (Namespace): Arguments containing options for motion transformations.
    
    Returns:
        Compose: A composed function of transforms.
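
    Example (illustrative; a minimal namespace standing in for parsed CLI args):
        >>> from argparse import Namespace
        >>> t = get_motion_transforms(Namespace(motion_blur=True, grayscale_motion=False))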
    """
    motion_transforms = []
    if args.motion_blur:
        motion_transforms.append(GaussianBlur(5, sigma=2.0))
    if args.grayscale_motion:
        motion_transforms.append(Grayscale(1))
    return Compose(motion_transforms)


def save_audio(path, audio, audio_rate=16000):
    """
    Saves the audio data as a WAV file.

    Args:
        path (str): The file path where the audio will be saved.
        audio (Tensor or np.array): The audio data.
        audio_rate (int): The sampling rate of the audio, defaults to 16000Hz.
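
    Example (illustrative; writes a 1-second 440 Hz tone to a hypothetical path):
        >>> t = np.linspace(0, 1, 16000, endpoint=False)
        >>> save_audio("tone.wav", 0.5 * np.sin(2 * np.pi * 440 * t))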
    """
    if torch.is_tensor(audio):
        aud = audio.squeeze().detach().cpu().numpy()
    else:
        aud = audio.copy()  # Make a copy so that we don't alter the object

    # Convert float audio in [-1, 1] to 16-bit PCM, clipping so values at
    # exactly +/-1.0 cannot wrap around the int16 range.
    aud = (np.clip(aud, -1.0, 1.0) * (2 ** 15 - 1)).astype(np.int16)
    wav.write(path, audio_rate, aud)


def save_video(path, video, fps=25, scale=2, audio=None, audio_rate=16000, overlay_pts=None, ffmpeg_experimental=False):
    """
    Saves the video data as an MP4 file. Optionally includes audio and overlay points.

    Args:
        path (str): The file path where the video will be saved.
        video (Tensor or np.array): The video data.
        fps (int): Frames per second of the video.
        scale (int): Scaling factor for the video dimensions.
        audio (Tensor or np.array, optional): Audio data.
        audio_rate (int, optional): The sampling rate for the audio.
        overlay_pts (list of points, optional): Points to overlay on the video frames.
        ffmpeg_experimental (bool): Whether to use experimental ffmpeg options.

    Returns:
        bool: Success status.
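
    Example (illustrative; writes ten random grayscale frames to a hypothetical path):
        >>> ok = save_video("out/clip.mp4", np.random.rand(10, 64, 64), fps=25)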
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    success = True
    out_size = (scale * video.shape[-1], scale * video.shape[-2])  # (width, height)
    video_path = get_temp_path(os.path.split(path)[0], ext=".mp4")
    if torch.is_tensor(video):
        vid = video.squeeze().detach().cpu().numpy()
    else:
        vid = video.copy()  # Make a copy so that we don't alter the object

    # Bring pixel values into the [0, 255] range expected by OpenCV.
    if np.min(vid) < 0:
        vid = 127 * vid + 127  # assume [-1, 1] input
    elif np.max(vid) <= 1:
        vid = 255 * vid  # assume [0, 1] input

    # (T, H, W) arrays are grayscale; (T, C, H, W) arrays are color.
    is_color = vid.ndim != 3

    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"), float(fps), out_size, isColor=is_color)
    for i, frame in enumerate(vid):
        if is_color:
            # (C, H, W) -> (H, W, C) and RGB -> BGR, as expected by OpenCV
            frame = cv2.cvtColor(np.rollaxis(frame, 0, 3), cv2.COLOR_RGB2BGR)

        if scale != 1:
            frame = cv2.resize(frame, out_size)

        write_frame = frame.astype('uint8')

        if overlay_pts is not None:
            for pt in overlay_pts[i]:
                cv2.circle(write_frame, (int(scale * pt[0]), int(scale * pt[1])), 2, (0, 0, 0), -1)

        writer.write(write_frame)
    writer.release()

    inputs = [ffmpeg.input(video_path)['v']]

    if audio is not None:  # Save the audio file
        audio_path = swp_extension(video_path, ".wav")
        save_audio(audio_path, audio, audio_rate)
        inputs += [ffmpeg.input(audio_path)['a']]

    try:
        if ffmpeg_experimental:
            out = ffmpeg.output(*inputs, path, strict='-2', loglevel="panic", vcodec='h264').overwrite_output()
        else:
            out = ffmpeg.output(*inputs, path, loglevel="panic", vcodec='h264').overwrite_output()
        out.run(quiet=True)
    except (ffmpeg.Error, OSError):  # encoding failed or ffmpeg is unavailable
        success = False

    if audio is not None and os.path.isfile(audio_path):
        os.remove(audio_path)
    if os.path.isfile(video_path):
        os.remove(video_path)

    return success


def load_image_to_torch(path):
    """
    Load an image from disk and convert it to a PyTorch tensor.

    Args:
        path (str): The path to the image file.
        
    Returns:
        torch.Tensor: A uint8 tensor of shape (C, H, W).
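
    Example (illustrative; the image path is hypothetical):
        >>> img = load_image_to_torch("face.png")  # -> uint8 tensor, (3, H, W)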
    """
    img = Image.open(path).convert('RGB')
    img = np.array(img)
    return torch.from_numpy(img).permute(2, 0, 1)  # (H, W, C) -> (C, H, W)


def get_temp_path(tmp_dir, mode="", ext=""):
    """
    Generate a temporary file path for storing data.
    
    Args:
        tmp_dir (str): The directory where the temporary file will be created.
        mode (str, optional): A string to append to the file name.
        ext (str, optional): The file extension.
        
    Returns:
        str: The full path to the temporary file.
    """
    # uuid4 gives a collision-safe random name without relying on
    # tempfile's private _get_candidate_names() helper.
    os.makedirs(tmp_dir, exist_ok=True)
    return os.path.join(tmp_dir, uuid.uuid4().hex + mode + ext)


def swp_extension(file, ext):
    """
    Swap the extension of a given file name.
    
    Args:
        file (str): The original file name.
        ext (str): The new extension.
        
    Returns:
        str: The file name with the new extension.
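
    Example:
        >>> swp_extension("clip.mp4", ".wav")
        'clip.wav'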
    """
    return os.path.splitext(file)[0] + ext


def pad_both_ends(tensor, left, right, dim=0):
    """
    Pad a tensor on both ends along a specific dimension.
    
    Args:
        tensor (torch.Tensor): The tensor to be padded.
        left (int): The padding size for the left side.
        right (int): The padding size for the right side.
        dim (int, optional): The dimension along which to pad.
        
    Returns:
        torch.Tensor: The padded tensor.
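
    Example (a 1-D tensor padded with one zero on the left, two on the right):
        >>> pad_both_ends(torch.ones(4), 1, 2).shape
        torch.Size([7])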
    """
    no_dims = len(tensor.size())
    if dim == -1:
        dim = no_dims - 1

    padding = [0] * 2 * no_dims
    padding[2 * (no_dims - dim - 1)] = left
    padding[2 * (no_dims - dim - 1) + 1] = right
    return F.pad(tensor, padding, "constant", 0)


def cut_n_stack(seq, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
    """
    Divide a sequence tensor into smaller snips and stack them.
    
    Args:
        seq (torch.Tensor): The original sequence tensor.
        snip_length (int): The length of each snip.
        cut_dim (int, optional): The dimension along which to cut.
        cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
        pad_samples (int, optional): Number of samples to pad at both ends.
        
    Returns:
        torch.Tensor: A tensor containing the stacked snips.
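
    Example (ten samples cut into 4-sample snips with a stride of 2):
        >>> cut_n_stack(torch.arange(10), 4, cutting_stride=2).shape
        torch.Size([4, 4])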
    """
    if cutting_stride is None:
        cutting_stride = snip_length

    pad_left = pad_samples // 2
    pad_right = pad_samples - pad_samples // 2

    seq = pad_both_ends(seq, pad_left, pad_right, dim=cut_dim)

    # Slice out each window and stack them along a new leading dimension.
    iterations = (seq.size(cut_dim) - snip_length) // cutting_stride + 1
    snips = [seq.narrow(cut_dim, i * cutting_stride, snip_length) for i in range(iterations)]
    return torch.stack(snips)


def create_windowed_sequence(seqs, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
    """
    Create a windowed sequence from a list of sequences.
    
    Args:
        seqs (Tensor or list of torch.Tensor): Batch of sequence tensors.
        snip_length (int): The length of each snip.
        cut_dim (int, optional): The dimension along which to cut.
        cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
        pad_samples (int, optional): Number of samples to pad at both ends.
        
    Returns:
        torch.Tensor: A tensor containing the windowed sequences.
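
    Example (a batch of two 10-sample sequences):
        >>> create_windowed_sequence(torch.zeros(2, 10), 4, cutting_stride=2).shape
        torch.Size([2, 4, 4])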
    """
    windowed_seqs = []
    for seq in seqs:
        windowed_seqs.append(cut_n_stack(seq, snip_length, cut_dim, cutting_stride, pad_samples).unsqueeze(0))

    return torch.cat(windowed_seqs)