Spaces:
Runtime error
Runtime error
import streamlit as st | |
from st_audiorec import st_audiorec | |
import matplotlib.pyplot as plt | |
import sounddevice as sd | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torchaudio | |
import wave | |
import io | |
from scipy.io import wavfile | |
# MODEL LOADING and INITIALISATION
# Streamlit re-executes this whole script on every widget interaction, so the
# TorchScript model is loaded once per browser session and then reused from
# st.session_state instead of being re-read from disk on every rerun.
if 'model' not in st.session_state:
    # map_location="cpu" lets weights that were saved from a GPU still load on
    # CPU-only hosts.
    _loaded_model = torch.jit.load("snorenetv1_small.ptl", map_location="cpu")
    _loaded_model.eval()
    st.session_state['model'] = _loaded_model
model = st.session_state['model']
# Session state: seed the per-session defaults only on the very first run,
# detected by the absence of the 'text' key.
if 'text' not in st.session_state:
    st.session_state.update({'text': 'Listening...', 'run': False})
# Audio parameters
st.sidebar.header('Audio Parameters')
# number_input (instead of int(text_input(...))) keeps these values numeric:
# typing a non-integer into the sidebar no longer crashes the whole script
# with a ValueError, and min_value=1 rules out zero/negative settings.
FRAMES_PER_BUFFER = int(
    st.sidebar.number_input('Frames per buffer', min_value=1, value=3200, step=1)
)
FORMAT = 'audio/wav'
CHANNELS = 1
RATE = int(st.sidebar.number_input('Rate', min_value=1, value=16000, step=1))
# Module-level recording placeholders. Nothing here opens an audio stream:
# `audio_data` starts as an empty buffer and `monitoring` as an "off" flag.
audio_data = []
monitoring = False
def start_monitoring():
    """Mark monitoring as running.

    Sets ``st.session_state['run']`` to True and mirrors the same value into
    the module-level ``monitoring`` flag.
    """
    global monitoring
    is_running = True
    st.session_state['run'] = is_running
    monitoring = is_running
def stop_monitoring():
    """Mark monitoring as stopped.

    Sets ``st.session_state['run']`` to False and mirrors the same value into
    the module-level ``monitoring`` flag.
    """
    global monitoring
    is_running = False
    st.session_state['run'] = is_running
    monitoring = is_running
# Page title plus a collapsible "About" section describing the app.
st.title('🎙️ Real-Time Snore Detection App')
with st.expander('About this App'):
    st.markdown('''
This Streamlit app from Hypermind Labs helps users detect
how much they are snoring during their sleep.
''')
# Length of the clip the classifier sees. The original code fed the first
# 16000 samples; together with the 16000 Hz default RATE this presumably means
# one second of 16 kHz audio -- TODO confirm against the model's training setup.
MODEL_WINDOW_SAMPLES = 16000


def _decode_wav(wav_bytes):
    """Parse WAV-encoded bytes into ``(sample_rate, mono float64 samples)``.

    ``st_audiorec`` is expected to return a complete WAV file (header
    included), so the bytes are parsed with ``scipy.io.wavfile`` rather than
    reinterpreted as raw int16. Raw reinterpretation treated the RIFF header
    as audio and raised on an odd byte count. Multi-channel recordings are
    averaged down to a single channel.

    Raises:
        ValueError: if the bytes are not a readable WAV file.
    """
    sample_rate, samples = wavfile.read(io.BytesIO(wav_bytes))
    samples = np.asarray(samples, dtype=np.float64)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    return sample_rate, samples


def _peak_normalize(samples, full_scale=32767):
    """Scale ``samples`` so the loudest one reaches ``full_scale``; return int16.

    A silent (all-zero) clip is returned as zeros instead of being divided by
    zero, which previously produced NaNs before the int16 cast.
    """
    peak = np.max(np.abs(samples))
    if peak == 0:
        return np.zeros(samples.shape, dtype=np.int16)
    return (samples / peak * full_scale).astype(np.int16)


def _to_model_input(samples, window=MODEL_WINDOW_SAMPLES):
    """Return the first ``window`` samples as a ``(1, window)`` float32 tensor.

    Clips shorter than ``window`` are zero-padded at the end, so the model
    always receives the same input length.
    """
    clip = np.asarray(samples[:window], dtype=np.float32)
    if clip.size < window:
        clip = np.pad(clip, (0, window - clip.size))
    return torch.from_numpy(clip).unsqueeze(0)


def _run_snore_detection(wav_bytes):
    """Decode, preprocess and classify one recording, writing results to the page."""
    try:
        source_rate, samples = _decode_wav(wav_bytes)
    except ValueError as err:
        st.error(f"Could not read the recording as WAV audio: {err}")
        return
    st.write(len(samples))
    if samples.size == 0:
        st.warning("The recording is empty; please record again.")
        return

    if source_rate != RATE:
        # Bring the clip to the sidebar RATE before windowing. Assumes RATE
        # matches the rate the model was trained at -- TODO confirm.
        samples = torchaudio.functional.resample(
            torch.from_numpy(samples.astype(np.float32)), source_rate, RATE
        ).numpy()

    pcm = _peak_normalize(samples)
    # Keep a copy of the preprocessed clip on disk as before, but labelled with
    # its actual sample rate instead of a hard-coded 96000 Hz.
    wavfile.write("output.wav", RATE, pcm)

    input_tensor = _to_model_input(pcm)
    st.write(input_tensor.shape)
    # Inference only: no_grad skips autograd bookkeeping. torch.abs is used
    # because np.abs() on a grad-tracking tensor raises RuntimeError.
    with torch.no_grad():
        result = model(input_tensor)
    # Same decision rule as before: the larger-magnitude score wins and a tie
    # falls through to SNORING.
    if torch.abs(result[0][0]) > torch.abs(result[0][1]):
        st.write("NON_SNORING")
    else:
        st.write("SNORING")


wav_audio_data = st_audiorec()
# None until something has been recorded; an empty payload is skipped as well.
if wav_audio_data:
    _run_snore_detection(wav_audio_data)
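# PERCENTAGE OF SNORING PLOT (currently disabled). Before re-enabling, note that
# the sketch below discards the output of `ptl_model(signal)` and then iterates
# over `model` itself, where it presumably means the model's output scores.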
# waveform, sample_rate = torchaudio.load('test/0_10.wav') | |
# resampler = T.Resample(sample_rate, RESAMPLE_RATE, dtype=waveform.dtype) | |
# signal = resampler(waveform) | |
# signal = torch.mean(signal, dim=0, keepdim=True) | |
# ptl_model(signal) | |
# snore = 0 | |
# other = 0 | |
# for row in model: | |
# for element in row: | |
# if element > 0.5: | |
# snore += 1 | |
# else: | |
# other += 1 | |
# total = snore + other | |
# snore_percentage = (snore / total) * 100 | |
# other_percentage = (other / total) * 100 | |
# categories = ["Snore", "Other"] | |
# percentages = [snore_percentage, other_percentage] | |
# plt.figure(figsize=(8, 4)) | |
# plt.barh(categories, percentages, color=['#ff0033', '#00ffee']) | |
# plt.xlabel('Percentage') | |
# plt.title('Percentage of "Snore" and "Other"') | |
# plt.xlim(0, 100) | |
# for i, percentage in enumerate(percentages): | |
# plt.text(percentage, i, f' {percentage:.2f}%', va='center') | |
# st.pyplot(plt) | |