CindyDelage committed on
Commit
81ffeb5
·
verified ·
1 Parent(s): dc319ca

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +13 -10
tasks/audio.py CHANGED
@@ -25,12 +25,12 @@ ROUTE = "/audio"
25
 
26
  device = 0 if torch.cuda.is_available() else -1 # Choix du périphérique GPU si dispo
27
 
28
- # Feature extraction
29
- feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
30
-
31
- def preprocess_function(examples):
32
- audio_arrays = [x["array"] for x in examples["audio"]]
33
- return feature_extractor(audio_arrays, sampling_rate=feature_extractor.sampling_rate, padding="longest", max_length=16000, truncation=True, return_tensors="pt")
34
 
35
  @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
36
  async def evaluate_audio(request: AudioEvaluationRequest):
@@ -56,7 +56,11 @@ async def evaluate_audio(request: AudioEvaluationRequest):
56
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
57
 
58
  # Prétraitement en streaming avec fonction explicite
59
- encoded_data_test = test_dataset.map(preprocess_function, remove_columns="audio", batched=True)#, keep_in_memory=False)
 
 
 
 
60
 
61
  del feature_extractor # Libération explicite
62
  gc.collect()
@@ -67,10 +71,9 @@ async def evaluate_audio(request: AudioEvaluationRequest):
67
  predictions = []
68
  logging.info("Début des prédictions par batch")
69
 
70
- for data in encoded_data_test:
71
- logging.info(data)
72
  with torch.no_grad():
73
- result = classifier(np.asarray(data["input_values"]))
74
 
75
  predicted_label = result[0]['label']
76
  predictions.append(1 if predicted_label == 'environment' else 0)
 
25
 
26
  device = 0 if torch.cuda.is_available() else -1 # Choix du périphérique GPU si dispo
27
 
28
def preprocess_function(example, feature_extractor):
    """Extract model input features for a single dataset example.

    Args:
        example: dataset row holding an "audio" dict with an "array" entry
            (the raw waveform samples).
        feature_extractor: Hugging Face feature extractor; its own
            ``sampling_rate`` is forwarded so the call matches the rate the
            extractor was configured for.

    Returns:
        The feature extractor's output, requested as PyTorch tensors.
    """
    waveform = example["audio"]["array"]
    target_rate = feature_extractor.sampling_rate
    # NOTE(review): when this runs inside datasets.map, "pt" tensors are
    # typically converted back to plain Python lists by the Arrow writer —
    # confirm the downstream consumer expects that.
    return feature_extractor(waveform, sampling_rate=target_rate, return_tensors="pt")
34
 
35
  @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
36
  async def evaluate_audio(request: AudioEvaluationRequest):
 
56
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
57
 
58
  # Prétraitement en streaming avec fonction explicite
59
+ test_dataset = test_dataset.map(
60
+ preprocess_function,
61
+ fn_kwargs={"feature_extractor": feature_extractor},
62
+ remove_columns=["audio"]
63
+ )
64
 
65
  del feature_extractor # Libération explicite
66
  gc.collect()
 
71
  predictions = []
72
  logging.info("Début des prédictions par batch")
73
 
74
+ for data in test_dataset:
 
75
  with torch.no_grad():
76
+ result = classifier(np.asarray(data["array"]))
77
 
78
  predicted_label = result[0]['label']
79
  predictions.append(1 if predicted_label == 'environment' else 0)