new2 / app.py
vivekvar's picture
Update app.py
7bb4a8e verified
raw
history blame
4.75 kB
import streamlit as st
from transformers import WhisperProcessor, WhisperForConditionalGeneration, RagTokenizer, RagRetriever, RagSequenceForGeneration
import torch
import soundfile as sf
import librosa
from moviepy.editor import VideoFileClip
import os
import tempfile
import logging
# Logging: create this module's named logger, then configure the root logger
# so that INFO-and-above records are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Load Whisper base model and processor
whisper_model_name = "openai/whisper-base"


@st.cache_resource(show_spinner=False)
def _load_whisper(model_name):
    """Load the Whisper processor and generation model for *model_name*.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects avoids re-creating them on each rerun.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    return processor, model


whisper_processor, whisper_model = _load_whisper(whisper_model_name)
# Load RAG sequence model and tokenizer
rag_model_name = "facebook/rag-sequence-nq"


@st.cache_resource(show_spinner=False)
def _load_rag_tokenizer(model_name):
    """Load the RAG tokenizer for *model_name* once per server process."""
    return RagTokenizer.from_pretrained(model_name)


@st.cache_resource(show_spinner=False)
def _load_rag_model(model_name):
    """Load the RAG retriever and the sequence model built on top of it.

    Any exception propagates to the caller; a call that raises produces no
    cached value, so loading is attempted again on the next script run.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(retriever, model)`` tuple.
    """
    # The previous ``dataset_path=local_dataset_path`` argument referenced a
    # name that is never defined in this file, so evaluating the call raised
    # NameError (which ``except ValueError`` did not catch). It is dropped.
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="exact",
        use_dummy_dataset=True,
        trust_remote_code=True,
    )
    logger.info("Successfully loaded RagRetriever with trust_remote_code=True")
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    return retriever, model


rag_tokenizer = _load_rag_tokenizer(rag_model_name)

# Try to load RagRetriever with trust_remote_code=True
try:
    rag_retriever, rag_model = _load_rag_model(rag_model_name)
except (ValueError, ImportError, OSError) as e:
    # Missing optional backends (ImportError) and download/file problems
    # (OSError) degrade the app the same way a ValueError does, instead of
    # aborting the whole script.
    logger.error(f"Error loading RagRetriever: {e}")
    rag_retriever = None
    # Keep the name defined so later references never hit a NameError.
    rag_model = None

if rag_retriever is None:
    logger.error("RagRetriever is not available, unable to proceed with loading RAG model.")
    st.error("RagRetriever is not available, unable to proceed with loading RAG model.")
def transcribe_audio(audio_path, language="ru"):
    """Run Whisper on an audio file and return the decoded text.

    The audio is loaded with librosa at 16 kHz and generation is driven by
    decoder prompt ids built for *language* with ``task="translate"``.

    Args:
        audio_path: Path of the audio file to load.
        language: Language handed to the decoder prompt (default ``"ru"``).

    Returns:
        The first decoded sequence, with special tokens removed.
    """
    sample_rate = 16000
    speech, _ = librosa.load(audio_path, sr=sample_rate)
    # Features are extracted exactly once; a second, discarded pass through
    # the full processor was removed.
    # NOTE(review): the feature extractor presumably works on a fixed-length
    # window, so long recordings may only be partly covered - confirm.
    input_features = whisper_processor.feature_extractor(
        speech, return_tensors="pt", sampling_rate=sample_rate
    ).input_features
    forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(
        language=language, task="translate"
    )
    predicted_ids = whisper_model.generate(
        input_features, forced_decoder_ids=forced_decoder_ids
    )
    return whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
def translate_and_summarize(text):
    """Generate output for *text* with the RAG sequence model.

    Args:
        text: Input string to tokenize and feed to the model.

    Returns:
        A list of decoded strings (special tokens removed). When the RAG
        retriever could not be loaded, a single-element list holding an
        explanatory message is returned instead.
    """
    if rag_retriever is not None:
        encoded = rag_tokenizer(text, return_tensors="pt")
        generated = rag_model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
        return rag_tokenizer.batch_decode(generated, skip_special_tokens=True)
    # Retriever unavailable: report it rather than attempting generation.
    return ["Translation and summarization feature is not available due to RAG retriever loading issue."]
def extract_audio_from_video(video_path, output_audio_path):
    """Write the audio track of a video file to *output_audio_path*.

    Args:
        video_path: Path of the video file to open.
        output_audio_path: Destination path for the extracted audio.

    Returns:
        *output_audio_path* when the video has an audio track, otherwise
        ``None`` (nothing is written in that case).
    """
    video_clip = VideoFileClip(video_path)
    try:
        audio_clip = video_clip.audio
        if audio_clip is None:
            return None
        audio_clip.write_audiofile(output_audio_path)
        return output_audio_path
    finally:
        # Always release the clip, including when writing the audio fails.
        video_clip.close()
st.title("Audio and Video Transcription & Summarization")

# Audio Upload Section
st.header("Upload an Audio File")
audio_file = st.file_uploader("Choose an audio file...", type=["wav", "mp3", "m4a"])
if audio_file is not None:
    st.audio(audio_file)
    # Keep the upload's own extension so the temporary file is not labelled
    # ".wav" when it actually holds mp3/m4a data; fall back to ".wav".
    audio_suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=audio_suffix) as tmp_file:
        tmp_file.write(audio_file.getbuffer())
        audio_path = tmp_file.name
    try:
        st.write("Transcribing audio...")
        transcription = transcribe_audio(audio_path)
        st.write("Transcription:", transcription)
        st.write("Translating and summarizing...")
        summary = translate_and_summarize(transcription)
        st.write("Translated Summary:", summary)
    except Exception as e:
        # Top-level UI boundary: record the traceback and show the message.
        logger.exception("Audio processing failed")
        st.error(f"An error occurred: {e}")
    finally:
        # The file was created with delete=False, so remove it explicitly.
        try:
            os.remove(audio_path)
        except OSError:
            logger.warning("Could not remove temporary file %s", audio_path)
# Video Upload Section
st.header("Upload a Video File")
video_file = st.file_uploader("Choose a video file...", type=["mp4", "mkv", "avi", "mov"])
if video_file is not None:
    # Keep the upload's own extension so mkv/avi/mov data is not written to a
    # file labelled ".mp4"; fall back to ".mp4".
    video_suffix = os.path.splitext(video_file.name)[1] or ".mp4"
    with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp_file:
        tmp_file.write(video_file.getbuffer())
        video_path = tmp_file.name
    # Reserve a path for the extracted audio and close the handle right away
    # so no open file object is left behind.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
        extracted_audio_path = tmp_audio.name
    try:
        st.video(video_file)
        st.write("Extracting audio from video...")
        audio_path = extract_audio_from_video(video_path, extracted_audio_path)
        if audio_path is not None:
            st.write("Transcribing audio...")
            transcription = transcribe_audio(audio_path)
            st.write("Transcription:", transcription)
            st.write("Translating and summarizing...")
            summary = translate_and_summarize(transcription)
            st.write("Translated Summary:", summary)
        else:
            st.write("No audio track found in the video file.")
    except Exception as e:
        # Top-level UI boundary: record the traceback and show the message.
        logger.exception("Video processing failed")
        st.error(f"An error occurred: {e}")
    finally:
        # Both files were created with delete=False, so remove them explicitly.
        for temp_path in (video_path, extracted_audio_path):
            try:
                os.remove(temp_path)
            except OSError:
                logger.warning("Could not remove temporary file %s", temp_path)