# NOTE: scraped Hugging Face Spaces page header (Space status at capture time: "Runtime error")
import os

import gradio as gr
import whisper
from transformers import MarianMTModel, MarianTokenizer

# Load the Whisper ASR model once at import time so every request reuses it.
# "base" balances speed and accuracy; alternatives: 'tiny', 'small', 'medium', 'large'.
model = whisper.load_model("base")
def load_translation_model(target_language):
    """Load the MarianMT tokenizer/model pair for English -> target_language.

    Args:
        target_language: ISO-639-1 code of the subtitle language ("fa", "es", "fr").

    Returns:
        (tokenizer, translation_model) tuple ready for translate_text().

    Raises:
        ValueError: if no MarianMT model is registered for target_language.
    """
    # Map of language codes to Helsinki-NLP MarianMT model names.
    # NOTE: all entries are English-source models (opus-mt-en-*), so the
    # transcription fed to them is assumed to be English.
    lang_models = {
        "fa": "Helsinki-NLP/opus-mt-en-fa",  # English to Persian (Farsi)
        "es": "Helsinki-NLP/opus-mt-en-es",  # English to Spanish
        "fr": "Helsinki-NLP/opus-mt-en-fr",  # English to French
        # Add more models for other languages as needed
    }
    model_name = lang_models.get(target_language)
    if not model_name:
        raise ValueError(f"Translation model for {target_language} not found")
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    translation_model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, translation_model
def translate_text(text, tokenizer, model):
    """Translate a single text segment with a MarianMT model.

    Args:
        text: source-language string (one subtitle segment).
        tokenizer: MarianTokenizer matching `model`.
        model: MarianMTModel; the name intentionally shadows the module-level
            Whisper `model` inside this function only.

    Returns:
        The translated string with special tokens removed.
    """
    # Tokenize, run beam-search generation, and decode only the first
    # (best) hypothesis back to plain text.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    translated = model.generate(**inputs)
    return tokenizer.decode(translated[0], skip_special_tokens=True)
def write_srt(transcription, output_file, tokenizer=None, translation_model=None): | |
with open(output_file, "w") as f: | |
for i, segment in enumerate(transcription['segments']): | |
start = segment['start'] | |
end = segment['end'] | |
text = segment['text'] | |
# Translate text if translation model is provided | |
if translation_model: | |
text = translate_text(text, tokenizer, translation_model) | |
# Format timestamps for SRT | |
start_time = whisper.utils.format_timestamp(start) | |
end_time = whisper.utils.format_timestamp(end) | |
print(f"Writing subtitle {i + 1}: {text.strip()} ({start_time} --> {end_time})") # Debug print | |
f.write(f"{i + 1}\n") | |
f.write(f"{start_time} --> {end_time}\n") | |
f.write(f"{text.strip()}\n\n") | |
def transcribe_video(video_file, language, target_language):
    """Gradio handler: transcribe a video and return the path of an SRT file.

    Args:
        video_file: uploaded file object; `.name` holds the path on disk.
        language: spoken-language code passed to Whisper's decoder.
        target_language: subtitle language; "en" skips translation.

    Returns:
        Path to the generated .srt file (same basename as the video).
    """
    # Transcribe the whole video in one call; result['segments'] carries
    # per-segment timing used by write_srt.
    result = model.transcribe(video_file.name, language=language)
    # Name the SRT after the video (extension swapped for .srt).
    video_name = os.path.splitext(video_file.name)[0]
    srt_file = f"{video_name}.srt"
    if target_language != "en":  # No translation needed if target is English
        # NOTE(review): the registered models are en->X, so this assumes the
        # transcription is English even when `language` != "en" — confirm.
        tokenizer, translation_model = load_translation_model(target_language)
    else:
        tokenizer, translation_model = None, None
    # Write the transcription as subtitles (with optional translation)
    write_srt(result, srt_file, tokenizer, translation_model)
    return srt_file
# Gradio interface: upload a video, pick the spoken language and the desired
# subtitle language, download the generated .srt file.
iface = gr.Interface(
    fn=transcribe_video,
    inputs=[
        gr.File(label="Upload Video"),
        gr.Dropdown(label="Select Video Language", choices=["en", "es", "fr", "de", "it", "pt"], value="en"),
        gr.Dropdown(label="Select Subtitle Language", choices=["en", "fa", "es", "fr"], value="fa"),  # Persian (fa) default
    ],
    outputs=gr.File(label="Download Subtitles"),
    title="Video Subtitle Generator with Translation",
    description="Upload a video file to generate subtitles using Whisper. Select the language of the video and the target subtitle language."
)

if __name__ == "__main__":
    iface.launch()