""" Taken from https://github.com/ClementPinard/FlowNetPytorch """
import pdb
import torch
import torch.nn.functional as F


def _reduce_error_map(error_map, target_flow, mask, sparse, mean):
    """Select the pixels that count and reduce a per-pixel error map.

    Args:
        error_map: per-pixel error; its first dimension is the batch.
        target_flow: two-channel ground-truth flow the error was computed
            against (channel dimension is dim 1).
        mask: selection applied to ``error_map``. ``None`` applies no
            extra filtering.
        sparse: if True, additionally drop pixels whose target flow is
            exactly 0 in both channels.
        mean: if True return the mean over the selected pixels, otherwise
            the sum over the selected pixels divided by the batch size.
    """
    batch_size = error_map.size(0)
    if sparse:
        # invalid flow is defined with both flow coordinates to be exactly 0
        invalid = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        valid = ~invalid
        if mask is not None:
            # Combine both criteria into ONE mask and index once.  Indexing
            # with ~invalid first would flatten the map to 1-D, after which
            # a (batch, h, w) mask can no longer be applied.
            valid = valid & mask.bool()
        selected = error_map[valid]
    else:
        selected = error_map[mask]

    if mean:
        return selected.mean()
    return selected.sum() / batch_size


def EPE(input_flow, target_flow, mask, sparse=False, mean=True):
    """End-point error: L2 norm (over dim 1) of ``target_flow - input_flow``.

    Only the first two channels of ``target_flow`` are used.

    Args:
        input_flow: predicted flow with two channels on dim 1.
        target_flow: ground-truth flow; channels beyond the first two are
            ignored.
        mask: selects which pixels of the error map contribute.
        sparse: if True, pixels whose target flow is exactly (0, 0) are
            treated as invalid and excluded as well.
        mean: if True return the mean over selected pixels, otherwise the
            sum over selected pixels divided by the batch size.

    Returns:
        A scalar tensor.
    """
    target_flow = target_flow[:, :2]
    error_map = torch.norm(target_flow - input_flow, 2, 1)
    return _reduce_error_map(error_map, target_flow, mask, sparse, mean)


def rob_EPE(input_flow, target_flow, mask, sparse=False, mean=True):
    """Variant of :func:`EPE` using ``(L1 norm + 0.01) ** 0.4`` per pixel.

    The L1 norm is taken over dim 1 of ``target_flow - input_flow``; the
    arguments and the reduction are the same as for :func:`EPE`.

    Returns:
        A scalar tensor.
    """
    target_flow = target_flow[:, :2]
    # TODO: the plain L2 end-point error (as in EPE) was used here before.
    error_map = (torch.norm(target_flow - input_flow, 1, 1) + 0.01).pow(0.4)
    return _reduce_error_map(error_map, target_flow, mask, sparse, mean)


def sparse_max_pool(input, size):
    """Downsample the input by considering 0 values as invalid.

    Unfortunately, no generic interpolation mode can resize a sparse map
    correctly. The strategy here is to use max pooling for positive values
    and "min pooling" for negative values; the two results are then summed.
    This technique allows sparsity to be minimized, contrary to nearest
    interpolation, which could potentially lose information for isolated
    data points.

    Args:
        input: tensor accepted by ``F.adaptive_max_pool2d``.
        size: target output size passed to ``F.adaptive_max_pool2d``.
    """
    positive = (input > 0).float()
    negative = (input < 0).float()
    pooled_positive = F.adaptive_max_pool2d(input * positive, size)
    # Max-pooling the negated negative part is a min-pool of the negatives.
    pooled_negative = F.adaptive_max_pool2d(-input * negative, size)
    return pooled_positive - pooled_negative


def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss = False):
    """Weighted sum of per-scale end-point errors.

    For every network output the target flow and the mask are resized to
    that output's spatial size and the (robust) EPE is computed with
    ``mean=False``.

    Args:
        network_output: a single tensor, or a tuple/list of tensors (one
            per scale).
        target_flow: full-resolution ground-truth flow.
        mask: full-resolution validity mask without a channel dimension
            (a channel dimension is added before resizing).
        weights: one weight per output; defaults to the five weights of
            the original article.
        sparse: resize the target with :func:`sparse_max_pool` instead of
            area interpolation, and ignore exactly-zero target pixels.
        rob_loss: use :func:`rob_EPE` instead of :func:`EPE`.

    Returns:
        The weighted loss (``0`` if there are no outputs).
    """
    loss_fn = rob_EPE if rob_loss else EPE

    def one_scale(output, target, mask, sparse):
        _, _, h, w = output.size()

        if sparse:
            target_scaled = sparse_max_pool(target, (h, w))
        else:
            target_scaled = F.interpolate(target, (h, w), mode='area')

        # Keep only pixels whose resized mask value is still exactly 1.
        # align_corners=False is the implicit default; it is spelled out to
        # match realEPE.
        mask_scaled = F.interpolate(
            mask.float().unsqueeze(1), (h, w),
            mode='bilinear', align_corners=False,
        ).squeeze(1) == 1
        return loss_fn(output, target_scaled, mask_scaled, sparse, mean=False)

    if not isinstance(network_output, (tuple, list)):
        network_output = [network_output]
    if weights is None:
        weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
    assert len(weights) == len(network_output), \
        'expected one weight per network output'

    loss = 0
    for output, weight in zip(network_output, weights):
        loss += weight * one_scale(output, target_flow, mask, sparse)
    return loss


def realEPE(output, target, mask, sparse=False):
    """Mean EPE after bilinearly upsampling ``output`` to the target size.

    Args:
        output: predicted flow, resized to the last two dimensions of
            ``target``.
        target: ground-truth flow (4-D).
        mask: selects which pixels contribute, see :func:`EPE`.
        sparse: forwarded to :func:`EPE`.
    """
    _, _, h, w = target.size()
    upsampled_output = F.interpolate(output, (h, w), mode='bilinear', align_corners=False)
    return EPE(upsampled_output, target, mask, sparse, mean=True)