from statistics import median

from skimage.metrics import structural_similarity


def getSSIM(gt, out, gt_flag=None, data_range=1):
    """Return the median SSIM over all (sample, channel) slices that have ground truth.

    Args:
        gt: Ground-truth array indexed as ``gt[i, j, ...]``. Axis 0 enumerates
            samples and axis 1 is iterated slice by slice (presumably an
            ``(N, C, H, W)`` array -- TODO confirm with callers).
        out: Prediction array, indexed the same way as ``gt``.
        gt_flag: Optional per-sample sequence of booleans. Samples whose flag
            is falsy are skipped. ``None`` means every sample has ground truth.
        data_range: Forwarded unchanged to ``structural_similarity``.

    Returns:
        The median of the SSIM values computed for each ``[i, j, ...]`` slice
        of the flagged samples.

    Raises:
        statistics.StatisticsError: if no sample is flagged (``median`` of an
            empty list).
    """
    if gt_flag is None:
        # all of the samples have GTs
        gt_flag = [True] * gt.shape[0]
    vals = []
    for i in range(gt.shape[0]):
        if not gt_flag[i]:
            continue
        vals.extend(
            structural_similarity(gt[i, j, ...], out[i, j, ...], data_range=data_range)
            for j in range(gt.shape[1])
        )
    return median(vals)


def ema(source, target, decay):
    """Update ``target`` in place as an exponential moving average of ``source``.

    Every entry of ``source.state_dict()`` is blended into the entry with
    the same key in ``target.state_dict()`` as
    ``decay * target + (1 - decay) * source``. Floating-point and complex
    entries are blended. Any other entry (e.g. an integer buffer) is copied
    from ``source`` verbatim. Blending such an entry would produce a float
    that ``copy_`` truncates back to an integer, so its value would stall
    or decay toward zero.

    Writing through ``state_dict()`` only updates ``target`` because, for
    ``torch.nn.Module``, the returned tensors share storage with the
    module's parameters and buffers.

    Args:
        source: Module whose current weights are folded in.
        target: Module holding the running average; modified in place.
        decay: Weight kept on ``target``'s existing values.

    Raises:
        KeyError: if ``source`` has a state entry that ``target`` lacks.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key, src_tensor in source_dict.items():
        dst = target_dict[key].data
        src = src_tensor.data
        if dst.is_floating_point() or dst.is_complex():
            dst.copy_(dst * decay + src * (1 - decay))
        else:
            # Averaging is meaningless for discrete values and would be
            # truncated on copy; track the source value instead.
            dst.copy_(src)


class WarmupLR:
    """Linear warm-up multiplier, callable with a step count.

    Returns ``min(step, warmup) / warmup``. The factor rises linearly from 0
    at step 0 to 1 at ``step >= warmup`` and stays at 1 afterwards.
    Presumably used as the ``lr_lambda`` of a
    ``torch.optim.lr_scheduler.LambdaLR`` -- TODO confirm with callers.
    """

    def __init__(self, warmup) -> None:
        # Number of steps over which the factor ramps up; <= 0 means no warm-up.
        self.warmup = warmup

    def __call__(self, step):
        if self.warmup <= 0:
            # Avoid ZeroDivisionError for warmup == 0. Negative values already
            # yielded 1.0 for non-negative steps, so treat both as "no warm-up".
            return 1.0
        return min(step, self.warmup) / self.warmup