CindyDelage commited on
Commit
7342f5a
·
verified ·
1 Parent(s): 1552e0b

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +22 -26
tasks/audio.py CHANGED
@@ -3,12 +3,10 @@ from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import numpy as np
6
- import random
7
  import os
8
  import torch
9
  import gc
10
  import psutil
11
- from torch.utils.data import DataLoader
12
  from transformers import AutoFeatureExtractor, pipeline
13
  from .utils.evaluation import AudioEvaluationRequest
14
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
@@ -25,6 +23,8 @@ router = APIRouter()
25
  DESCRIPTION = "Random Baseline"
26
  ROUTE = "/audio"
27
 
 
 
28
  @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
29
  async def evaluate_audio(request: AudioEvaluationRequest):
30
  """
@@ -33,13 +33,14 @@ async def evaluate_audio(request: AudioEvaluationRequest):
33
  # Get space info
34
  username, space_url = get_space_info()
35
 
36
- # Load dataset
37
  logging.info("Chargement des données")
38
- dataset = load_dataset(request.dataset_name, streaming=True,token=os.getenv("HF_TOKEN"))
39
  logging.info("Données chargées")
40
 
41
  test_dataset = dataset["test"]
42
  del dataset
 
43
  # Start tracking emissions
44
  tracker.start()
45
  tracker.start_task("inference")
@@ -47,39 +48,33 @@ async def evaluate_audio(request: AudioEvaluationRequest):
47
  # Feature extraction
48
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
49
 
50
- def preprocess_function(examples):
51
- audio_arrays = [x["array"] for x in examples["audio"]]
52
- return feature_extractor(audio_arrays, sampling_rate=feature_extractor.sampling_rate, padding="longest", max_length=16000, truncation=True, return_tensors="pt")
 
 
 
 
 
 
 
 
 
53
 
54
- encoded_data_test = test_dataset.map(preprocess_function, remove_columns="audio", batched=True)#, keep_in_memory=False)
55
- del feature_extractor
56
- del audio_arrays
57
- # Pipeline de classification
58
- classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI", device=-1)
59
-
60
  predictions = []
61
  logging.info("Début des prédictions par batch")
62
-
63
- for data in encoded_data_test:
64
- # Récupérer les données audio et le label
65
  with torch.no_grad():
66
  result = classifier(np.asarray(data["input_values"]))
67
 
68
  predicted_label = result[0]['label']
69
  predictions.append(1 if predicted_label == 'environment' else 0)
70
- del result
71
- del predicted_label
72
-
73
- # Nettoyage mémoire après chaque batch
74
- #del input_values
75
  torch.cuda.empty_cache()
76
  gc.collect()
77
 
78
- # Log mémoire toutes les 500 prédictions
79
- if len(predictions) % 500 == 0:
80
- logging.info(f"Nettoyage mémoire après {len(predictions)} prédictions")
81
- logging.info(f"Utilisation mémoire : {psutil.virtual_memory().percent}%")
82
-
83
  logging.info("Fin des prédictions")
84
 
85
  # Stop tracking emissions
@@ -105,5 +100,6 @@ async def evaluate_audio(request: AudioEvaluationRequest):
105
  "test_seed": request.test_seed
106
  }
107
  }
 
108
  logging.info("Returning results")
109
  return results
 
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import numpy as np
 
6
  import os
7
  import torch
8
  import gc
9
  import psutil
 
10
  from transformers import AutoFeatureExtractor, pipeline
11
  from .utils.evaluation import AudioEvaluationRequest
12
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
23
  DESCRIPTION = "Random Baseline"
24
  ROUTE = "/audio"
25
 
26
+ device = 0 if torch.cuda.is_available() else -1 # Choix du périphérique GPU si dispo
27
+
28
  @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
29
  async def evaluate_audio(request: AudioEvaluationRequest):
30
  """
 
33
  # Get space info
34
  username, space_url = get_space_info()
35
 
36
+ # Load dataset en streaming
37
  logging.info("Chargement des données")
38
+ dataset = load_dataset(request.dataset_name, streaming=True, token=os.getenv("HF_TOKEN"))
39
  logging.info("Données chargées")
40
 
41
  test_dataset = dataset["test"]
42
  del dataset
43
+
44
  # Start tracking emissions
45
  tracker.start()
46
  tracker.start_task("inference")
 
48
  # Feature extraction
49
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
50
 
51
+ def preprocess_function(example):
52
+ audio_array = example["audio"]["array"]
53
+ return feature_extractor(audio_array, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
54
+
55
+ # Prétraitement en streaming
56
+ test_dataset = test_dataset.map(preprocess_function, remove_columns=["audio"])
57
+
58
+ del feature_extractor # Libération explicite
59
+ gc.collect()
60
+
61
+ # Pipeline de classification optimisé
62
+ classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI", device=device)
63
 
 
 
 
 
 
 
64
  predictions = []
65
  logging.info("Début des prédictions par batch")
66
+
67
+ for data in test_dataset:
 
68
  with torch.no_grad():
69
  result = classifier(np.asarray(data["input_values"]))
70
 
71
  predicted_label = result[0]['label']
72
  predictions.append(1 if predicted_label == 'environment' else 0)
73
+
74
+ del result # Nettoyage mémoire
 
 
 
75
  torch.cuda.empty_cache()
76
  gc.collect()
77
 
 
 
 
 
 
78
  logging.info("Fin des prédictions")
79
 
80
  # Stop tracking emissions
 
100
  "test_seed": request.test_seed
101
  }
102
  }
103
+
104
  logging.info("Returning results")
105
  return results