ggirishg commited on
Commit
8f1708c
·
verified ·
1 Parent(s): 0e88a7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -71
app.py CHANGED
@@ -10,12 +10,19 @@ import time
10
  import tempfile
11
  import streamlit.components.v1 as components
12
 
13
- # Ensure setup.sh is executable and then run it using bash
14
- subprocess.run(['chmod', '+x', 'setup.sh'])
15
- subprocess.run(['bash', 'setup.sh'], check=True)
 
16
 
17
- # Load the model from TensorFlow Hub
18
- m = hub.KerasLayer('https://tfhub.dev/google/nonsemantic-speech-benchmark/trillsson4/1')
 
 
 
 
 
 
19
 
20
  class TransformerEncoder(tf.keras.layers.Layer):
21
  def __init__(self, embed_dim, num_heads, ff_dim, rate=0.01, **kwargs):
@@ -49,7 +56,14 @@ class TransformerEncoder(tf.keras.layers.Layer):
49
  })
50
  return config
51
 
52
- model = load_model('autism_detection_model3.h5', custom_objects={'TransformerEncoder': TransformerEncoder})
 
 
 
 
 
 
 
53
 
54
  def extract_features(path):
55
  sample_rate = 16000
@@ -59,6 +73,9 @@ def extract_features(path):
59
  if array.shape[0] > 1:
60
  array = np.mean(array, axis=0, keepdims=True)
61
 
 
 
 
62
  embeddings = m(array)['embedding']
63
  embeddings.shape.assert_is_compatible_with([None, 1024])
64
  embeddings = np.squeeze(np.array(embeddings), axis=0)
@@ -69,20 +86,44 @@ st.markdown('<span style="color:black; font-size: 48px; font-weight: bold;">Neu<
69
 
70
  option = st.radio("**Choose an option:**", ["Upload an audio file", "Record audio"])
71
 
72
- if option == "Upload an audio file":
73
- uploaded_file = st.file_uploader("Upload an audio file (.wav)", type=["wav"])
74
- if uploaded_file is not None:
75
- start_time = time.time() # Record start time
76
- with st.spinner('Extracting features...'):
77
- # Save the uploaded file temporarily
78
- with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
79
- temp_file.write(uploaded_file.getbuffer())
80
- temp_file_path = temp_file.name
81
-
82
- features = extract_features(temp_file_path)
83
- os.remove(temp_file_path)
84
-
85
- # Display prediction probabilities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  prediction = model.predict(np.expand_dims(features, axis=0))
87
  autism_probability = prediction[0][1]
88
  normal_probability = prediction[0][0]
@@ -116,67 +157,181 @@ if option == "Upload an audio file":
116
  unsafe_allow_html=True
117
  )
118
 
 
 
 
 
 
 
 
 
 
 
 
119
  elapsed_time = round(time.time() - start_time, 2)
120
  st.write(f"Elapsed Time: {elapsed_time} seconds")
121
 
122
  else: # Option is "Record audio"
123
- # Load and display the local index.html file
124
- with open("index.html", 'r', encoding='utf-8') as f:
125
- html_content = f.read()
126
- components.html(html_content, height=600)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- if st.button("Click to Predict"):
129
- # Save the recorded audio file temporarily
130
- recorded_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.wav').name
131
- converted_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.wav').name
 
132
 
133
- # Run the ffmpeg command to convert the recorded audio
134
- os.system(f'ffmpeg -i {recorded_audio_path} -acodec pcm_s16le -ar 16000 -ac 1 {converted_audio_path}')
135
-
136
- # Process the converted audio file
137
- features = extract_features(converted_audio_path)
138
 
139
- # Display prediction probabilities
140
- prediction = model.predict(np.expand_dims(features, axis=0))
141
- autism_probability = prediction[0][1]
142
- normal_probability = prediction[0][0]
 
 
 
 
 
 
 
 
 
143
 
144
- st.subheader("Prediction Probabilities:")
 
 
145
 
146
- if autism_probability > normal_probability:
147
- st.markdown(
148
- f'<div style="background-color:#658EA9;padding:20px;border-radius:10px;margin-bottom:40px;">'
149
- f'<h3 style="color:black;">Autism: {autism_probability}</h3>'
150
- '</div>',
151
- unsafe_allow_html=True
152
- )
153
- st.markdown(
154
- f'<div style="background-color:#ADD8E6;padding:20px;border-radius:10px;margin-bottom:40px;">'
155
- f'<h3 style="color:black;">Normal: {normal_probability}</h3>'
156
- '</div>',
157
- unsafe_allow_html=True
158
- )
159
- else:
160
- st.markdown(
161
- f'<div style="background-color:#658EA9;padding:20px;border-radius:10px;margin-bottom:40px;">'
162
- f'<h3 style="color:black;">Normal: {normal_probability}</h3>'
163
- '</div>',
164
- unsafe_allow_html=True
165
- )
166
- st.markdown(
167
- f'<div style="background-color:#ADD8E6;padding:20px;border-radius:10px;margin-bottom:40px;">'
168
- f'<h3 style="color:black;">Autism: {autism_probability}</h3>'
169
- '</div>',
170
- unsafe_allow_html=True
171
- )
172
 
173
- # Remove temporary audio files
174
- try:
175
- os.remove(recorded_audio_path)
176
- except Exception as e:
177
- print(f"Error deleting '{recorded_audio_path}': {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  try:
180
- os.remove(converted_audio_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  except Exception as e:
182
- print(f"Error deleting '{converted_audio_path}': {e}")
 
10
  import tempfile
11
  import streamlit.components.v1 as components
12
 
13
+ # Attempt to set GPU memory growth
14
+ try:
15
+ from tensorflow.compat.v1 import ConfigProto
16
+ from tensorflow.compat.v1 import InteractiveSession
17
 
18
+ config = ConfigProto()
19
+ config.gpu_options.allow_growth = True
20
+ session = InteractiveSession(config=config)
21
+ except Exception as e:
22
+ st.warning(f"Could not set GPU memory growth: {e}")
23
+
24
+ model_path = 'TrillsonFeature_model'
25
+ m = hub.load(model_path)
26
 
27
  class TransformerEncoder(tf.keras.layers.Layer):
28
  def __init__(self, embed_dim, num_heads, ff_dim, rate=0.01, **kwargs):
 
56
  })
57
  return config
58
 
59
+ def load_autism_model():
60
+ try:
61
+ return load_model('autism_detection_model3.h5', custom_objects={'TransformerEncoder': TransformerEncoder})
62
+ except Exception as e:
63
+ st.error(f"Error loading model: {e}")
64
+ return None
65
+
66
+ model = load_autism_model()
67
 
68
  def extract_features(path):
69
  sample_rate = 16000
 
73
  if array.shape[0] > 1:
74
  array = np.mean(array, axis=0, keepdims=True)
75
 
76
+ # Truncate the audio to 10 seconds for reducing memory usage
77
+ array = array[:, :sample_rate * 10]
78
+
79
  embeddings = m(array)['embedding']
80
  embeddings.shape.assert_is_compatible_with([None, 1024])
81
  embeddings = np.squeeze(np.array(embeddings), axis=0)
 
86
 
87
  option = st.radio("**Choose an option:**", ["Upload an audio file", "Record audio"])
88
 
89
+ def run_prediction(features):
90
+ try:
91
+ prediction = model.predict(np.expand_dims(features, axis=0))
92
+ autism_probability = prediction[0][1]
93
+ normal_probability = prediction[0][0]
94
+
95
+ st.subheader("Prediction Probabilities:")
96
+
97
+ if autism_probability > normal_probability:
98
+ st.markdown(
99
+ f'<div style="background-color:#658EA9;padding:20px;border-radius:10px;margin-bottom:40px;">'
100
+ f'<h3 style="color:black;">Autism: {autism_probability}</h3>'
101
+ '</div>',
102
+ unsafe_allow_html=True
103
+ )
104
+ st.markdown(
105
+ f'<div style="background-color:#ADD8E6;padding:20px;border-radius:10px;margin-bottom:40px;">'
106
+ f'<h3 style="color:black;">Normal: {normal_probability}</h3>'
107
+ '</div>',
108
+ unsafe_allow_html=True
109
+ )
110
+ else:
111
+ st.markdown(
112
+ f'<div style="background-color:#658EA9;padding:20px;border-radius:10px;margin-bottom:40px;">'
113
+ f'<h3 style="color:black;">Normal: {normal_probability}</h3>'
114
+ '</div>',
115
+ unsafe_allow_html=True
116
+ )
117
+ st.markdown(
118
+ f'<div style="background-color:#ADD8E6;padding:20px;border-radius:10px;margin-bottom:40px;">'
119
+ f'<h3 style="color:black;">Autism: {autism_probability}</h3>'
120
+ '</div>',
121
+ unsafe_allow_html=True
122
+ )
123
+
124
+ except tf.errors.ResourceExhaustedError as e:
125
+ st.error("Resource exhausted error: switching to CPU.")
126
+ with tf.device('/cpu:0'):
127
  prediction = model.predict(np.expand_dims(features, axis=0))
128
  autism_probability = prediction[0][1]
129
  normal_probability = prediction[0][0]
 
157
  unsafe_allow_html=True
158
  )
159
 
160
+ if option == "Upload an audio file":
161
+ uploaded_file = st.file_uploader("Upload an audio file (.wav)", type=["wav"])
162
+ if uploaded_file is not None:
163
+ start_time = time.time() # Record start time
164
+ with st.spinner('Extracting features...'):
165
+ # Process the uploaded file
166
+ with open("temp_audio.wav", "wb") as f:
167
+ f.write(uploaded_file.getbuffer())
168
+ features = extract_features("temp_audio.wav")
169
+ os.remove("temp_audio.wav")
170
+ run_prediction(features)
171
  elapsed_time = round(time.time() - start_time, 2)
172
  st.write(f"Elapsed Time: {elapsed_time} seconds")
173
 
174
  else: # Option is "Record audio"
175
+ audio_recorder_html = '''
176
+ <!DOCTYPE html>
177
+ <html lang="en">
178
+ <head>
179
+ <meta charset="UTF-8">
180
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
181
+ <title>Audio Recorder</title>
182
+ <style>
183
+ body {
184
+ font-family: Arial, sans-serif;
185
+ background-color: #ffffff;
186
+ margin: 0;
187
+ padding: 0;
188
+ display: flex;
189
+ justify-content: center;
190
+ align-items: center;
191
+ height: 100vh;
192
+ }
193
 
194
+ .container {
195
+ text-align: center;
196
+ background-color: #ffffff;
197
+ border-radius: 0%;
198
+ }
199
 
200
+ h1 {
201
+ color: #000000;
202
+ }
 
 
203
 
204
+ button {
205
+ background-color: #40826D;
206
+ color: rgb(0, 0, 0);
207
+ border: none;
208
+ padding: 10px 20px;
209
+ text-align: center;
210
+ text-decoration: none;
211
+ display: inline-block;
212
+ font-size: 16px;
213
+ margin: 10px;
214
+ cursor: pointer;
215
+ border-radius: 5px;
216
+ }
217
 
218
+ button:hover {
219
+ background-color: #40826D;
220
+ }
221
 
222
+ button:disabled {
223
+ background-color: #df5e5e;
224
+ cursor: not-allowed;
225
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ #timer {
228
+ font-size: 20px;
229
+ margin-top: 20px;
230
+ color: #000000;
231
+ }
232
+ </style>
233
+ </head>
234
+ <body>
235
+ <div class="container">
236
+ <h1>Audio Recorder</h1>
237
+ <button id="startRecording">Start Recording</button>
238
+ <button id="stopRecording" disabled>Stop Recording</button>
239
+ <div id="timer">00:00</div>
240
+ </div>
241
+
242
+ <script>
243
+ let recorder;
244
+ let audioChunks = [];
245
+ let startTime;
246
+ let timerInterval;
247
+
248
+ function updateTime() {
249
+ const elapsedTime = Math.floor((Date.now() - startTime) / 1000);
250
+ const minutes = Math.floor(elapsedTime / 60);
251
+ const seconds = elapsedTime % 60;
252
+ const formattedTime = `${minutes.toString().padStart(2, '0')}:${seconds.toString().padStart(2, '0')}`;
253
+ document.getElementById('timer').textContent = formattedTime;
254
+ }
255
 
256
+ navigator.mediaDevices.getUserMedia({ audio: true })
257
+ .then(stream => {
258
+ recorder = new MediaRecorder(stream);
259
+
260
+ recorder.ondataavailable = e => {
261
+ audioChunks.push(e.data);
262
+ };
263
+
264
+ recorder.onstart = () => {
265
+ startTime = Date.now();
266
+ timerInterval = setInterval(updateTime, 1000);
267
+ };
268
+
269
+ recorder.onstop = () => {
270
+ const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
271
+ const audioUrl = URL.createObjectURL(audioBlob);
272
+ const a = document.createElement('a');
273
+ a.href = audioUrl;
274
+ a.download = 'recorded_audio.wav';
275
+ document.body.appendChild(a);
276
+ a.click();
277
+
278
+ // Reset
279
+ audioChunks = [];
280
+ clearInterval(timerInterval);
281
+ };
282
+ })
283
+ .catch(err => {
284
+ console.error('Permission to access microphone denied:', err);
285
+ });
286
+
287
+ document.getElementById('startRecording').addEventListener('click', () => {
288
+ recorder.start();
289
+ document.getElementById('startRecording').disabled = true;
290
+ document.getElementById('stopRecording').disabled = false;
291
+ setTimeout(() => {
292
+ recorder.stop();
293
+ document.getElementById('startRecording').disabled = false;
294
+ document.getElementById('stopRecording').disabled = true;
295
+ }, 15000); // 15 seconds
296
+ });
297
+
298
+ document.getElementById('stopRecording').addEventListener('click', () => {
299
+ recorder.stop();
300
+ document.getElementById('startRecording').disabled = false;
301
+ document.getElementById('stopRecording').disabled = true;
302
+ });
303
+ </script>
304
+ </body>
305
+ </html>
306
+ '''
307
+ st.components.v1.html(audio_recorder_html, height=600)
308
+
309
+ if st.button("Click to Predict"):
310
  try:
311
+ # Run the ffmpeg command to convert the recorded audio
312
+ command = 'ffmpeg -i C:/Users/giris/Downloads/recorded_audio.wav -acodec pcm_s16le -ar 16000 -ac 1 C:/Users/giris/Downloads/recorded_audio2.wav'
313
+ result = subprocess.run(command, shell=True, capture_output=True, text=True)
314
+ if result.returncode != 0:
315
+ st.error(f"Error running ffmpeg: {result.stderr}")
316
+ else:
317
+ # Check if the file exists
318
+ if not os.path.exists("C:/Users/giris/Downloads/recorded_audio2.wav"):
319
+ st.error("The converted audio file was not created.")
320
+ else:
321
+ # Process the converted audio file
322
+ features = extract_features("C:/Users/giris/Downloads/recorded_audio2.wav")
323
+ run_prediction(features)
324
+
325
+ # Try to delete the first audio file
326
+ try:
327
+ os.remove("recorded_audio.wav")
328
+ except Exception as e:
329
+ print(f"Error deleting 'recorded_audio.wav': {e}")
330
+
331
+ # Try to delete the second audio file
332
+ try:
333
+ os.remove("recorded_audio2.wav")
334
+ except Exception as e:
335
+ print(f"Error deleting 'recorded_audio2.wav': {e}")
336
  except Exception as e:
337
+ st.error(f"An error occurred: {e}")