Spaces:
Sleeping
Sleeping
import os | |
import re | |
import spaces | |
import shutil | |
import zipfile | |
import torch | |
import numpy as np | |
import pandas as pd | |
from pathlib import Path | |
import gradio as gr | |
from pydub import AudioSegment | |
from transformers import pipeline | |
# ------------------------------------------------- | |
# 1. Configuration et Initialisation | |
# ------------------------------------------------- | |
# Whisper checkpoint handed to the speech-recognition pipeline below.
MODEL_NAME = "openai/whisper-large-v3"

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the Whisper speech-recognition pipeline once, when the module loads.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    device=device,
    model_kwargs={"low_cpu_mem_usage": True},
)

# Scratch directory for the extracted audio clips; created if it is missing.
TEMP_DIR = "./temp_audio"
Path(TEMP_DIR).mkdir(parents=True, exist_ok=True)
def init_metadata_state():
    """Return a new, empty list used as the initial segment-metadata state."""
    initial_state = list()
    return initial_state
# ------------------------------------------------- | |
# 2. Transcription de l'audio avec Whisper (Timestamps de fin + Marge de Sécurité) | |
# ------------------------------------------------- | |
# An elision letter (l, d, m, c, j, n, s, t in either case), one space, a
# straight or curly apostrophe, one space, then the first character of a word.
_SPACED_ELISION = re.compile(r"\b([lLdDmMcCjJnNsStT]) ['’] (\w)")


def correct_typography(text):
    """Close up the spaces around an apostrophe that follows an elision letter.

    Only the exact shape "<letter> ' <word>" (one space on each side of the
    apostrophe) is rewritten, and the result always uses a straight
    apostrophe. Any other text is returned unchanged.
    """
    return _SPACED_ELISION.sub(r"\1'\2", text)
def _build_word_timestamps(words, margin=0.06, gap=0.01):
    """Convert word chunks into ``(text, start, end)`` tuples with padded ends.

    Args:
        words: sequence of dicts with a "text" entry and a "timestamp"
            ``(start, end)`` pair, as read from the pipeline result.
        margin: seconds added to the end of a word that has a successor.
        gap: seconds kept free before the start of the following word.

    Returns:
        A list of ``(text, start, end)`` tuples where ``end >= start``.
    """
    timeline = []
    previous_end = 0.0
    last_position = len(words) - 1
    for position, word in enumerate(words):
        start_time = word["timestamp"][0]
        if start_time is None:
            # No start reported: continue from where the previous word ended.
            start_time = previous_end
        end_time = word["timestamp"][1]
        if end_time is None:
            # No end reported: assume the word lasts half a second.
            end_time = start_time + 0.5
        if position < last_position:
            next_start = words[position + 1]["timestamp"][0]
            if next_start is not None:
                # Pad the end, but never run into the following word.
                end_time = min(end_time + margin, next_start - gap)
        # The clamp above can land before the word's own start when two
        # words begin almost together; never emit a negative duration.
        end_time = max(end_time, start_time)
        timeline.append((word["text"], start_time, end_time))
        previous_end = end_time
    return timeline


def transcribe_audio(audio_path):
    """Transcribe an audio file with word-level timestamps.

    Args:
        audio_path: path of the audio file; a falsy value means "no file".

    Returns:
        A 4-tuple ``(transcription, word_timestamps, annotated, audio_path)``:
        the typographically corrected text (or an error message), the list of
        ``(text, start, end)`` tuples (``None`` on error), the text with each
        word followed by ``[start-end]`` (``""`` on error), and the audio path
        (``None`` when no file was supplied).
    """
    if not audio_path:
        print("[LOG] Aucun fichier audio fourni.")
        # Keep each slot's type aligned with the success return: text for
        # the annotated transcription, and no file for the audio slot.
        return "Aucun fichier audio fourni", None, "", None

    print(f"[LOG] Début de la transcription de {audio_path}...")
    result = pipe(audio_path, return_timestamps="word")
    words = result.get("chunks", [])
    if not words:
        print("[LOG ERROR] Erreur : Aucun timestamp détecté.")
        # Hand the same path back, as the success path does, instead of an
        # empty string.
        return "Erreur : Aucun timestamp détecté.", None, "", audio_path

    # Fix the typography before the text is shown to the user.
    raw_transcription = correct_typography(" ".join(w["text"] for w in words))

    # Pad word ends by 60 ms while keeping 10 ms clear of the next word.
    word_timestamps = _build_word_timestamps(words, margin=0.06, gap=0.01)
    transcription_with_timestamps = " ".join(
        f"{text}[{start:.2f}-{end:.2f}]" for text, start, end in word_timestamps
    )
    print(f"[LOG] Transcription brute corrigée : {raw_transcription}")
    return raw_transcription, word_timestamps, transcription_with_timestamps, audio_path
# ------------------------------------------------- | |
# 3. Enregistrement des segments définis par l'utilisateur (Affichage sur Interface) | |
# ------------------------------------------------- | |
def save_segments(table_data):
    """Validate the user-defined segments and build their confirmation text.

    Args:
        table_data: DataFrame with "Texte", "Début (s)" and "Fin (s)" columns.

    Returns:
        ``(segments, message)``. On success, ``segments`` is a DataFrame with
        columns "Texte", "Début (s)", "Fin (s)" and "ID", and ``message`` is a
        Markdown list of the stored segments. On the first invalid row, an
        empty DataFrame and an error message are returned instead.
    """
    print("[LOG] Enregistrement des segments définis par l'utilisateur...")

    def _normalise(raw_value):
        # A decimal comma is accepted and turned into a decimal point.
        return str(raw_value).replace(",", ".")

    def _looks_numeric(candidate):
        # Only digits and dots are allowed; float() below does the real parse.
        return candidate.replace(".", "").isdigit()

    kept_rows = []
    message_parts = ["**📌 Segments enregistrés :**\n"]

    for idx, row in table_data.iterrows():
        text = row["Texte"]
        begin_cell = row["Début (s)"]
        finish_cell = row["Fin (s)"]
        segment_id = f"seg_{idx+1:02d}"
        try:
            begin_raw = _normalise(begin_cell)
            finish_raw = _normalise(finish_cell)
            if not (_looks_numeric(begin_raw) and _looks_numeric(finish_raw)):
                raise ValueError("Valeurs de timestamps invalides")
            begin = float(begin_raw)
            finish = float(finish_raw)
            if begin < 0 or finish <= begin:
                raise ValueError("Valeurs incohérentes")
            kept_rows.append([text, begin, finish, segment_id])
            log_message = f"- `{segment_id}` | **Texte** : {text} | ⏱ **{begin:.2f}s - {finish:.2f}s**"
            message_parts.append(log_message + "\n")
            print(f"[LOG] {log_message}")
        except ValueError as e:
            # One bad row aborts the whole save.
            print(f"[LOG ERROR] Erreur de conversion des timestamps : {e}")
            return pd.DataFrame(), "❌ **Erreur** : Vérifiez que les valeurs sont bien des nombres valides."

    segments = pd.DataFrame(kept_rows, columns=["Texte", "Début (s)", "Fin (s)", "ID"])
    return segments, "".join(message_parts)
# ------------------------------------------------- | |
# 4. Génération du fichier ZIP | |
# ------------------------------------------------- | |
def generate_zip(metadata_state, audio_path, zip_name):
    """Export the saved segments as a ZIP holding a CSV and one WAV per segment.

    Args:
        metadata_state: DataFrame with "Texte", "Début (s)" and "Fin (s)"
            columns, or a tuple whose first item is that DataFrame. Any other
            value (``None``, a list, ...) is treated as "no segments".
        audio_path: path of the source audio file the segments are cut from.
        zip_name: prefix used for the archive name and for the segment IDs.

    Returns:
        The path of the archive inside ``TEMP_DIR``, or ``None`` when there
        are no segments or no audio file to export.
    """
    if isinstance(metadata_state, tuple):
        metadata_state = metadata_state[0]  # Unwrap the DataFrame from a tuple
    # Callers may still hold a non-DataFrame placeholder (e.g. an empty list)
    # when nothing has been saved yet; treat that as "nothing to export".
    if not isinstance(metadata_state, pd.DataFrame) or metadata_state.empty:
        print("[LOG ERROR] Aucun segment valide trouvé pour la génération du ZIP.")
        return None
    if not audio_path:
        print("[LOG ERROR] Aucun fichier audio fourni pour la génération du ZIP.")
        return None

    # zip_name is free text and ends up in file names: neutralise path
    # separators and other characters that are unsafe in a file name.
    prefix = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", str(zip_name or "")).strip()

    zip_path = os.path.join(TEMP_DIR, f"{prefix}_dataset.zip")
    if os.path.exists(zip_path):
        os.remove(zip_path)
    metadata_csv_path = os.path.join(TEMP_DIR, f"{prefix}_metadata.csv")

    # Work on a copy so the caller's DataFrame is left untouched.
    metadata = metadata_state.copy()
    # IDs double as the WAV file names inside the archive.
    metadata["ID"] = [f"{prefix}_seg_{i+1:02d}" for i in range(len(metadata))]
    # Empty "Commentaires" column, left for later annotation.
    metadata["Commentaires"] = ""
    metadata = metadata[["ID", "Texte", "Début (s)", "Fin (s)", "Commentaires"]]

    # Intermediate files are only needed until they are inside the archive.
    scratch_files = [metadata_csv_path]
    try:
        metadata.to_csv(metadata_csv_path, sep="|", index=False)
        original_audio = AudioSegment.from_file(audio_path)
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.write(metadata_csv_path, "metadata.csv")
            for _, row in metadata.iterrows():
                # Round rather than truncate: 0.57 * 1000 is 569.99... as a float.
                start_ms = int(round(float(row["Début (s)"]) * 1000))
                end_ms = int(round(float(row["Fin (s)"]) * 1000))
                segment_audio = original_audio[start_ms:end_ms]
                segment_filename = f"{row['ID']}.wav"
                segment_path = os.path.join(TEMP_DIR, segment_filename)
                scratch_files.append(segment_path)
                segment_audio.export(segment_path, format="wav")
                zf.write(segment_path, segment_filename)
    finally:
        for scratch_path in scratch_files:
            if os.path.exists(scratch_path):
                os.remove(scratch_path)

    print("[LOG] Fichier ZIP généré avec succès.")
    return zip_path
# ------------------------------------------------- | |
# 5. Interface utilisateur Gradio | |
# ------------------------------------------------- | |
# Declare the Gradio interface: the widgets first, then the event wiring.
with gr.Blocks() as demo:
    gr.Markdown("# Application de Découpe Audio")
    # State slot for the saved segments, seeded by init_metadata_state().
    # NOTE(review): the seed is a list, whereas save_segments writes a
    # DataFrame into this slot — confirm every consumer copes with both.
    metadata_state = gr.State(init_metadata_state())
    # Audio widget configured with type="filepath".
    audio_input = gr.Audio(type="filepath", label="Fichier audio")
    # Prefix passed to generate_zip as its zip_name argument.
    zip_name = gr.Textbox(label="Nom du fichier ZIP", interactive=True)
    # Both transcription boxes are filled by transcribe_audio.
    raw_transcription = gr.Textbox(label="Transcription", interactive=True)
    transcription_timestamps = gr.Textbox(label="Transcription avec Timestamps", interactive=True)
    # Segment table with three string columns, passed to save_segments.
    table = gr.Dataframe(headers=["Texte", "Début (s)", "Fin (s)"], datatype=["str", "str", "str"], row_count=(1, "dynamic"))
    save_button = gr.Button("Enregistrer les segments")
    save_message = gr.Markdown(label="📢 **Message de confirmation**")
    generate_button = gr.Button("Générer ZIP")
    zip_file = gr.File(label="Télécharger le ZIP")
    # Receives the second value returned by transcribe_audio; no other
    # handler in this block takes it as an input.
    word_timestamps = gr.State()
    # Transcribe whenever the audio component's value changes.
    # NOTE(review): audio_input is both the trigger and an output of this
    # event — confirm that writing it back cannot re-trigger `change`.
    audio_input.change(transcribe_audio, inputs=audio_input, outputs=[raw_transcription, word_timestamps, transcription_timestamps, audio_input])
    # Validate the table and store the resulting DataFrame in metadata_state.
    save_button.click(save_segments, inputs=table, outputs=[metadata_state, save_message])
    # Build the archive from the stored segments and offer it for download.
    generate_button.click(generate_zip, inputs=[metadata_state, audio_input, zip_name], outputs=zip_file)
# Enable the request queue, then start the app.
demo.queue().launch()