Spaces:
Sleeping
Sleeping
File size: 867 Bytes
9bb001a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
from einops import rearrange
@torch.no_grad()
def get_optical_flow(raft_model, pixel_values, video_length, encode_chunk_size=48, num_flow_updates=14):
    """Estimate optical flow between every pair of consecutive frames.

    Frame ``t`` of each clip is paired with frame ``t + 1``. All pairs from all
    clips are flattened into a single batch and passed to ``raft_model`` in
    chunks of at most ``encode_chunk_size`` pairs, to bound peak memory. Runs
    under ``torch.no_grad()``.

    Args:
        raft_model: Callable invoked as
            ``raft_model(img1, img2, num_flow_updates)``; only the last element
            (``[-1]``) of its return value is kept, and that element must be
            4-D with one row per frame pair. Presumably a torchvision-style
            RAFT model returning a list of progressively refined flow fields,
            the last being the final estimate -- confirm against the caller.
        pixel_values: 5-D tensor laid out as ``(b, f, c, h, w)``
            (batch, frames, channels, height, width).
        video_length: Number of frame pairs per clip, i.e. ``f - 1``. Used to
            validate the layout before the flattened result is split back
            into clips.
        encode_chunk_size: Maximum number of frame pairs per ``raft_model``
            call.
        num_flow_updates: Forwarded positionally as the third argument of
            ``raft_model``.

    Returns:
        Tensor of shape ``(b, c_out, f - 1, h_out, w_out)``, where ``c_out``,
        ``h_out`` and ``w_out`` are whatever ``raft_model`` emits per pair.
        The result is a permuted view and is not contiguous.

    Raises:
        ValueError: if ``pixel_values`` is not 5-D, has fewer than 2 frames,
            if ``video_length != f - 1``, or if ``encode_chunk_size < 1``.
    """
    if pixel_values.dim() != 5:
        raise ValueError(
            f"pixel_values must be 5-D (b, f, c, h, w), got shape {tuple(pixel_values.shape)}"
        )
    batch_size, num_frames = pixel_values.shape[:2]
    num_pairs = num_frames - 1
    if num_pairs < 1:
        raise ValueError(f"need at least 2 frames to compute optical flow, got {num_frames}")
    if video_length != num_pairs:
        # The flattened batch is ordered (clip, pair). Splitting it back with
        # any other per-clip length would silently mix frames of different clips.
        raise ValueError(
            f"video_length must equal the number of frame pairs ({num_pairs}), got {video_length}"
        )
    if encode_chunk_size < 1:
        raise ValueError(f"encode_chunk_size must be positive, got {encode_chunk_size}")

    # (b, f-1, c, h, w) -> (b*(f-1), c, h, w); row i*(f-1)+t is pair (t, t+1) of clip i.
    imgs_1 = pixel_values[:, :-1].flatten(0, 1)
    imgs_2 = pixel_values[:, 1:].flatten(0, 1)

    flows = [
        raft_model(chunk_1, chunk_2, num_flow_updates)[-1]
        for chunk_1, chunk_2 in zip(
            torch.split(imgs_1, encode_chunk_size),
            torch.split(imgs_2, encode_chunk_size),
        )
    ]
    flow_embedding = torch.cat(flows)

    # (b*(f-1), c, h, w) -> (b, f-1, c, h, w) -> (b, c, f-1, h, w)
    flow_embedding = flow_embedding.reshape(batch_size, num_pairs, *flow_embedding.shape[1:])
    return flow_embedding.permute(0, 2, 1, 3, 4)