File size: 10,315 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
# Copyright 2021 Jiatong Shi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple
import numpy as np
import torch
from typeguard import check_argument_types
from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.diar.decoder.abs_decoder import AbsDecoder
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class ESPnetDiarizationModel(AbsESPnetModel):
"""Speaker Diarization model"""
def __init__(
self,
frontend: Optional[AbsFrontend],
normalize: Optional[AbsNormalize],
label_aggregator: torch.nn.Module,
encoder: AbsEncoder,
decoder: AbsDecoder,
loss_type: str = "pit", # only support pit loss for now
):
assert check_argument_types()
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.num_spk = decoder.num_spk
self.normalize = normalize
self.frontend = frontend
self.label_aggregator = label_aggregator
self.loss_type = loss_type
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor = None,
spk_labels: torch.Tensor = None,
spk_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, samples)
speech_lengths: (Batch,) default None for chunk interator,
because the chunk-iterator does not
have the speech_lengths returned.
see in
espnet2/iterators/chunk_iter_factory.py
spk_labels: (Batch, )
"""
assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# 2. Decoder (baiscally a predction layer after encoder_out)
pred = self.decoder(encoder_out, encoder_out_lens)
# 3. Aggregate time-domain labels
spk_labels, spk_labels_lengths = self.label_aggregator(
spk_labels, spk_labels_lengths
)
if self.loss_type == "pit":
loss, perm_idx, perm_list, label_perm = self.pit_loss(
pred, spk_labels, encoder_out_lens
)
(
correct,
num_frames,
speech_scored,
speech_miss,
speech_falarm,
speaker_scored,
speaker_miss,
speaker_falarm,
speaker_error,
) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
if speech_scored > 0 and num_frames > 0:
sad_mr, sad_fr, mi, fa, cf, acc, der = (
speech_miss / speech_scored,
speech_falarm / speech_scored,
speaker_miss / speaker_scored,
speaker_falarm / speaker_scored,
speaker_error / speaker_scored,
correct / num_frames,
(speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
)
else:
sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0
stats = dict(
loss=loss.detach(),
sad_mr=sad_mr,
sad_fr=sad_fr,
mi=mi,
fa=fa,
cf=cf,
acc=acc,
der=der,
)
else:
raise NotImplementedError
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
spk_labels: torch.Tensor = None,
spk_labels_lengths: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 3. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim)
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = speech.shape[0]
speech_lengths = (
speech_lengths
if speech_lengths is not None
else torch.ones(batch_size).int() * speech.shape[1]
)
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def pit_loss_single_permute(self, pred, label, length):
bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
mask = self.create_length_mask(length, label.size(1), label.size(2))
loss = bce_loss(pred, label)
loss = loss * mask
loss = torch.sum(torch.mean(loss, dim=2), dim=1)
loss = torch.unsqueeze(loss, dim=1)
return loss
def pit_loss(self, pred, label, lengths):
# Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
num_output = label.size(2)
permute_list = [np.array(p) for p in permutations(range(num_output))]
loss_list = []
for p in permute_list:
label_perm = label[:, :, p]
loss_perm = self.pit_loss_single_permute(pred, label_perm, lengths)
loss_list.append(loss_perm)
loss = torch.cat(loss_list, dim=1)
min_loss, min_idx = torch.min(loss, dim=1)
loss = torch.sum(min_loss) / torch.sum(lengths.float())
batch_size = len(min_idx)
label_list = []
for i in range(batch_size):
label_list.append(label[i, :, permute_list[min_idx[i]]].data.cpu().numpy())
label_permute = torch.from_numpy(np.array(label_list)).float()
return loss, min_idx, permute_list, label_permute
def create_length_mask(self, length, max_len, num_output):
batch_size = len(length)
mask = torch.zeros(batch_size, max_len, num_output)
for i in range(batch_size):
mask[i, : length[i], :] = 1
mask = to_device(self, mask)
return mask
@staticmethod
def calc_diarization_error(pred, label, length):
# Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
(batch_size, max_len, num_output) = label.size()
# mask the padding part
mask = np.zeros((batch_size, max_len, num_output))
for i in range(batch_size):
mask[i, : length[i], :] = 1
# pred and label have the shape (batch_size, max_len, num_output)
label_np = label.data.cpu().numpy().astype(int)
pred_np = (pred.data.cpu().numpy() > 0).astype(int)
label_np = label_np * mask
pred_np = pred_np * mask
length = length.data.cpu().numpy()
# compute speech activity detection error
n_ref = np.sum(label_np, axis=2)
n_sys = np.sum(pred_np, axis=2)
speech_scored = float(np.sum(n_ref > 0))
speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))
# compute speaker diarization error
speaker_scored = float(np.sum(n_ref))
speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
num_frames = np.sum(length)
return (
correct,
num_frames,
speech_scored,
speech_miss,
speech_falarm,
speaker_scored,
speaker_miss,
speaker_falarm,
speaker_error,
)
|