import os
import streamlit as st
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import librosa
import srt
from datetime import timedelta
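# Assumed dependencies (versions are not pinned in the original):
#   pip install streamlit transformers torch librosa srt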

# ๋ชจ๋ธ ๋ฐ ํ”„๋กœ์„ธ์„œ ๋กœ๋“œ
@st.cache_resource
def load_model():
    model = WhisperForConditionalGeneration.from_pretrained("lcjln/AIME_Project_The_Final")
    processor = WhisperProcessor.from_pretrained("lcjln/AIME_The_Final")
    return model, processor

model, processor = load_model()

# Streamlit web application interface
st.title("Whisper Subtitle Generator")

# Upload multiple WAV files
uploaded_files = st.file_uploader("Drag and drop WAV files here", type=["wav"], accept_multiple_files=True)

# ํŒŒ์ผ ๋ชฉ๋ก์„ ๋ณด์—ฌ์คŒ
if uploaded_files:
    st.write("์—…๋กœ๋“œ๋œ ํŒŒ์ผ ๋ชฉ๋ก:")
    for uploaded_file in uploaded_files:
        st.write(uploaded_file.name)

    # ์‹คํ–‰ ๋ฒ„ํŠผ
    if st.button("์‹คํ–‰"):
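        # All cues share one running timeline: each cue starts where the previous
        # one ended, so subtitles from multiple files run back-to-back rather
        # than restarting at 00:00 for each file.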
        combined_subs = []
        last_end_time = timedelta(0)
        subtitle_index = 1

        for uploaded_file in uploaded_files:
            st.write(f"Processing: {uploaded_file.name}")

            # Initialize the progress bar
            progress_bar = st.progress(0)

            # Load the WAV file, resampling to the 16 kHz rate Whisper expects
            st.write("Processing the audio file...")
            audio, sr = librosa.load(uploaded_file, sr=16000)

            progress_bar.progress(50)

            # Transcribe with the Whisper model
            st.write("Generating subtitles with the model...")
            segments = split_audio(audio, sr, segment_duration=5)
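            # Fixed 5-second chunks keep each cue short but can cut words at
            # chunk boundaries; silence- or VAD-based splitting would be a more
            # precise (and heavier) alternative.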

            for i, segment in enumerate(segments):
                inputs = processor(segment, return_tensors="pt", sampling_rate=16000)
                with torch.no_grad():
                    # max_length is capped at 448, the Whisper decoder's maximum
                    # number of target positions; scores are kept for filtering below
                    outputs = model.generate(inputs["input_features"], max_length=448, return_dict_in_generate=True, output_scores=True)

                # ํ…์ŠคํŠธ ๋””์ฝ”๋”ฉ
                transcription = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip()

                # ์‹ ๋ขฐ๋„ ์ ์ˆ˜ ๊ณ„์‚ฐ (์ถ”๊ฐ€์ ์ธ ์‹ ๋ขฐ๋„ ํ•„ํ„ฐ๋ง ์ ์šฉ)
                avg_logit_score = torch.mean(outputs.scores[-1]).item()

                # ์‹ ๋ขฐ๋„ ์ ์ˆ˜๊ฐ€ ๋‚ฎ๊ฑฐ๋‚˜ ํ…์ŠคํŠธ๊ฐ€ ๋น„์–ด์žˆ๋Š” ๊ฒฝ์šฐ ๋ฌด์‹œ
                if transcription and avg_logit_score > -5.0:
                    segment_duration = librosa.get_duration(y=segment, sr=sr)
                    end_time = last_end_time + timedelta(seconds=segment_duration)

                    combined_subs.append(
                        srt.Subtitle(
                            index=subtitle_index,
                            start=last_end_time,
                            end=end_time,
                            content=transcription
                        )
                    )
                    last_end_time = end_time
                    subtitle_index += 1

            progress_bar.progress(100)
            st.success(f"{uploaded_file.name}์˜ ์ž๋ง‰์ด ์„ฑ๊ณต์ ์œผ๋กœ ์ƒ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค!")

        # ๋ชจ๋“  ์ž๋ง‰์„ ํ•˜๋‚˜์˜ SRT ํŒŒ์ผ๋กœ ์ €์žฅ
        st.write("์ตœ์ข… SRT ํŒŒ์ผ์„ ์ƒ์„ฑํ•˜๋Š” ์ค‘์ž…๋‹ˆ๋‹ค...")
        srt_content = srt.compose(combined_subs)
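        # srt.compose renders the cues in standard SubRip form, e.g.:
        #   1
        #   00:00:00,000 --> 00:00:05,000
        #   First transcribed segment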

        final_srt_file_path = "combined_output.srt"
        with open(final_srt_file_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        st.success("์ตœ์ข… SRT ํŒŒ์ผ์ด ์„ฑ๊ณต์ ์œผ๋กœ ์ƒ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค!")

        # ์ตœ์ข… SRT ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ ๋ฒ„ํŠผ
        with open(final_srt_file_path, "rb") as srt_file:
            st.download_button(label="SRT ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ", data=srt_file, file_name=final_srt_file_path, mime="text/srt")

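# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py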