waidhoferj commited on
Commit
0a2992f
·
1 Parent(s): 1c22425

fixed preprocessing

Browse files
models/audio_spectrogram_transformer.py CHANGED
@@ -34,9 +34,9 @@ class AST(nn.Module):
34
  super().__init__(*args, **kwargs)
35
  id2label, label2id = get_id_label_mapping(labels)
36
  config = ASTConfig(
37
- hidden_size=300,
38
- num_attention_heads=5,
39
- num_hidden_layers=3,
40
  id2label=id2label,
41
  label2id=label2id,
42
  num_labels=len(label2id),
@@ -48,9 +48,13 @@ class AST(nn.Module):
48
  return self.model(x).logits
49
 
50
 
 
 
 
51
  class ASTExtractorWrapper:
52
  def __init__(self, sampling_rate=16000, return_tensors="pt") -> None:
53
- self.extractor = ASTFeatureExtractor()
 
54
  self.sampling_rate = sampling_rate
55
  self.return_tensors = return_tensors
56
  self.waveform_pipeline = WaveformTrainingPipeline() # TODO configure from yaml
@@ -62,7 +66,11 @@ class ASTExtractorWrapper:
62
  x = self.extractor(
63
  x, return_tensors=self.return_tensors, sampling_rate=self.sampling_rate
64
  )
65
- return x["input_values"].squeeze(0).to(device)
 
 
 
 
66
 
67
 
68
  def train_lightning_ast(config: dict):
 
34
  super().__init__(*args, **kwargs)
35
  id2label, label2id = get_id_label_mapping(labels)
36
  config = ASTConfig(
37
+ hidden_size=256,
38
+ num_hidden_layers=6,
39
+ num_attention_heads=4,
40
  id2label=id2label,
41
  label2id=label2id,
42
  num_labels=len(label2id),
 
48
  return self.model(x).logits
49
 
50
 
51
+ # TODO: Remove waveform normalization from ASTFeatureExtractor.
52
+ # Find correct mean and std dev
53
+ # Find correct max length
54
  class ASTExtractorWrapper:
55
  def __init__(self, sampling_rate=16000, return_tensors="pt") -> None:
56
+ max_length = 1024
57
+ self.extractor = ASTFeatureExtractor(do_normalize=False, max_length=max_length)
58
  self.sampling_rate = sampling_rate
59
  self.return_tensors = return_tensors
60
  self.waveform_pipeline = WaveformTrainingPipeline() # TODO configure from yaml
 
66
  x = self.extractor(
67
  x, return_tensors=self.return_tensors, sampling_rate=self.sampling_rate
68
  )
69
+
70
+ x = x["input_values"].squeeze(0).to(device)
71
+ # normalize
72
+ x = (x - x.mean()) / x.std()
73
+ return x
74
 
75
 
76
  def train_lightning_ast(config: dict):