Spaces:
Runtime error
Runtime error
""" | |
Taken from https://github.com/ClementPinard/FlowNetPytorch
""" | |
import pdb | |
import torch | |
import torch.nn.functional as F | |
def EPE(input_flow, target_flow, mask=None, sparse=False, mean=True):
    """Compute the end-point error (EPE) between a predicted and a target flow.

    The per-pixel EPE is the L2 norm, taken over dim 1 (the flow channels),
    of ``target_flow - input_flow``.

    Args:
        input_flow: predicted flow; presumably (B, 2, H, W) -- TODO confirm.
        target_flow: ground-truth flow. Only its first two channels are
            compared; any additional channels are ignored.
        mask: boolean tensor selecting the pixels that contribute to the
            loss. It indexes the per-pixel EPE map, so its shape must match
            that map (B, H, W). ``None`` selects every pixel.
        sparse: if True, pixels whose target flow is exactly (0, 0) are
            treated as invalid and are excluded in addition to ``mask``.
        mean: if True, return the mean EPE over the selected pixels
            (NaN when nothing is selected); otherwise return their summed
            EPE divided by the batch size.

    Returns:
        A scalar tensor.
    """
    target_flow = target_flow[:, :2]
    EPE_map = torch.norm(target_flow - input_flow, 2, 1)
    batch_size = EPE_map.size(0)

    valid = mask
    if sparse:
        # invalid flow is defined with both flow coordinates to be exactly 0.
        # Combine it with the caller's mask instead of replacing the mask and
        # pre-filtering EPE_map, so EPE_map keeps its (B, H, W) shape and can
        # be indexed by the combined mask below.
        invalid = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        valid = ~invalid if valid is None else valid & ~invalid

    selected = EPE_map if valid is None else EPE_map[valid]
    if mean:
        return selected.mean()
    return selected.sum() / batch_size
def rob_EPE(input_flow, target_flow, mask=None, sparse=False, mean=True, eps=0.01, q=0.4):
    """Robust end-point error: per pixel ``(||target - pred||_1 + eps) ** q``.

    This is the EPE computation with the per-pixel L2 norm replaced by a
    robust penalty. The penalty is the L1 norm over dim 1 (the flow channels),
    offset by ``eps`` and raised to the power ``q``.

    Args:
        input_flow: predicted flow; presumably (B, 2, H, W) -- TODO confirm.
        target_flow: ground-truth flow. Only its first two channels are
            compared; any additional channels are ignored.
        mask: boolean tensor selecting the pixels that contribute to the
            loss. Its shape must match the per-pixel penalty map (B, H, W).
            ``None`` selects every pixel.
        sparse: if True, pixels whose target flow is exactly (0, 0) are
            treated as invalid and are excluded in addition to ``mask``.
        mean: if True, return the mean penalty over the selected pixels
            (NaN when nothing is selected); otherwise return their summed
            penalty divided by the batch size.
        eps: offset added to the L1 error before exponentiation (default 0.01,
            the previously hard-coded value).
        q: exponent applied to the offset error (default 0.4, the previously
            hard-coded value).

    Returns:
        A scalar tensor.
    """
    target_flow = target_flow[:, :2]
    EPE_map = (torch.norm(target_flow - input_flow, 1, 1) + eps).pow(q)
    batch_size = EPE_map.size(0)

    valid = mask
    if sparse:
        # invalid flow is defined with both flow coordinates to be exactly 0.
        # Combine it with the caller's mask instead of replacing the mask and
        # pre-filtering EPE_map, so EPE_map keeps its (B, H, W) shape and can
        # be indexed by the combined mask below.
        invalid = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        valid = ~invalid if valid is None else valid & ~invalid

    selected = EPE_map if valid is None else EPE_map[valid]
    if mean:
        return selected.mean()
    return selected.sum() / batch_size
def sparse_max_pool(input, size):
    '''Downsample ``input`` to spatial ``size`` while treating 0 as "no data".

    No standard interpolation mode resizes a sparse map correctly. Instead,
    the positive entries are max-pooled and the negative entries are
    "min-pooled" (max-pooled after negation), and the two results are
    combined: positive maximum minus negated-negative maximum.
    Compared with nearest-neighbour interpolation, this keeps the output as
    dense as possible and does not drop isolated data points.'''
    # Zero out every non-positive entry, then keep the largest positive value.
    positive_part = input * (input > 0).float()
    pooled_positive = F.adaptive_max_pool2d(positive_part, size)

    # Negate the negative entries (zeroing the rest) so that max pooling
    # picks the most negative original value.
    negated_negative_part = -input * (input < 0).float()
    pooled_negative = F.adaptive_max_pool2d(negated_negative_part, size)

    return pooled_positive - pooled_negative
def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss=False):
    """Weighted sum of per-scale flow losses for a multi-scale prediction.

    For each prediction in ``network_output``, the target flow and the mask
    are resized to that prediction's spatial size. The prediction is then
    scored with EPE, or rob_EPE when ``rob_loss`` is set, using the
    "sum / batch size" reduction. The per-scale losses are combined with
    ``weights``.

    Args:
        network_output: a single flow tensor or a list/tuple of them, one per
            scale. Each is unpacked as (B, C, h, w) via ``.size()``.
        target_flow: full-resolution ground-truth flow (4-D tensor).
        mask: validity mask. It is unsqueezed at dim 1 before resizing, so it
            is presumably (B, H, W) -- TODO confirm against callers.
        weights: one weight per prediction. Defaults to the five weights from
            the original article, which requires exactly five predictions.
        sparse: if True, the target is downsampled with sparse_max_pool and
            the loss treats zero-flow target pixels as invalid.
        rob_loss: if True, use rob_EPE instead of EPE for every scale.

    Returns:
        The weighted total loss.
    """
    def one_scale(output, target, mask, sparse):
        # Resize target and mask to this prediction's resolution, then score it.
        _, _, h, w = output.size()
        if sparse:
            target_scaled = sparse_max_pool(target, (h, w))
        else:
            target_scaled = F.interpolate(target, (h, w), mode='area')
        # Keep a low-resolution pixel only if its bilinearly resized mask
        # value is exactly 1, i.e. every input pixel blended into it with a
        # nonzero weight is valid. align_corners=False is passed explicitly,
        # as in realEPE.
        mask = F.interpolate(mask.float().unsqueeze(1), (h, w), mode='bilinear',
                             align_corners=False).squeeze(1) == 1
        if rob_loss:
            return rob_EPE(output, target_scaled, mask, sparse, mean=False)
        return EPE(output, target_scaled, mask, sparse, mean=False)

    # A bare tensor is treated as a single-scale prediction.
    if not isinstance(network_output, (tuple, list)):
        network_output = [network_output]
    if weights is None:
        weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
    assert len(weights) == len(network_output), (
        'got %d weights for %d network outputs' % (len(weights), len(network_output)))

    loss = 0
    for output, weight in zip(network_output, weights):
        loss += weight * one_scale(output, target_flow, mask, sparse)
    return loss
def realEPE(output, target, mask, sparse=False):
    """Mean EPE of ``output`` after upsampling it to the target's resolution.

    ``output`` is bilinearly resized (align_corners=False) to the spatial size
    of the 4-D ``target`` and then scored with EPE using mean reduction.
    ``mask`` and ``sparse`` are forwarded to EPE unchanged.
    """
    _, _, height, width = target.size()
    full_res_output = F.interpolate(
        output, size=(height, width), mode='bilinear', align_corners=False
    )
    return EPE(full_res_output, target, mask, sparse, mean=True)