# app.py — Video subtitle generator (Whisper transcription + MarianMT translation).
# NOTE: the Hugging Face raw-view page chrome ("raw / history blame / 3.64 kB",
# commit 87b68ab) was pasted into this file and has been commented out here so
# the module parses as valid Python.
import gradio as gr
import whisper
import os
from transformers import MarianMTModel, MarianTokenizer
# Load the Whisper speech-to-text model once at import time; it is shared by
# every request handled by this app.
model = whisper.load_model("base") # Choose 'tiny', 'base', 'small', 'medium', or 'large'

# Load MarianMT translation model for different languages
def load_translation_model(target_language):
    """Return a ``(tokenizer, model)`` MarianMT pair for English -> *target_language*.

    Raises:
        ValueError: if no checkpoint is registered for *target_language*.
    """
    # Registry of supported target-language codes -> MarianMT checkpoint names.
    supported = {
        "fa": "Helsinki-NLP/opus-mt-en-fa",  # English to Persian (Farsi)
        "es": "Helsinki-NLP/opus-mt-en-es",  # English to Spanish
        "fr": "Helsinki-NLP/opus-mt-en-fr",  # English to French
        # Add more models for other languages as needed
    }
    if target_language not in supported:
        raise ValueError(f"Translation model for {target_language} not found")
    checkpoint = supported[target_language]
    return (
        MarianTokenizer.from_pretrained(checkpoint),
        MarianMTModel.from_pretrained(checkpoint),
    )
def translate_text(text, tokenizer, model):
    """Translate *text* with a MarianMT tokenizer/model pair; return the string."""
    # Encode, generate the translation, then decode the first (only) sequence.
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    output_ids = model.generate(**encoded)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def write_srt(transcription, output_file, tokenizer=None, translation_model=None):
with open(output_file, "w") as f:
for i, segment in enumerate(transcription['segments']):
start = segment['start']
end = segment['end']
text = segment['text']
# Translate text if translation model is provided
if translation_model:
text = translate_text(text, tokenizer, translation_model)
# Format timestamps for SRT
start_time = whisper.utils.format_timestamp(start)
end_time = whisper.utils.format_timestamp(end)
print(f"Writing subtitle {i + 1}: {text.strip()} ({start_time} --> {end_time})") # Debug print
f.write(f"{i + 1}\n")
f.write(f"{start_time} --> {end_time}\n")
f.write(f"{text.strip()}\n\n")
def transcribe_video(video_file, language, target_language):
    """Transcribe *video_file* with Whisper and write subtitles next to it.

    Returns the path of the generated ``.srt`` file (translated when
    *target_language* is not ``"en"``).
    """
    # Run Whisper on the uploaded file in the language the user selected.
    transcription = model.transcribe(video_file.name, language=language)

    # Name the subtitle file after the video, swapping only the extension.
    base, _ = os.path.splitext(video_file.name)
    srt_path = f"{base}.srt"

    # English targets need no translation; otherwise load an en->X MarianMT pair.
    # NOTE(review): the MarianMT checkpoints are English-source only, so
    # translation presumably assumes the video language is "en" — confirm.
    tokenizer = translation_model = None
    if target_language != "en":
        tokenizer, translation_model = load_translation_model(target_language)

    # Write the transcription as subtitles (with optional translation).
    write_srt(transcription, srt_path, tokenizer, translation_model)
    return srt_path
# Gradio interface: wires transcribe_video into a simple upload -> download UI.
iface = gr.Interface(
    fn=transcribe_video,
    inputs=[
        gr.File(label="Upload Video"),
        # Source (spoken) language passed straight to Whisper's transcribe().
        gr.Dropdown(label="Select Video Language", choices=["en", "es", "fr", "de", "it", "pt"], value="en"),
        # Target subtitle language; non-"en" choices trigger MarianMT translation.
        gr.Dropdown(label="Select Subtitle Language", choices=["en", "fa", "es", "fr"], value="fa") # Added Persian (fa) as an option
    ],
    outputs=gr.File(label="Download Subtitles"),
    title="Video Subtitle Generator with Translation",
    description="Upload a video file to generate subtitles using Whisper. Select the language of the video and the target subtitle language."
)

# Launch the UI only when run as a script (not when imported as a module).
if __name__ == "__main__":
    iface.launch()