File size: 867 Bytes
9bb001a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from einops import rearrange

@torch.no_grad()
def get_optical_flow(raft_model, pixel_values, video_length=None, encode_chunk_size=48, num_flow_updates=14):
    """Estimate optical flow between every pair of consecutive frames.

    Frame ``t`` of each clip is paired with frame ``t + 1``. The pairs of all
    clips are flattened into one batch, and ``raft_model`` is run on chunks of
    at most ``encode_chunk_size`` pairs (presumably to bound peak memory).
    Gradients are disabled for the whole call.

    Args:
        raft_model: Callable invoked as
            ``raft_model(imgs_1, imgs_2, num_flow_updates)``. The last element
            of its return value must be a 4-D ``(n, c, h, w)`` tensor holding
            one flow field per input pair. Presumably a RAFT model such as
            torchvision's, which returns a list of refinements; confirm
            against callers.
        pixel_values: 5-D tensor laid out as ``(batch, frames, channels,
            height, width)``.
        video_length: Number of flow fields per clip. It must equal
            ``frames - 1``, because ``frames`` images yield ``frames - 1``
            consecutive pairs. ``None`` infers it from ``pixel_values``.
        encode_chunk_size: Maximum number of frame pairs per model call.
        num_flow_updates: Passed through unchanged as the third positional
            argument of ``raft_model``.

    Returns:
        Tensor laid out as ``(batch, c, video_length, h, w)``, where ``c``,
        ``h`` and ``w`` are taken from the model's output.

    Raises:
        ValueError: If ``pixel_values`` is not 5-D, has fewer than two
            frames, ``video_length`` differs from ``frames - 1``, or
            ``encode_chunk_size`` is not positive.
    """
    if pixel_values.dim() != 5:
        raise ValueError(
            f"pixel_values must be 5-D (b, f, c, h, w), got shape {tuple(pixel_values.shape)}"
        )
    batch_size, num_frames = pixel_values.shape[:2]
    num_pairs = num_frames - 1
    if num_pairs < 1:
        raise ValueError(f"need at least 2 frames to compute optical flow, got {num_frames}")
    if video_length is None:
        video_length = num_pairs
    elif video_length != num_pairs:
        # Any other value would either fail to reshape or, worse, silently
        # regroup flow fields across clip boundaries.
        raise ValueError(
            f"video_length must equal frames - 1 = {num_pairs}, got {video_length}"
        )
    if encode_chunk_size < 1:
        raise ValueError(f"encode_chunk_size must be positive, got {encode_chunk_size}")

    # Source / target frame of every consecutive pair: "b f c h w -> (b f) c h w".
    imgs_1 = pixel_values[:, :-1].flatten(0, 1)
    imgs_2 = pixel_values[:, 1:].flatten(0, 1)

    flow_chunks = []
    for start in range(0, imgs_1.shape[0], encode_chunk_size):
        end = start + encode_chunk_size
        # Keep only the last element of the model's output.
        flow_chunks.append(raft_model(imgs_1[start:end], imgs_2[start:end], num_flow_updates)[-1])

    # cat may keep a non-default memory format of its inputs; normalise it.
    flow = torch.cat(flow_chunks).contiguous()
    # "(b f) c h w -> b c f h w": rows are clip-major, so split them back per clip.
    flow = flow.reshape(batch_size, video_length, *flow.shape[1:])
    return flow.permute(0, 2, 1, 3, 4)