# NOTE: removed non-Python residue from a file-hosting page header that was
# pasted above the module ("endo-yuki-t / initial commit / d7dbcdd / raw /
# history blame / 3.19 kB"); it broke the file's syntax.
"""
Taken from https://github.com/ClementPinard/FlowNetPytorch
"""
import pdb
import torch
import torch.nn.functional as F
def EPE(input_flow, target_flow, mask, sparse=False, mean=True):
    """Average (or batch-normalized sum of) endpoint error between flows.

    Args:
        input_flow: predicted flow, shape (B, 2, H, W).
        target_flow: ground-truth flow, shape (B, C>=2, H, W); only the
            first two channels are used.
        mask: boolean validity mask, shape (B, H, W); applied only in the
            dense (sparse=False) case.
        sparse: if True, ignore `mask` and instead treat pixels whose
            target flow is exactly (0, 0) as invalid.
        mean: if True return the mean EPE over valid pixels, otherwise the
            sum of valid-pixel EPEs divided by the batch size.
    """
    target_flow = target_flow[:, :2]
    # Per-pixel L2 endpoint error, shape (B, H, W).
    EPE_map = torch.norm(target_flow - input_flow, 2, 1)
    batch_size = EPE_map.size(0)
    if sparse:
        # Invalid flow is defined with both flow coordinates exactly 0.
        invalid = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        # Bug fix: the original filtered with ~invalid and then indexed the
        # already-flattened result with the stale (B, H, W) mask again,
        # which is a shape mismatch. Select valid pixels exactly once.
        EPE_map = EPE_map[~invalid]
    else:
        EPE_map = EPE_map[mask]
    if mean:
        return EPE_map.mean()
    return EPE_map.sum() / batch_size
def rob_EPE(input_flow, target_flow, mask, sparse=False, mean=True):
    """Robust endpoint error: a sub-linear penalty of the L1 flow error.

    Per pixel the penalty is (|du| + |dv| + 0.01) ** 0.4, which compresses
    large errors compared with the plain L2 EPE.

    Args:
        input_flow: predicted flow, shape (B, 2, H, W).
        target_flow: ground-truth flow, shape (B, C>=2, H, W); only the
            first two channels are used.
        mask: boolean validity mask, shape (B, H, W); applied only in the
            dense (sparse=False) case.
        sparse: if True, ignore `mask` and instead treat pixels whose
            target flow is exactly (0, 0) as invalid.
        mean: if True return the mean penalty over valid pixels, otherwise
            the sum of valid-pixel penalties divided by the batch size.
    """
    target_flow = target_flow[:, :2]
    # Robust per-pixel penalty on the L1 distance, shape (B, H, W).
    EPE_map = (torch.norm(target_flow - input_flow, 1, 1) + 0.01).pow(0.4)
    batch_size = EPE_map.size(0)
    if sparse:
        # Invalid flow is defined with both flow coordinates exactly 0.
        invalid = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        # Bug fix: same double-indexing defect as in EPE — the original
        # filtered with ~invalid and then re-indexed with the stale mask,
        # causing a shape mismatch. Select valid pixels exactly once.
        EPE_map = EPE_map[~invalid]
    else:
        EPE_map = EPE_map[mask]
    if mean:
        return EPE_map.mean()
    return EPE_map.sum() / batch_size
def sparse_max_pool(input, size):
    """Downsample a sparse map, treating zero values as invalid samples.

    No generic interpolation mode can resize a sparse map correctly: area
    or bilinear modes average invalid zeros into valid data, and nearest
    interpolation can drop isolated valid points. Instead, max-pool the
    positive part and the negative part separately and recombine them, so
    any non-zero sample inside a pooling window survives.
    """
    # clamp(min=0) keeps the positive part; clamp(-x, min=0) the magnitude
    # of the negative part. Equivalent to masking with (x > 0) / (x < 0).
    pooled_pos = F.adaptive_max_pool2d(torch.clamp(input, min=0), size)
    pooled_neg = F.adaptive_max_pool2d(torch.clamp(-input, min=0), size)
    return pooled_pos - pooled_neg
def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss = False):
    """Weighted sum of per-scale EPE losses for a pyramid of predictions.

    Each element of `network_output` is compared against `target_flow`
    downsampled to that prediction's resolution; `mask` is rescaled
    alongside in the dense case. Uses `rob_EPE` when `rob_loss` is True,
    otherwise `EPE`, always with `mean=False` (per-batch sums).
    """
    def one_scale(pred, full_target, full_mask, is_sparse):
        _, _, out_h, out_w = pred.size()
        if is_sparse:
            # Sparse targets need the zero-preserving pooling; the mask is
            # left at full resolution (the sparse loss derives its own).
            scaled_target = sparse_max_pool(full_target, (out_h, out_w))
            scaled_mask = full_mask
        else:
            scaled_target = F.interpolate(full_target, (out_h, out_w), mode='area')
            scaled_mask = F.interpolate(full_mask.float().unsqueeze(1), (out_h, out_w), mode='bilinear').squeeze(1) == 1
        loss_fn = rob_EPE if rob_loss else EPE
        return loss_fn(pred, scaled_target, scaled_mask, is_sparse, mean=False)

    if type(network_output) not in [tuple, list]:
        network_output = [network_output]
    if weights is None:
        weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
    assert(len(weights) == len(network_output))
    return sum(w * one_scale(pred, target_flow, mask, sparse)
               for pred, w in zip(network_output, weights))
def realEPE(output, target, mask, sparse=False):
    """Mean EPE at full resolution: upsample the prediction to the
    ground-truth size with bilinear interpolation, then average the
    endpoint error over the masked (or sparse-valid) pixels."""
    _, _, gt_h, gt_w = target.size()
    full_res_output = F.interpolate(output, (gt_h, gt_w), mode='bilinear', align_corners=False)
    return EPE(full_res_output, target, mask, sparse, mean=True)