Update 3 files
- ced_model/configuration_ced.py +3 -6
- ced_model/feature_extraction_ced.py +50 -8
- ced_model/modeling_ced.py +28 -45
ced_model/configuration_ced.py
CHANGED

```diff
@@ -123,15 +123,12 @@ class CedConfig(PretrainedConfig):
         self.qkv_bias = qkv_bias
         self.target_length = target_length
         self.win_size = kwargs.get("win_size", 512)
+        self.loss = "BCE"
 
         if self.outputdim == 527:
-            with open(
-                cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r"
-            ) as f:
+            with open(cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r") as f:
                 self.id2label = {
-                    int(line.split(",", maxsplit=3)[0]): line.split(",", maxsplit=3)[2]
-                    .replace('"', "")
-                    .strip("\n")
+                    int(line.split(",", maxsplit=3)[0]): line.split(",", maxsplit=3)[2].replace('"', "").strip("\n")
                     for line in f.readlines()[1:]
                 }
         self.label2id = {v: k for k, v in self.id2label.items()}
```
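The new `loss` attribute is hard-coded to `"BCE"` in `__init__` rather than exposed as a constructor argument, so selecting cross-entropy means overwriting it after construction. A minimal sketch, assuming the repo is importable as a `ced_model` package and that the default constructor arguments suffice (note the default `outputdim` may trigger the AudioSet label-file download above):

```python
from ced_model.configuration_ced import CedConfig

config = CedConfig()  # __init__ now sets config.loss = "BCE"
print(config.loss)    # -> "BCE"

# Overwrite for single-label classification; CedForAudioClassification.forward
# reads this attribute when computing the training loss.
config.loss = "CE"
```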
ced_model/feature_extraction_ced.py
CHANGED

```diff
@@ -16,7 +16,7 @@
 Feature extractor class for CED.
 """
 
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,14 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
         self.f_max = f_max
         self.hop_size = hop_size
 
+        self.model_input_names = ["input_values"]
+
     def __call__(
         self,
-        x: Union[np.ndarray, torch.Tensor],
+        x: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
         sampling_rate: Optional[int] = None,
+        max_length: Optional[int] = 16000,
+        truncation: bool = True,
         return_tensors="pt",
     ) -> BatchFeature:
         r"""
@@ -88,6 +92,14 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
 
         Args:
             x: Input audio signal tensor.
+            sampling_rate (int, *optional*, defaults to `None`):
+                Sampling rate of the input audio signal.
+            max_length (int, *optional*, defaults to 16000):
+                Maximum length of the input audio signal.
+            truncation (bool, *optional*, defaults to `True`):
+                Whether to truncate the input signal to max_length.
+            return_tensors (str, *optional*, defaults to "pt"):
+                If set to "pt", the return type will be a PyTorch tensor.
 
         Returns:
             BatchFeature: A dictionary containing the extracted features.
@@ -96,9 +108,7 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
             sampling_rate = self.sampling_rate
 
         if return_tensors != "pt":
-            raise NotImplementedError(
-                "Only return_tensors='pt' is currently supported."
-            )
+            raise NotImplementedError("Only return_tensors='pt' is currently supported.")
 
         mel_spectrogram = audio_transforms.MelSpectrogram(
             f_min=self.f_min,
@@ -112,10 +122,42 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
         )
         amplitude_to_db = audio_transforms.AmplitudeToDB(top_db=120)
 
-
-
-
+        if isinstance(x, np.ndarray):
+            if x.ndim == 1:
+                x = x[np.newaxis, :]
+            if x.ndim != 2:
+                raise ValueError("np.ndarray input must be a 1D or 2D.")
+            x = torch.from_numpy(x)
+        elif isinstance(x, torch.Tensor):
+            if x.dim() == 1:
+                x = x.unsqueeze(0)
+            if x.dim() != 2:
+                raise ValueError("torch.Tensor input must be a 1D or 2D.")
+        elif isinstance(x, (list, tuple)):
+            longest_length = max(x_.shape[0] for x_ in x)
+            if not truncation and max_length < longest_length:
+                max_length = longest_length
+
+            if all(isinstance(x_, np.ndarray) for x_ in x):
+                if not all(x_.ndim == 1 for x_ in x):
+                    raise ValueError("All np.ndarray in a list must be 1D.")
+
+                x_trim = [x_[:max_length] for x_ in x]
+                x_pad = [np.pad(x_, (0, max_length - x_.shape[0]), mode="constant", constant_values=0) for x_ in x_trim]
+                x = torch.stack([torch.from_numpy(x_) for x_ in x_pad])
+            elif all(isinstance(x_, torch.Tensor) for x_ in x):
+                if not all(x_.dim() == 1 for x_ in x):
+                    raise ValueError("All torch.Tensor in a list must be 1D.")
+                x_pad = [torch.nn.functional.pad(x_, (0, max_length - x_.shape[0]), value=0) for x_ in x]
+                x = torch.stack(x_pad)
+            else:
+                raise ValueError("Input list must be numpy arrays or PyTorch tensors.")
+        else:
+            raise ValueError(
+                "Input must be a numpy array, a list of numpy arrays, a PyTorch tensor, or a list of PyTorch tensor."
+            )
 
+        x = x.float()
         x = mel_spectrogram(x)
         x = amplitude_to_db(x)
         return BatchFeature({"input_values": x})
```
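With the new list handling, variable-length clips can be batched directly: each clip is truncated to `max_length` samples (or, when `truncation=False`, `max_length` grows to the longest clip), zero-padded to a common length, stacked, and converted to a batch of log-mel spectrograms. A minimal usage sketch; the checkpoint id `mispeech/ced-tiny` is illustrative:

```python
import numpy as np
from ced_model.feature_extraction_ced import CedFeatureExtractor

# Illustrative checkpoint id; substitute any CED checkpoint.
extractor = CedFeatureExtractor.from_pretrained("mispeech/ced-tiny")

# Two mono clips of different lengths at 16 kHz.
clips = [
    np.random.randn(12000).astype(np.float32),
    np.random.randn(20000).astype(np.float32),
]

# Clips are trimmed/padded to max_length samples and stacked into one batch.
batch = extractor(clips, sampling_rate=16000, max_length=16000, truncation=True)
print(batch["input_values"].shape)  # (2, n_mels, frames)
```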
ced_model/modeling_ced.py
CHANGED

```diff
@@ -106,9 +106,7 @@ class CedAudioPatchEmbed(nn.Module):
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
 
-        self.proj = nn.Conv2d(
-            in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride
-        )
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
     def forward(self, x):
@@ -143,11 +141,7 @@ class CedAttention(nn.Module):
 
     def forward(self, x):
         B, N, C = x.shape
-        qkv = (
-            self.qkv(x)
-            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
-            .permute(2, 0, 3, 1, 4)
-        )
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -221,9 +215,7 @@ class DropPath(nn.Module):
         return f"drop_prob={round(self.drop_prob,3):0.3f}"
 
 
-def drop_path(
-    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
+def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
 
     This is the same as the DropConnect impl I (https://github.com/rwightman) created for EfficientNet, etc networks,
@@ -236,9 +228,7 @@ def drop_path(
     if drop_prob == 0.0 or not training:
         return x
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (
-        x.ndim - 1
-    )  # work with diff dim tensors, not just 2D ConvNets
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
@@ -373,17 +363,11 @@ class CedModel(CedPreTrainedModel):
             patch_stride=config.patch_stride,
         )
 
-        self.time_pos_embed = nn.Parameter(
-            torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
-        )
-        self.freq_pos_embed = nn.Parameter(
-            torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
-        )
+        self.time_pos_embed = nn.Parameter(torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02)
+        self.freq_pos_embed = nn.Parameter(torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02)
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
         act_layer = nn.GELU
-        dpr = [
-            x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)
-        ]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]  # stochastic depth decay rule
         self.pos_drop = nn.Dropout(p=config.drop_rate)
         self.blocks = nn.Sequential(
             *[
@@ -407,13 +391,16 @@ class CedModel(CedPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
         _, _, _, t = x.shape
         x = x + self.time_pos_embed[:, :, :, :t]
-        x = (
-            x + self.freq_pos_embed[:, :, :, :]
-        )  # Just to support __getitem__ in posembed
+        x = x + self.freq_pos_embed[:, :, :, :]  # Just to support __getitem__ in posembed
 
         # x = rearrange(x, 'b c f t -> b (f t) c')
         x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))
@@ -442,9 +429,7 @@ class CedModel(CedPreTrainedModel):
 
         if splits[-1].shape[-1] < self.maximal_allowed_length:
             if self.config.pad_last:
-                pad = torch.zeros(
-                    *x.shape[:-1], self.maximal_allowed_length, device=x.device
-                )
+                pad = torch.zeros(*x.shape[:-1], self.maximal_allowed_length, device=x.device)
                 pad[..., : splits[-1].shape[-1]] = splits[-1]
                 splits = torch.stack((*splits[:-1], pad), dim=0)
             else:
@@ -497,9 +482,7 @@ class CedForAudioClassification(CedPreTrainedModel):
         elif self.config.pooling == "dm":
             # Unpack using the frequency dimension, which is constant
             # 'b (f t) d -> b f t d', f=self.patch_embed.grid_size[0])
-            x = torch.reshape(
-                x, (x.shape[0], self.patch_embed.grid_size[0], -1, x.shape[3])
-            )
+            x = torch.reshape(x, (x.shape[0], self.patch_embed.grid_size[0], -1, x.shape[3]))
 
             # First poolin frequency, then sigmoid the (B T D) output
             x = self.outputlayer(x.mean(1)).sigmoid()
@@ -507,9 +490,10 @@ class CedForAudioClassification(CedPreTrainedModel):
         else:
             return x.mean(1)
 
-    @add_start_docstrings_to_model_forward(
-        CED_INPUTS_DOCSTRING.format("batch_size, sequence_length")
-    )
+    def freeze_encoder(self):
+        self.encoder._freeze_parameters()
+
+    @add_start_docstrings_to_model_forward(CED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         checkpoint=_SEQ_CLASS_CHECKPOINT,
         output_type=SequenceClassifierOutput,
@@ -519,9 +503,7 @@ class CedForAudioClassification(CedPreTrainedModel):
         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
-    def forward(
-        self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None
-    ):
+    def forward(self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None):
         """
         Runs a forward pass of the CED model for audio classification task.
 
@@ -554,14 +536,15 @@ class CedForAudioClassification(CedPreTrainedModel):
         logits = self.forward_head(last_hidden_states)
 
         if labels is not None:
-            loss_fct = nn.BCEWithLogitsLoss()
-            labels = nn.functional.one_hot(
-                labels, num_classes=self.config.outputdim
-            ).float()
+            if self.config.loss == "CE":
+                loss_fct = nn.CrossEntropyLoss()
+            elif self.config.loss == "BCE":
+                loss_fct = nn.BCEWithLogitsLoss()
+            else:
+                raise NotImplementedError("Need to set 'CE' or 'BCE' as config.loss.")
+            labels = nn.functional.one_hot(labels, num_classes=self.config.outputdim).float()
             loss = loss_fct(logits, labels)
         else:
             loss = None
 
-        return SequenceClassifierOutput(
-            logits=logits, loss=loss, hidden_states=last_hidden_states
-        )
+        return SequenceClassifierOutput(logits=logits, loss=loss, hidden_states=last_hidden_states)
```
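The labels branch now selects the criterion from `config.loss` and one-hot encodes integer class indices before either loss (PyTorch's `CrossEntropyLoss` accepts class-probability targets since 1.10). A standalone sketch of the same logic, with an illustrative batch of 4 and the 527-class AudioSet output dimension:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 527)          # (batch, config.outputdim)
labels = torch.randint(0, 527, (4,))  # integer class indices

loss_name = "BCE"                     # mirrors config.loss; "CE" picks CrossEntropyLoss
loss_fct = nn.BCEWithLogitsLoss() if loss_name == "BCE" else nn.CrossEntropyLoss()

# One-hot encode before the loss, as forward() now does for both settings.
targets = nn.functional.one_hot(labels, num_classes=527).float()
loss = loss_fct(logits, targets)
```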
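`freeze_encoder` delegates to the new `CedModel._freeze_parameters`, so only the classification head stays trainable, a common setup for lightweight fine-tuning. A minimal sketch; the checkpoint id is again illustrative:

```python
from ced_model.modeling_ced import CedForAudioClassification

model = CedForAudioClassification.from_pretrained("mispeech/ced-tiny")
model.freeze_encoder()  # sets requires_grad = False on every encoder parameter

# Only head parameters remain trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)
```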