"""
Diarization Permutation Invariant Task

Authors
  * Jiatong Shi 2022
  * Leo 2022
"""

from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn

from s3prl.metric.diarization import calc_diarization_error
from s3prl.nn.pit import get_label_perm, pit_loss

from .base import Task

# maximum tolerated difference (in frames) between predicted and label lengths
TOLERANT_FRAME_DIFF = 2

__all__ = ["DiarizationPIT"]


class DiarizationPIT(Task):
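    """Speaker diarization task trained with permutation-invariant training (PIT).

    ``model`` is any ``nn.Module`` mapping ``(x, x_len)`` to per-speaker
    frame-level activity logits and their valid lengths. ``pit_loss``
    searches over speaker permutations of the labels and optimizes against
    the best-matching one.

    A minimal usage sketch (``model`` stands in for a user-provided
    diarization network):

        task = DiarizationPIT(model)
        loss, cacheable = task(
            "train", x, x_len, label, label_len, record_id, chunk_id
        )
    """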
    def __init__(
        self,
        model: nn.Module,
    ):
        super().__init__()
        self.model = model
        self.objective = pit_loss

    def _tile_representations(self, reps, factor):
        """
        Tile up the representations by `factor`.
        Input - sequence of representations, shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations, shape: (batch_size, seq_len * factor, feature_dim)
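        For example, with factor=2, frames [f1, f2] become [f1, f1, f2, f2]:
        each frame is duplicated `factor` times consecutively.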
        """
        assert (
            len(reps.shape) == 3
        ), "Input argument `reps` has invalid shape: {}".format(reps.shape)
        tiled_reps = reps.repeat(1, 1, factor)
        tiled_reps = tiled_reps.reshape(
            reps.size(0), reps.size(1) * factor, reps.size(2)
        )
        return tiled_reps

    def _match_length(self, inputs, labels):
        """
        Since the upstream extraction process can sometimes cause a mismatch
        between the sequence lengths of inputs and labels:
        - if len(inputs) > len(labels), we truncate the final few timesteps of inputs to match the length of labels
        - if len(inputs) < len(labels), we duplicate the last timestep of inputs to match the length of labels
        Note that the length of labels should never be changed.
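        For example, with input_len=98 and label_len=200: factor=round(200/98)=2,
        so inputs are first tiled to 196 frames, then the last frame is repeated
        4 more times to reach 200.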
        """
        input_len, label_len = inputs.size(1), labels.size(1)

        factor = int(round(label_len / input_len))
        if factor > 1:
            inputs = self._tile_representations(inputs, factor)
            input_len = inputs.size(1)

        if input_len > label_len:
            inputs = inputs[:, :label_len, :]
        elif input_len < label_len:
            pad_vec = inputs[:, -1, :].unsqueeze(1)  # (batch_size, 1, feature_dim)
            inputs = torch.cat(
                (inputs, pad_vec.repeat(1, label_len - input_len, 1)), dim=1
            )  # (batch_size, seq_len, feature_dim), where seq_len == labels.size(-1)
        return inputs, labels

    def predict(self, x, x_len):
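        """Run the wrapped model; returns frame-level logits and their valid lengths."""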
        predicted, predicted_len = self.model(x, x_len)
        return predicted, predicted_len

    def forward(
        self,
        _mode: str,
        x,
        x_len,
        label,
        label_len,
        record_id: List[str],
        chunk_id: torch.Tensor,
        _dump_dir: Optional[str] = None,
    ):
        predicted, predicted_len = self.predict(x, x_len)

        # sanity check: upstream downsampling may shift lengths slightly, but
        # predicted and label lengths should agree within TOLERANT_FRAME_DIFF frames
        for pl, ll in zip(predicted_len, label_len):
            assert (
                abs(pl - ll) <= TOLERANT_FRAME_DIFF
            ), f"predicted: {pl}, label: {ll}, TOLERANT_FRAME_DIFF: {TOLERANT_FRAME_DIFF}"

        # align sequence lengths, find the speaker permutation of the labels that
        # minimizes the loss, and permute the labels accordingly for scoring
        predicted, label = self._match_length(predicted, label)
        loss, perm_idx, perm_list = self.objective(predicted, label.float(), label_len)
        label_perm = get_label_perm(label, perm_idx, perm_list)

        (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        ) = calc_diarization_error(predicted, label_perm, label_len)

        # derive standard diarization metrics: speech miss / false-alarm rates
        # (SAD_MR, SAD_FR), speaker miss (MI), false alarm (FA), confusion (CF),
        # frame-level accuracy (ACC), and diarization error rate (DER)
        if speech_scored > 0 and speaker_scored > 0 and num_frames > 0:
            SAD_MR, SAD_FR, MI, FA, CF, ACC, DER = (
                speech_miss / speech_scored,
                speech_falarm / speech_scored,
                speaker_miss / speaker_scored,
                speaker_falarm / speaker_scored,
                speaker_error / speaker_scored,
                correct / num_frames,
                (speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
            )
        else:
            # avoid division by zero when nothing was scored in this batch
            SAD_MR, SAD_FR, MI, FA, CF, ACC, DER = 0, 0, 0, 0, 0, 0, 0

        if _mode == "test" and _dump_dir is not None:
            assert (
                len(set(record_id)) == 1
            ), "During testing, all utterances in a batch should come from the same recording"

            if len(label_len) > 1:
                assert (
                    len(set(label_len[:-1].tolist())) == 1
                ), "Except for the final chunk, chunks from the same recording should have the same length"

            # restore the original chunk order within the recording, then stack
            # the chunks and convert logits to per-frame probabilities
            predicted_sorted = [predicted[idx] for idx in chunk_id.long().argsort()]
            predict = torch.vstack(predicted_sorted).detach().cpu()
            predict = torch.sigmoid(predict)

            prediction_dir = Path(_dump_dir) / "prediction"
            prediction_dir.mkdir(exist_ok=True, parents=True)
            torch.save(predict, prediction_dir / f"{record_id[0]}.pt")

        cacheable = dict(
            loss=loss.detach().cpu(),
            accuracy=ACC,
            der=DER,
        )

        return loss, cacheable

    def reduction(self, _mode: str, cached_results: List[dict], _dump_dir: Optional[str] = None):
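        """Average the cached per-batch accuracy and DER into epoch-level logs."""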
        results = self.parse_cached_results(cached_results)
        logs = dict(
            accuracy=torch.FloatTensor(results["accuracy"]).mean().item(),
            der=torch.FloatTensor(results["der"]).mean().item(),
        )
        return logs