lcjln committed (verified)
Commit d20cd0c · Parent(s): 8f5fb37

Update app.py

Files changed (1): app.py (+39 -27)
app.py CHANGED

@@ -1,7 +1,7 @@
 import os
 import streamlit as st
-import torch
 from transformers import WhisperForConditionalGeneration, WhisperProcessor
+import torch
 import librosa
 import srt
 from datetime import timedelta
@@ -15,9 +15,8 @@ def load_model():
 
 model, processor = load_model()
 
-# Web application interface
+# Streamlit web application interface
 st.title("Whisper Subtitle Generator")
-st.write("Upload WAV files to generate subtitles.")
 
 # Upload multiple WAV files
 uploaded_files = st.file_uploader("Drag and drop WAV files here", type=["wav"], accept_multiple_files=True)
@@ -48,31 +47,37 @@ if uploaded_files:
 
         # Transcribe with the Whisper model
         st.write("Generating subtitles with the model...")
-        inputs = processor(audio, return_tensors="pt", sampling_rate=16000)
-        with torch.no_grad():
-            predicted_ids = model.generate(inputs["input_features"], max_length=2048)
-
-        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
-
-        progress_bar.progress(80)
-
-        # Generate the SRT subtitles
-        st.write("Generating the SRT file...")
-        lines = transcription.split(". ")
-        step = len(audio) / sr / len(lines)
-        start_time = last_end_time
-
-        for line in lines:
-            end_time = start_time + timedelta(seconds=step)
-            combined_subs.append(
-                srt.Subtitle(index=subtitle_index, start=start_time, end=end_time, content=line)
-            )
-            start_time = end_time
-            subtitle_index += 1
-
-        last_end_time = start_time  # Record the last end time so the next file's timing starts from here
+        segments = split_audio(audio, sr, segment_duration=5)
+
+        for i, segment in enumerate(segments):
+            inputs = processor(segment, return_tensors="pt", sampling_rate=16000)
+            with torch.no_grad():
+                outputs = model.generate(inputs["input_features"], max_length=2048, return_dict_in_generate=True, output_scores=True)
+
+            # Decode the text
+            transcription = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip()
+
+            # Compute a confidence score (used for additional confidence filtering)
+            avg_logit_score = torch.mean(outputs.scores[-1]).item()
+
+            # Skip the segment if the confidence score is low or the text is empty
+            if transcription and avg_logit_score > -5.0:
+                segment_duration = librosa.get_duration(y=segment, sr=sr)
+                end_time = last_end_time + timedelta(seconds=segment_duration)
+
+                combined_subs.append(
+                    srt.Subtitle(
+                        index=subtitle_index,
+                        start=last_end_time,
+                        end=end_time,
+                        content=transcription
+                    )
+                )
+                last_end_time = end_time
+                subtitle_index += 1
 
         progress_bar.progress(100)
+        st.success(f"Subtitles for {uploaded_file.name} were generated successfully!")
 
     # Save all subtitles into a single SRT file
     st.write("Generating the final SRT file...")
@@ -86,4 +91,11 @@ if uploaded_files:
 
     # Final SRT file download button
     with open(final_srt_file_path, "rb") as srt_file:
-        st.download_button(label="Download SRT file", data=srt_file, file_name=final_srt_file_path, mime="text/srt")
+        st.download_button(label="Download SRT file", data=srt_file, file_name=final_srt_file_path, mime="text/srt")
+
+def split_audio(audio, sr, segment_duration=5):
+    segments = []
+    for i in range(0, len(audio), int(segment_duration * sr)):
+        segment = audio[i:i + int(segment_duration * sr)]
+        segments.append(segment)
+    return segments
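For context, here is a minimal, self-contained sketch of the flow the updated app.py follows: split the audio into fixed-length chunks, transcribe each chunk with Whisper, drop empty or low-confidence results, and append the rest to a running SRT timeline. It is an illustration rather than the app's exact code: the function name `transcribe_to_subtitles`, the `openai/whisper-small` checkpoint, and the `example.wav` input are placeholders (the commit does not show which checkpoint `load_model()` uses), and the `split_audio` helper is simply defined above the code that calls it.

```python
from datetime import timedelta

import librosa
import srt
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor


def split_audio(audio, sr, segment_duration=5):
    """Split a 1-D audio array into consecutive chunks of segment_duration seconds."""
    step = int(segment_duration * sr)
    return [audio[i:i + step] for i in range(0, len(audio), step)]


def transcribe_to_subtitles(audio, sr, model, processor, score_threshold=-5.0):
    """Transcribe each chunk and collect the surviving results as srt.Subtitle entries."""
    subs, last_end, index = [], timedelta(0), 1
    for segment in split_audio(audio, sr, segment_duration=5):
        inputs = processor(segment, sampling_rate=16000, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_features"],
                return_dict_in_generate=True,
                output_scores=True,
            )
        text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip()

        # Same rough confidence proxy as in the commit: mean of the last step's logits.
        confidence = torch.mean(outputs.scores[-1]).item()
        if not text or confidence <= score_threshold:
            continue  # drop empty or low-confidence segments

        end = last_end + timedelta(seconds=librosa.get_duration(y=segment, sr=sr))
        subs.append(srt.Subtitle(index=index, start=last_end, end=end, content=text))
        last_end, index = end, index + 1
    return subs


if __name__ == "__main__":
    # "openai/whisper-small" and "example.wav" are placeholders, not values from the app.
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    audio, sr = librosa.load("example.wav", sr=16000)
    print(srt.compose(transcribe_to_subtitles(audio, sr, model, processor)))
```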
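On the scoring side, `torch.mean(outputs.scores[-1])` averages the raw logits of the final generation step over the entire vocabulary, so the -5.0 cut-off is a fairly blunt heuristic. If a finer-grained measure is ever wanted, one alternative (not used in this commit) is the `compute_transition_scores` helper that transformers exposes on generation models; the sketch below averages the log-probabilities of the tokens that were actually generated. The -1.0 threshold shown is illustrative only.

```python
import torch


def mean_token_logprob(model, outputs):
    """Average log-probability of the generated tokens (greedy decoding assumed)."""
    # With normalize_logits=True the returned transition scores are log-probabilities,
    # one per generated token, with shape (batch, generated_length).
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )
    return transition_scores.mean().item()


# Possible use inside the per-segment loop (illustrative threshold):
# if transcription and mean_token_logprob(model, outputs) > -1.0:
#     ...
```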