# Snore-Detector / app.py
# Author: AbdullaShafeeg; commit 75b78d9 ("update"); 3.07 kB
import streamlit as st
from st_audiorec import st_audiorec
import matplotlib.pyplot as plt
import sounddevice as sd
import numpy as np
import pandas as pd
import torch
import torchaudio
import wave
import io
from scipy.io import wavfile
# MODEL LOADING and INITIALISATION
@st.cache_resource
def _load_model(path="snorenetv1_small.ptl"):
    """Load the TorchScript snore classifier and put it in eval mode.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps one loaded model for the server process
    instead of re-reading and deserialising the file on each rerun.
    """
    loaded = torch.jit.load(path)
    loaded.eval()
    return loaded


model = _load_model()
# Session state: seed default values the first time this browser session
# runs the script (st.session_state survives reruns within a session).
if 'text' not in st.session_state:
    st.session_state['text'], st.session_state['run'] = 'Listening...', False
# Audio parameters
st.sidebar.header('Audio Parameters')
# number_input only accepts numbers (>= 1 here), so a stray non-numeric
# entry cannot raise ValueError and abort the whole script run.
# NOTE(review): within this file neither FRAMES_PER_BUFFER nor RATE is read
# by the analysis code; confirm whether they are still needed.
FRAMES_PER_BUFFER = int(
    st.sidebar.number_input('Frames per buffer', min_value=1, value=3200, step=1)
)
FORMAT = 'audio/wav'
CHANNELS = 1
RATE = int(st.sidebar.number_input('Rate', min_value=1, value=16000, step=1))
# Module-level monitoring flag and sample buffer. Streamlit re-executes the
# script on every interaction, so both are reset at the start of each run.
monitoring, audio_data = False, []
def _set_monitoring(active):
    """Record the monitoring state in the session and in the module global."""
    global monitoring
    st.session_state['run'] = active
    monitoring = active


def start_monitoring():
    """Mark monitoring as active for this session."""
    _set_monitoring(True)


def stop_monitoring():
    """Mark monitoring as inactive for this session."""
    _set_monitoring(False)
# Page header and a collapsible description of the app.
st.title('🎙️ Real-Time Snore Detection App')

with st.expander('About this App'):
    st.markdown(
        '''
        This Streamlit app from Hypermind Labs helps users detect
        how much they are snoring during their sleep.
        '''
    )
# Number of samples passed to the model per inference call.
# NOTE(review): the audio is not resampled before windowing, so a window is
# 16000 samples at whatever rate the recorder produced, not necessarily one
# second; confirm against the model's training setup.
WINDOW_SAMPLES = 16000


def _normalize_to_int16(samples):
    """Peak-normalise int16 ``samples`` to the full int16 range.

    Returns an int16 array. A silent (all-zero) input is returned unchanged
    rather than dividing by zero.
    """
    # Widen before abs(): np.abs(np.int16(-32768)) overflows back to -32768.
    peak = int(np.max(np.abs(samples.astype(np.int32))))
    if peak == 0:
        return samples.astype(np.int16)
    return ((samples / peak) * 32767).astype(np.int16)


def _classify_windows(waveform, classifier, window=WINDOW_SAMPLES):
    """Classify consecutive, non-overlapping windows of ``waveform``.

    Every complete window is classified and a trailing partial window is
    dropped. A recording shorter than one window is classified as a single
    short chunk, so at least one prediction is always made.

    Returns:
        (snore_count, other_count)
    """
    starts = range(0, len(waveform) - window + 1, window)
    chunks = [waveform[start:start + window] for start in starts] or [waveform]
    snore_count = other_count = 0
    # Inference only: no autograd graph, and outputs can be compared directly.
    with torch.no_grad():
        for chunk in chunks:
            input_tensor = torch.tensor(chunk).unsqueeze(0).to(torch.float32)
            scores = classifier(input_tensor)[0].abs()
            # NOTE(review): comparing absolute scores assumes the model emits
            # non-negative scores (e.g. probabilities); for raw logits or
            # log-probabilities an argmax would be the right decision rule.
            if scores[0] > scores[1]:
                other_count += 1
            else:
                snore_count += 1
    return snore_count, other_count


def _plot_percentages(snore_count, other_count):
    """Render a horizontal bar chart of snore vs. other window percentages."""
    total = snore_count + other_count
    percentages = [snore_count / total * 100, other_count / total * 100]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.barh(["Snore", "Other"], percentages, color=['#ff0033', '#00ffee'])
    ax.set_xlabel('Percentage')
    ax.set_title('Percentage of Snoring')
    ax.set_xlim(0, 100)
    for i, percentage in enumerate(percentages):
        ax.text(percentage, i, f' {percentage:.2f}%', va='center')
    st.pyplot(fig)
    # Close explicitly so figures do not accumulate across Streamlit reruns.
    plt.close(fig)


wav_audio_data = st_audiorec()

if wav_audio_data is not None:
    # NOTE(review): the recorder returns WAV bytes, so this view also treats
    # the file header (and, for multi-channel audio, interleaved channels) as
    # samples; confirm the recorder's format before relying on exact values.
    data = np.frombuffer(wav_audio_data, dtype=np.int16)
    st.write(len(data))
    if data.size == 0:
        st.warning("The recording is empty; nothing to analyse.")
    else:
        # Processed in memory (no temp file), so concurrent sessions cannot
        # overwrite each other's audio.
        waveform = _normalize_to_int16(data)
        snore, other = _classify_windows(waveform, model)
        # PERCENTAGE OF SNORING PLOT
        _plot_percentages(snore, other)