# Copyright Hear Benchmark Team
# Copyright Shu-wen Yang (refactor from https://github.com/hearbenchmark/hear-eval-kit)
from typing import List, Optional
import torch
from s3prl.dataio.encoder.category import CategoryEncoder
from s3prl.task.base import Task
from ._hear_score import available_scores, validate_score_return_type

__all__ = ["ScenePredictionTask"]


class OneHotToCrossEntropyLoss(torch.nn.Module):
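    """Cross-entropy loss for one-hot encoded targets.

    `torch.nn.CrossEntropyLoss` expects class-index targets, so the one-hot
    targets are first converted to indices with `argmax`.
    """
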
def __init__(self):
super().__init__()
        self.loss = torch.nn.CrossEntropyLoss()

def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
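        # Each target row must sum to one, i.e. be a valid one-hot vector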
assert torch.all(torch.sum(y, dim=1) == y.new_ones(y.shape[0]))
y = y.argmax(dim=1)
        return self.loss(y_hat, y)


class ScenePredictionTask(Task):
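    """Utterance-level scene prediction task, refactored from the HEAR evaluation kit.

    Depending on ``prediction_type``, the task uses either a sigmoid
    activation with BCE loss ("multilabel") or a softmax activation with
    cross-entropy loss ("multiclass").

    A minimal construction sketch; the model, category encoder, and score
    name below are illustrative placeholders, not fixtures of this module::

        task = ScenePredictionTask(
            model=downstream_model,      # any module: (x, x_len) -> (logits, ...)
            category=category_encoder,   # a CategoryEncoder over the label set
            prediction_type="multiclass",
            scores=["top1_acc"],         # names must be keys of available_scores
        )
    """
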
def __init__(
self,
model: torch.nn.Module,
category: CategoryEncoder,
prediction_type: str,
scores: List[str],
):
super().__init__()
self.model = model
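        # Bidirectional mappings between label strings and class indices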
self.label_to_idx = {
str(category.decode(idx)): idx for idx in range(len(category))
}
self.idx_to_label = {
idx: str(category.decode(idx)) for idx in range(len(category))
}
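        # One scorer per requested name, each taken from available_scores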
self.scores = [
available_scores[score](label_to_idx=self.label_to_idx) for score in scores
]
if prediction_type == "multilabel":
self.activation: torch.nn.Module = torch.nn.Sigmoid()
self.logit_loss = torch.nn.BCEWithLogitsLoss()
elif prediction_type == "multiclass":
self.activation = torch.nn.Softmax(dim=-1)
self.logit_loss = OneHotToCrossEntropyLoss()
else:
            raise ValueError(f"Unknown prediction_type {prediction_type}")

def predict(self, x, x_len):
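        """Run the downstream model and apply the output activation.

        Returns:
            prediction: activated scores with shape (batch_size, num_class)
            logits: raw model outputs with shape (batch_size, num_class)
        """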
logits, _ = self.model(x, x_len)
prediction = self.activation(logits)
        return prediction, logits

def forward(
        self, _mode: str, x, x_len, y, labels, unique_name: str, _dump_dir: Optional[str] = None
):
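        """Compute the loss for one batch and cache per-utterance results."""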
y_pr, y_hat = self.predict(x, x_len)
loss = self.logit_loss(y_hat.float(), y.float())
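        # Cache per-utterance CPU tensors so reduction() can aggregate them later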
cacheable = dict(
loss=loss.detach().cpu().item(),
label=y.detach().cpu().unbind(dim=0), # (batch_size, num_class)
logit=y_hat.detach().cpu().unbind(dim=0), # (batch_size, num_class)
prediction=y_pr.detach().cpu().unbind(dim=0), # (batch_size, num_class)
)
        return loss, cacheable

def log_scores(self, score_args):
"""Logs the metric score value for each score defined for the model"""
assert hasattr(self, "scores"), "Scores for the model should be defined"
end_scores = {}
        # The first score in `self.scores` is the optimization criterion
for score in self.scores:
score_ret = score(*score_args)
validate_score_return_type(score_ret)
# If the returned score is a tuple, store each subscore as separate entry
if isinstance(score_ret, tuple):
end_scores[f"{score}"] = score_ret[0][1]
# All other scores will also be logged
for subscore, value in score_ret:
end_scores[f"{score}_{subscore}"] = value
elif isinstance(score_ret, float):
end_scores[f"{score}"] = score_ret
else:
raise ValueError(
f"Return type {type(score_ret)} is unexpected. Return type of "
"the score function should either be a "
"tuple(tuple) or float."
)
        return end_scores

def reduction(
self,
_mode: str,
cached_results: List[dict],
        _dump_dir: Optional[str] = None,
):
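        """Aggregate cached batch results into epoch-level logs.

        Recomputes the loss over all cached utterances and, for the valid
        and test modes, evaluates every configured score.
        """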
result = self.parse_cached_results(cached_results)
target = torch.stack(result["label"], dim=0)
prediction_logit = torch.stack(result["logit"], dim=0)
prediction = torch.stack(result["prediction"], dim=0)
loss = self.logit_loss(prediction_logit, target)
logs = dict(
loss=loss.detach().cpu().item(),
)
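        # Metric scores are computed only for validation and testing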
if _mode in ["valid", "test"]:
logs.update(
self.log_scores(
score_args=(
prediction.detach().cpu().numpy(),
target.detach().cpu().numpy(),
),
)
)
return logs