Update 3 files
- ced_model/configuration_ced.py +3 -6
- ced_model/feature_extraction_ced.py +50 -8
- ced_model/modeling_ced.py +28 -45
ced_model/configuration_ced.py CHANGED

@@ -123,15 +123,12 @@ class CedConfig(PretrainedConfig):
         self.qkv_bias = qkv_bias
         self.target_length = target_length
         self.win_size = kwargs.get("win_size", 512)
+        self.loss = "BCE"
 
         if self.outputdim == 527:
-            with open(
-                cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r"
-            ) as f:
+            with open(cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r") as f:
                 self.id2label = {
-                    int(line.split(",", maxsplit=3)[0]): line.split(",", maxsplit=3)[2]
-                    .replace('"', "")
-                    .strip("\n")
+                    int(line.split(",", maxsplit=3)[0]): line.split(",", maxsplit=3)[2].replace('"', "").strip("\n")
                     for line in f.readlines()[1:]
                 }
                 self.label2id = {v: k for k, v in self.id2label.items()}
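The only functional change here is the new `loss` attribute, hard-coded to "BCE", which `CedForAudioClassification.forward` reads when computing the training loss (see the modeling diff below); the rest is line-wrapping cleanup. A minimal sketch of switching to cross-entropy, assuming the package is importable as `ced_model` and that the default config arguments are acceptable:

from ced_model.configuration_ced import CedConfig
from ced_model.modeling_ced import CedForAudioClassification

config = CedConfig()   # __init__ sets config.loss = "BCE"
config.loss = "CE"     # the value is hard-coded, so override it after construction
model = CedForAudioClassification(config)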
ced_model/feature_extraction_ced.py CHANGED

@@ -16,7 +16,7 @@
     Feature extractor class for CED.
 """
 
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,14 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
         self.f_max = f_max
         self.hop_size = hop_size
 
+        self.model_input_names = ["input_values"]
+
     def __call__(
         self,
-        x: Union[np.ndarray, torch.Tensor],
+        x: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
         sampling_rate: Optional[int] = None,
+        max_length: Optional[int] = 16000,
+        truncation: bool = True,
         return_tensors="pt",
     ) -> BatchFeature:
         r"""
@@ -88,6 +92,14 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
 
         Args:
             x: Input audio signal tensor.
+            sampling_rate (int, *optional*, defaults to `None`):
+                Sampling rate of the input audio signal.
+            max_length (int, *optional*, defaults to 16000):
+                Maximum length of the input audio signal.
+            truncation (bool, *optional*, defaults to `True`):
+                Whether to truncate the input signal to max_length.
+            return_tensors (str, *optional*, defaults to "pt"):
+                If set to "pt", the return type will be a PyTorch tensor.
 
         Returns:
             BatchFeature: A dictionary containing the extracted features.
@@ -96,9 +108,7 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
             sampling_rate = self.sampling_rate
 
         if return_tensors != "pt":
-            raise NotImplementedError(
-                "Only return_tensors='pt' is currently supported."
-            )
+            raise NotImplementedError("Only return_tensors='pt' is currently supported.")
 
         mel_spectrogram = audio_transforms.MelSpectrogram(
             f_min=self.f_min,
@@ -112,10 +122,42 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
         )
         amplitude_to_db = audio_transforms.AmplitudeToDB(top_db=120)
 
-
-
-
+        if isinstance(x, np.ndarray):
+            if x.ndim == 1:
+                x = x[np.newaxis, :]
+            if x.ndim != 2:
+                raise ValueError("np.ndarray input must be 1D or 2D.")
+            x = torch.from_numpy(x)
+        elif isinstance(x, torch.Tensor):
+            if x.dim() == 1:
+                x = x.unsqueeze(0)
+            if x.dim() != 2:
+                raise ValueError("torch.Tensor input must be 1D or 2D.")
+        elif isinstance(x, (list, tuple)):
+            longest_length = max(x_.shape[0] for x_ in x)
+            if not truncation and max_length < longest_length:
+                max_length = longest_length
+
+            if all(isinstance(x_, np.ndarray) for x_ in x):
+                if not all(x_.ndim == 1 for x_ in x):
+                    raise ValueError("All np.ndarray in a list must be 1D.")
+
+                x_trim = [x_[:max_length] for x_ in x]
+                x_pad = [np.pad(x_, (0, max_length - x_.shape[0]), mode="constant", constant_values=0) for x_ in x_trim]
+                x = torch.stack([torch.from_numpy(x_) for x_ in x_pad])
+            elif all(isinstance(x_, torch.Tensor) for x_ in x):
+                if not all(x_.dim() == 1 for x_ in x):
+                    raise ValueError("All torch.Tensor in a list must be 1D.")
+                x_pad = [torch.nn.functional.pad(x_, (0, max_length - x_.shape[0]), value=0) for x_ in x]
+                x = torch.stack(x_pad)
+            else:
+                raise ValueError("Input list must be numpy arrays or PyTorch tensors.")
+        else:
+            raise ValueError(
+                "Input must be a numpy array, a list of numpy arrays, a PyTorch tensor, or a list of PyTorch tensors."
+            )
 
+        x = x.float()
         x = mel_spectrogram(x)
         x = amplitude_to_db(x)
         return BatchFeature({"input_values": x})
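With this change the extractor accepts a single 1D/2D array or tensor, or a list of 1D clips that it truncates and zero-pads to `max_length` samples before computing the log-mel features (the tensor-list branch relies on negative padding in `torch.nn.functional.pad` to truncate). A minimal usage sketch, assuming the default constructor arguments of `CedFeatureExtractor` are sufficient:

import numpy as np
from ced_model.feature_extraction_ced import CedFeatureExtractor

extractor = CedFeatureExtractor()

# Two clips of different lengths: the short one is zero-padded to max_length,
# the long one is truncated because truncation=True.
clips = [
    np.random.randn(12_000).astype(np.float32),
    np.random.randn(20_000).astype(np.float32),
]
features = extractor(clips, sampling_rate=16000, max_length=16000, truncation=True)
print(features["input_values"].shape)  # (batch, n_mels, frames)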
ced_model/modeling_ced.py CHANGED

@@ -106,9 +106,7 @@ class CedAudioPatchEmbed(nn.Module):
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
 
-        self.proj = nn.Conv2d(
-            in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride
-        )
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
     def forward(self, x):
@@ -143,11 +141,7 @@ class CedAttention(nn.Module):
 
     def forward(self, x):
         B, N, C = x.shape
-        qkv = (
-            self.qkv(x)
-            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
-            .permute(2, 0, 3, 1, 4)
-        )
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -221,9 +215,7 @@ class DropPath(nn.Module):
         return f"drop_prob={round(self.drop_prob,3):0.3f}"
 
 
-def drop_path(
-    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
+def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
 
     This is the same as the DropConnect impl I (https://github.com/rwightman) created for EfficientNet, etc networks,
@@ -236,9 +228,7 @@ def drop_path(
     if drop_prob == 0.0 or not training:
         return x
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (
-        x.ndim - 1
-    )  # work with diff dim tensors, not just 2D ConvNets
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
     if keep_prob > 0.0 and scale_by_keep:
         random_tensor.div_(keep_prob)
@@ -373,17 +363,11 @@ class CedModel(CedPreTrainedModel):
             patch_stride=config.patch_stride,
         )
 
-        self.time_pos_embed = nn.Parameter(
-            torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
-        )
-        self.freq_pos_embed = nn.Parameter(
-            torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
-        )
+        self.time_pos_embed = nn.Parameter(torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02)
+        self.freq_pos_embed = nn.Parameter(torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02)
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
         act_layer = nn.GELU
-        dpr = [
-            x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)
-        ]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]  # stochastic depth decay rule
         self.pos_drop = nn.Dropout(p=config.drop_rate)
         self.blocks = nn.Sequential(
             *[
@@ -407,13 +391,16 @@ class CedModel(CedPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
         _, _, _, t = x.shape
         x = x + self.time_pos_embed[:, :, :, :t]
-        x = (
-            x + self.freq_pos_embed[:, :, :, :]
-        )  # Just to support __getitem__ in posembed
+        x = x + self.freq_pos_embed[:, :, :, :]  # Just to support __getitem__ in posembed
 
         # x = rearrange(x, 'b c f t -> b (f t) c')
         x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))
@@ -442,9 +429,7 @@ class CedModel(CedPreTrainedModel):
 
         if splits[-1].shape[-1] < self.maximal_allowed_length:
             if self.config.pad_last:
-                pad = torch.zeros(
-                    *x.shape[:-1], self.maximal_allowed_length, device=x.device
-                )
+                pad = torch.zeros(*x.shape[:-1], self.maximal_allowed_length, device=x.device)
                 pad[..., : splits[-1].shape[-1]] = splits[-1]
                 splits = torch.stack((*splits[:-1], pad), dim=0)
             else:
@@ -497,9 +482,7 @@ class CedForAudioClassification(CedPreTrainedModel):
         elif self.config.pooling == "dm":
             # Unpack using the frequency dimension, which is constant
             # 'b (f t) d -> b f t d', f=self.patch_embed.grid_size[0])
-            x = torch.reshape(
-                x, (x.shape[0], self.patch_embed.grid_size[0], -1, x.shape[3])
-            )
+            x = torch.reshape(x, (x.shape[0], self.patch_embed.grid_size[0], -1, x.shape[3]))
 
             # First poolin frequency, then sigmoid the (B T D) output
             x = self.outputlayer(x.mean(1)).sigmoid()
@@ -507,9 +490,10 @@ class CedForAudioClassification(CedPreTrainedModel):
         else:
            return x.mean(1)
 
-    @add_start_docstrings_to_model_forward(
-        CED_INPUTS_DOCSTRING.format("batch_size, sequence_length")
-    )
+    def freeze_encoder(self):
+        self.encoder._freeze_parameters()
+
+    @add_start_docstrings_to_model_forward(CED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         checkpoint=_SEQ_CLASS_CHECKPOINT,
         output_type=SequenceClassifierOutput,
@@ -519,9 +503,7 @@ class CedForAudioClassification(CedPreTrainedModel):
         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
-    def forward(
-        self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None
-    ):
+    def forward(self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None):
         """
         Runs a forward pass of the CED model for audio classification task.
 
@@ -554,14 +536,15 @@ class CedForAudioClassification(CedPreTrainedModel):
         logits = self.forward_head(last_hidden_states)
 
         if labels is not None:
-
-
-
-
+            if self.config.loss == "CE":
+                loss_fct = nn.CrossEntropyLoss()
+            elif self.config.loss == "BCE":
+                loss_fct = nn.BCEWithLogitsLoss()
+            else:
+                raise NotImplementedError("Need to set 'CE' or 'BCE' as config.loss.")
+            labels = nn.functional.one_hot(labels, num_classes=self.config.outputdim).float()
             loss = loss_fct(logits, labels)
         else:
             loss = None
 
-        return SequenceClassifierOutput(
-            logits=logits, loss=loss, hidden_states=last_hidden_states
-        )
+        return SequenceClassifierOutput(logits=logits, loss=loss, hidden_states=last_hidden_states)
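Beyond the line-wrapping cleanup, two behavioral additions stand out: `CedModel._freeze_parameters` with the matching `CedForAudioClassification.freeze_encoder`, and a loss function selected by `config.loss`, with integer class labels one-hot encoded before the loss is applied. A minimal fine-tuning sketch; the input shape is illustrative and the model is built from a fresh config rather than a published checkpoint:

import torch
from ced_model.configuration_ced import CedConfig
from ced_model.modeling_ced import CedForAudioClassification

config = CedConfig()                      # config.loss defaults to "BCE"
model = CedForAudioClassification(config)
model.freeze_encoder()                    # freezes the CedModel backbone; the classification head stays trainable

input_values = torch.randn(2, 64, 1001)   # (batch, n_mels, frames) log-mel features; shape is illustrative
labels = torch.tensor([3, 10])            # integer class ids; forward() one-hot encodes them internally

outputs = model(input_values, labels=labels)
print(outputs.loss, outputs.logits.shape)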