CindyDelage commited on
Commit
7f14609
·
verified ·
1 Parent(s): 8572a30

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +8 -4
tasks/audio.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  import torch
8
  import gc
9
  import psutil
10
- from transformers import AutoFeatureExtractor, pipeline
11
  from .utils.evaluation import AudioEvaluationRequest
12
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
13
  from dotenv import load_dotenv
@@ -58,9 +58,13 @@ async def evaluate_audio(request: AudioEvaluationRequest):
58
 
59
  gc.collect()
60
 
61
- model = torch.load("CindyDelage/Challenge_HuggingFace_DFG_FrugalAI")
62
- model.eval()
63
- model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)
 
 
 
 
64
 
65
  classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI",feature_extractor=feature_extractor, device=device)
66
  predictions = []
 
7
  import torch
8
  import gc
9
  import psutil
10
+ from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor, pipeline
11
  from .utils.evaluation import AudioEvaluationRequest
12
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
13
  from dotenv import load_dotenv
 
58
 
59
  gc.collect()
60
 
61
+ # Charger le modèle depuis Hugging Face Hub (par exemple, à partir de l'ID du modèle)
62
+ model_name = "CindyDelage/Challenge_HuggingFace_DFG_FrugalAI" # Nom du modèle dans Hugging Face Hub
63
+ model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name) # Charger le modèle pour la classification audio
64
+
65
+ # Appliquer la quantification dynamique si nécessaire
66
+ model.eval() # Mettre le modèle en mode évaluation
67
+ model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8) # Appliquer la quantification dynamique
68
 
69
  classifier = pipeline("audio-classification", model="CindyDelage/Challenge_HuggingFace_DFG_FrugalAI",feature_extractor=feature_extractor, device=device)
70
  predictions = []