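"""Streamlit app that turns uploaded WAV files into a single SRT subtitle
file using a fine-tuned Whisper model."""
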
import streamlit as st
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import librosa
import srt
from datetime import timedelta

# Load the model and processor (cached so they are only loaded once per session)
@st.cache_resource
def load_model():
    model = WhisperForConditionalGeneration.from_pretrained("lcjln/AIME_Project_The_Final")
    processor = WhisperProcessor.from_pretrained("lcjln/AIME_The_Final")
    return model, processor

model, processor = load_model()

# Split audio into fixed-length chunks. Defined before the UI code below,
# which calls it inside the button handler during the same script run.
def split_audio(audio, sr, segment_duration=5):
    segments = []
    step = int(segment_duration * sr)
    for i in range(0, len(audio), step):
        segments.append(audio[i:i + step])
    return segments
# Streamlit web application interface
st.title("Whisper Subtitle Generator")

# Upload multiple WAV files
uploaded_files = st.file_uploader("Drag and drop WAV files here", type=["wav"], accept_multiple_files=True)
# Show the list of uploaded files
if uploaded_files:
    st.write("Uploaded files:")
    for uploaded_file in uploaded_files:
        st.write(uploaded_file.name)

    # Run button
    if st.button("Run"):
        combined_subs = []
        last_end_time = timedelta(0)
        subtitle_index = 1

        for uploaded_file in uploaded_files:
            st.write(f"Processing: {uploaded_file.name}")

            # Initialize the progress bar
            progress_bar = st.progress(0)
            # Load the WAV file and resample it to 16 kHz, the rate Whisper expects
            st.write("Processing the audio file...")
            audio, sr = librosa.load(uploaded_file, sr=16000)
            progress_bar.progress(50)

            # Transcribe with the Whisper model
            st.write("Generating subtitles with the model...")
            segments = split_audio(audio, sr, segment_duration=5)
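            # 5-second chunks keep each generate() call small; Whisper itself
            # works on windows of up to 30 seconds, so this is conservative.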
            for segment in segments:
                inputs = processor(segment, return_tensors="pt", sampling_rate=16000)
                with torch.no_grad():
                    outputs = model.generate(
                        inputs["input_features"],
                        max_length=448,  # Whisper's decoder supports at most 448 output tokens
                        return_dict_in_generate=True,
                        output_scores=True,
                    )

                # Decode the generated token IDs into text
                transcription = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip()

                # Crude confidence estimate: mean logit of the last generation step
                avg_logit_score = torch.mean(outputs.scores[-1]).item()

                # Skip the segment if the confidence is low or the text is empty
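                # NOTE: -5.0 is an empirical cutoff; mean logits are not calibrated
                # probabilities, so this threshold may need tuning per model.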
                if transcription and avg_logit_score > -5.0:
                    segment_duration = librosa.get_duration(y=segment, sr=sr)
                    end_time = last_end_time + timedelta(seconds=segment_duration)
                    combined_subs.append(
                        srt.Subtitle(
                            index=subtitle_index,
                            start=last_end_time,
                            end=end_time,
                            content=transcription,
                        )
                    )
                    last_end_time = end_time
                    subtitle_index += 1

            progress_bar.progress(100)
            st.success(f"Subtitles for {uploaded_file.name} were generated successfully!")
        # Write all subtitles into a single SRT file
        st.write("Composing the final SRT file...")
        srt_content = srt.compose(combined_subs)
        final_srt_file_path = "combined_output.srt"
        with open(final_srt_file_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        st.success("The final SRT file was generated successfully!")

        # Download button for the final SRT file
        with open(final_srt_file_path, "rb") as srt_file:
            st.download_button(label="Download SRT file", data=srt_file, file_name=final_srt_file_path, mime="text/srt")
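# To run locally (assuming this script is saved as app.py):
#   streamlit run app.py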