CindyDelage commited on
Commit
e9b633a
·
verified ·
1 Parent(s): abbb8ef

Upload audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +118 -0
tasks/audio.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from datetime import datetime
3
+ from datasets import load_dataset
4
+ from sklearn.metrics import accuracy_score
5
+ import random
6
+ import os
7
+
8
+ from .utils.evaluation import AudioEvaluationRequest
9
+ from .utils.emissions import tracker, clean_emissions_data, get_space_info
10
+
11
+ from dotenv import load_dotenv
12
+ load_dotenv()
13
+
14
+ router = APIRouter()
15
+
16
+ DESCRIPTION = "Random Baseline"
17
+ ROUTE = "/audio"
18
+
19
+
20
+
21
+ @router.post(ROUTE, tags=["Audio Task"],
22
+ description=DESCRIPTION)
23
+ async def evaluate_audio(request: AudioEvaluationRequest):
24
+ """
25
+ Evaluate audio classification for rainforest sound detection.
26
+
27
+ Current Model: Random Baseline
28
+ - Makes random predictions from the label space (0-1)
29
+ - Used as a baseline for comparison
30
+ """
31
+ # Get space info
32
+ username, space_url = get_space_info()
33
+
34
+ # Define the label mapping
35
+ LABEL_MAPPING = {
36
+ "chainsaw": 0,
37
+ "environment": 1
38
+ }
39
+ # Load and prepare the dataset
40
+ # Because the dataset is gated, we need to use the HF_TOKEN environment variable to authenticate
41
+ dataset = load_dataset(request.dataset_name,token=dataset_name,token=os.getenv("HF_TOKEN"))
42
+
43
+ # Split dataset
44
+ train_test = dataset["train"]
45
+ test_dataset = dataset["test"]
46
+
47
+ # Start tracking emissions
48
+ tracker.start()
49
+ tracker.start_task("inference")
50
+
51
+ #--------------------------------------------------------------------------------------------
52
+ # YOUR MODEL INFERENCE CODE HERE
53
+ # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
54
+ #--------------------------------------------------------------------------------------------
55
+
56
+ # Make random predictions (placeholder for actual model inference)
57
+ true_labels = test_dataset["label"]
58
+ import torch
59
+ from transformers import pipeline
60
+ from sklearn import preprocessing
61
+ #encoded_data_fine_tuned_model = train_test["train"].map(preprocess_function, remove_columns="audio", batched=True)
62
+
63
+ from datasets import Dataset
64
+
65
+ # Utilisation du pipeline directement sur le dataset
66
+ classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI", feature_extractor=feature_extractor)
67
+
68
+ # Correctly access the audio data
69
+ audio_data = [example["array"] for example in dataset["test"]["audio"]]
70
+
71
+ # Prédiction sur tout le dataset
72
+ results = classifier(audio_data, batch_size=8)
73
+
74
+ predictions = []
75
+ for result in results:
76
+ # Check if result is a dictionary
77
+ if isinstance(result, dict):
78
+ # Get the label with the highest score
79
+ predicted_label = result['label']
80
+ else:
81
+ # If result is not a dictionary, access it as a list
82
+ predicted_label = result[0]['label'] # Assuming the dictionary is the first element
83
+
84
+ # Assign 1 for "environment", 0 for "chainsaw"
85
+ if predicted_label == 'environment':
86
+ predictions.append(1)
87
+ else:
88
+ predictions.append(0)
89
+
90
+ #--------------------------------------------------------------------------------------------
91
+ # YOUR MODEL INFERENCE STOPS HERE
92
+ #--------------------------------------------------------------------------------------------
93
+
94
+ # Stop tracking emissions
95
+ emissions_data = tracker.stop_task()
96
+
97
+ # Calculate accuracy
98
+ accuracy = accuracy_score(true_labels, predictions)
99
+
100
+ # Prepare results dictionary
101
+ results = {
102
+ "username": username,
103
+ "space_url": space_url,
104
+ "submission_timestamp": datetime.now().isoformat(),
105
+ "model_description": DESCRIPTION,
106
+ "accuracy": float(accuracy),
107
+ "energy_consumed_wh": emissions_data.energy_consumed * 1000,
108
+ "emissions_gco2eq": emissions_data.emissions * 1000,
109
+ "emissions_data": clean_emissions_data(emissions_data),
110
+ "api_route": ROUTE,
111
+ "dataset_config": {
112
+ "dataset_name": request.dataset_name,
113
+ "test_size": request.test_size,
114
+ "test_seed": request.test_seed
115
+ }
116
+ }
117
+
118
+ return results