vivekvar committed on
Commit 66d9db5 · verified
1 Parent(s): b46b3df

Update app.py

Files changed (1)
  1. app.py +111 -85
app.py CHANGED
@@ -1,85 +1,111 @@
- import streamlit as st
- from transformers import WhisperProcessor, WhisperForConditionalGeneration, RagTokenizer, RagRetriever, RagSequenceForGeneration
- import torch
- import soundfile as sf
- import librosa
- from moviepy.editor import VideoFileClip
- import os
-
- # Load Whisper base model and processor
- whisper_model_name = "openai/whisper-base"
- whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
- whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
-
- # Load RAG sequence model and tokenizer
- rag_model_name = "facebook/rag-sequence-nq"
- rag_tokenizer = RagTokenizer.from_pretrained(rag_model_name)
- rag_retriever = RagRetriever.from_pretrained(rag_model_name, index_name="exact", use_dummy_dataset=True, trust_remote_code=True)
- rag_model = RagSequenceForGeneration.from_pretrained(rag_model_name, retriever=rag_retriever)
-
- def transcribe_audio(audio_path, language="ru"):
-     speech, rate = librosa.load(audio_path, sr=16000)
-     inputs = whisper_processor(speech, return_tensors="pt", sampling_rate=16000)
-     input_features = whisper_processor.feature_extractor(speech, return_tensors="pt", sampling_rate=16000).input_features
-     predicted_ids = whisper_model.generate(input_features, forced_decoder_ids=whisper_processor.get_decoder_prompt_ids(language=language, task="translate"))
-     transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-     return transcription
-
- def translate_and_summarize(text):
-     inputs = rag_tokenizer(text, return_tensors="pt")
-     input_ids = inputs["input_ids"]
-     attention_mask = inputs["attention_mask"]
-     outputs = rag_model.generate(input_ids=input_ids, attention_mask=attention_mask)
-     return rag_tokenizer.batch_decode(outputs, skip_special_tokens=True)
-
- def extract_audio_from_video(video_path, output_audio_path):
-     video_clip = VideoFileClip(video_path)
-     audio_clip = video_clip.audio
-     if audio_clip is not None:
-         audio_clip.write_audiofile(output_audio_path)
-         return output_audio_path
-     else:
-         return None
-
- st.title("Audio and Video Transcription & Summarization")
-
- # Audio Upload Section
- st.header("Upload an Audio File")
- audio_file = st.file_uploader("Choose an audio file...", type=["wav", "mp3", "m4a"])
-
- if audio_file is not None:
-     audio_path = os.path.join("/tmp", audio_file.name)
-     with open(audio_path, "wb") as f:
-         f.write(audio_file.getbuffer())
-
-     st.audio(audio_file)
-     st.write("Transcribing audio...")
-     transcription = transcribe_audio(audio_path)
-     st.write("Transcription:", transcription)
-
-     st.write("Translating and summarizing...")
-     summary = translate_and_summarize(transcription)
-     st.write("Translated Summary:", summary)
- # Video Upload Section
- st.header("Upload a Video File")
- video_file = st.file_uploader("Choose a video file...", type=["mp4", "mkv", "avi", "mov"])
-
- if video_file is not None:
-     video_path = os.path.join("/tmp", video_file.name)
-     with open(video_path, "wb") as f:
-         f.write(video_file.getbuffer())
-
-     st.video(video_file)
-     st.write("Extracting audio from video...")
-     audio_path = extract_audio_from_video(video_path, "/tmp/extracted_audio.wav")
-
-     if audio_path is not None:
-         st.write("Transcribing audio...")
-         transcription = transcribe_audio(audio_path)
-         st.write("Transcription:", transcription)
-
-         st.write("Translating and summarizing...")
-         summary = translate_and_summarize(transcription)
-         st.write("Translated Summary:", summary)
-     else:
-         st.write("No audio track found in the video file.")
+ import streamlit as st
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration, RagTokenizer, RagRetriever, RagSequenceForGeneration
+ import torch
+ import soundfile as sf
+ import librosa
+ from moviepy.editor import VideoFileClip
+ import os
+ import tempfile
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Load Whisper base model and processor
+ whisper_model_name = "openai/whisper-base"
+ whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
+ whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
+
+ # Load RAG sequence model and tokenizer
+ rag_model_name = "facebook/rag-sequence-nq"
+ rag_tokenizer = RagTokenizer.from_pretrained(rag_model_name)
+
+ # Try to load RagRetriever with trust_remote_code=True
+ try:
+     rag_retriever = RagRetriever.from_pretrained(
+         rag_model_name,
+         index_name="exact",
+         use_dummy_dataset=True,
+         trust_remote_code=True
+     )
+     logger.info("Successfully loaded RagRetriever with trust_remote_code=True")
+ except ValueError as e:
+     logger.error(f"Error loading RagRetriever: {e}")
+     st.error(f"Error loading RagRetriever: {e}")
+
+ rag_model = RagSequenceForGeneration.from_pretrained(rag_model_name, retriever=rag_retriever)
+
+ def transcribe_audio(audio_path, language="ru"):
+     speech, rate = librosa.load(audio_path, sr=16000)
+     inputs = whisper_processor(speech, return_tensors="pt", sampling_rate=16000)
+     input_features = whisper_processor.feature_extractor(speech, return_tensors="pt", sampling_rate=16000).input_features
+     predicted_ids = whisper_model.generate(input_features, forced_decoder_ids=whisper_processor.get_decoder_prompt_ids(language=language, task="translate"))
+     transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+     return transcription
+
+ def translate_and_summarize(text):
+     inputs = rag_tokenizer(text, return_tensors="pt")
+     input_ids = inputs["input_ids"]
+     attention_mask = inputs["attention_mask"]
+     outputs = rag_model.generate(input_ids=input_ids, attention_mask=attention_mask)
+     return rag_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+ def extract_audio_from_video(video_path, output_audio_path):
+     video_clip = VideoFileClip(video_path)
+     audio_clip = video_clip.audio
+     if audio_clip is not None:
+         audio_clip.write_audiofile(output_audio_path)
+         return output_audio_path
+     else:
+         return None
+
+ st.title("Audio and Video Transcription & Summarization")
+
+ # Audio Upload Section
+ st.header("Upload an Audio File")
+ audio_file = st.file_uploader("Choose an audio file...", type=["wav", "mp3", "m4a"])
+
+ if audio_file is not None:
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+         tmp_file.write(audio_file.getbuffer())
+         audio_path = tmp_file.name
+
+     st.audio(audio_file)
+     st.write("Transcribing audio...")
+     try:
+         transcription = transcribe_audio(audio_path)
+         st.write("Transcription:", transcription)
+
+         st.write("Translating and summarizing...")
+         summary = translate_and_summarize(transcription)
+         st.write("Translated Summary:", summary)
+     except Exception as e:
+         st.error(f"An error occurred: {e}")
+
+ # Video Upload Section
+ st.header("Upload a Video File")
+ video_file = st.file_uploader("Choose a video file...", type=["mp4", "mkv", "avi", "mov"])
+
+ if video_file is not None:
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
+         tmp_file.write(video_file.getbuffer())
+         video_path = tmp_file.name
+
+     st.video(video_file)
+     st.write("Extracting audio from video...")
+     audio_path = extract_audio_from_video(video_path, tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name)
+
+     if audio_path is not None:
+         st.write("Transcribing audio...")
+         try:
+             transcription = transcribe_audio(audio_path)
+             st.write("Transcription:", transcription)
+
+             st.write("Translating and summarizing...")
+             summary = translate_and_summarize(transcription)
+             st.write("Translated Summary:", summary)
+         except Exception as e:
+             st.error(f"An error occurred: {e}")
+     else:
+         st.write("No audio track found in the video file.")