Update tasks/audio.py
tasks/audio.py  CHANGED  (+21 -17)
@@ -13,6 +13,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
 from dotenv import load_dotenv
 import logging
 import csv
+import torch.nn.utils.prune as prune
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -32,6 +33,14 @@ def preprocess_function(example, feature_extractor):
         sampling_rate=feature_extractor.sampling_rate, padding="longest", max_length=16000, truncation=True, return_tensors="pt"
     )
 
+def apply_pruning(model, amount=0.3):
+    """Apply pruning to the model's weights."""
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            prune.l1_unstructured(module, name="weight", amount=amount)
+            prune.remove(module, "weight")
+    return model
+
 @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
 async def evaluate_audio(request: AudioEvaluationRequest):
     """
@@ -51,28 +60,27 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     tracker.start()
     tracker.start_task("inference")
 
-
     feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
 
-    test_dataset = test_dataset.map(preprocess_function, fn_kwargs={"feature_extractor": feature_extractor}, remove_columns="audio", batched=True,
+    test_dataset = test_dataset.map(preprocess_function, fn_kwargs={"feature_extractor": feature_extractor}, remove_columns="audio", batched=True, batch_size=32)
 
     gc.collect()
 
-
-
-    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)  # Load the model for audio classification
-
-    # Apply dynamic quantization if needed
-    model.eval()  # Put the model in evaluation mode
-    model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)  # Apply dynamic quantization
+    model_name = "CindyDelage/Challenge_HuggingFace_DFG_FrugalAI"
+    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
 
-
+    # Apply dynamic quantization and pruning
+    model.eval()
+    model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)
+    model = apply_pruning(model, amount=0.3)  # Prune 30% of the linear weights
+
+    classifier = pipeline("audio-classification", model=model, feature_extractor=feature_extractor, device=device)
     predictions = []
     logging.info("Début des prédictions par batch")
     for data in iter(test_dataset):
         with torch.no_grad():
             result = classifier(np.asarray(data["input_values"]), batch_size=1)
-
+
         predicted_label = result[0]['label']
         label = 1 if predicted_label == 'environment' else 0
         predictions.append(label)
@@ -84,8 +92,6 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     gc.collect()
 
     logging.info("Fin des prédictions")
-    del result
-    del label
     del classifier
     del feature_extractor
 
@@ -93,9 +99,7 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     # Stop tracking emissions
     emissions_data = tracker.stop_task()
 
-
-    true_labels = []  # List to store the true labels
-
+    true_labels = []
     for example in test_dataset:
         true_labels.append(example["label"])
 
@@ -119,4 +123,4 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     }
 
     logging.info("Returning results")
-    return results
+    return results
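For context, the new apply_pruning helper combines unstructured L1 magnitude pruning with the existing dynamic quantization step. Below is a minimal, self-contained sketch (not part of the commit) of that combination on a toy stack of nn.Linear layers; the toy layer sizes and the {nn.Linear} qconfig spec are illustrative assumptions, while the prune and quantize_dynamic calls mirror the ones added in the diff.

# Standalone sketch (illustrative, not from the commit): prune, then quantize.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

toy_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))  # toy sizes (assumption)

# Same calls as apply_pruning: zero out the 30% smallest-magnitude weights of
# each Linear layer, then fold the pruning mask back into the weight tensor.
for module in toy_model.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.3)
        prune.remove(module, "weight")

# Dynamic quantization replaces nn.Linear modules with dynamically quantized ones.
toy_model.eval()
quantized = torch.quantization.quantize_dynamic(toy_model, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    print(quantized(torch.randn(1, 16)).shape)  # torch.Size([1, 2])

One design note: because quantize_dynamic swaps nn.Linear for quantized module types, the sketch prunes before quantizing so the isinstance(module, torch.nn.Linear) check still matches the layers being pruned.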