CindyDelage commited on
Commit
f1fb962
·
verified ·
1 Parent(s): 7f9b2de

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +20 -20
tasks/audio.py CHANGED
@@ -55,36 +55,36 @@ async def evaluate_audio(request: AudioEvaluationRequest):
55
 
56
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
57
  # Prétraitement en streaming avec fonction explicite
58
- test_dataset = test_dataset.map(preprocess_function, fn_kwargs={"feature_extractor": feature_extractor}, remove_columns="audio", batched=True)
59
 
60
  gc.collect()
61
 
62
  # Pipeline de classification optimisé
63
  classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI", device=device)
64
-
65
- logging.info("Début des prédictions par batch")
66
 
67
- with open('predictions.csv', mode='w', newline='') as file:
68
- writer = csv.writer(file)
69
- writer.writerow(['predicted_label']) # Écrire les en-têtes
70
- # Traiter les données et écrire les résultats dans le fichier
71
- for data in test_dataset:
72
- with torch.no_grad():
73
- result = classifier(np.asarray(data["input_values"]), batch_size=2)
74
 
75
- predicted_label = result[0]['label']
76
- label = 1 if predicted_label == 'environment' else 0
77
-
78
- # Écrire chaque prédiction directement dans le fichier
79
- writer.writerow([label])
80
-
81
- # Nettoyer la mémoire après chaque itération
82
- del result
83
- torch.cuda.empty_cache()
84
- gc.collect()
85
 
86
  logging.info("Fin des prédictions")
 
 
 
 
87
 
 
88
  # Stop tracking emissions
89
  emissions_data = tracker.stop_task()
90
 
 
55
 
56
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
57
  # Prétraitement en streaming avec fonction explicite
58
+ test_dataset = test_dataset.map(preprocess_function, fn_kwargs={"feature_extractor": feature_extractor}, remove_columns="audio", batched=True, batch_size=32) # Choix de la taille du batch)
59
 
60
  gc.collect()
61
 
62
  # Pipeline de classification optimisé
63
  classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI", device=device)
64
+ predictions = [] # Liste pour stocker les prédictions
 
65
 
66
+ logging.info("Début des prédictions par batch")
67
+ for data in iter(test_dataset):
68
+ with torch.no_grad():
69
+ result = classifier(np.asarray(data["input_values"]), batch_size=2)
 
 
 
70
 
71
+ predicted_label = result[0]['label']
72
+ label = 1 if predicted_label == 'environment' else 0
73
+ predictions.append(label) # Ajouter la prédiction à la liste
74
+
75
+ # Nettoyer la mémoire après chaque itération
76
+ del result
77
+ del label
78
+ torch.cuda.empty_cache()
79
+ gc.collect()
 
80
 
81
  logging.info("Fin des prédictions")
82
+ del result
83
+ del label
84
+ del classifier
85
+ del feature_extractor
86
 
87
+ gc.collect()
88
  # Stop tracking emissions
89
  emissions_data = tracker.stop_task()
90