File size: 986 Bytes
485c2ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from statistics import median
from skimage.metrics import structural_similarity

def getSSIM(gt, out, gt_flag=None, data_range=1):
    """Return the median SSIM over all (sample, slice) pairs that have ground truth.

    For every sample index ``i`` whose ``gt_flag[i]`` is truthy and every
    index ``j`` along axis 1, ``structural_similarity`` is computed between
    ``gt[i, j, ...]`` and ``out[i, j, ...]``. The median of all those scores
    is returned.

    Args:
        gt: Reference array indexed as ``gt[i, j, ...]`` (at least two leading
            axes; presumably (sample, channel/frame, H, W) — confirm with callers).
        out: Prediction array indexed the same way as ``gt``.
        gt_flag: Optional per-sample sequence; samples whose flag is falsy are
            skipped. ``None`` means every sample has ground truth.
        data_range: Forwarded to ``structural_similarity``.

    Returns:
        The median SSIM score, or ``float('nan')`` when no slice was scored
        (e.g. every flag is false or there are no samples).
    """
    if gt_flag is None:  # every sample has a ground truth
        gt_flag = [True] * gt.shape[0]

    scores = [
        structural_similarity(gt[i, j, ...], out[i, j, ...], data_range=data_range)
        for i in range(gt.shape[0])
        if gt_flag[i]
        for j in range(gt.shape[1])
    ]
    if not scores:
        # statistics.median([]) raises an opaque StatisticsError; an undefined
        # metric is reported as NaN instead (as numpy's median of [] does).
        return float("nan")
    return median(scores)

def ema(source, target, decay):
    """Update ``target``'s state in place as an exponential moving average of ``source``.

    For every key of ``source.state_dict()``, the matching floating-point
    entry of ``target.state_dict()`` is overwritten with
    ``target * decay + source * (1 - decay)``.

    Non-floating-point entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) cannot hold a fractional blend. Blending them
    would be silently truncated on ``copy_``, so they are copied from
    ``source`` unchanged.

    Args:
        source: Module whose current state is blended in.
        target: Module whose state tensors are modified in place.
        decay: Weight kept from ``target``'s current values; ``1 - decay``
            is taken from ``source``.
    """
    source_state = source.state_dict()
    target_state = target.state_dict()
    for key, src in source_state.items():
        dst = target_state[key].data
        if dst.is_floating_point():
            dst.copy_(dst * decay + src.data * (1 - decay))
        else:
            # Integer/bool buffers: averaging then truncating corrupts them.
            dst.copy_(src.data)


class WarmupLR:
    """Linear warmup factor: a callable mapping a step count to a multiplier.

    ``instance(step)`` returns ``min(step, warmup) / warmup``. The factor rises
    linearly from 0 at step 0 to 1 at ``step == warmup`` and stays at 1
    afterwards. A non-positive ``warmup`` disables warmup (factor is always 1).
    The callable shape fits ``torch.optim.lr_scheduler.LambdaLR``'s
    ``lr_lambda``, which is presumably how it is used — confirm with callers.
    """

    def __init__(self, warmup) -> None:
        # Number of steps over which the factor ramps up to 1.
        self.warmup = warmup

    def __call__(self, step):
        if self.warmup <= 0:
            # warmup == 0 would divide by zero; treat it as "no warmup".
            return 1.0
        return min(step, self.warmup) / self.warmup