CindyDelage committed on
Commit cc249c3 · verified · 1 Parent(s): 7f14609

Update tasks/audio.py

Files changed (1)
  1. tasks/audio.py +21 -17
tasks/audio.py CHANGED
@@ -13,6 +13,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
 from dotenv import load_dotenv
 import logging
 import csv
+import torch.nn.utils.prune as prune
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -32,6 +33,14 @@ def preprocess_function(example, feature_extractor):
         sampling_rate=feature_extractor.sampling_rate, padding="longest", max_length=16000, truncation=True, return_tensors="pt"
     )
 
+def apply_pruning(model, amount=0.3):
+    """Apply pruning to the model's weights."""
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            prune.l1_unstructured(module, name="weight", amount=amount)
+            prune.remove(module, "weight")
+    return model
+
 @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
 async def evaluate_audio(request: AudioEvaluationRequest):
     """
@@ -51,28 +60,27 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     tracker.start()
     tracker.start_task("inference")
 
-
     feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
 
-    test_dataset = test_dataset.map(preprocess_function, fn_kwargs={"feature_extractor": feature_extractor}, remove_columns="audio", batched=True, batch_size=32)  # Batch size choice
+    test_dataset = test_dataset.map(preprocess_function, fn_kwargs={"feature_extractor": feature_extractor}, remove_columns="audio", batched=True, batch_size=32)
 
     gc.collect()
 
-    # Load the model from the Hugging Face Hub (e.g., from the model ID)
-    model_name = "CindyDelage/Challenge_HuggingFace_DFG_FrugalAI"  # Model name on the Hugging Face Hub
-    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)  # Load the model for audio classification
-
-    # Apply dynamic quantization if needed
-    model.eval()  # Put the model in evaluation mode
-    model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)  # Apply dynamic quantization
+    model_name = "CindyDelage/Challenge_HuggingFace_DFG_FrugalAI"
+    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
 
-    classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI", feature_extractor=feature_extractor, device=device)
+    # Apply dynamic quantization and pruning
+    model.eval()
+    model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)
+    model = apply_pruning(model, amount=0.3)  # Prune 30% of the linear weights
+
+    classifier = pipeline("audio-classification", model=model, feature_extractor=feature_extractor, device=device)
     predictions = []
     logging.info("Début des prédictions par batch")
     for data in iter(test_dataset):
         with torch.no_grad():
             result = classifier(np.asarray(data["input_values"]), batch_size=1)
 
         predicted_label = result[0]['label']
         label = 1 if predicted_label == 'environment' else 0
         predictions.append(label)
@@ -84,8 +92,6 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     gc.collect()
 
     logging.info("Fin des prédictions")
-    del result
-    del label
     del classifier
     del feature_extractor
 
@@ -93,9 +99,7 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     # Stop tracking emissions
    emissions_data = tracker.stop_task()
 
-    # Calculate accuracy
-    true_labels = []  # List to store the true labels
-
+    true_labels = []
     for example in test_dataset:
         true_labels.append(example["label"])
 
@@ -119,4 +123,4 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     }
 
     logging.info("Returning results")
-    return results
+    return results
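Review note: in this diff, apply_pruning runs after torch.quantization.quantize_dynamic. Dynamic quantization swaps every eligible torch.nn.Linear for a quantized dynamic Linear module that is no longer an nn.Linear subclass, so the isinstance(module, torch.nn.Linear) check inside apply_pruning matches nothing and the pruning pass is effectively a no-op on the quantized model. A minimal sketch of the reversed order, pruning the float weights first (same model ID and 30% amount as the diff; this is a suggested fix, not the committed code):

    import torch
    import torch.nn.utils.prune as prune
    from transformers import Wav2Vec2ForSequenceClassification

    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        "CindyDelage/Challenge_HuggingFace_DFG_FrugalAI"
    )
    model.eval()

    # Prune first, while the float torch.nn.Linear modules still exist.
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=0.3)
            prune.remove(module, "weight")  # bake the zeros into the weight tensor

    # Then quantize: eligible nn.Linear modules are swapped for int8 dynamic versions.
    model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)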
 
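For reference, a minimal standalone sketch of what the prune.l1_unstructured / prune.remove pair used by apply_pruning does, on a hypothetical toy layer rather than anything in this repo:

    import torch
    import torch.nn.utils.prune as prune

    layer = torch.nn.Linear(8, 4)

    # l1_unstructured reparametrizes the module: the original tensor is kept as
    # weight_orig, and weight is recomputed as weight_orig * weight_mask, with
    # the 30% smallest-magnitude entries masked to zero.
    prune.l1_unstructured(layer, name="weight", amount=0.3)
    print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight_orig']

    # remove() makes the pruning permanent: weight becomes a plain parameter
    # again with the zeros baked in, and the mask/hook machinery is dropped.
    prune.remove(layer, "weight")
    print(float((layer.weight == 0).float().mean()))  # roughly 0.3

One caveat for the frugality goal: unstructured zeros do not by themselves make dense matmuls faster or smaller at inference time, so in this pipeline the measurable savings come mainly from the int8 dynamic quantization.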