|
from statistics import median |
|
from skimage.metrics import structural_similarity |
|
|
|
def getSSIM(gt, out, gt_flag=None, data_range=1):
    """Return the median SSIM between ``gt`` and ``out`` over selected pairs.

    SSIM is computed independently for every slice pair
    ``gt[i, j, ...]`` / ``out[i, j, ...]``. Here ``i`` ranges over axis 0 of
    ``gt`` (only where the corresponding flag is truthy) and ``j`` ranges
    over axis 1 of ``gt``. The median of all collected scores is returned.

    Args:
        gt: Reference array indexed as ``gt[i, j, ...]`` (presumably
            ``(N, C, H, W)`` -- TODO confirm against callers).
        out: Array compared against ``gt``, indexed the same way.
        gt_flag: Per-sample selectors along axis 0 of ``gt``. Samples whose
            flag is falsy are skipped. ``None`` selects every sample.
        data_range: Forwarded unchanged to ``structural_similarity``.

    Returns:
        The median of the collected SSIM scores.

    Raises:
        statistics.StatisticsError: If no pair was scored, e.g. every flag is
            falsy or ``gt`` has no samples along axis 0.
    """
    if gt_flag is None:
        flags = [True] * gt.shape[0]
    else:
        flags = gt_flag

    # Sample-major order (i outer, j inner), identical to the original loop.
    scores = [
        structural_similarity(gt[i, j, ...], out[i, j, ...], data_range=data_range)
        for i in range(gt.shape[0])
        if flags[i]
        for j in range(gt.shape[1])
    ]
    return median(scores)
|
|
|
def ema(source, target, decay):
    """Blend ``source``'s state into ``target`` in place as an exponential moving average.

    For every key in ``source.state_dict()``:

    * Floating-point (or complex) entries of ``target`` are overwritten with
      ``target * decay + source * (1 - decay)``.
    * Any other entry (integer/bool buffers such as a BatchNorm
      ``num_batches_tracked`` counter) is copied from ``source`` verbatim.
      Averaging it in floating point and writing the result back into an
      integer tensor would silently truncate it and leave a wrong count.

    Args:
        source: Module whose current state is blended in.
        target: Module updated in place. It must contain every key of
            ``source``.
        decay: Weight kept on the existing ``target`` values. ``1 - decay``
            is the weight given to ``source``.

    Raises:
        KeyError: If ``target`` lacks a key present in ``source``.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key, src_entry in source_dict.items():
        # state_dict() holds references to the module's parameters/buffers,
        # so the in-place copy_ writes straight into ``target``.
        tgt = target_dict[key].data
        src = src_entry.data
        if tgt.is_floating_point() or tgt.is_complex():
            tgt.copy_(tgt * decay + src * (1 - decay))
        else:
            # Counters/masks cannot be meaningfully averaged; keep them in
            # sync with the source instead of truncating a float blend.
            tgt.copy_(src)
|
|
|
|
|
class WarmupLR:
    """Linear warm-up multiplier indexed by step.

    Calling the instance with ``step`` returns ``min(step, warmup) / warmup``.
    The factor rises linearly from 0 at step 0 to 1 at step ``warmup`` and
    stays at 1 afterwards. Presumably used as the ``lr_lambda`` of
    ``torch.optim.lr_scheduler.LambdaLR``; confirm against callers.
    """

    def __init__(self, warmup) -> None:
        # Number of steps over which the factor ramps from 0 to 1.
        self.warmup = warmup

    def __call__(self, step):
        """Return the multiplicative factor for ``step``.

        A non-positive ``warmup`` means "no warm-up": the factor is 1.0
        immediately. Without this guard, ``warmup == 0`` raises
        ``ZeroDivisionError``.
        """
        if self.warmup <= 0:
            return 1.0
        return min(step, self.warmup) / self.warmup

    def __repr__(self) -> str:
        return f"{type(self).__name__}(warmup={self.warmup!r})"