HypermindLabs committed on
Commit
8f954b0
·
1 Parent(s): 1fd25f4
Files changed (1) hide show
  1. app.py +34 -8
app.py CHANGED
@@ -5,7 +5,8 @@ import sounddevice as sd
5
  import numpy as np
6
  import pandas as pd
7
  import torch
8
- # import torchaudio
 
9
  import wave
10
  import io
11
  from scipy.io import wavfile
@@ -14,14 +15,38 @@ import time
14
  import os
15
  import atexit
16
  import librosa
 
 
17
 
18
  # MODEL LOADING and INITIALISATION
19
-
20
- def load_model():
21
- model = torch.jit.load("snorenetv1_small.ptl")
22
- model.eval()
23
- return model
24
- model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  # Audio parameters
@@ -31,7 +56,8 @@ def process_data(waveform_chunks):
31
  for chunk in waveform_chunks:
32
  input_tensor = torch.tensor(chunk).unsqueeze(0).to(torch.float32)
33
  # st.write(input_tensor[0][98])
34
- result = model(input_tensor)
 
35
  # st.write(result)
36
  if np.abs(result[0][0]) > np.abs(result[0][1]):
37
  other += 1
 
5
  import numpy as np
6
  import pandas as pd
7
  import torch
8
+ import torch.nn as nn
9
+ import torchaudio
10
  import wave
11
  import io
12
  from scipy.io import wavfile
 
15
  import os
16
  import atexit
17
  import librosa
18
+ import torchaudio.functional as F
19
+ import torchaudio.transforms as T
20
 
21
  # MODEL LOADING and INITIALISATION
22
# Spectrogram settings shared by the SnoreNet front end.
n_fft = 1024
win_length = None  # None lets torchaudio fall back to win_length = n_fft
hop_length = 32


# Input tensor shape was ([1,16000])
class SnoreNet(nn.Module):
    """Classify raw audio clips from their flattened power spectrogram.

    The network computes a power spectrogram of the waveform, flattens it
    per batch item, and passes it through a Tanh hidden layer of width 512
    followed by a 2-way output layer with log-softmax.

    Args:
        num_samples: Number of audio samples per input clip. It fixes the
            input width of the first linear layer, so it must match the
            clips fed to ``forward``. The default of 16000 yields the
            257013 input features used by the original implementation.
    """

    def __init__(self, num_samples=16000):
        super().__init__()
        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=True,
            pad_mode="reflect",
            power=2.0,
        )
        # A one-sided spectrogram has n_fft // 2 + 1 frequency bins, and
        # with center=True it has num_samples // hop_length + 1 frames.
        # Deriving the size here replaces the hard-coded 257013
        # (513 bins * 501 frames for the default settings).
        n_freq_bins = n_fft // 2 + 1
        n_frames = num_samples // hop_length + 1
        self.fc1 = nn.Linear(n_freq_bins * n_frames, 512)
        self.act1 = nn.Tanh()
        self.fc2 = nn.Linear(512, 2)
        self.logs1 = nn.LogSoftmax(dim=1)

    def forward(self, raw_audio_tensor):
        """Return per-class scores for a batch of raw waveforms.

        Args:
            raw_audio_tensor: Waveform batch; dimension 0 is treated as the
                batch dimension (the original note gives shape [1, 16000]).

        Returns:
            Tensor with one row per batch item and two columns holding the
            absolute value of the log-softmax output. Log-softmax is never
            positive, so these are negative log-probabilities: the smaller
            value marks the more likely class.
        """
        spectrogram = self.transform(raw_audio_tensor)
        # Flatten everything except the batch dimension for the linear layer.
        spectrogram = spectrogram.reshape(spectrogram.size(0), -1)
        output = self.act1(self.fc1(spectrogram))
        output = self.fc2(output)
        output = torch.abs(self.logs1(output))
        return output
46
+
47
# Build the network and restore the trained weights once at import time.
model = SnoreNet()
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# host; load_state_dict then copies the values into the module's parameters.
model.load_state_dict(torch.load('snoreNetv1.pt', map_location="cpu"))
# Put the module in evaluation mode for inference.
model.eval()
50
 
51
 
52
  # Audio parameters
 
56
  for chunk in waveform_chunks:
57
  input_tensor = torch.tensor(chunk).unsqueeze(0).to(torch.float32)
58
  # st.write(input_tensor[0][98])
59
+ with torch.no_grad():
60
+ result = model(input_tensor)
61
  # st.write(result)
62
  if np.abs(result[0][0]) > np.abs(result[0][1]):
63
  other += 1