import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

from basicsr.archs.gmflow.gmflow.gmflow import GMFlow


class FlowGenerator(nn.Module):
    """Optical-flow estimator wrapping a GMFlow model.

    Args:
        path (str, optional): Path to a pre-trained GMFlow checkpoint. The file
            is expected to hold a dict whose ``'model'`` entry is the GMFlow
            state dict (loaded with ``strict=True``). Default: None, i.e. the
            GMFlow weights are left as constructed.
        requires_grad (bool): If True, all parameters of this module are
            trainable and GMFlow is put in train mode; otherwise they are
            frozen and GMFlow is put in eval mode. Default: False.
    """

    def __init__(self, path=None, requires_grad=False):
        super().__init__()
        self.model = GMFlow()
        if path is not None:
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from trusted sources. The lambda keeps every tensor
            # on the storage it was deserialized to (CPU) instead of the
            # device it was saved from.
            checkpoint = torch.load(
                path, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(checkpoint['model'], strict=True)

        if requires_grad:
            self.model.train()
        else:
            self.model.eval()
        # Freeze or unfreeze every parameter in one pass.
        for param in self.parameters():
            param.requires_grad = bool(requires_grad)

    def forward(self, im1, im2, attn_splits_list=None, corr_radius_list=None,
                prop_radius_list=None):
        """Estimate optical flow between two images with GMFlow.

        Args:
            im1 (Tensor): First image, shape (n, c, h, w). Values are assumed
                to lie in [-1, 1] (TODO confirm with callers); they are mapped
                to [0, 255] before being passed to GMFlow.
            im2 (Tensor): Second image, same shape as ``im1``.
            attn_splits_list (list[int], optional): Forwarded to GMFlow.
                Default: [2].
            corr_radius_list (list[int], optional): Forwarded to GMFlow.
                Default: [-1].
            prop_radius_list (list[int], optional): Forwarded to GMFlow.
                Default: [-1].

        Returns:
            Tensor: The last element of GMFlow's ``'flow_preds'`` output
            (computed with ``pred_bidir_flow=False``).

        Raises:
            ValueError: If ``im1`` and ``im2`` have different shapes.
        """
        # Build fresh lists per call rather than sharing mutable defaults.
        if attn_splits_list is None:
            attn_splits_list = [2]
        if corr_radius_list is None:
            corr_radius_list = [-1]
        if prop_radius_list is None:
            prop_radius_list = [-1]

        if im1.shape != im2.shape:
            raise ValueError(
                f'im1 and im2 must have the same shape, '
                f'got {tuple(im1.shape)} and {tuple(im2.shape)}')

        # Rescale from [-1, 1] to the [0, 255] pixel range.
        im1 = (im1 + 1) / 2 * 255
        im2 = (im2 + 1) / 2 * 255

        results = self.model(
            im1,
            im2,
            attn_splits_list=attn_splits_list,
            corr_radius_list=corr_radius_list,
            prop_radius_list=prop_radius_list,
            pred_bidir_flow=False)
        # Keep only the final prediction.
        return results['flow_preds'][-1]


if __name__ == '__main__':
    h, w = 512, 512
    model = FlowGenerator(
        path='../../weights/GMFlow/gmflow_sintel-0c07dcb3.pth').cuda()
    model.eval()
    print(model)

    x = torch.randn((1, 3, h, w)).cuda()
    y = torch.randn((1, 3, h, w)).cuda()
    with torch.no_grad():
        out = model(x, y)
    print(out.shape)