ChiBenevisamPas commited on
Commit
442428c
·
verified ·
1 Parent(s): 20bdd0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -51
app.py CHANGED
@@ -1,46 +1,16 @@
1
  import gradio as gr
2
  import whisper
3
  import os
4
- from transformers import MarianMTModel, MarianTokenizer
5
 
6
  # Load the Whisper model
7
  model = whisper.load_model("base") # Choose 'tiny', 'base', 'small', 'medium', or 'large'
8
 
9
- # Load MarianMT translation model for Persian
10
- def load_translation_model(target_language):
11
- # Map of language codes to MarianMT model names
12
- lang_models = {
13
- "fa": "Helsinki-NLP/opus-mt-en-fa", # English to Persian (Farsi)
14
- "es": "Helsinki-NLP/opus-mt-en-es", # English to Spanish
15
- "fr": "Helsinki-NLP/opus-mt-en-fr", # English to French
16
- # Add more models for other languages as needed
17
- }
18
-
19
- model_name = lang_models.get(target_language)
20
- if not model_name:
21
- raise ValueError(f"Translation model for {target_language} not found")
22
-
23
- tokenizer = MarianTokenizer.from_pretrained(model_name)
24
- translation_model = MarianMTModel.from_pretrained(model_name)
25
- return tokenizer, translation_model
26
-
27
- def translate_text(text, tokenizer, model):
28
- # Tokenize the input text and translate
29
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
30
- translated = model.generate(**inputs)
31
- return tokenizer.decode(translated[0], skip_special_tokens=True)
32
-
33
- def write_srt(transcription, output_file, tokenizer=None, translation_model=None):
34
  with open(output_file, "w") as f:
35
  for i, segment in enumerate(transcription['segments']):
36
  start = segment['start']
37
  end = segment['end']
38
  text = segment['text']
39
-
40
- # Translate text if translation model is provided
41
- if translation_model:
42
- text = translate_text(text, tokenizer, translation_model)
43
-
44
  # Format timestamps for SRT
45
  start_time = whisper.utils.format_timestamp(start)
46
  end_time = whisper.utils.format_timestamp(end)
@@ -49,36 +19,25 @@ def write_srt(transcription, output_file, tokenizer=None, translation_model=None
49
  f.write(f"{start_time} --> {end_time}\n")
50
  f.write(f"{text.strip()}\n\n")
51
 
52
- def transcribe_video(video_file, language, target_language):
53
  # Transcribe the video to generate subtitles
54
- result = model.transcribe(video_file.name, language=language)
55
-
56
- # Get the video file name without extension and create the SRT file name
57
- video_name = os.path.splitext(os.path.basename(video_file.name))[0]
58
- srt_file = f"{video_name}_{target_language}.srt"
59
 
60
- # Load the translation model for the selected language
61
- if target_language != "en": # No translation needed if target is English
62
- tokenizer, translation_model = load_translation_model(target_language)
63
- else:
64
- tokenizer, translation_model = None, None
65
 
66
- # Write the transcription as subtitles (with optional translation)
67
- write_srt(result, srt_file, tokenizer, translation_model)
68
 
69
  return srt_file
70
 
71
  # Gradio interface
72
  iface = gr.Interface(
73
  fn=transcribe_video,
74
- inputs=[
75
- gr.File(label="Upload Video"),
76
- gr.Dropdown(label="Select Video Language", choices=["en", "es", "fr", "de", "it", "pt"], value="en"),
77
- gr.Dropdown(label="Select Subtitle Language", choices=["en", "fa", "es", "fr"], value="fa") # Added Persian (fa) as an option
78
- ],
79
  outputs=gr.File(label="Download Subtitles"),
80
- title="Video Subtitle Generator with Translation",
81
- description="Upload a video file to generate subtitles using Whisper. Select the language of the video and the target subtitle language."
82
  )
83
 
84
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import whisper
3
  import os
 
4
 
5
  # Load the Whisper model
6
  model = whisper.load_model("base") # Choose 'tiny', 'base', 'small', 'medium', or 'large'
7
 
8
+ def write_srt(transcription, output_file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  with open(output_file, "w") as f:
10
  for i, segment in enumerate(transcription['segments']):
11
  start = segment['start']
12
  end = segment['end']
13
  text = segment['text']
 
 
 
 
 
14
  # Format timestamps for SRT
15
  start_time = whisper.utils.format_timestamp(start)
16
  end_time = whisper.utils.format_timestamp(end)
 
19
  f.write(f"{start_time} --> {end_time}\n")
20
  f.write(f"{text.strip()}\n\n")
21
 
22
+ def transcribe_video(video_file):
23
  # Transcribe the video to generate subtitles
24
+ result = model.transcribe(video_file)
 
 
 
 
25
 
26
+ # Save the transcription to an .srt file
27
+ srt_file = "generated_subtitles.srt"
 
 
 
28
 
29
+ # Write the transcription as subtitles
30
+ write_srt(result, srt_file)
31
 
32
  return srt_file
33
 
34
  # Gradio interface
35
  iface = gr.Interface(
36
  fn=transcribe_video,
37
+ inputs=gr.File(label="Upload Video"),
 
 
 
 
38
  outputs=gr.File(label="Download Subtitles"),
39
+ title="Video Subtitle Generator",
40
+ description="Upload a video file to generate subtitles using Whisper."
41
  )
42
 
43
  if __name__ == "__main__":