File size: 3,389 Bytes
d008cad
94ec7f8
063e974
 
 
99cbfbc
63beaa0
 
a406b0b
 
b2a4dc4
f758117
63beaa0
 
d008cad
cda85cb
 
 
 
d008cad
cda85cb
 
 
 
94ec7f8
cda85cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99cbfbc
 
301baa4
063e974
a406b0b
 
e2496fa
a406b0b
 
 
 
 
 
 
 
63beaa0
a406b0b
 
 
 
 
 
 
63beaa0
 
 
 
 
 
 
 
 
 
99cbfbc
 
 
f758117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
063e974
99cbfbc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
from st_audiorec import st_audiorec
import matplotlib.pyplot as plt
import sounddevice as sd
import numpy as np
import pandas as pd
import torch
import torchaudio
import wave
import io
from scipy.io import wavfile
# MODEL LOADING and INITIALISATION
# Load the TorchScript snore classifier and put it in inference (eval) mode in
# one step; nn.Module.eval() returns the module itself.
# NOTE(review): Streamlit re-executes this script on every interaction, so the
# model is re-read from disk on each rerun — consider caching it.
model = torch.jit.load("snorenetv1_small.ptl").eval()

# Session state: on the first run of a browser session ('text' not yet set),
# seed both keys with their initial values.
if 'text' not in st.session_state:
    for state_key, initial_value in (('text', 'Listening...'), ('run', False)):
        st.session_state[state_key] = initial_value

# Audio parameters, editable from the sidebar.
# NOTE(review): the recording/inference code in this file does not appear to
# read these values (it hard-codes its own rates) — confirm before relying on them.
st.sidebar.header('Audio Parameters')


def _sidebar_int(label, default):
    """Render an integer-valued text box in the sidebar and return its value.

    Args:
        label: Widget label shown in the sidebar.
        default: Integer used as the initial text and as the fallback.

    Returns:
        The entered value parsed with ``int()``. If the entry is not a valid
        integer (e.g. letters or an emptied box), ``default`` is returned and a
        sidebar warning is shown instead of raising ``ValueError`` and aborting
        the whole script run.
    """
    raw = st.sidebar.text_input(label, str(default))
    try:
        return int(raw)
    except ValueError:
        st.sidebar.warning(f"{label!r} must be an integer; using {default}.")
        return default


FRAMES_PER_BUFFER = _sidebar_int('Frames per buffer', 3200)
FORMAT = 'audio/wav'
CHANNELS = 1
RATE = _sidebar_int('Rate', 16000)

# Module-level monitoring state. No audio stream is opened here.
# NOTE(review): Streamlit re-executes the script on each interaction, so both
# names are reset to these initial values on every rerun.
monitoring, audio_data = False, []

def start_monitoring():
    """Turn monitoring on: set the module flag and session_state['run'] to True."""
    global monitoring
    monitoring = True
    st.session_state['run'] = True

def stop_monitoring():
    """Turn monitoring off: set the module flag and session_state['run'] to False."""
    global monitoring
    monitoring = False
    st.session_state['run'] = False

# Page title plus a collapsible "About" section describing the app.
st.title('🎙️ Real-Time Snore Detection App')

with st.expander('About this App'):
    st.markdown('''
    This Streamlit app from Hypermind Labs helps users detect
    how much they are snoring during their sleep.
    ''')



def _peak_normalize(samples):
    """Scale int16 samples so the loudest one reaches full scale (about +/-32767).

    Args:
        samples: 1-D int16 NumPy array.

    Returns:
        An int16 array of the same length. An empty or all-zero (silent) input
        yields zeros instead of dividing by a zero peak, which would produce
        NaNs that are then cast to arbitrary int16 values.
    """
    if samples.size == 0:
        return np.zeros(samples.shape, dtype=np.int16)
    # Widen before abs(): np.abs on int16 maps -32768 to itself (overflow),
    # which would hide the true peak.
    peak = int(np.max(np.abs(samples.astype(np.int32))))
    if peak == 0:
        return np.zeros(samples.shape, dtype=np.int16)
    return ((samples / peak) * 32767).astype(np.int16)


# Record audio in the browser, normalise it, save it, and classify it.
wav_audio_data = st_audiorec()
if wav_audio_data is not None:
    # Reinterpret the returned bytes as int16 samples. `count` drops a trailing
    # odd byte instead of letting np.frombuffer raise on it.
    # NOTE(review): if st_audiorec() returns a complete WAV file, its header
    # bytes are decoded as samples here too — confirm that is what the model expects.
    data = np.frombuffer(wav_audio_data, dtype=np.int16,
                         count=len(wav_audio_data) // 2)
    st.write(len(data))

    scaled = _peak_normalize(data)

    # Persist the normalised recording as 16-bit mono PCM. The 96000 Hz frame
    # rate is kept from the original code.
    # NOTE(review): presumably it reflects interleaved stereo at 48 kHz — confirm.
    with wave.open("output.wav", "wb") as waveobj:
        waveobj.setnchannels(1)
        waveobj.setframerate(96000)
        waveobj.setsampwidth(2)
        waveobj.setcomptype("NONE", "NONE")
        waveobj.writeframes(scaled.tobytes())

    # Feed the model from memory rather than re-reading output.wav. Every
    # session writes that same relative path, so a read-back can race with
    # another user's write. Reading the file back would yield exactly these
    # int16 samples.
    # NOTE(review): recordings shorter than 16000 samples give a shorter input —
    # confirm whether the model accepts variable-length input.
    input_tensor = torch.tensor(scaled[:16000]).unsqueeze(0).to(torch.float32)
    st.write(input_tensor.shape)

    # Inference only: no_grad skips autograd bookkeeping and keeps the logits
    # convertible to plain Python numbers.
    with torch.no_grad():
        result = model(input_tensor)

    # Decision rule unchanged: class 0 wins when its output has the larger magnitude.
    if abs(result[0][0].item()) > abs(result[0][1].item()):
        st.write("NON_SNORING")
    else:
        st.write("SNORING")

    # TODO: "Percentage of snoring" plot. An earlier draft loaded the recording
    # with torchaudio, resampled it and averaged the channels to mono, ran the
    # model, and counted output elements above 0.5 as "Snore" and the rest as
    # "Other". It then drew a horizontal matplotlib bar chart of the two
    # percentages (0-100, labelled with two decimals) and rendered it with st.pyplot.