# Source: s3prl/task/speaker_verification_task.py (from the wavlm-large / s3prl upload)
"""
Speaker Verification with Softmax-based loss
Authors
* Po-Han Chi 2021
* Haibin Wu 2022
"""
import logging
from typing import List, Tuple
import torch
import torch.nn.functional as F
from tqdm import tqdm
from s3prl.dataio.encoder.category import CategoryEncoder
from s3prl.metric import accuracy, compute_eer, compute_minDCF
from s3prl.nn import amsoftmax, softmax
from . import Task
logger = logging.getLogger(__name__)
__all__ = ["SpeakerClassifier", "SpeakerVerification"]
class SpeakerClassifier(torch.nn.Module):
    """
    A minimal speaker-embedding extractor used as an interface example.

    The forward pass only checks the input feature dimension and emits a
    random embedding; real models should follow the same signature.

    Attributes:
        input_size (int): expected feature dimension of the input
        output_size (int): dimension of the produced speaker embedding
    """

    def __init__(self, input_size: int = 3, output_size: int = 4):
        super().__init__()
        self._input_size = input_size
        self._output_size = output_size

    @property
    def input_size(self) -> int:
        return self._input_size

    @property
    def output_size(self) -> int:
        return self._output_size

    def forward(self, x, x_len):
        """
        Args:
            x (torch.Tensor): (batch_size, timesteps, input_size)
            x_len (torch.LongTensor): (batch_size, )

        Return:
            output (torch.Tensor): (batch_size, output_size)
        """
        assert x.size(-1) == self.input_size
        output = torch.randn(x.size(0), self.output_size)
        # BUGFIX: the original did `assert output` — truthiness of a
        # multi-element tensor raises RuntimeError — and never returned the
        # result. Assert the shape instead and return the embedding.
        assert output.size(-1) == self.output_size
        return output
class SpeakerVerification(Task):
    """
    Speaker verification trained with a softmax-based classification loss.

    The loss head maps ``model.output_size``-dim embeddings to
    ``len(category)`` speaker classes; at test time, verification trials are
    scored with cosine similarity between enrollment and test embeddings.

    Args:
        model (SpeakerClassifier):
            actual model or a callable config for the model
        category (CategoryEncoder):
            encodes each speaker label (str) into a numeric class id;
            len(category) is the number of training speakers
        test_trials (List[Tuple[int, str, str]]):
            each tuple in the list consists of (label, enroll_utt, test_utt)
        loss_type (str): "softmax" or "amsoftmax"
        loss_conf (dict): extra keyword arguments for the loss_type class
    """

    def __init__(
        self,
        model: SpeakerClassifier,
        category: CategoryEncoder,
        test_trials: List[Tuple[int, str, str]] = None,
        loss_type: str = "amsoftmax",
        loss_conf: dict = None,
    ):
        super().__init__()
        self.model = model
        self.category = category
        self.trials = test_trials

        if loss_type == "amsoftmax":
            loss_cls = amsoftmax
        elif loss_type == "softmax":
            loss_cls = softmax
        else:
            raise ValueError(f"Unsupported loss_type {loss_type}")

        # BUGFIX: loss_conf defaults to None and unpacking None with **
        # raises TypeError; fall back to an empty dict.
        self.loss: torch.nn.Module = loss_cls(
            input_size=self.model.output_size,
            output_size=len(self.category),
            **(loss_conf or {}),
        )
        assert self.loss.output_size == len(category)

    def get_state(self):
        """Return the stateful parts (the loss head) for checkpointing."""
        return {
            "loss_state": self.loss.state_dict(),
        }

    def set_state(self, state: dict):
        """Restore the loss head's parameters saved by :meth:`get_state`."""
        self.loss.load_state_dict(state["loss_state"])

    def predict(self, x: torch.Tensor, x_len: torch.LongTensor):
        """
        Extract speaker embeddings.

        Args:
            x (torch.Tensor): (batch_size, timestamps, input_size)
            x_len (torch.LongTensor): (batch_size, )

        Return:
            torch.Tensor:
                (batch_size, output_size) speaker embeddings
        """
        spk_embeddings = self.model(x, x_len)
        return spk_embeddings

    def train_step(
        self,
        x: torch.Tensor,
        x_len: torch.LongTensor,
        class_id: torch.LongTensor,
        unique_name: List[str],
        _dump_dir: str = None,
    ):
        """Run one training step: embed, compute loss, cache statistics."""
        spk_embeddings = self.predict(x, x_len)
        loss, logits = self.loss(spk_embeddings, class_id)

        # argmax over classes gives the predicted speaker id per utterance
        prediction = logits.argmax(dim=-1).detach().cpu().tolist()
        cacheable = dict(
            loss=loss.detach().cpu().item(),
            class_id=class_id.detach().cpu().tolist(),
            prediction=prediction,
            unique_name=unique_name,
        )
        return loss, cacheable

    def train_reduction(self, cached_results: list, _dump_dir: str = None):
        """Aggregate cached train-step results into mean loss and accuracy."""
        results = self.parse_cached_results(cached_results)
        acc = accuracy(results["prediction"], results["class_id"])
        loss = torch.FloatTensor(results["loss"]).mean().item()
        return dict(
            loss=loss,
            accuracy=acc,
        )

    def test_step(
        self,
        x: torch.Tensor,
        x_len: torch.LongTensor,
        unique_name: List[str],
        _dump_dir: str,
    ):
        """
        Extract and cache speaker embeddings for the test utterances.

        Args:
            x (torch.Tensor): (batch_size, timestamps, input_size)
            x_len: torch.LongTensor
            unique_name (List[str])

        Return:
            unique_name (List[str])
            output (torch.Tensor):
                speaker embeddings corresponding to unique_name
        """
        spk_embeddings = self.predict(x, x_len)
        # BUGFIX: the annotation says List[str] but the original called
        # .tolist() unconditionally (only valid on numpy arrays / tensors);
        # accept both array-like and plain-list unique_name.
        names = (
            unique_name.tolist()
            if hasattr(unique_name, "tolist")
            else list(unique_name)
        )
        cacheable = dict(
            unique_name=names,
            spk_embedding=spk_embeddings.detach().cpu().unbind(dim=0),
        )
        return None, cacheable

    def test_reduction(self, cached_results: List[dict], _dump_dir: str):
        """Score all trials with cosine similarity; compute EER and minDCF."""
        if self.trials is None:
            raise ValueError(
                "test_trials must be provided at construction time to run "
                "test_reduction"
            )
        results = self.parse_cached_results(cached_results)
        embeddings = dict(zip(results["unique_name"], results["spk_embedding"]))

        trials = self.trials
        scores = []
        labels = []
        for label, enroll, test in tqdm(trials, desc="Test Scoring", total=len(trials)):
            enroll_embd = embeddings[enroll]
            test_embd = embeddings[test]
            score = F.cosine_similarity(enroll_embd, test_embd, dim=0).item()
            scores.append(score)
            labels.append(label)

        EER, EERthreshold = compute_eer(labels, scores)
        minDCF, minDCFthreshold = compute_minDCF(labels, scores, p_target=0.01)
        return dict(
            EER=EER,
            EERthreshold=EERthreshold.item(),
            minDCF=minDCF,
            minDCF_threshold=minDCFthreshold,
        )