from numpy.lib.npyio import load
from torch._C import device
import sys
sys.path.append('/scratch/shared/beegfs/szwu/projects/video3d/RAFT')
from core.raft import RAFT

from .utils import InputPadder
import torch


class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    The instance ``__dict__`` is replaced by the mapping itself, so
    ``d.key`` and ``d['key']`` always refer to the same storage.
    """

    def __init__(self, *args, **kwargs):
        # Accept exactly what ``dict`` accepts (mapping/iterable and/or kwargs).
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the dict contents.
        self.__dict__ = self



class FlowModel:
    """Inference wrapper around a RAFT optical-flow network.

    Loads RAFT weights from a checkpoint and exposes :meth:`compute_flow`,
    which takes two images as numpy arrays and returns the estimated flow
    as a numpy array.
    """

    def __init__(self, model, device):
        """Build the network and load its weights.

        Args:
            model: Path of the checkpoint file passed to ``torch.load``.
            device: Target device for the network and its inputs (anything
                accepted by ``Module.to`` / ``Tensor.to``).
        """
        # Options are read via attribute access (e.g. ``args.model`` in
        # ``load_model``), hence AttrDict. The remaining flags are presumably
        # consumed by the RAFT constructor -- TODO confirm against RAFT.
        args = AttrDict({'model': model, 'small': False, 'mixed_precision': False, 'alternate_corr': False})
        self.model = self.load_model(args, device)
        self.device = device

    @staticmethod
    def load_model(args, device):
        """Create RAFT, load the ``args.model`` checkpoint, and return it in eval mode on ``device``.

        Args:
            args: Option object with attribute access; ``args.model`` is the
                checkpoint path, and the whole object is passed to ``RAFT``.
            device: Device the returned module is moved to.

        Returns:
            The unwrapped RAFT module (not the DataParallel wrapper).
        """
        # The weights are loaded through a DataParallel wrapper (the checkpoint
        # keys presumably carry the ``module.`` prefix -- TODO confirm), and the
        # wrapper is then discarded so inference runs on a single device.
        model = torch.nn.DataParallel(RAFT(args))
        # Load to CPU first: a checkpoint saved on a GPU then also loads on
        # CPU-only machines, and the weights are not duplicated on the saved
        # device before being copied into the model.
        model.load_state_dict(torch.load(args.model, map_location='cpu'))

        model = model.module
        model.to(device)
        model.eval()
        return model

    def preprocess_image(self, image):
        """Convert a numpy image into a padded, batched float tensor on ``self.device``.

        The input is assumed to be channels-last ``(H, W, C)`` (it is permuted
        with ``(2, 0, 1)``), and the result has a leading batch axis of size 1.
        Values are cast to float without rescaling, and no channel reordering or
        resizing is applied. ``torch.from_numpy`` rejects arrays with negative
        strides, so pass a contiguous copy of flipped or sliced views.

        Returns:
            Tuple ``(padded_image, padder)``; ``padder.unpad`` undoes the padding.
        """
        tensor = torch.from_numpy(image).permute(2, 0, 1).float()
        tensor = tensor.to(self.device)[None]
        padder = InputPadder(tensor.shape)
        return padder.pad(tensor)[0], padder

    def compute_flow(self, frame, next_frame, iters=20):
        """Estimate optical flow from ``frame`` to ``next_frame``.

        Args:
            frame: First image, in the format accepted by :meth:`preprocess_image`.
            next_frame: Second image; must have the same shape as ``frame``,
                since a single padder is used to unpad the result.
            iters: Forwarded to the model's ``iters`` argument (RAFT's number
                of refinement iterations).

        Returns:
            Channels-last numpy array of the unpadded flow (presumably
            ``(H, W, 2)`` for RAFT -- confirm).
        """
        # Inference only: with autograd enabled, the output would require grad
        # (the weights do), and ``.numpy()`` raises on such tensors. It would
        # also keep the whole activation graph alive.
        with torch.no_grad():
            frame, _ = self.preprocess_image(frame)
            next_frame, padder = self.preprocess_image(next_frame)
            # In test mode the model returns two outputs; the second
            # (presumably the full-resolution flow) is kept.
            _, flow = self.model(frame, next_frame, iters=iters, test_mode=True)
        return padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy()