# Copyright Hear Benchmark Team
# Copyright Shu-wen Yang (refactor from https://github.com/hearbenchmark/hear-eval-kit)
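
"""Scene (utterance-level) prediction task for the HEAR benchmark."""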

from typing import List, Optional

import torch

from s3prl.dataio.encoder.category import CategoryEncoder
from s3prl.task.base import Task

from ._hear_score import available_scores, validate_score_return_type

__all__ = ["ScenePredictionTask"]


class OneHotToCrossEntropyLoss(torch.nn.Module):
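    """Cross-entropy loss for one-hot encoded targets.

    ``torch.nn.CrossEntropyLoss`` expects class indices, so the one-hot target
    is first checked (every row must sum to 1) and then converted with argmax.
    """
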
    def __init__(self):
        super().__init__()
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        assert torch.all(torch.sum(y, dim=1) == y.new_ones(y.shape[0]))
        y = y.argmax(dim=1)
        return self.loss(y_hat, y)


class ScenePredictionTask(Task):
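    """Utterance-level (scene) classification task for the HEAR benchmark.

    The wrapped ``model`` maps a padded feature batch ``(x, x_len)`` to one logit
    vector per utterance. Logits are turned into probabilities with a sigmoid
    ("multilabel") or softmax ("multiclass") activation and scored with the
    metrics named in ``scores``.

    Args:
        model: module whose forward returns a ``(logits, _)`` pair given ``(x, x_len)``
        category: encoder holding the label vocabulary
        prediction_type: either "multilabel" or "multiclass"
        scores: metric names used to look up entries in ``available_scores``

    Example:
        A minimal construction sketch; the model and the score name below are
        placeholders, not guaranteed by this module::

            task = ScenePredictionTask(
                model=my_probe,             # placeholder: returns (logits, _) for (x, x_len)
                category=category_encoder,  # len(category) == number of classes
                prediction_type="multiclass",
                scores=["top1_acc"],        # assumed key of available_scores
            )
    """
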
    def __init__(
        self,
        model: torch.nn.Module,
        category: CategoryEncoder,
        prediction_type: str,
        scores: List[str],
    ):
        super().__init__()
        self.model = model

        self.label_to_idx = {
            str(category.decode(idx)): idx for idx in range(len(category))
        }
        self.idx_to_label = {
            idx: str(category.decode(idx)) for idx in range(len(category))
        }
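        # Build one metric object per requested score name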
        self.scores = [
            available_scores[score](label_to_idx=self.label_to_idx) for score in scores
        ]

        if prediction_type == "multilabel":
            self.activation: torch.nn.Module = torch.nn.Sigmoid()
            self.logit_loss = torch.nn.BCEWithLogitsLoss()
        elif prediction_type == "multiclass":
            self.activation = torch.nn.Softmax(dim=-1)
            self.logit_loss = OneHotToCrossEntropyLoss()
        else:
            raise ValueError(f"Unknown prediction_type {prediction_type}")

    def predict(self, x, x_len):
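        """Run the model and apply the output activation.

        Returns:
            (prediction, logits): post-activation probabilities and raw logits,
            both of shape (batch_size, num_class).
        """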
        logits, _ = self.model(x, x_len)
        prediction = self.activation(logits)
        return prediction, logits

    def forward(
        self, _mode: str, x, x_len, y, labels, unique_name: str, _dump_dir: Optional[str] = None
    ):
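        """Compute the batch loss and cache per-utterance tensors for ``reduction``."""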
        y_pr, y_hat = self.predict(x, x_len)
        loss = self.logit_loss(y_hat.float(), y.float())

        cacheable = dict(
            loss=loss.detach().cpu().item(),
            # each entry below is unbound from (batch_size, num_class) into a
            # tuple of per-utterance (num_class,) tensors
            label=y.detach().cpu().unbind(dim=0),
            logit=y_hat.detach().cpu().unbind(dim=0),
            prediction=y_pr.detach().cpu().unbind(dim=0),
        )

        return loss, cacheable

    def log_scores(self, score_args):
        """Logs the metric score value for each score defined for the model"""
        assert hasattr(self, "scores"), "Scores for the model should be defined"
        end_scores = {}
        # The first score in `self.scores` is the optimization criterion
        for score in self.scores:
            score_ret = score(*score_args)
            validate_score_return_type(score_ret)
            # If the returned score is a tuple, store each subscore as a separate entry
            if isinstance(score_ret, tuple):
                end_scores[f"{score}"] = score_ret[0][1]
                # All other scores will also be logged
                for subscore, value in score_ret:
                    end_scores[f"{score}_{subscore}"] = value
            elif isinstance(score_ret, float):
                end_scores[f"{score}"] = score_ret
            else:
                raise ValueError(
                    f"Return type {type(score_ret)} is unexpected. Return type of "
                    "the score function should either be a "
                    "tuple(tuple) or float."
                )
        return end_scores

    def reduction(
        self,
        _mode: str,
        cached_results: List[dict],
        _dump_dir: Optional[str] = None,
    ):
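        """Aggregate cached batch results into split-level logs.

        The loss is recomputed over the whole split; metric scores are added
        only for the "valid" and "test" modes.
        """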
        result = self.parse_cached_results(cached_results)

        target = torch.stack(result["label"], dim=0)
        prediction_logit = torch.stack(result["logit"], dim=0)
        prediction = torch.stack(result["prediction"], dim=0)

        loss = self.logit_loss(prediction_logit, target)

        logs = dict(
            loss=loss.detach().cpu().item(),
        )

        if _mode in ["valid", "test"]:
            logs.update(
                self.log_scores(
                    score_args=(
                        prediction.detach().cpu().numpy(),
                        target.detach().cpu().numpy(),
                    ),
                )
            )

        return logs