# Uploaded by soumickmj ("Upload DiffAE", commit ad947b4, verified; 986 bytes).
from statistics import median
from skimage.metrics import structural_similarity
def getSSIM(gt, out, gt_flag=None, data_range=1):
    """Return the median SSIM between ``gt`` and ``out`` over samples and their sub-slices.

    SSIM is computed separately for every ``(i, j)`` pair, comparing
    ``gt[i, j, ...]`` against ``out[i, j, ...]`` with
    ``skimage.metrics.structural_similarity``. The median of all those scores
    is returned.

    Args:
        gt: reference array indexed as ``gt[i, j, ...]``. Presumably shaped
            (batch, channel/slice, *spatial); TODO confirm against callers.
        out: prediction array, indexed the same way as ``gt``.
        gt_flag: optional per-sample sequence. Samples whose flag is falsy are
            skipped. ``None`` means every sample has a ground truth.
        data_range: forwarded unchanged to ``structural_similarity``.

    Returns:
        The ``statistics.median`` of the collected SSIM scores. This raises
        ``statistics.StatisticsError`` if no sample was selected.
    """
    n_samples = gt.shape[0]
    if gt_flag is None:
        # No flags supplied: treat every sample as having a ground truth.
        gt_flag = [True] * n_samples

    scores = []
    for i in range(n_samples):
        if gt_flag[i]:
            for j in range(gt.shape[1]):
                score = structural_similarity(
                    gt[i, j, ...], out[i, j, ...], data_range=data_range
                )
                scores.append(score)
    return median(scores)
def ema(source, target, decay):
    """Fold ``source``'s current state into ``target`` as an exponential moving average.

    For each key in ``source.state_dict()``, the matching tensor in
    ``target.state_dict()`` is updated in place to
    ``target * decay + source * (1 - decay)``. Because ``state_dict()``
    tensors share storage with the module's parameters and buffers,
    ``target`` itself is modified. ``source`` is only read.

    Non-floating entries (e.g. integer step counters) cannot be averaged
    meaningfully. Blending them and then truncating on copy made the averaged
    copy lag far behind ``source``, so they are copied verbatim instead.

    Args:
        source: module whose current weights are averaged in (presumably the
            model being trained; confirm against callers).
        target: module holding the running average. It must contain every key
            of ``source.state_dict()``; a missing key raises ``KeyError``.
        decay: fraction of the existing ``target`` value to keep. ``1`` leaves
            ``target`` unchanged and ``0`` copies ``source``.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key, src in source_dict.items():
        tgt = target_dict[key]
        if tgt.is_floating_point() or tgt.is_complex():
            # In-place lerp: tgt + (1 - decay) * (src - tgt)
            # == tgt * decay + src * (1 - decay), without allocating temporaries.
            # lerp_ requires matching dtypes, so cast src to the target's dtype.
            tgt.lerp_(src.to(dtype=tgt.dtype), 1 - decay)
        else:
            tgt.copy_(src)
class WarmupLR:
    """Linear warmup multiplier: ramps from 0 to 1 over ``warmup`` steps, then stays at 1.

    Instances are callables mapping a step index to a multiplicative factor.
    They are presumably passed as ``lr_lambda`` to
    ``torch.optim.lr_scheduler.LambdaLR`` (confirm against callers).
    """

    def __init__(self, warmup) -> None:
        # Number of steps over which the factor rises linearly to 1.
        # A value <= 0 disables warmup.
        self.warmup = warmup

    def __call__(self, step):
        """Return the factor for ``step``: ``min(step, warmup) / warmup``.

        When ``warmup <= 0`` there is no warmup phase and 1.0 is returned.
        Previously ``warmup == 0`` raised ZeroDivisionError.
        """
        if self.warmup <= 0:
            return 1.0
        return min(step, self.warmup) / self.warmup

    def __repr__(self) -> str:
        return f"{type(self).__name__}(warmup={self.warmup!r})"