File size: 3,650 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
#   FileName     [ dataset.py ]
#   Synopsis     [ the speaker diarization dataset ]
#   Source       [ Refactored from https://github.com/hitachi-speech/EEND ]
#   Author       [ Jiatong Shi ]
#   Copyright    [ Copyleft(c), Johns Hopkins University ]
"""*********************************************************************************************"""

###############
# IMPORTATION #
###############
import torch
import numpy as np
from itertools import permutations


def create_length_mask(length, max_len, num_output, device):
    """Build a 0/1 mask that zeroes out padded frames.

    The mask is computed with a single vectorized comparison directly on
    ``device`` rather than filled row by row on the CPU and copied over.

    Args:
        length: per-sample number of valid frames (tensor, list or array of
            ints); one entry per batch element.
        max_len: size of the frame dimension of the mask.
        num_output: size of the last dimension of the mask (the mask is
            identical across it).
        device: device on which the mask is created.

    Returns:
        Tensor of shape ``(len(length), max_len, num_output)`` in the default
        floating dtype, with 1 at frame positions ``< length[i]`` and 0
        elsewhere. Lengths larger than ``max_len`` simply give all-ones rows.
    """
    lengths = torch.as_tensor(length, device=device).reshape(-1, 1)
    frame_idx = torch.arange(max_len, device=device).unsqueeze(0)
    # (batch, max_len) boolean: True on valid (non-padded) frames.
    valid = (frame_idx < lengths).to(torch.get_default_dtype())
    # Broadcast across the output dimension; contiguous() gives callers a
    # regular writable tensor, as the original zeros-based mask did.
    return valid.unsqueeze(-1).expand(-1, -1, num_output).contiguous()


def pit_loss_single_permute(output, label, length):
    """Masked per-utterance BCE loss for one fixed speaker assignment.

    Args:
        output: logits, same shape as ``label``.
        label: target activities; its dims 1 and 2 size the padding mask
            (presumably ``(batch, frames, speakers)`` -- confirm with callers).
        length: per-utterance number of valid frames.

    Returns:
        Tensor of shape ``(batch, 1)``: element-wise BCE-with-logits averaged
        over the last dimension, then summed over the (masked) frames.
    """
    frame_mask = create_length_mask(length, label.size(1), label.size(2), label.device)
    elementwise = torch.nn.functional.binary_cross_entropy_with_logits(
        output, label, reduction="none"
    )
    per_frame = (elementwise * frame_mask).mean(dim=2)
    return per_frame.sum(dim=1, keepdim=True)


def pit_loss(output, label, length):
    """Permutation-invariant training (PIT) loss.

    Every ordering of the last dimension of ``label`` is scored with
    ``pit_loss_single_permute``. For each utterance the cheapest ordering is
    kept, and the kept losses are normalized by the total number of valid frames.

    Args:
        output: logits, same shape as ``label``.
        label: target activities; permutations are taken over dim 2.
        length: tensor of per-utterance valid frame counts.

    Returns:
        ``(loss, min_idx, permute_list)``. ``loss`` is a scalar tensor,
        ``min_idx`` holds the per-utterance index of the best permutation,
        and ``permute_list`` lists the permutations as NumPy index arrays
        (in ``itertools.permutations`` order).
    """
    n_out = label.size(2)
    orders = [np.array(order) for order in permutations(range(n_out))]
    # Column j holds the loss of every utterance under permutation j.
    per_order = torch.cat(
        [pit_loss_single_permute(output, label[:, :, order], length) for order in orders],
        dim=1,
    )
    best_loss, best_idx = per_order.min(dim=1)
    total_frames = length.float().to(label.device).sum()
    return best_loss.sum() / total_frames, best_idx, orders


def get_label_perm(label, perm_idx, perm_list):
    """Reorder each utterance's label columns by its selected permutation.

    Args:
        label: label tensor; sample ``i`` is ``label[i]``, whose last
            dimension is reordered.
        perm_idx: per-sample index into ``perm_list`` (e.g. the ``min_idx``
            returned by ``pit_loss``).
        perm_list: sequence of index arrays, one per permutation.

    Returns:
        A detached CPU float32 tensor holding the reordered labels stacked
        along dim 0. An empty ``perm_idx`` yields a tensor of shape
        ``(0, label.size(1), label.size(2))``.
    """
    # Work directly on tensors: no per-sample NumPy round-trip, and
    # detach() replaces the deprecated ``.data``.
    label = label.detach()
    batch_size = len(perm_idx)
    if batch_size == 0:
        # torch.stack([]) would raise; keep the expected 3-D layout instead.
        return torch.zeros(0, label.size(1), label.size(2), dtype=torch.float32)
    permuted = []
    for b in range(batch_size):
        order = torch.as_tensor(
            perm_list[int(perm_idx[b])], dtype=torch.long, device=label.device
        )
        permuted.append(label[b].index_select(-1, order))
    return torch.stack(permuted).cpu().float()


def calc_diarization_error(pred, label, length):
    """Accumulate frame-level diarization statistics for one batch.

    ``pred`` entries greater than 0 count as active (consistent with
    logits). Padding frames, i.e. those at or beyond ``length[i]``, are
    masked out of every statistic.

    Args:
        pred: predictions with the same shape as ``label``.
        label: reference activities, unpacked as
            ``(batch_size, max_len, num_output)``.
        length: tensor of per-sample valid frame counts.

    Returns:
        Tuple ``(correct, num_frames, speech_scored, speech_miss,
        speech_falarm, speaker_scored, speaker_miss, speaker_falarm,
        speaker_error)``. ``correct`` is the count of matching elements
        divided by ``num_output``, and ``num_frames`` is the sum of
        ``length``.
    """
    batch_size, max_len, num_output = label.size()
    lengths = length.detach().cpu().numpy()

    # 1 on valid frames, 0 on padding.
    valid = np.zeros((batch_size, max_len, num_output))
    for b in range(batch_size):
        valid[b, : lengths[b], :] = 1

    ref = label.detach().cpu().numpy().astype(int) * valid
    hyp = (pred.detach().cpu().numpy() > 0).astype(int) * valid

    # Number of active outputs per frame.
    ref_count = ref.sum(axis=2)
    hyp_count = hyp.sum(axis=2)

    # Speech activity detection statistics.
    speech_scored = float((ref_count > 0).sum())
    speech_miss = float(((ref_count > 0) & (hyp_count == 0)).sum())
    speech_falarm = float(((ref_count == 0) & (hyp_count > 0)).sum())

    # Speaker-level statistics.
    speaker_scored = float(ref_count.sum())
    speaker_miss = float(np.maximum(ref_count - hyp_count, 0).sum())
    speaker_falarm = float(np.maximum(hyp_count - ref_count, 0).sum())
    matched = ((ref == 1) & (hyp == 1)).sum(axis=2)
    speaker_error = float((np.minimum(ref_count, hyp_count) - matched).sum())

    correct = float(((ref == hyp) * valid).sum() / num_output)
    num_frames = np.sum(lengths)
    return (
        correct,
        num_frames,
        speech_scored,
        speech_miss,
        speech_falarm,
        speaker_scored,
        speaker_miss,
        speaker_falarm,
        speaker_error,
    )