Sourudra commited on
Commit
383490d
·
verified ·
1 Parent(s): e15433e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +271 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,274 @@
1
- import altair as alt
2
  import numpy as np
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
+ import joblib
5
+
6
+ # Set page config
7
+ st.set_page_config(page_title="Stress Detection using One-Class SVM", layout="centered")
8
+
9
+ # Custom CSS for background and styles
10
+ st.markdown(
11
+ """
12
+ <style>
13
+ .stApp {
14
+ background-image: url("https://i.postimg.cc/vZb3ymYT/360-F-1375669005-ebg3mldxps5-ZYr-QFl-Y6-EX3e-CINw-VDeo-F.jpg");
15
+ background-size: cover;
16
+ background-repeat: no-repeat;
17
+ background-attachment: fixed;
18
+ }
19
+
20
+ .block-container {
21
+ background-color: rgba(0, 0, 0, 0.6);
22
+ color: white;
23
+ padding: 2rem;
24
+ border-radius: 15px;
25
+ max-width: 800px;
26
+ margin: 2rem auto;
27
+ box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5);
28
+ backdrop-filter: blur(8px);
29
+ -webkit-backdrop-filter: blur(8px);
30
+ border: 1px solid rgba(255, 255, 255, 0.1);
31
+ }
32
+
33
+ .sensor-row {
34
+ display: flex;
35
+ justify-content: space-around;
36
+ font-size: 1.1rem;
37
+ margin-top: 1.2rem;
38
+ margin-bottom: 1rem;
39
+ }
40
+
41
+ .sensor-row > div {
42
+ padding: 0.5rem 1rem;
43
+ background-color: rgba(255, 255, 255, 0.1);
44
+ border-radius: 8px;
45
+ }
46
+
47
+ .scroll-box {
48
+ max-height: 400px;
49
+ overflow-y: auto;
50
+ border: 1px solid #ccc;
51
+ padding: 1rem;
52
+ background-color: rgba(255,255,255,0.05);
53
+ border-radius: 10px;
54
+ }
55
+
56
+ h1, h2, h3, p, div {
57
+ color: white !important;
58
+ }
59
+
60
+ section.main > div:first-child {
61
+ padding-top: 0rem;
62
+ }
63
+
64
+ header[data-testid="stHeader"] {
65
+ height: 0rem;
66
+ visibility: hidden;
67
+ }
68
+ </style>
69
+ """,
70
+ unsafe_allow_html=True
71
+ )
72
+
73
+ st.title("Stress Detection")
74
+ st.markdown("Select a mode to detect stress from sensor readings:")
75
+
76
+ # Load model and scaler
77
+ try:
78
+ model = joblib.load("one_class_svm_stress_model.pkl")
79
+ scaler = joblib.load("scaler.pkl")
80
+ except Exception as e:
81
+ st.error(f"Error loading model or scaler: {e}")
82
+ st.stop()
83
+
84
+ # Load or create hardcoded dataset
85
+ try:
86
+ df = pd.read_csv("simulated_stress_data.csv")
87
+ df = df[['HR', 'HRV', 'EDA']].head(100)
88
+ except:
89
+ df = pd.DataFrame({
90
+ "HR": np.random.randint(60, 120, 100),
91
+ "HRV": np.random.uniform(20, 80, 100),
92
+ "EDA": np.random.uniform(0.1, 5.0, 100)
93
+ })
94
+
95
+ # Radio button selection
96
+ mode = st.radio("Choose input mode:", ["Manual Readings", "Generate Readings", "Test Dataset"], horizontal=True)
97
+
98
+ # Manual Input
99
+ if mode == "Manual Readings":
100
+ hr = st.number_input("Heart Rate (HR)", min_value=60, max_value=120, value=80)
101
+ hrv = st.number_input("Heart Rate Variability (HRV)", min_value=20.0, max_value=80.0, value=50.0)
102
+ eda = st.number_input("Electrodermal Activity (EDA)", min_value=0.1, max_value=5.0, value=2.0)
103
+
104
+ if st.button("Predict"):
105
+ sample = np.array([[hr, hrv, eda]])
106
+ scaled = scaler.transform(sample)
107
+ pred = model.predict(scaled)
108
+ label = "Stress" if pred[0] == -1 else "No Stress"
109
+
110
+ st.markdown(
111
+ f"""
112
+ <div class="sensor-row">
113
+ <div><strong>HR:</strong> {hr} bpm</div>
114
+ <div><strong>HRV:</strong> {hrv:.2f} ms</div>
115
+ <div><strong>EDA:</strong> {eda:.2f} µS</div>
116
+ </div>
117
+ """, unsafe_allow_html=True
118
+ )
119
+
120
+ st.subheader("Prediction")
121
+ if label == "No Stress":
122
+ st.success(label)
123
+ else:
124
+ st.error(label)
125
+
126
+ # Generate Random Input
127
+ elif mode == "Generate Readings":
128
+ if st.button("Generate and Predict"):
129
+ hr = np.random.randint(60, 120)
130
+ hrv = np.random.uniform(20, 80)
131
+ eda = np.random.uniform(0.1, 5.0)
132
+
133
+ sample = np.array([[hr, hrv, eda]])
134
+ scaled = scaler.transform(sample)
135
+ pred = model.predict(scaled)
136
+ label = "Stress" if pred[0] == -1 else "No Stress"
137
+
138
+ st.markdown(
139
+ f"""
140
+ <div class="sensor-row">
141
+ <div><strong>HR:</strong> {hr} bpm</div>
142
+ <div><strong>HRV:</strong> {hrv:.2f} ms</div>
143
+ <div><strong>EDA:</strong> {eda:.2f} µS</div>
144
+ </div>
145
+ """, unsafe_allow_html=True
146
+ )
147
+
148
+ st.subheader("Prediction")
149
+ if label == "No Stress":
150
+ st.success(label)
151
+ else:
152
+ st.error(label)
153
+
154
+
155
+ # Test Dataset (Scrollable)
156
+
157
+ elif mode == "Test Dataset":
158
+ st.markdown("### Select a row from test dataset for prediction:")
159
+
160
+ # Session state for pagination and prediction result
161
+ if "page" not in st.session_state:
162
+ st.session_state.page = 0
163
+ if "last_prediction" not in st.session_state:
164
+ st.session_state.last_prediction = None
165
+ st.session_state.last_row = None
166
+
167
+ rows_per_page = 5
168
+ df_filtered = df
169
+ total_pages = max(1, (len(df_filtered) - 1) // rows_per_page + 1)
170
+
171
+ # CSS for table rows
172
+ st.markdown("""
173
+ <style>
174
+ .scrollable-table {
175
+ max-height: 350px;
176
+ overflow-y: auto;
177
+ padding: 10px;
178
+ background-color: rgba(255,255,255,0.05);
179
+ border-radius: 10px;
180
+ border: 1px solid #ccc;
181
+ }
182
+ .row-box {
183
+ border-radius: 8px;
184
+ padding: 6px;
185
+ margin-bottom: 4px;
186
+ background-color: rgba(0, 128, 255, 0.15);
187
+ height: 40px;
188
+ display: flex;
189
+ align-items: center;
190
+ transition: background-color 0.3s ease;
191
+ }
192
+ .row-box:hover {
193
+ background-color: rgba(0, 128, 255, 0.3);
194
+ }
195
+ </style>
196
+ """, unsafe_allow_html=True)
197
+
198
+ # Display scrollable table
199
+ with st.container():
200
+ #st.markdown('<div class="scrollable-table">', unsafe_allow_html=True)
201
+
202
+ col1, col2, col3, col4 = st.columns([3, 3, 3, 2])
203
+ col1.markdown("**HR**")
204
+ col2.markdown("**HRV**")
205
+ col3.markdown("**EDA**")
206
+ col4.markdown("**Predict**")
207
+
208
+ page_data = df_filtered.iloc[
209
+ st.session_state.page * rows_per_page : (st.session_state.page + 1) * rows_per_page
210
+ ]
211
+
212
+ for idx, row in page_data.iterrows():
213
+ row_style = "row-box"
214
+
215
+ col1, col2, col3, col4 = st.columns([3, 3, 3, 2])
216
+ with col1:
217
+ st.markdown(f'<div class="{row_style}">{row["HR"]}</div>', unsafe_allow_html=True)
218
+ with col2:
219
+ st.markdown(f'<div class="{row_style}">{row["HRV"]:.2f}</div>', unsafe_allow_html=True)
220
+ with col3:
221
+ st.markdown(f'<div class="{row_style}">{row["EDA"]:.2f}</div>', unsafe_allow_html=True)
222
+ with col4:
223
+ if st.button("Select", key=f"select_{idx}"):
224
+ sample = np.array([[row['HR'], row['HRV'], row['EDA']]])
225
+ sample_scaled = scaler.transform(sample)
226
+ pred = model.predict(sample_scaled)
227
+ label = "Stress" if pred[0] == -1 else "No Stress"
228
+
229
+ st.session_state.last_prediction = label
230
+ st.session_state.last_row = row
231
+
232
+ st.markdown('</div>', unsafe_allow_html=True)
233
+
234
+
235
+ # Tighter Pagination Controls
236
+ col1, col2, col3 = st.columns([1.2, 1.2, 3])
237
+ with col1:
238
+ if st.button("⬅️ Previous"):
239
+ if st.session_state.page > 0:
240
+ st.session_state.page -= 1
241
+
242
+ with col2:
243
+ if st.button("Next ➡️"):
244
+ if st.session_state.page < total_pages - 1:
245
+ st.session_state.page += 1
246
+
247
+ with col3:
248
+ st.markdown(
249
+ f"<div style='text-align:left; margin-top: 0.5rem;'>Page {st.session_state.page + 1} of {total_pages}</div>",
250
+ unsafe_allow_html=True
251
+ )
252
+
253
+
254
+ # Final prediction display at the end
255
+ if st.session_state.last_prediction and st.session_state.last_row is not None:
256
+ row = st.session_state.last_row
257
+ label = st.session_state.last_prediction
258
+
259
+ st.markdown("---")
260
+ st.markdown("### Prediction")
261
+ st.markdown(
262
+ f"""
263
+ <div class="sensor-row">
264
+ <div><strong>HR:</strong> {row['HR']} bpm</div>
265
+ <div><strong>HRV:</strong> {row['HRV']:.2f} ms</div>
266
+ <div><strong>EDA:</strong> {row['EDA']:.2f} µS</div>
267
+ </div>
268
+ """, unsafe_allow_html=True
269
+ )
270
 
271
+ if label == "No Stress":
272
+ st.success(label)
273
+ else:
274
+ st.error(label)