Kabatubare committed on
Commit 84de51b · verified · 1 Parent(s): 4a724b6

Update app.py

Files changed (1):
  1. app.py +11 -19
app.py CHANGED
@@ -3,27 +3,17 @@ import librosa
 import numpy as np
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForAudioClassification
+from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
 import random
 
+# Model and feature extractor loading from the specified local path
 model = AutoModelForAudioClassification.from_pretrained("./")
+feature_extractor = ASTFeatureExtractor.from_pretrained("./")
 
-def custom_feature_extraction(audio, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
-    S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
-    S_DB = librosa.power_to_db(S, ref=np.max)
-    pitches, _ = librosa.piptrack(y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length)
-    spectral_centroids = librosa.feature.spectral_centroid(y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length)
-    # Correct the dimensionality issue
-    pitches_max = np.max(pitches, axis=0, keepdims=True)
-    spectral_centroids = spectral_centroids.reshape(1, -1)
-    # Ensure the concatenation axis has matching dimensions
-    features = np.concatenate([S_DB, pitches_max, spectral_centroids], axis=0)
-    features_tensor = torch.tensor(features).float()
-    if features_tensor.shape[1] > target_length:
-        features_tensor = features_tensor[:, :target_length]
-    else:
-        features_tensor = F.pad(features_tensor, (0, target_length - features_tensor.shape[1]), 'constant', 0)
-    return features_tensor.unsqueeze(0)
+def custom_feature_extraction(audio, sr=16000, n_mels=128, target_length=1024):
+    # Using the loaded feature extractor
+    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
+    return features.input_values
 
 def apply_time_shift(waveform, max_shift_fraction=0.1):
     shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
@@ -31,13 +21,16 @@ def apply_time_shift(waveform, max_shift_fraction=0.1):
 
 def predict_voice(audio_file_path):
     try:
-        waveform, sample_rate = librosa.load(audio_file_path, sr=None)
+        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
         augmented_waveform = apply_time_shift(waveform)
+
         original_features = custom_feature_extraction(waveform, sr=sample_rate)
         augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
+
         with torch.no_grad():
             outputs_original = model(original_features)
             outputs_augmented = model(augmented_features)
+
         logits = (outputs_original.logits + outputs_augmented.logits) / 2
         predicted_index = logits.argmax()
         label = model.config.id2label[predicted_index.item()]
@@ -55,4 +48,3 @@ iface = gr.Interface(
 )
 
 iface.launch()
-
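
Taken together, the change replaces the hand-rolled librosa features (a 128-band mel spectrogram concatenated with pitch-maximum and spectral-centroid rows, giving a 130-row matrix) with the checkpoint's own ASTFeatureExtractor, so inference-time preprocessing matches what the model was trained on. Below is a minimal sketch of the resulting inference path, assuming the repository root holds an AST checkpoint with its preprocessor_config.json; "sample.wav" is a hypothetical input file, and the time-shift augmentation and Gradio wiring are omitted:

import librosa
import torch
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor

model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")

# Resample on load to the rate the extractor expects (16 kHz for AST),
# mirroring the sr=feature_extractor.sampling_rate change in the diff.
# "sample.wav" is a hypothetical input file.
waveform, sample_rate = librosa.load("sample.wav", sr=feature_extractor.sampling_rate, mono=True)

# The extractor computes the log-mel input and pads or truncates it to
# the model's fixed input length internally.
inputs = feature_extractor(waveform, sampling_rate=sample_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits
print(model.config.id2label[logits.argmax().item()])

Delegating padding and truncation to the bundled extractor also removes the manual slice/F.pad logic: an AST checkpoint trained on plain 128-bin spectrograms has no notion of the extra pitch and centroid rows the old code appended.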