Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -36,9 +36,11 @@ def augment_and_extract_features(audio_path, sr=16000, n_mfcc=40, n_fft=2048, ho
|
|
36 |
def predict_voice(audio_file_path):
|
37 |
try:
|
38 |
features_tensor = augment_and_extract_features(audio_file_path)
|
39 |
-
#
|
40 |
-
if features_tensor.
|
41 |
-
features_tensor =
|
|
|
|
|
42 |
with torch.no_grad():
|
43 |
outputs = model(features_tensor)
|
44 |
|
@@ -63,4 +65,3 @@ iface = gr.Interface(
|
|
63 |
)
|
64 |
|
65 |
iface.launch()
|
66 |
-
|
|
|
36 |
def predict_voice(audio_file_path):
|
37 |
try:
|
38 |
features_tensor = augment_and_extract_features(audio_file_path)
|
39 |
+
# Adjust model input size or preprocessing to avoid size mismatch with convolution kernel
|
40 |
+
if features_tensor.dim() < 4: # Ensure tensor is 4D (batch, channel, height, width) for CNNs
|
41 |
+
features_tensor = features_tensor.unsqueeze(1) # Add a channel dimension if missing
|
42 |
+
# Apply adaptive pooling to match model expected input size if necessary
|
43 |
+
features_tensor = torch.nn.AdaptiveAvgPool2d((model.config.num_labels, model.config.num_labels))(features_tensor)
|
44 |
with torch.no_grad():
|
45 |
outputs = model(features_tensor)
|
46 |
|
|
|
65 |
)
|
66 |
|
67 |
iface.launch()
|
|