Robertomarting committed on
Commit
642e7e6
verified
1 Parent(s): fa869ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -36
app.py CHANGED
@@ -26,7 +26,7 @@ def is_white_noise(audio, threshold=0.75):
26
  kurt = kurtosis(audio)
27
  return np.abs(kurt) < 0.1 and np.mean(np.abs(audio)) < threshold
28
 
29
- def create_audio_dataframe(audio_tuple, target_sr=16000, target_duration=1.0):
30
  data = []
31
  target_length = int(target_sr * target_duration)
32
 
@@ -49,60 +49,41 @@ def create_audio_dataframe(audio_tuple, target_sr=16000, target_duration=1.0):
49
  if len(audio_data) > target_length:
50
  for i in range(0, len(audio_data), target_length):
51
  segment = audio_data[i:i + target_length]
52
- if len(segment) == target_length:
53
- if not is_white_noise(segment):
54
- data.append({"audio": segment})
55
  else:
56
  if not is_white_noise(audio_data):
57
- data.append({"audio": audio_data})
58
 
59
- df = pd.DataFrame(data)
60
- return df
61
 
62
- def convert_bytes_to_float64(byte_list):
63
- return [float(i) for i in byte_list]
64
-
65
- def preprocess_function(examples):
66
- audio_arrays = examples["audio"]
67
  inputs = processor(
68
- audio_arrays,
69
  padding=True,
70
  sampling_rate=processor.sampling_rate,
71
  max_length=int(processor.sampling_rate * 1),
72
  truncation=True,
 
73
  )
74
  return inputs
75
 
76
  def predict_audio(audio):
77
- df = create_audio_dataframe(audio)
78
- df['audio'] = df['audio'].apply(convert_bytes_to_float64)
79
 
80
- # Convertir el dataframe a Dataset
81
- predict_dataset = Dataset.from_pandas(df)
82
- dataset = DatasetDict({
83
- 'train': predict_dataset
84
- })
85
-
86
- if '__index_level_0__' in dataset['train'].column_names:
87
- dataset['train'] = dataset['train'].remove_columns(['__index_level_0__'])
88
-
89
- encoded_dataset = dataset.map(preprocess_function, remove_columns=["audio"], batched=True)
90
-
91
- # Crear el Trainer para la predicción
92
- trainer = Trainer(
93
- model=model,
94
- eval_dataset=encoded_dataset["train"]
95
- )
96
 
97
  # Realizar las predicciones
98
- predictions_output = trainer.predict(encoded_dataset["train"].with_format("torch"))
 
99
 
100
- # Obtener las predicciones y etiquetas verdaderas
101
- predictions = predictions_output.predictions
102
- labels = predictions_output.label_ids
103
 
104
  # Convertir logits a probabilidades
105
- probabilities = F.softmax(torch.tensor(predictions), dim=-1).numpy()
106
  predicted_classes = probabilities.argmax(axis=1)
107
 
108
  # Obtener la etiqueta más común
 
26
  kurt = kurtosis(audio)
27
  return np.abs(kurt) < 0.1 and np.mean(np.abs(audio)) < threshold
28
 
29
+ def process_audio(audio_tuple, target_sr=16000, target_duration=1.0):
30
  data = []
31
  target_length = int(target_sr * target_duration)
32
 
 
49
  if len(audio_data) > target_length:
50
  for i in range(0, len(audio_data), target_length):
51
  segment = audio_data[i:i + target_length]
52
+ if len(segment) == target_length and not is_white_noise(segment):
53
+ data.append(segment)
 
54
  else:
55
  if not is_white_noise(audio_data):
56
+ data.append(audio_data)
57
 
58
+ return data
 
59
 
60
+ def preprocess_audio(audio_segments):
 
 
 
 
61
  inputs = processor(
62
+ audio_segments,
63
  padding=True,
64
  sampling_rate=processor.sampling_rate,
65
  max_length=int(processor.sampling_rate * 1),
66
  truncation=True,
67
+ return_tensors="pt" # Directamente retorna tensores de PyTorch
68
  )
69
  return inputs
70
 
71
  def predict_audio(audio):
72
+ # Procesar el audio y obtener las listas de numpy
73
+ audio_segments = process_audio(audio)
74
 
75
+ # Preprocesar el audio (aplica directamente al array numpy)
76
+ inputs = preprocess_audio(audio_segments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # Realizar las predicciones
79
+ with torch.no_grad():
80
+ outputs = model(**inputs)
81
 
82
+ # Obtener los logits de las predicciones
83
+ logits = outputs.logits
 
84
 
85
  # Convertir logits a probabilidades
86
+ probabilities = torch.nn.functional.softmax(logits, dim=-1).numpy()
87
  predicted_classes = probabilities.argmax(axis=1)
88
 
89
  # Obtener la etiqueta más común