File size: 4,301 Bytes
e9b633a
 
8572a30
 
a6ce206
e9b633a
29697c3
 
 
7f14609
e9b633a
 
 
2031133
445a3a0
cc249c3
2031133
 
 
 
e9b633a
 
 
 
 
 
 
8572a30
7342f5a
7f9b2de
81ffeb5
aace22c
13717ee
81ffeb5
9914858
cc249c3
 
 
 
 
 
 
 
29697c3
e9b633a
 
 
 
 
 
29697c3
2031133
7342f5a
2031133
29697c3
e9b633a
80a180c
7342f5a
e9b633a
 
 
7f9b2de
 
8572a30
cc249c3
bbe5903
7342f5a
8572a30
cc249c3
 
2b6d28e
cc249c3
 
 
 
 
 
8572a30
f1fb962
 
 
8e71413
cc249c3
f1fb962
 
8572a30
f1fb962
 
 
 
 
 
39c4047
29697c3
f1fb962
 
e9b633a
f1fb962
e9b633a
 
 
cc249c3
8572a30
 
f78df1f
e9b633a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7342f5a
2031133
cc249c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset  
from sklearn.metrics import accuracy_score  
import numpy as np
import os
import torch
import gc
import psutil
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor, pipeline
from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from dotenv import load_dotenv
import logging
import csv
import torch.nn.utils.prune as prune

# Configure logging so progress messages are visible in the server logs.
logging.basicConfig(level=logging.INFO)
logging.info("Début du fichier python")
load_dotenv()  # load environment variables (e.g. HF_TOKEN) from a local .env file

router = APIRouter()

# Endpoint metadata reported back in the evaluation results.
DESCRIPTION = "Random Baseline"
ROUTE = "/audio"

# transformers `pipeline` device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1  

def preprocess_function(example, feature_extractor):
    """Turn a batch of raw audio examples into model-ready input features.

    Args:
        example: batched dataset slice; ``example["audio"]`` is a list of
            dicts each holding a raw waveform under ``"array"``.
        feature_extractor: Wav2Vec2-style feature extractor; its own
            ``sampling_rate`` is used for the conversion.

    Returns:
        The feature extractor's output (PyTorch tensors), padded to the
        longest waveform in the batch and truncated at 16000 samples.
    """
    waveforms = [clip["array"] for clip in example["audio"]]
    return feature_extractor(
        waveforms,
        sampling_rate=feature_extractor.sampling_rate,
        padding="longest",
        max_length=16000,
        truncation=True,
        return_tensors="pt",
    )

def apply_pruning(model, amount=0.3):
    """Permanently prune the smallest-magnitude weights of every linear layer.

    Args:
        model: any ``torch.nn.Module``; only ``torch.nn.Linear`` submodules
            are affected.
        amount: fraction of each linear layer's weights to zero out
            (L1-unstructured, i.e. lowest absolute values first).

    Returns:
        The same model instance, modified in place.
    """
    linear_layers = (m for m in model.modules() if isinstance(m, torch.nn.Linear))
    for layer in linear_layers:
        prune.l1_unstructured(layer, name="weight", amount=amount)
        # Fold the pruning mask into the weight tensor so the reparametrization
        # hooks are removed and the zeros become permanent.
        prune.remove(layer, "weight")
    return model

@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """
    Evaluate audio classification for rainforest sound detection.

    Streams the test split of the requested dataset, runs a pruned and
    dynamically-quantized Wav2Vec2 classifier over it while tracking energy
    consumption, and returns accuracy plus emissions data.

    Args:
        request: evaluation parameters (dataset name, test size/seed).

    Returns:
        dict with submission metadata, accuracy, and emissions figures.
    """
    # Get space info
    username, space_url = get_space_info()

    logging.info("Chargement des données")
    dataset = load_dataset(request.dataset_name, streaming=True, token=os.getenv("HF_TOKEN"))
    logging.info("Données chargées")

    test_dataset = dataset["test"]
    del dataset

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

    # Lazily preprocess while streaming; only the "audio" column is dropped,
    # so "label" remains available on each yielded example.
    test_dataset = test_dataset.map(preprocess_function, fn_kwargs={"feature_extractor": feature_extractor}, remove_columns="audio", batched=True, batch_size=32)

    gc.collect()

    model_name = "CindyDelage/Challenge_HuggingFace_DFG_FrugalAI"
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)

    # Shrink the model: dynamic int8 quantization + 30% pruning of linear weights.
    model.eval()
    # NOTE(review): dynamically-quantized models run on CPU only; if `device`
    # selects a GPU the pipeline may fail or fall back — confirm on the target host.
    model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)
    model = apply_pruning(model, amount=0.3)

    classifier = pipeline("audio-classification", model=model, feature_extractor=feature_extractor, device=device)
    predictions = []
    true_labels = []
    logging.info("Début des prédictions par batch")
    for data in iter(test_dataset):
        with torch.no_grad():
            result = classifier(np.asarray(data["input_values"]), batch_size=1)

        predicted_label = result[0]['label']
        predictions.append(1 if predicted_label == 'environment' else 0)
        # Collect the ground truth here instead of iterating the dataset a
        # second time: another pass over a streaming IterableDataset would
        # re-download and re-preprocess every example.
        true_labels.append(data["label"])

        # Free memory after each iteration (no-op for CUDA cache on CPU-only hosts).
        del result
        torch.cuda.empty_cache()
        gc.collect()

    logging.info("Fin des prédictions")
    del classifier
    del feature_extractor

    gc.collect()
    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    accuracy = accuracy_score(true_labels, predictions)

    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    logging.info("Returning results")
    return results