File size: 3,673 Bytes
51e2f90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
from abc import ABC
from typing import Any, Dict, Union
import torch
import torch_audiomentations as tam
from torch import nn
from models.bandit.core.data._types import BatchedDataDict, DataDict
class BaseAugmentor(nn.Module, ABC):
    """Common base class for augmentation modules operating on data items.

    Subclasses are expected to override :meth:`forward`; the implementation
    here only raises ``NotImplementedError``.
    """

    def forward(
        self,
        item: Union[DataDict, BatchedDataDict],
    ) -> Union[DataDict, BatchedDataDict]:
        """Augment ``item`` and return it.

        Args:
            item: A single or batched data item.

        Returns:
            The augmented item (same kind of container as the input).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError
class StemAugmentor(BaseAugmentor):
    """Augment every source stem of an item and rebuild the mixture from them.

    The augmentation config maps a key to a transform specification. Keys are
    either stem names or one of two special keys:

    * ``"[common]"``  - applied to every non-mixture stem, before anything else.
    * ``"[default]"`` - applied to stems that have no entry of their own
      (see ``apply_both_default_and_common`` for the interaction with
      ``"[common]"``).

    Each specification is a dict with a ``"name"`` (looked up as an attribute
    of ``torch_audiomentations``) and ``"kwargs"``. For ``"Compose"`` the
    ``"kwargs"`` dict must contain ``"transforms"`` (a list of nested
    ``{"name", "kwargs"}`` specifications) and ``"kwargs"`` (keyword arguments
    for ``Compose`` itself); for any other name ``"kwargs"`` is passed
    straight to the transform constructor.

    NOTE(review): the result of every transform call is read through
    ``.samples``, so the transforms presumably need to be configured to
    return a dict-like output - confirm against the torch_audiomentations
    documentation.
    """

    def __init__(
        self,
        audiomentations: Dict[str, Dict[str, Any]],
        fix_clipping: bool = True,
        scaler_margin: float = 0.5,
        apply_both_default_and_common: bool = False,
    ) -> None:
        """Build one transform per config entry.

        Args:
            audiomentations: Mapping from stem name (or ``"[common]"`` /
                ``"[default]"``) to a transform specification.
            fix_clipping: If true, ``forward`` rescales the item whenever any
                stem's absolute peak exceeds 1.0.
            scaler_margin: Upper bound of the random head-room added to the
                peak when rescaling a clipping item.
            apply_both_default_and_common: If true, ``"[default]"`` is applied
                to stems without their own entry even when ``"[common]"`` is
                configured; otherwise ``"[default]"`` is skipped whenever
                ``"[common]"`` exists.

        Raises:
            KeyError: If a specification lacks a required key.
            AttributeError: If a transform name does not exist in
                ``torch_audiomentations``.
        """
        super().__init__()

        self.has_default = "[default]" in audiomentations
        self.has_common = "[common]" in audiomentations
        self.apply_both_default_and_common = apply_both_default_and_common

        self.augmentations = nn.ModuleDict(
            {
                stem: self._build_transform(audiomentations[stem])
                for stem in audiomentations
            }
        )
        self.fix_clipping = fix_clipping
        self.scaler_margin = scaler_margin

    @staticmethod
    def _build_transform(spec: Dict[str, Any]) -> nn.Module:
        """Instantiate the torch_audiomentations transform described by ``spec``."""
        name = spec["name"]
        factory = getattr(tam, name)

        if name == "Compose":
            compose_cfg = spec["kwargs"]
            children = [
                getattr(tam, child["name"])(**child["kwargs"])
                for child in compose_cfg["transforms"]
            ]
            return factory(children, **compose_cfg["kwargs"])

        return factory(**spec["kwargs"])

    def _apply(self, key: str, audio: torch.Tensor) -> torch.Tensor:
        """Run the transform registered under ``key`` and return its samples."""
        return self.augmentations[key](audio).samples

    def check_and_fix_clipping(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        """Scale all stems down by one shared factor if any of them clips.

        The absolute peak is taken over every entry of ``item["audio"]``. If
        it exceeds 1.0, every entry is multiplied in place by
        ``1 / (peak + r * scaler_margin)`` where ``r`` is a single value drawn
        with ``torch.rand``. The same factor is used for all stems, so their
        relative levels are unchanged.

        Args:
            item: Data item whose ``"audio"`` entry maps stem names to tensors.

        Returns:
            The same ``item`` object, possibly with rescaled audio. An item
            with no audio entries is returned unchanged.
        """
        audio = item["audio"]
        stems = list(audio)
        if not stems:
            # Nothing to inspect; max() over an empty sequence would raise.
            return item

        peak = max(audio[stem].abs().max().item() for stem in stems)
        if peak > 1.0:
            # Prefer the mixture's device, but do not require a mixture entry
            # so the method also works on items that only carry stems.
            reference = audio["mixture"] if "mixture" in audio else audio[stems[0]]
            headroom = torch.rand((1,), device=reference.device) * self.scaler_margin
            scaler = 1.0 / (peak + headroom)
            for stem in stems:
                audio[stem] *= scaler

        return item

    def forward(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        """Augment each non-mixture stem, then recompute ``"mixture"``.

        Per stem, in order: the ``"[common]"`` transform (if configured), then
        the stem's own transform if it has one, otherwise the ``"[default]"``
        transform (if configured and either no ``"[common]"`` exists or
        ``apply_both_default_and_common`` is set). Afterwards ``"mixture"`` is
        replaced by the sum of all augmented stems and, if ``fix_clipping`` is
        set, the item is passed through :meth:`check_and_fix_clipping`.

        Args:
            item: Data item whose ``"audio"`` entry maps stem names to tensors.

        Returns:
            The same ``item`` object with its audio entries updated.
        """
        audio = item["audio"]
        sources = [stem for stem in audio if stem != "mixture"]

        for stem in sources:
            if self.has_common:
                audio[stem] = self._apply("[common]", audio[stem])

            if stem in self.augmentations:
                audio[stem] = self._apply(stem, audio[stem])
            elif self.has_default and (
                not self.has_common or self.apply_both_default_and_common
            ):
                audio[stem] = self._apply("[default]", audio[stem])

        if sources:
            # Only rebuild the mixture when there is something to sum:
            # sum([]) is the integer 0, which would replace the mixture tensor
            # and break the clipping check below.
            audio["mixture"] = sum(
                [audio[stem] for stem in sources]
            )  # type: ignore[call-overload, assignment]

        if self.fix_clipping:
            item = self.check_and_fix_clipping(item)

        return item
|