#!/usr/bin/python3
# -*- coding: utf-8 -*-
import numpy as np
import torch


def anti_wrapping_function(x: torch.Tensor) -> torch.Tensor:
    """Return the absolute distance of *x* from the nearest multiple of 2*pi.

    Each element has the closest integer multiple of 2*pi subtracted from it
    and the absolute value of the remainder is returned, so values that
    differ by whole turns compare as equal.

    Args:
        x: Tensor of angles / angle differences (radians assumed, since the
            period used is 2*pi).

    Returns:
        Tensor with the same shape as ``x`` holding the wrapped magnitudes.
    """
    two_pi = 2 * np.pi
    # Same evaluation order as ``round(x / 2pi) * 2 * pi`` so results match
    # the un-factored expression bit for bit.
    nearest_turn = torch.round(x / two_pi) * 2 * np.pi
    return torch.abs(x - nearest_turn)


def phase_losses(phase_r: torch.Tensor, phase_g: torch.Tensor):
    """Compute three anti-wrapped mean losses between two phase tensors.

    The first loss compares the tensors directly; the other two compare
    their first-order differences along dim 1 and dim 2 respectively.
    Every residual goes through ``anti_wrapping_function`` before averaging.

    NOTE(review): the names suggest instantaneous phase (ip), group delay
    (gd) and instantaneous angular frequency (iaf), with inputs presumably
    laid out as (batch, frequency, time) -- confirm against the caller.

    Args:
        phase_r: First phase tensor (presumably the reference); needs at
            least 3 dimensions because dim 2 is differenced.
        phase_g: Second phase tensor (presumably the generated one), shaped
            compatibly with ``phase_r``.

    Returns:
        Tuple ``(ip_loss, gd_loss, iaf_loss)`` of scalar tensors.
    """
    ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))

    # Difference of first-order differences along dim 1.
    gd_residual = torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)
    gd_loss = torch.mean(anti_wrapping_function(gd_residual))

    # Difference of first-order differences along dim 2.
    iaf_residual = torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)
    iaf_loss = torch.mean(anti_wrapping_function(iaf_residual))

    return ip_loss, gd_loss, iaf_loss


if __name__ == '__main__':
    pass