# datasetTTS / app.py
# Last change: "Update app.py" by Woziii (commit d0b0179, verified)
import io
import math
import os
import re
import spaces
import shutil
import zipfile
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import gradio as gr
from pydub import AudioSegment
from transformers import pipeline
# -------------------------------------------------
# 1. Configuration et Initialisation
# -------------------------------------------------
MODEL_NAME = "openai/whisper-large-v3"

# Pick the accelerator: CUDA when torch can see a GPU, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Whisper speech-recognition pipeline, built once when the module is imported
# and reused by every transcription request.
pipe = pipeline(
    "automatic-speech-recognition",
    model=MODEL_NAME,
    model_kwargs={"low_cpu_mem_usage": True},
    device=device,
)

# Scratch directory for generated audio extracts; created on startup if missing.
TEMP_DIR = "./temp_audio"
Path(TEMP_DIR).mkdir(parents=True, exist_ok=True)
def init_metadata_state():
    """Return a fresh, empty segment table for the ``metadata_state`` gr.State.

    The table carries the same columns that ``save_segments`` returns, so it
    can be handed straight to ``generate_zip`` (which calls ``.empty`` on the
    state) even before the user has saved any segment. A list here would make
    that ``.empty`` access raise AttributeError.
    """
    return pd.DataFrame(columns=["Texte", "Début (s)", "Fin (s)", "ID"])
# -------------------------------------------------
# 2. Transcription de l'audio avec Whisper (Timestamps de fin + Marge de Sécurité)
# -------------------------------------------------
# French elisions ("l'", "d'", "qu'", "jusqu'", ...) with a stray space on at
# least one side of the apostrophe (straight ' or typographic ’).
_ELISION_SPACING = re.compile(
    r"\b(jusqu|lorsqu|puisqu|qu|[ldmcjnst])(?:\s+['’]\s*|['’]\s+)(\w)",
    re.IGNORECASE,
)


def correct_typography(text):
    """Remove stray spaces around apostrophes in French elisions.

    ``"l ' homme"``, ``"l' homme"``, ``"l 'homme"`` and ``"qu ' il"`` become
    ``"l'homme"`` / ``"qu'il"``, with the apostrophe normalised to a straight
    ``'``. Elisions that already have no space around the apostrophe are left
    untouched (including any typographic ``’``).

    Args:
        text: transcription text to clean.

    Returns:
        The corrected text.
    """
    return _ELISION_SPACING.sub(r"\1'\2", text)
# Extra time (s) appended to each word's end so cuts do not clip its tail.
END_MARGIN_S = 0.06  # 60 ms
# Gap (s) always kept between a word's extended end and the next word's start.
NEXT_WORD_GAP_S = 0.01  # 10 ms
# Duration (s) assumed for a word whose end timestamp is missing (None).
DEFAULT_WORD_DURATION_S = 0.5


def _word_timestamps(words, margin=END_MARGIN_S, gap=NEXT_WORD_GAP_S):
    """Return ``(text, start, end)`` tuples with a safety margin on each end.

    Args:
        words: Whisper word chunks, each a dict with ``"text"`` and a
            ``"timestamp"`` pair ``(start, end)``; ``end`` may be None.
        margin: seconds added to every word's end.
        gap: seconds kept free before the next word's start.

    Returns:
        One ``(stripped_text, start, end)`` tuple per word, with ``end >= start``.
    """
    result = []
    for i, word in enumerate(words):
        start, end = word["timestamp"]
        if end is None:
            end = start + DEFAULT_WORD_DURATION_S
        end += margin
        if i + 1 < len(words):
            next_start = words[i + 1]["timestamp"][0]
            if next_start is not None:
                # Never run into the next word: stop `gap` seconds before it.
                end = min(end, next_start - gap)
        # Overlapping or identical timestamps could otherwise yield end < start.
        result.append((word["text"].strip(), start, max(start, end)))
    return result


@spaces.GPU(duration=120)
def transcribe_audio(audio_path):
    """Transcribe ``audio_path`` with Whisper and compute word-level cut points.

    Returns:
        A 4-tuple matching the Gradio outputs wired to this callback:
        ``(raw_transcription, word_timestamps, transcription_with_timestamps,
        audio_path)``, where ``word_timestamps`` is a list of
        ``(text, start_s, end_s)`` tuples. On failure the first item is an
        error message, the list is None and the timestamp text is empty.
    """
    if not audio_path:
        print("[LOG] Aucun fichier audio fourni.")
        return "Aucun fichier audio fourni", None, "", None

    print(f"[LOG] Début de la transcription de {audio_path}...")
    result = pipe(audio_path, return_timestamps="word")
    words = result.get("chunks", [])
    if not words:
        print("[LOG ERROR] Erreur : Aucun timestamp détecté.")
        # Keep the uploaded audio in place rather than clearing the input.
        return "Erreur : Aucun timestamp détecté.", None, "", audio_path

    word_timestamps = _word_timestamps(words)

    # Word texts are stripped, so a single-space join avoids doubled spaces;
    # typography is fixed before the text is shown.
    raw_transcription = correct_typography(" ".join(text for text, _, _ in word_timestamps))
    transcription_with_timestamps = " ".join(
        f"{text}[{start:.2f}-{end:.2f}]" for text, start, end in word_timestamps
    )

    print(f"[LOG] Transcription brute corrigée : {raw_transcription}")
    return raw_transcription, word_timestamps, transcription_with_timestamps, audio_path
# -------------------------------------------------
# 3. Enregistrement des segments définis par l'utilisateur (Affichage sur Interface)
# -------------------------------------------------
def _is_blank_cell(value):
    """Return True when a table cell is empty: None, NaN or whitespace-only text."""
    if value is None:
        return True
    if isinstance(value, float):
        return math.isnan(value)
    return not str(value).strip()


def _parse_seconds(value):
    """Parse a timestamp cell such as "1.5", "1,5" or 1.5 into finite seconds.

    Raises:
        ValueError: if the cell is not a finite number.
    """
    seconds = float(str(value).strip().replace(",", "."))
    # float() accepts "nan"/"inf", which would slip past the range checks.
    if not math.isfinite(seconds):
        raise ValueError("Valeurs de timestamps invalides")
    return seconds


def save_segments(table_data):
    """Validate the user's segment table and build the metadata DataFrame.

    Rows whose three cells are all blank are skipped. Every other row must
    have numeric timestamps (comma or dot decimals) with
    ``0 <= start < end``.

    Args:
        table_data: DataFrame with columns "Texte", "Début (s)", "Fin (s)",
            or None.

    Returns:
        ``(DataFrame[Texte, Début (s), Fin (s), ID], markdown_message)``.
        On an invalid row: an empty DataFrame and an error message.
    """
    print("[LOG] Enregistrement des segments définis par l'utilisateur...")
    if table_data is None:
        table_data = pd.DataFrame(columns=["Texte", "Début (s)", "Fin (s)"])

    formatted_data = []
    message_lines = ["**📌 Segments enregistrés :**"]
    for _, row in table_data.iterrows():
        text, start_cell, end_cell = row["Texte"], row["Début (s)"], row["Fin (s)"]
        if all(_is_blank_cell(cell) for cell in (text, start_cell, end_cell)):
            continue  # e.g. an untouched row of the editable table
        # Numbered by accepted segment so IDs stay contiguous.
        segment_id = f"seg_{len(formatted_data) + 1:02d}"
        try:
            start_time = _parse_seconds(start_cell)
            end_time = _parse_seconds(end_cell)
            if start_time < 0 or end_time <= start_time:
                raise ValueError("Valeurs incohérentes")
        except ValueError as e:
            print(f"[LOG ERROR] Erreur de conversion des timestamps : {e}")
            return pd.DataFrame(), "❌ **Erreur** : Vérifiez que les valeurs sont bien des nombres valides."

        formatted_data.append([text, start_time, end_time, segment_id])
        log_message = f"- `{segment_id}` | **Texte** : {text} | ⏱ **{start_time:.2f}s - {end_time:.2f}s**"
        message_lines.append(log_message)
        print(f"[LOG] {log_message}")

    confirmation_message = "\n".join(message_lines) + "\n"
    return pd.DataFrame(formatted_data, columns=["Texte", "Début (s)", "Fin (s)", "ID"]), confirmation_message
# -------------------------------------------------
# 4. Génération du fichier ZIP
# -------------------------------------------------
def _safe_archive_stem(zip_name):
    """Return a file-name-safe stem built from the user-supplied ZIP name.

    Characters other than letters, digits, '_', '-', '.' and spaces (notably
    path separators) are replaced by '_', so the archive always lands inside
    TEMP_DIR. A missing or blank name falls back to "dataset".
    """
    stem = re.sub(r"[^\w\-. ]+", "_", str(zip_name or "")).strip()
    return stem or "dataset"


def generate_zip(metadata_state, audio_path, zip_name):
    """Build a ZIP holding ``metadata.csv`` plus one WAV file per segment.

    Args:
        metadata_state: segment table with "Texte", "Début (s)" and "Fin (s)"
            columns (seconds as numbers), or a tuple whose first item is it.
        audio_path: path of the source audio to cut.
        zip_name: user-chosen base name for the archive and the segment IDs.

    Returns:
        Path of the archive written inside TEMP_DIR, or None when there is no
        segment table / no segment / no audio.
    """
    if isinstance(metadata_state, tuple):
        # Extract the DataFrame if a (table, message) pair was stored.
        metadata_state = metadata_state[0] if metadata_state else None
    if not isinstance(metadata_state, pd.DataFrame) or metadata_state.empty:
        print("[LOG ERROR] Aucun segment valide trouvé pour la génération du ZIP.")
        return None
    if not audio_path:
        print("[LOG ERROR] Aucun fichier audio fourni pour la génération du ZIP.")
        return None

    stem = _safe_archive_stem(zip_name)
    zip_path = os.path.join(TEMP_DIR, f"{stem}_dataset.zip")

    # Work on a copy so the DataFrame held in gr.State is not mutated.
    metadata = metadata_state.copy()
    # IDs are derived from the archive name so CSV rows and WAV names match.
    metadata["ID"] = [f"{stem}_seg_{i:02d}" for i in range(1, len(metadata) + 1)]
    metadata["Commentaires"] = ""
    metadata = metadata[["ID", "Texte", "Début (s)", "Fin (s)", "Commentaires"]]

    # Decode the audio before opening the archive so a decoding failure does
    # not leave a truncated ZIP behind.
    original_audio = AudioSegment.from_file(audio_path)

    # Mode "w" truncates any previous archive with the same name. Everything
    # is written from memory, so no stray CSV/WAV files pile up in TEMP_DIR.
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("metadata.csv", metadata.to_csv(sep="|", index=False))
        for _, row in metadata.iterrows():
            # round(), not int(): 0.57 * 1000 == 569.999... would truncate to 569.
            start_ms = int(round(row["Début (s)"] * 1000))
            end_ms = int(round(row["Fin (s)"] * 1000))
            buffer = io.BytesIO()
            original_audio[start_ms:end_ms].export(buffer, format="wav")
            zf.writestr(f"{row['ID']}.wav", buffer.getvalue())

    print("[LOG] Fichier ZIP généré avec succès.")
    return zip_path
# -------------------------------------------------
# 5. Interface utilisateur Gradio
# -------------------------------------------------
# Gradio interface wiring the transcription, segment-saving and ZIP-export steps.
with gr.Blocks() as demo:
    gr.Markdown("# Application de Découpe Audio")
    # Segment table: initialised by init_metadata_state, written by
    # save_segments, read by generate_zip.
    metadata_state = gr.State(init_metadata_state())
    # Audio is passed to the callbacks as a file path (type="filepath").
    audio_input = gr.Audio(type="filepath", label="Fichier audio")
    # Base name that generate_zip uses for the archive.
    zip_name = gr.Textbox(label="Nom du fichier ZIP", interactive=True)
    # Both text boxes are filled by transcribe_audio.
    raw_transcription = gr.Textbox(label="Transcription", interactive=True)
    transcription_timestamps = gr.Textbox(label="Transcription avec Timestamps", interactive=True)
    # Editable segment table; every column is typed "str", so save_segments
    # parses the timestamps itself.
    table = gr.Dataframe(headers=["Texte", "Début (s)", "Fin (s)"], datatype=["str", "str", "str"], row_count=(1, "dynamic"))
    save_button = gr.Button("Enregistrer les segments")
    save_message = gr.Markdown(label="📢 **Message de confirmation**")
    generate_button = gr.Button("Générer ZIP")
    zip_file = gr.File(label="Télécharger le ZIP")
    # (text, start, end) tuples returned by transcribe_audio; no callback in
    # this block reads them back.
    word_timestamps = gr.State()

    # Transcribe whenever the audio value changes.
    # NOTE(review): audio_input is also an output of this handler — confirm
    # that writing it back cannot re-trigger the change event.
    audio_input.change(transcribe_audio, inputs=audio_input, outputs=[raw_transcription, word_timestamps, transcription_timestamps, audio_input])
    # Validate the table and store the result in metadata_state.
    save_button.click(save_segments, inputs=table, outputs=[metadata_state, save_message])
    # Cut the audio per saved segment and offer the archive for download.
    generate_button.click(generate_zip, inputs=[metadata_state, audio_input, zip_name], outputs=zip_file)

demo.queue().launch()