Spaces:
Sleeping
Sleeping
from numpy.lib.npyio import load | |
from torch._C import device | |
import sys | |
sys.path.append('/scratch/shared/beegfs/szwu/projects/video3d/RAFT') | |
from core.raft import RAFT | |
from .utils import InputPadder | |
import torch | |
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    ``d.key`` and ``d['key']`` refer to the same entry, because the
    instance's attribute namespace is the dictionary itself.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping exactly like ``dict(*args, **kwargs)``."""
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the mapping: attribute access and
        # item access now share one storage.
        self.__dict__ = self
class FlowModel():
    """Wrapper around a RAFT network that computes flow between two frames.

    The network is built and loaded once in the constructor;
    ``compute_flow`` then runs it on a pair of numpy images.
    """

    def __init__(self, model, device):
        """Build the RAFT network and load its weights.

        Args:
            model: Checkpoint location; it is handed to ``torch.load``.
            device: Device the network and the input images are moved to.
        """
        args = AttrDict({
            'model': model,
            'small': False,
            'mixed_precision': False,
            'alternate_corr': False,
        })
        self.model = self.load_model(args, device)
        self.device = device

    @staticmethod
    def load_model(args, device):
        """Instantiate RAFT, load the checkpoint and switch to eval mode.

        Declared static because it takes no ``self``; this makes both
        ``self.load_model(args, device)`` and
        ``FlowModel.load_model(args, device)`` valid calls.

        Args:
            args: Options object passed to ``RAFT``; ``args.model`` is the
                checkpoint path.
            device: Device the returned network is moved to.

        Returns:
            The unwrapped RAFT module, on ``device``, in eval mode.
        """
        # The state dict is loaded into a DataParallel wrapper and the inner
        # module extracted afterwards -- presumably because the checkpoint
        # keys carry the ``module.`` prefix (TODO confirm against checkpoint).
        wrapped = torch.nn.DataParallel(RAFT(args))
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be loaded on a host where that device is unavailable.
        state = torch.load(args.model, map_location=device)
        wrapped.load_state_dict(state)
        net = wrapped.module
        net.to(device)
        net.eval()
        return net

    def preprocess_image(self, image):
        """Convert a numpy image into a padded, batched float tensor.

        No channel reordering or resizing is performed here.

        Args:
            image: 3-D numpy array with channels on the last axis
                (assumed height x width x channels -- TODO confirm).

        Returns:
            A ``(tensor, padder)`` pair: the padded tensor with a leading
            batch axis on ``self.device``, and the ``InputPadder`` that
            produced it (needed later to undo the padding).
        """
        tensor = torch.from_numpy(image).permute(2, 0, 1).float()
        tensor = tensor.to(self.device)
        tensor = tensor[None]  # add the batch axis
        padder = InputPadder(tensor.shape)
        return padder.pad(tensor)[0], padder

    def compute_flow(self, frame, next_frame, iters=20):
        """Run the network on two frames and return the flow as numpy.

        Args:
            frame: First image, in the format ``preprocess_image`` accepts.
            next_frame: Second image, same format; expected to have the
                same shape as ``frame``.
            iters: Value forwarded to the network as ``iters``.

        Returns:
            Numpy array with channels on the last axis, holding the second
            output of the network for the single batch item, after the
            padding has been removed.
        """
        frame, _ = self.preprocess_image(frame)
        # Both frames are expected to share a shape, which makes the two
        # padders interchangeable; the one built for next_frame is used.
        next_frame, padder = self.preprocess_image(next_frame)
        # Inference only: without no_grad the output would require grad and
        # the .numpy() conversion below would raise.
        with torch.no_grad():
            _, flow = self.model(frame, next_frame, iters=iters, test_mode=True)
        return padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy()