File size: 4,373 Bytes
66d9db5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
from transformers import WhisperProcessor, WhisperForConditionalGeneration, RagTokenizer, RagRetriever, RagSequenceForGeneration
import torch
import soundfile as sf
import librosa
from moviepy.editor import VideoFileClip
import os
import tempfile
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Streamlit re-executes this whole script on every widget interaction, so the
# heavyweight checkpoint loads below are wrapped in st.cache_resource to load
# them once per server process instead of on every rerun. On Streamlit
# versions without st.cache_resource the decorator degrades to a no-op.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_whisper(model_name):
    """Return the (processor, model) pair for the Whisper checkpoint *model_name*."""
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    return processor, model


@_cache_resource
def _load_rag_tokenizer(model_name):
    """Return the RagTokenizer for the checkpoint *model_name*."""
    return RagTokenizer.from_pretrained(model_name)


@_cache_resource
def _load_rag_retriever(model_name):
    """Build the RagRetriever for *model_name* using the dummy "exact" index.

    Propagates any exception from RagRetriever.from_pretrained; the caller
    below handles ValueError.
    """
    return RagRetriever.from_pretrained(
        model_name,
        index_name="exact",
        use_dummy_dataset=True,
        trust_remote_code=True,
    )


@_cache_resource
def _load_rag_model(model_name, _retriever):
    """Return RagSequenceForGeneration for *model_name* wired to *_retriever*.

    The leading underscore tells st.cache_resource not to hash the retriever
    argument (hashing such an object is expensive or may fail); the cache is
    keyed on *model_name* only.
    """
    return RagSequenceForGeneration.from_pretrained(model_name, retriever=_retriever)


# Load Whisper base model and processor
whisper_model_name = "openai/whisper-base"
whisper_processor, whisper_model = _load_whisper(whisper_model_name)

# Load RAG sequence model and tokenizer
rag_model_name = "facebook/rag-sequence-nq"
rag_tokenizer = _load_rag_tokenizer(rag_model_name)

# Try to load RagRetriever with trust_remote_code=True
try:
    rag_retriever = _load_rag_retriever(rag_model_name)
except ValueError as e:
    logger.error("Error loading RagRetriever: %s", e)
    st.error(f"Error loading RagRetriever: {e}")
    # Without a retriever the RAG model cannot be built and `rag_retriever`
    # would be unbound below, so halt this run. st.stop() ends the script
    # inside a Streamlit session; the bare raise covers contexts where
    # st.stop() returns instead of aborting.
    st.stop()
    raise
else:
    logger.info("Successfully loaded RagRetriever with trust_remote_code=True")

rag_model = _load_rag_model(rag_model_name, rag_retriever)

def transcribe_audio(audio_path, language="ru", task="translate"):
    """Run Whisper over the audio file at *audio_path* and return the decoded text.

    The file is loaded and resampled to 16 kHz with librosa. Whisper's feature
    extractor pads or truncates its input to a single 30-second window, so the
    signal is cut into consecutive 30-second chunks. Each chunk is decoded
    separately, and the non-empty results are joined with single spaces. A
    word that straddles a chunk boundary may be split.

    Args:
        audio_path: Path to an audio file that librosa can read.
        language: Language code passed to Whisper's decoder prompt.
        task: Whisper task for the decoder prompt. The default "translate"
            matches the previous behaviour; "transcribe" keeps the spoken
            language.

    Returns:
        The decoded text, or "" when the file contains no samples.
    """
    sampling_rate = 16000
    window = 30 * sampling_rate  # samples per Whisper input window
    speech, _ = librosa.load(audio_path, sr=sampling_rate)

    # The decoder prompt is identical for every chunk, so build it once.
    prompt_ids = whisper_processor.get_decoder_prompt_ids(language=language, task=task)

    pieces = []
    for start in range(0, len(speech), window):
        input_features = whisper_processor.feature_extractor(
            speech[start:start + window],
            return_tensors="pt",
            sampling_rate=sampling_rate,
        ).input_features
        predicted_ids = whisper_model.generate(input_features, forced_decoder_ids=prompt_ids)
        text = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        pieces.append(text.strip())
    return " ".join(piece for piece in pieces if piece)

def translate_and_summarize(text):
    """Feed *text* to the module-level RAG sequence model and return its outputs.

    NOTE(review): rag_model is loaded from "facebook/rag-sequence-nq", a
    retrieval-augmented checkpoint for open-domain question answering. It
    treats *text* as a question rather than translating or summarizing it.
    Confirm this is the intended model.

    Args:
        text: Input string, typically a Whisper transcription.

    Returns:
        A list of decoded strings from rag_model.generate, one per generated
        sequence.
    """
    # truncation=True caps the input at the tokenizer's model_max_length, so
    # long transcriptions (e.g. from multi-minute audio) do not overflow the
    # question encoder's input size.
    inputs = rag_tokenizer(text, return_tensors="pt", truncation=True)
    outputs = rag_model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    return rag_tokenizer.batch_decode(outputs, skip_special_tokens=True)

def extract_audio_from_video(video_path, output_audio_path):
    """Write the audio track of *video_path* to *output_audio_path*.

    Args:
        video_path: Path to a video file that moviepy's VideoFileClip can open.
        output_audio_path: Destination file for the extracted audio.

    Returns:
        *output_audio_path* if the video has an audio track, otherwise None.
    """
    # The context manager closes the clip (and its file readers) when done,
    # including on the early-return and error paths; previously the clip was
    # never closed.
    with VideoFileClip(video_path) as video_clip:
        audio_clip = video_clip.audio
        if audio_clip is None:
            return None
        audio_clip.write_audiofile(output_audio_path)
    return output_audio_path

st.title("Audio and Video Transcription & Summarization")

# Audio Upload Section
st.header("Upload an Audio File")
audio_file = st.file_uploader("Choose an audio file...", type=["wav", "mp3", "m4a"])

if audio_file is not None:
    # Keep the uploaded file's real extension so MP3/M4A data is not stored
    # under a misleading ".wav" name before it is decoded.
    audio_suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=audio_suffix) as tmp_file:
        tmp_file.write(audio_file.getbuffer())
        audio_path = tmp_file.name

    st.audio(audio_file)
    st.write("Transcribing audio...")
    try:
        transcription = transcribe_audio(audio_path)
        st.write("Transcription:", transcription)

        st.write("Translating and summarizing...")
        summary = translate_and_summarize(transcription)
        st.write("Translated Summary:", summary)
    except Exception as e:
        # UI boundary: log the full traceback for operators, show a short
        # message to the user.
        logger.exception("Audio processing failed")
        st.error(f"An error occurred: {e}")
    finally:
        # Streamlit reruns the script on every interaction; delete this run's
        # temp copy so uploads do not pile up. Best effort only.
        try:
            os.remove(audio_path)
        except OSError:
            logger.warning("Could not delete temporary file %s", audio_path)

# Video Upload Section
st.header("Upload a Video File")
video_file = st.file_uploader("Choose a video file...", type=["mp4", "mkv", "avi", "mov"])

if video_file is not None:
    # Keep the uploaded container's own extension instead of forcing ".mp4".
    video_suffix = os.path.splitext(video_file.name)[1] or ".mp4"
    with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp_file:
        tmp_file.write(video_file.getbuffer())
        video_path = tmp_file.name
    # Reserve a path for the extracted audio and close the handle right away,
    # rather than relying on garbage collection of an unnamed
    # NamedTemporaryFile object.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
        extracted_wav_path = tmp_audio.name

    st.video(video_file)
    st.write("Extracting audio from video...")
    try:
        # Extraction is inside the try so a corrupt or unsupported upload is
        # reported in the UI like any other processing error.
        audio_path = extract_audio_from_video(video_path, extracted_wav_path)

        if audio_path is not None:
            st.write("Transcribing audio...")
            transcription = transcribe_audio(audio_path)
            st.write("Transcription:", transcription)

            st.write("Translating and summarizing...")
            summary = translate_and_summarize(transcription)
            st.write("Translated Summary:", summary)
        else:
            st.write("No audio track found in the video file.")
    except Exception as e:
        # UI boundary: log the full traceback for operators, show a short
        # message to the user.
        logger.exception("Video processing failed")
        st.error(f"An error occurred: {e}")
    finally:
        # Streamlit reruns the script on every interaction; delete this run's
        # temp files so they do not pile up. Best effort only.
        for temp_path in (video_path, extracted_wav_path):
            try:
                os.remove(temp_path)
            except OSError:
                logger.warning("Could not delete temporary file %s", temp_path)