CindyDelage commited on
Commit
29697c3
·
verified ·
1 Parent(s): d894f7c

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +40 -83
tasks/audio.py CHANGED
@@ -5,18 +5,18 @@ from sklearn.metrics import accuracy_score
5
  import numpy as np
6
  import random
7
  import os
8
- from transformers import AutoFeatureExtractor
9
-
 
 
 
10
  from .utils.evaluation import AudioEvaluationRequest
11
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
12
-
13
  from dotenv import load_dotenv
14
  import logging
15
 
16
  # Configurer le logging
17
  logging.basicConfig(level=logging.INFO)
18
-
19
- # Utiliser le logging au lieu de print
20
  logging.info("Début du fichier python")
21
  load_dotenv()
22
 
@@ -25,112 +25,69 @@ router = APIRouter()
25
  DESCRIPTION = "Random Baseline"
26
  ROUTE = "/audio"
27
 
28
-
29
-
30
- @router.post(ROUTE, tags=["Audio Task"],
31
- description=DESCRIPTION)
32
-
33
  async def evaluate_audio(request: AudioEvaluationRequest):
34
  """
35
  Evaluate audio classification for rainforest sound detection.
36
-
37
- Current Model: Random Baseline
38
- - Makes random predictions from the label space (0-1)
39
- - Used as a baseline for comparison
40
  """
41
  # Get space info
42
  username, space_url = get_space_info()
43
-
44
- # Define the label mapping
45
- LABEL_MAPPING = {
46
- "chainsaw": 0,
47
- "environment": 1
48
- }
49
- # Load and prepare the dataset
50
- # Because the dataset is gated, we need to use the HF_TOKEN environment variable to authenticate
51
  logging.info("Chargement des données")
52
- dataset = load_dataset(request.dataset_name,token=os.getenv("HF_TOKEN"))
53
  logging.info("Données chargées")
54
- # Split dataset
55
- train_test = dataset["train"]
56
  test_dataset = dataset["test"]
57
 
58
  # Start tracking emissions
59
  tracker.start()
60
  tracker.start_task("inference")
61
 
62
- #--------------------------------------------------------------------------------------------
63
- # YOUR MODEL INFERENCE CODE HERE
64
- # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
65
- #--------------------------------------------------------------------------------------------
66
-
67
- # Make random predictions (placeholder for actual model inference)
68
- true_labels = test_dataset["label"]
69
- import torch
70
- from transformers import pipeline
71
- from sklearn import preprocessing
72
- from transformers import AutoFeatureExtractor
73
-
74
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
75
-
76
  def preprocess_function(examples):
77
  audio_arrays = [x["array"] for x in examples["audio"]]
78
- inputs = feature_extractor(audio_arrays, sampling_rate=feature_extractor.sampling_rate, padding="longest", max_length=16000, truncation=True,return_tensors="pt")
79
- return inputs
80
-
81
  encoded_data_test = test_dataset.map(preprocess_function, remove_columns="audio", batched=True)
82
 
83
- from datasets import Dataset
84
- from transformers import AutoFeatureExtractor
85
-
86
- # Utilisation du pipeline directement sur le dataset
87
- classifier = pipeline("audio-classification",
88
- model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI",
89
- device=-1)
90
- # Correctly access the audio data
91
- # audio_data = [example["array"] for example in dataset["test"]["audio"]]
92
  predictions = []
93
-
94
- logging.info("Début des prédictions")
95
 
96
- for example in encoded_data_test:
97
- logging.info("Nombre de prédictions faites :", len(predictions))
98
- input_values = np.array(example["input_values"])
99
- result = classifier(input_values) # Utilisation des données pré-traitées
100
- #result = classifier(example["input_values"]) # Utilisation des données pré-traitées
101
- predicted_label = result[0]['label']
102
- predictions.append(1 if predicted_label == 'environment' else 0)
103
-
104
-
105
- logging.info("Fin des prédictions")
106
- #predictions = []
107
- # for result in results:
108
- # Check if result is a dictionary
109
- #if isinstance(result, dict):
110
- # # Get the label with the highest score
111
- # predicted_label = result['label']
112
- # else:
113
- # If result is not a dictionary, access it as a list
114
- # predicted_label = result[0]['label'] # Assuming the dictionary is the first element
115
 
116
- # Assign 1 for "environment", 0 for "chainsaw"
117
- #if predicted_label == 'environment':
118
- # predictions.append(1)
119
- #else:
120
- # predictions.append(0)
121
- #print(len(predictions))
122
-
123
- #--------------------------------------------------------------------------------------------
124
- # YOUR MODEL INFERENCE STOPS HERE
125
- #--------------------------------------------------------------------------------------------
126
 
127
  # Stop tracking emissions
128
  emissions_data = tracker.stop_task()
129
 
130
  # Calculate accuracy
 
131
  accuracy = accuracy_score(true_labels, predictions)
132
 
133
- # Prepare results dictionary
134
  results = {
135
  "username": username,
136
  "space_url": space_url,
@@ -148,4 +105,4 @@ async def evaluate_audio(request: AudioEvaluationRequest):
148
  }
149
  }
150
  logging.info("Returning results")
151
- return results
 
5
  import numpy as np
6
  import random
7
  import os
8
+ import torch
9
+ import gc
10
+ import psutil
11
+ from torch.utils.data import DataLoader
12
+ from transformers import AutoFeatureExtractor, pipeline
13
  from .utils.evaluation import AudioEvaluationRequest
14
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
15
  from dotenv import load_dotenv
16
  import logging
17
 
18
  # Configurer le logging
19
  logging.basicConfig(level=logging.INFO)
 
 
20
  logging.info("Début du fichier python")
21
  load_dotenv()
22
 
 
25
  DESCRIPTION = "Random Baseline"
26
  ROUTE = "/audio"
27
 
28
+ @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
 
 
 
 
29
  async def evaluate_audio(request: AudioEvaluationRequest):
30
  """
31
  Evaluate audio classification for rainforest sound detection.
 
 
 
 
32
  """
33
  # Get space info
34
  username, space_url = get_space_info()
35
+
36
+ # Load dataset
 
 
 
 
 
 
37
  logging.info("Chargement des données")
38
+ dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
39
  logging.info("Données chargées")
40
+
 
41
  test_dataset = dataset["test"]
42
 
43
  # Start tracking emissions
44
  tracker.start()
45
  tracker.start_task("inference")
46
 
47
+ # Feature extraction
 
 
 
 
 
 
 
 
 
 
 
48
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
49
+
50
  def preprocess_function(examples):
51
  audio_arrays = [x["array"] for x in examples["audio"]]
52
+ return feature_extractor(audio_arrays, sampling_rate=feature_extractor.sampling_rate, padding="longest", max_length=16000, truncation=True, return_tensors="pt")
53
+
 
54
  encoded_data_test = test_dataset.map(preprocess_function, remove_columns="audio", batched=True)
55
 
56
+ # Pipeline de classification
57
+ classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI", device=-1)
58
+
59
+ # DataLoader pour batch processing
60
+ BATCH_SIZE = 8
61
+ dataset_for_loader = [{"input_values": torch.tensor(example["input_values"])} for example in encoded_data_test]
62
+ dataloader = DataLoader(dataset_for_loader, batch_size=BATCH_SIZE)
63
+
 
64
  predictions = []
65
+ logging.info("Début des prédictions par batch")
 
66
 
67
+ for batch in dataloader:
68
+ input_values = batch["input_values"]
69
+ results = classifier(input_values) # Pipeline en batch
70
+
71
+ for result in results:
72
+ predicted_label = result[0]['label']
73
+ predictions.append(1 if predicted_label == 'environment' else 0)
74
+
75
+ # Nettoyage mémoire toutes les 500 prédictions
76
+ if len(predictions) % 500 == 0:
77
+ torch.cuda.empty_cache()
78
+ gc.collect()
79
+ logging.info(f"Nettoyage de la mémoire après {len(predictions)} prédictions")
80
+ logging.info(f"Utilisation mémoire : {psutil.virtual_memory().percent}%")
 
 
 
 
 
81
 
82
+ logging.info("Fin des prédictions")
 
 
 
 
 
 
 
 
 
83
 
84
  # Stop tracking emissions
85
  emissions_data = tracker.stop_task()
86
 
87
  # Calculate accuracy
88
+ true_labels = test_dataset["label"]
89
  accuracy = accuracy_score(true_labels, predictions)
90
 
 
91
  results = {
92
  "username": username,
93
  "space_url": space_url,
 
105
  }
106
  }
107
  logging.info("Returning results")
108
+ return results