File size: 3,232 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from gyraudio.audio_separation.properties import SIGNAL, NOISE, TOTAL, LOSS_TYPE, COEFFICIENT, SNR
import torch


def snr(prediction: torch.Tensor, ground_truth: torch.Tensor, reduce="mean") -> torch.Tensor:
    """Signal-to-noise ratio (in dB) of a prediction against its reference.

    Energies are summed over the last two dimensions of the tensors, so both
    inputs need at least two dimensions; every remaining leading index yields
    one SNR value.

    Args:
        prediction (torch.Tensor): estimated tensor
        ground_truth (torch.Tensor): reference tensor
        reduce (str): "mean" averages all SNR values into a single scalar;
            any other value returns the un-reduced SNR values.
    Returns:
        torch.Tensor: SNR in dB (scalar when reduce == "mean")
    """
    reduction_axes = (-2, -1)
    # float32 machine epsilon keeps the ratio and the logarithm finite when
    # either energy is exactly zero.
    tiny = torch.finfo(torch.float32).eps
    residual = prediction - ground_truth
    signal_energy = (ground_truth ** 2).sum(dim=reduction_axes)
    error_energy = (residual ** 2).sum(dim=reduction_axes)
    ratio = (signal_energy + tiny) / (error_energy + tiny)
    per_item_db = 10 * torch.log10(ratio)
    if reduce != "mean":
        return per_item_db
    return per_item_db.mean()


# Default cost configuration: one entry per tracked quantity.
# SIGNAL and NOISE are both scored with MSE and carry a 0.5 coefficient
# (each gets its own inner dict); the SNR entry only names its function
# and has no COEFFICIENT.
DEFAULT_COST = {
    weighted_key: {COEFFICIENT: 0.5, LOSS_TYPE: torch.nn.functional.mse_loss}
    for weighted_key in (SIGNAL, NOISE)
}
DEFAULT_COST[SNR] = {LOSS_TYPE: snr}


class Costs:
    """Keep track of cost functions.

    Each key of the `costs` mapping is associated with a callable
    (`LOSS_TYPE`) and, optionally, a weight (`COEFFICIENT`). Keys that have a
    non-zero coefficient contribute to the weighted `TOTAL`; the others are
    only tracked.
    ```
    for epoch in range(...):
        metric.reset_epoch()
        for step in dataloader(...):
            ... # forward
            prediction = model.forward(batch)
            metric.update(prediction1, groundtruth1, SIGNAL1)
            metric.update(prediction2, groundtruth2, SIGNAL2)
            loss = metric.finish_step()

            loss.backward()
            ... # backprop
        metric.finish_epoch()
        ... # log metrics
    ```
    NOTE: per-step values are not cleared between steps. A key that is not
    updated during a step keeps the (float) value of the previous step and is
    accumulated again by `finish_step`.
    """

    def __init__(self, name: str, costs=DEFAULT_COST) -> None:
        """Store the configuration and create empty accumulators.

        Args:
            name (str): label used as prefix by `__repr__`
            costs (dict): mapping key -> {LOSS_TYPE: callable, COEFFICIENT: weight (optional)}
        """
        self.name = name
        self.keys = list(costs.keys())
        self.cost = costs
        # Create the accumulators right away so that update / finish_step /
        # __repr__ do not raise AttributeError when called before the first
        # explicit reset_epoch().
        self.reset_epoch()

    def __reset_step(self) -> None:
        """Set the per-step value of every configured key to 0."""
        self.metrics = {key: 0. for key in self.keys}

    def reset_epoch(self) -> None:
        """Clear per-step values, epoch accumulators and the step counter."""
        self.__reset_step()
        self.total_metric = {key: 0. for key in self.keys+[TOTAL]}
        self.count = 0

    def update(self,
               prediction: torch.Tensor,
               ground_truth: torch.Tensor,
               key: str
               ) -> torch.Tensor:
        """Evaluate the cost registered under `key` for the current step.

        Args:
            prediction (torch.Tensor): prediction tensor
            ground_truth (torch.Tensor): ground truth tensor
            key (str): entry of the cost mapping to evaluate (must not be TOTAL)
        Returns:
            torch.Tensor: value returned by the cost function (also stored as
            the current step value of `key`)
        """
        assert key != TOTAL
        # Compute loss for a single batch (=step)
        loss_signal = self.cost[key][LOSS_TYPE](prediction, ground_truth)
        self.metrics[key] = loss_signal
        # Return the value so the method honours its annotated return type.
        return loss_signal

    def finish_step(self) -> torch.Tensor:
        """Combine the step values into the weighted total and accumulate them.

        Returns:
            torch.Tensor: weighted sum of the step values whose cost entry has
            a non-zero COEFFICIENT. It is captured before the values are
            converted to Python floats, so it is the object to call
            `backward()` on. It is the float 0. when no weighted key holds a
            tensor.
        """
        # Reset current total
        self.metrics[TOTAL] = 0.
        # Sum all metrics to total
        for key in self.metrics:
            if key != TOTAL and self.cost[key].get(COEFFICIENT, False):
                self.metrics[TOTAL] += self.cost[key][COEFFICIENT]*self.metrics[key]
        loss_signal = self.metrics[TOTAL]
        # Store plain Python numbers in the accumulators instead of tensors.
        for key in self.metrics:
            if not isinstance(self.metrics[key], float):
                self.metrics[key] = self.metrics[key].item()
            self.total_metric[key] += self.metrics[key]
        self.count += 1
        return loss_signal

    def finish_epoch(self) -> None:
        """Turn the accumulated sums into per-step averages.

        Does nothing when no step has been finished since the last
        `reset_epoch` (avoids a division by zero, e.g. with an empty loader).
        """
        if self.count == 0:
            return
        for key in self.total_metric:
            self.total_metric[key] /= self.count

    def __repr__(self) -> str:
        """Return `name` followed by every accumulated metric in scientific notation."""
        body = "".join(
            f"{key}: {value:.3e} | " for key, value in self.total_metric.items()
        )
        return f"{self.name}\t:\t{body}"