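"""Streamlit app that transcribes uploaded WAV files with a Whisper model
checkpoint and exports the combined result as a single SRT subtitle file."""
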
import streamlit as st
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import librosa
import srt
from datetime import timedelta

# ๋ชจ๋ธ ๋ฐ ํ”„๋กœ์„ธ์„œ ๋กœ๋“œ
@st.cache_resource
def load_model():
model = WhisperForConditionalGeneration.from_pretrained("lcjln/AIME_Project_The_Final")
processor = WhisperProcessor.from_pretrained("lcjln/AIME_The_Final")
return model, processor
model, processor = load_model()
# Streamlit ์›น ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ์ธํ„ฐํŽ˜์ด์Šค
st.title("Whisper ์ž๋ง‰ ์ƒ์„ฑ๊ธฐ")
# Upload one or more WAV files
uploaded_files = st.file_uploader("Drag and drop WAV files here", type=["wav"], accept_multiple_files=True)

# Show the list of uploaded files
if uploaded_files:
    st.write("Uploaded files:")
    for uploaded_file in uploaded_files:
        st.write(uploaded_file.name)
    # Run button
    if st.button("Run"):
        combined_subs = []
        last_end_time = timedelta(0)
        subtitle_index = 1
        for uploaded_file in uploaded_files:
            st.write(f"Processing: {uploaded_file.name}")
            # Initialize the progress bar
            progress_bar = st.progress(0)
            # Load the WAV file, resampling to the 16 kHz rate Whisper expects
            st.write("Processing the audio file...")
            audio, sr = librosa.load(uploaded_file, sr=16000)
            progress_bar.progress(50)
            # Generate subtitles with the Whisper model, 5 seconds at a time
            st.write("Generating subtitles with the model...")
            segments = split_audio(audio, sr, segment_duration=5)
            for segment in segments:
                inputs = processor(segment, return_tensors="pt", sampling_rate=16000)
                with torch.no_grad():
                    outputs = model.generate(inputs["input_features"], max_length=2048, return_dict_in_generate=True, output_scores=True)
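                # With output_scores=True, outputs.scores is a tuple of logit tensors,
                # one per generated token, each of shape (batch_size, vocab_size)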
                # Decode the generated token IDs into text
                transcription = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip()
                # Rough confidence estimate: mean logit of the last generation step
                avg_logit_score = torch.mean(outputs.scores[-1]).item()
                # Skip the segment if the text is empty or the confidence is low
                if transcription and avg_logit_score > -5.0:
                    segment_duration = librosa.get_duration(y=segment, sr=sr)
                    end_time = last_end_time + timedelta(seconds=segment_duration)
                    combined_subs.append(
                        srt.Subtitle(
                            index=subtitle_index,
                            start=last_end_time,
                            end=end_time,
                            content=transcription,
                        )
                    )
                    last_end_time = end_time
                    subtitle_index += 1
            progress_bar.progress(100)
            st.success(f"Subtitles for {uploaded_file.name} were generated successfully!")

        # Combine all subtitles into a single SRT file
        st.write("Generating the final SRT file...")
        srt_content = srt.compose(combined_subs)
        final_srt_file_path = "combined_output.srt"
        with open(final_srt_file_path, "w", encoding="utf-8") as f:
            f.write(srt_content)
        st.success("The final SRT file was generated successfully!")

        # Download button for the final SRT file
        with open(final_srt_file_path, "rb") as srt_file:
            st.download_button(label="Download SRT file", data=srt_file, file_name=final_srt_file_path, mime="text/srt")
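
# To try this app locally (assuming the Streamlit CLI is installed):
#   streamlit run app.py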