""" Utterance Classification Tasks Authors * Leo 2022 """ import logging from typing import List import numpy as np import torch import torch.nn.functional as F from s3prl.dataio.encoder.category import CategoryEncoder, CategoryEncoders from s3prl.metric import accuracy from . import Task logger = logging.getLogger(__name__) __all__ = [ "UtteranceClassifierExample", "UtteranceClassificationTask", ] class UtteranceClassifierExample(torch.nn.Module): """ Attributes: input_size: int output_size: int """ def __init__(self, input_size=3, output_size=4): super().__init__() self._input_size = input_size self._output_size = output_size @property def input_size(self): return self._input_size @property def output_size(self): return self._output_size def forward(self, x, x_len): """ Args: x (torch.Tensor): (batch_size, timestemps, input_size) x_len (torch.LongTensor): (batch_size, ) Return: output (torch.Tensor): (batch_size, output_size) """ assert x.size(-1) == self.input_size output = torch.randn(x.size(0), self.output_size) assert output class UtteranceClassificationTask(Task): """ Attributes: input_size (int): defined by model.input_size output_size (int): defined by len(categories) """ def __init__(self, model: UtteranceClassifierExample, category: CategoryEncoder): """ model.output_size should match len(categories) Args: model (UtteranceClassifier) category: encode: str -> int decode: int -> str __len__: -> int """ super().__init__() self.model = model self.category = category assert self.model.output_size == len(category) def predict(self, x: torch.Tensor, x_len: torch.LongTensor): """ Args: x (torch.Tensor): (batch_size, timestamps, input_size) x_len (torch.LongTensor): (batch_size, ) Return: logits (torch.Tensor): (batch_size, output_size) prediction (list): prediction strings """ logits: torch.Tensor = self.model(x, x_len) predictions = [ self.category.decode(index) for index in logits.argmax(dim=-1).detach().cpu().tolist() ] return logits, predictions def forward( self, _mode: str, x: torch.Tensor, x_len: torch.LongTensor, class_id: torch.LongTensor, label: List[str], unique_name: List[str], _dump_dir: str = None, ): logits, prediction = self.predict(x, x_len) loss = F.cross_entropy(logits, class_id) cacheable = dict( loss=loss.detach().cpu(), prediction=prediction, label=[self.category.decode(idx) for idx in class_id], unique_name=unique_name, ) return loss, cacheable def reduction(self, _mode: str, cached_results: List[dict], _dump_dir: str = None): results = self.parse_cached_results(cached_results) predictions = results["prediction"] labels = results["label"] losses = results["loss"] acc = accuracy(predictions, labels) loss = (sum(losses) / len(losses)).item() return dict( loss=loss, accuracy=acc, ) class UtteranceMultiClassClassificationTask(Task): def __init__(self, model: UtteranceClassifierExample, categories: CategoryEncoders): super().__init__() self.model = model self.categories = categories assert self.model.output_size == len(categories) def predict(self, x: torch.Tensor, x_len: torch.LongTensor): """ Args: x (torch.Tensor): (batch_size, timestamps, input_size) x_len (torch.LongTensor): (batch_size, ) Return: logit (torch.Tensor): List[(batch_size, sub_output_size)] prediction (np.array): (batch_size, num_category) """ logits: torch.Tensor = self.model(x, x_len) logit_start = 0 sub_logits, sub_predictions = [], [] for category in self.categories: logit_end = logit_start + len(category) sub_logit = logits[:, logit_start:logit_end] sub_logits.append(sub_logit) sub_predictions.append( [ category.decode(index) for index in sub_logit.argmax(dim=-1).detach().cpu().tolist() ] ) logit_start = logit_end prediction = np.array(sub_predictions, dtype="object").T return sub_logits, prediction def forward( self, _mode: str, x: torch.Tensor, x_len: torch.LongTensor, class_ids: torch.LongTensor, labels: np.ndarray, unique_name: List[str], _dump_dir: str = None, ): """ Args: x: torch.Tensor, (batch_size, timestamps, input_size) x_len: torch.LongTensor, (batch_size) class_ids: torch.LongTensor, (batch_size, num_category) labels: np.ndarray, (batch_size, num_category) Return: loss: torch.Tensor prediction: np.ndarray label: np.ndarray """ logit, prediction = self.predict(x, x_len) loss = sum( [ F.cross_entropy(sub_logit, class_id) for sub_logit, class_id in zip(logit, class_ids.T) ] ) cacheable = dict( loss=loss.detach().cpu(), prediction=prediction.tolist(), label=labels.tolist(), unique_name=unique_name, ) return loss, cacheable def reduction(self, _mode: str, cached_results: List[dict], _dump_dir: str = None): results = self.parse_cached_results(cached_results) losses = results["loss"] predictions = results["prediction"] labels = results["label"] acc = accuracy(predictions, labels) loss = (sum(losses) / len(losses)).item() return dict( loss=loss, accuracy=acc, )