Spaces:
Sleeping
Sleeping
import torch | |
from einops import rearrange | |
def get_optical_flow(raft_model, pixel_values, video_length, encode_chunk_size=48, num_flow_updates=14):
    """Estimate optical flow between every pair of consecutive frames.

    For each clip in the batch, frame ``t`` is paired with frame ``t + 1``.
    All pairs are flattened into a single image batch and passed through
    ``raft_model`` in chunks of at most ``encode_chunk_size`` pairs.

    Args:
        raft_model: Callable invoked as
            ``raft_model(img1, img2, num_flow_updates)``. Its return value must
            be indexable, and only the last element is kept. Presumably
            torchvision's RAFT, which returns one flow prediction per
            refinement iteration (TODO confirm against caller).
        pixel_values: 5-D tensor laid out as
            ``(batch, frames, channels, height, width)``. At least two frames
            are needed; with fewer there are no pairs, and ``torch.cat``
            raises on the empty list of chunks.
        video_length: Number of frame pairs per clip. It must equal
            ``pixel_values.shape[1] - 1``.
        encode_chunk_size: Maximum number of frame pairs per model call,
            presumably to bound peak memory. Must be positive.
        num_flow_updates: Forwarded to ``raft_model`` as its third positional
            argument.

    Returns:
        Tensor of shape ``(batch, flow_channels, video_length, H', W')``. The
        channel and spatial sizes are whatever ``raft_model`` produces per
        pair.

    Raises:
        ValueError: If ``encode_chunk_size`` is not positive.
        einops.EinopsError: If ``video_length`` is inconsistent with the
            number of frame pairs in ``pixel_values``.
    """
    if encode_chunk_size <= 0:
        raise ValueError(f"encode_chunk_size must be positive, got {encode_chunk_size}")

    batch_size = pixel_values.shape[0]

    # Pair frame t (source) with frame t + 1 (target), then fold the pair axis
    # into the batch axis so the model sees a plain 4-D image batch.
    imgs_1 = rearrange(pixel_values[:, :-1], "b f c h w -> (b f) c h w")
    imgs_2 = rearrange(pixel_values[:, 1:], "b f c h w -> (b f) c h w")

    flow_chunks = []
    for start in range(0, imgs_1.shape[0], encode_chunk_size):
        end = start + encode_chunk_size
        # Keep only the last element of the model output (the final refinement
        # if the model is RAFT-like).
        flow_chunks.append(raft_model(imgs_1[start:end], imgs_2[start:end], num_flow_updates)[-1])
    flow_embedding = torch.cat(flow_chunks).contiguous()

    # Pin both b and f. With f alone, einops only checks divisibility, so a
    # wrong video_length could silently re-split the flattened pairs across
    # batch entries instead of raising.
    return rearrange(flow_embedding, "(b f) c h w -> b c f h w", b=batch_size, f=video_length)