|
import sys |
|
import os |
|
sys.path.insert(0, os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/core')) |
|
from raft import RAFT |
|
from utils import flow_viz |
|
sys.path = sys.path[1:] |
|
import torch |
|
from cwm.utils import imagenet_unnormalize |
|
from torch import nn |
|
import argparse |
|
|
|
|
|
class Args:
    """Static RAFT configuration, used in place of an ``argparse.Namespace``.

    Defaults are class attributes; assigning on an instance overrides a value
    for that instance only. Iterating yields ``(name, value)`` pairs for every
    public, non-callable setting (so ``dict(Args())`` works), and
    ``name in args`` reports whether a setting exists, as with
    ``argparse.Namespace``.
    """

    # Pretrained checkpoint location under ~/.cache/torch.
    model = os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/models/raft-sintel.pth')

    small = False

    path = None

    mixed_precision = False

    alternate_corr = False

    def _settings(self):
        """Return the effective settings: class defaults overlaid with instance overrides."""
        settings = {}
        # Class-level defaults live on the class (and its bases), not in
        # ``self.__dict__``, so collect them explicitly, base classes first.
        for klass in reversed(type(self).__mro__):
            settings.update(
                (name, value)
                for name, value in vars(klass).items()
                if not name.startswith('_') and not callable(value)
            )
        # Values assigned on this instance take precedence over the defaults.
        settings.update(vars(self))
        return settings

    def __iter__(self):
        """Yield ``(name, value)`` for each effective setting."""
        yield from self._settings().items()

    def __contains__(self, name):
        """Return True if ``name`` is a defined setting (default or override)."""
        return name in self._settings()
|
|
|
class RAFTInterface(nn.Module):
    """Frozen, inference-only wrapper around a pretrained RAFT optical-flow model.

    The weights are loaded onto the CPU from the checkpoint path in
    ``Args.model``; move the wrapper with ``.to(device)`` as usual. Every RAFT
    parameter has ``requires_grad`` disabled, and the network is kept in eval
    mode even if ``.train()`` is called on this module or a parent.
    """

    # Spatial dims are padded up to a multiple of this before calling RAFT,
    # mirroring the InputPadder used by RAFT's reference demo.
    pad_multiple = 8

    def __init__(self):
        super().__init__()

        args = Args()

        # Wrap in DataParallel to load, then unwrap: presumably the checkpoint
        # was saved from a DataParallel model (keys prefixed with ``module.``).
        model = torch.nn.DataParallel(RAFT(args))
        model.load_state_dict(torch.load(args.model, map_location=torch.device('cpu')))
        self.model = model.module
        self.model.eval()

        # RAFT is used purely as a fixed flow estimator: freeze all weights.
        for p in self.model.parameters():
            p.requires_grad = False

    def train(self, mode=True):
        """Set this wrapper's mode while keeping the wrapped RAFT model in eval mode.

        ``nn.Module.train`` recurses into submodules, so without this override a
        parent's ``.train()`` call would undo the ``eval()`` from ``__init__``.
        """
        super().train(mode)
        self.model.eval()
        return self

    @staticmethod
    def prepare_inputs(x):
        """Rescale an image tensor to the 0-255 range before it is passed to RAFT.

        The choice is a heuristic on the value range:

        * all values in [0, 1]: multiplied by 255;
        * any negative value: passed through ``imagenet_unnormalize`` (presumably
          undoing ImageNet mean/std normalisation) and then multiplied by 255;
        * otherwise: returned unchanged (assumed to already be 0-255).
        """
        if x.max() <= 1.0 and x.min() >= 0.:
            x = x * 255.
        elif x.min() < 0:
            x = imagenet_unnormalize(x)
            x = x * 255.
        return x

    def _padding_for(self, height, width):
        """Return ``[left, right, top, bottom]`` padding up to a multiple of ``pad_multiple``."""
        pad_h = (-height) % self.pad_multiple
        pad_w = (-width) % self.pad_multiple
        # Split as evenly as possible, with any odd pixel on the right/bottom.
        return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]

    @staticmethod
    def _pad(x, pad):
        """Edge-replicate the last two dims of ``x`` by ``pad`` = [left, right, top, bottom]."""
        # Replicate padding may not support integer tensors, so promote to float
        # here; the unpadded path leaves the dtype untouched.
        if not x.is_floating_point():
            x = x.float()
        return nn.functional.pad(x, pad, mode='replicate')

    def forward(self, x0, x1, return_magnitude=False, iters=20):
        """Estimate optical flow from ``x0`` to ``x1`` with the frozen RAFT model.

        Args:
            x0, x1: image batches of identical shape, assumed ``(B, 3, H, W)``
                (TODO confirm against callers). Their value range is normalised
                by ``prepare_inputs``. ``H`` and ``W`` need not be multiples of 8:
                the inputs are edge-padded before RAFT and the flow is cropped back.
            return_magnitude: if True, also return the per-pixel L2 norm of the flow.
            iters: number of RAFT refinement iterations (default 20).

        Returns:
            The flow (channel dim 1 holds the flow components, same ``H``/``W`` as
            the inputs), or ``(flow, magnitude)`` where
            ``magnitude = flow.norm(p=2, dim=1)``.
        """
        x0 = self.prepare_inputs(x0)
        x1 = self.prepare_inputs(x1)

        height, width = x0.shape[-2:]
        pad = self._padding_for(height, width)
        padded = any(pad)
        if padded:
            x0 = self._pad(x0, pad)
            x1 = self._pad(x1, pad)

        with torch.no_grad():
            _, flow_up = self.model(x0, x1, iters=iters, test_mode=True)

        if padded:
            # Drop the padded border so the flow aligns with the original pixels.
            left, _, top, _ = pad
            flow_up = flow_up[..., top:top + height, left:left + width]

        if return_magnitude:
            flow_magnitude = flow_up.norm(p=2, dim=1)
            return flow_up, flow_magnitude

        return flow_up

    def viz(self, flow):
        """Render the first flow field of the batch as an image via ``flow_viz.flow_to_image``."""
        # (C, H, W) -> (H, W, C) numpy array; detach so flows that carry
        # autograd history can still be converted.
        flow_rgb = flow_viz.flow_to_image(flow[0].detach().permute(1, 2, 0).cpu().numpy())
        return flow_rgb