File size: 3,883 Bytes
1ab2ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import glob
import torchaudio
import torchaudio.transforms as T
import numpy as np
from matplotlib import pyplot as plt
import librosa
import librosa.display
from df import enhance, init_df
import streamlit as st
from streamlit.components.v1 import html

# Sampling rate in Hz that audio is resampled to before enhancement; the
# enhanced signal is resampled back from this rate to the upload's rate.
df_sr = 48000

# Used both as the browser page title and as the on-page heading.
app_title = "μ†ŒμŒ μ–΅μ œ 도ꡬ"

# Load the default model and its processing state from the `df` package.
# The third value returned by init_df() is not used by this app.
# NOTE(review): this executes at module level, i.e. each time the script
# body is run — confirm whether the load should be cached.
model, df_state, _ = init_df()


def display_audio_info(audio, title, sr=48000):
    """Render a spectrogram and a waveform of ``audio`` in two columns.

    The samples are clipped to ``[-1.0, 1.0]`` before plotting. When a 2-D
    array is given, only its first row (presumably the first channel) is
    plotted.

    Args:
        audio: Array-like of audio samples, 1-D or 2-D.
        title: Prefix for the two section headings.
        sr: Sampling rate of ``audio`` in Hz, used to scale the spectrogram
            axes. Defaults to 48000, the value of the module-level ``df_sr``.
    """
    # Create the two side-by-side columns.
    col1, col2 = st.columns(2)

    audio = np.clip(audio, -1.0, 1.0)
    if len(np.shape(audio)) == 2:
        audio = audio[0]

    # Left column: spectrogram.
    with col1:
        st.markdown(f"### {title} - Spectrogram")
        D = librosa.stft(audio)
        S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
        fig, ax = plt.subplots()
        # Pass the real sampling rate: without it specshow scales the axes
        # for its own default rate, mislabelling both time and frequency.
        img = librosa.display.specshow(
            S_db, sr=sr, x_axis='time', y_axis='linear', ax=ax)
        fig.colorbar(img, ax=ax, format="%+2.f dB")
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated renders do not accumulate figures.
        plt.close(fig)

    # Right column: waveform.
    with col2:
        st.markdown(f"### {title} - Waveform")
        fig, ax = plt.subplots()
        # Draw on the explicit axes instead of pyplot's implicit "current" one.
        ax.plot(audio)
        ax.set_xticks([])
        ax.set_ylim(-1, 1)
        st.pyplot(fig)
        plt.close(fig)


# MIME subtypes accepted from the uploader. 'x-wav' and 'wave' are alternative
# labels some browsers report for WAV files.
_SUPPORTED_MIME_SUBTYPES = ('wav', 'x-wav', 'wave', 'mpeg', 'ogg')


def _remove_stale_outputs():
    """Delete ``enhanced_*`` files from the working directory, best effort."""
    for stale_path in glob.glob('enhanced_*'):
        try:
            os.remove(stale_path)
        except OSError:
            # The file may already have been removed by a concurrent run, or
            # the match may not be a removable file. Cleanup is housekeeping
            # only and must not abort the upload being processed.
            continue


def _denoise_upload(uploaded_file):
    """Decode one uploaded file, enhance it, and render input and output.

    Shows the "unsupported file format" message and returns early when the
    upload cannot be decoded.
    """
    with st.spinner('μ†ŒμŒ μ œκ±°ν•˜λŠ” 쀑'):
        try:
            noisy_audio, sr = torchaudio.load(uploaded_file)
        except Exception:
            # UI boundary: a corrupt or mislabelled upload should produce a
            # message for the user rather than a traceback.
            st.text('μ§€μ›ν•˜μ§€ μ•ŠλŠ” 파일 ν˜•μ‹μž…λ‹ˆλ‹€.')
            return
        st.audio(noisy_audio.numpy(), sample_rate=sr)

        # Resample when the sampling rate is not 48000 Hz.
        if sr != df_sr:
            to_model_rate = T.Resample(orig_freq=sr, new_freq=df_sr)
            noisy_audio = to_model_rate(noisy_audio)
        display_audio_info(noisy_audio.numpy(), "μž…λ ₯")

    with st.spinner('μ†ŒμŒ μ œκ±°ν•˜λŠ” 쀑'):
        output_audio = enhance(model, df_state, noisy_audio)
        enhanced_audio = output_audio
        st.divider()
        # Resample back to the upload's rate for playback when it differs.
        if sr != df_sr:
            to_source_rate = T.Resample(orig_freq=df_sr, new_freq=sr)
            enhanced_audio = to_source_rate(enhanced_audio)
        st.audio(enhanced_audio.numpy(), sample_rate=sr)
        # The plots use the model-rate signal, like the input plots above.
        display_audio_info(output_audio.numpy(), "좜λ ₯")


def main():
    """Build the Streamlit page: uploader, enhancement results, donate button.

    When a file is uploaded, leftover ``enhanced_*`` files are cleaned up,
    the upload's MIME subtype is checked, and supported files are enhanced
    and displayed. The "Buy me a coffee" button is always rendered and
    pinned to the bottom-right corner of the page.
    """
    st.set_page_config(page_title=app_title, page_icon="favicon.ico",
                       layout="centered", initial_sidebar_state="auto", menu_items=None)

    button = """<script type="text/javascript" src="https://cdnjs.buymeacoffee.com/1.0.0/button.prod.min.js" data-name="bmc-button" data-slug="woojae" data-color="#FFDD00" data-emoji="β˜•"  data-font="Cookie" data-text="Buy me a coffee" data-outline-color="#000000" data-font-color="#000000" data-coffee-color="#ffffff" ></script>"""

    st.title(app_title)
    st.divider()
    st.header('μ†μ‰½κ²Œ λΆˆν•„μš”ν•œ μ†ŒμŒμ„ μ œκ±°ν•˜μ„Έμš”!')

    uploaded_file = st.file_uploader(
        "λ³€ν™˜ν•  νŒŒμΌμ„ μ—…λ‘œλ“œ ν•΄μ£Όμ„Έμš”. (지원 ν˜•μ‹: .wav, .mp3, .opus)")

    if uploaded_file:
        # Delete files left over from previous downloads.
        _remove_stale_outputs()

        uploaded_file_type = uploaded_file.type.split('/')[-1]
        if uploaded_file_type not in _SUPPORTED_MIME_SUBTYPES:
            st.text('μ§€μ›ν•˜μ§€ μ•ŠλŠ” 파일 ν˜•μ‹μž…λ‹ˆλ‹€.')
        else:
            _denoise_upload(uploaded_file)

    html(button, height=70, width=240)

    # The button is rendered inside a 240px-wide iframe; select it by that
    # width and pin it to the bottom-right corner.
    st.markdown(
        """
        <style>
            iframe[width="240"] {
                position: fixed;
                bottom: 30px;
                right: 10px;
            }
        </style>
        """,
        unsafe_allow_html=True,
    )


# Build the page only when this file is executed as the main module, not
# when it is imported.
if __name__ == '__main__':
    main()