Jfink09 commited on
Commit
afe13a1
·
verified ·
1 Parent(s): 8a7bb30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -71
app.py CHANGED
@@ -1,46 +1,131 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import torch
4
- import torch.nn as nn
5
- import torch.optim as optim
6
- from sklearn.metrics import r2_score
7
-
8
- class RegressionModel2(nn.Module):
9
- def __init__(self, input_dim2, hidden_dim2, output_dim2):
10
- super(RegressionModel2, self).__init__()
11
- self.fc1 = nn.Linear(input_dim2, hidden_dim2)
12
- self.relu1 = nn.ReLU()
13
- self.fc2 = nn.Linear(hidden_dim2, output_dim2)
14
- self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)
15
-
16
- def forward(self, x2):
17
- out = self.fc1(x2)
18
- out = self.relu1(out)
19
- out = self.batch_norm1(out)
20
- out = self.fc2(out)
21
- return out
22
-
23
- # Load the saved model state dictionary
24
- model = RegressionModel2(3, 32, 1)
25
- model.load_state_dict(torch.load('model.pt'))
26
- model.eval() # Set the model to evaluation mode
27
-
28
- # Define a function to make predictions
29
- def predict_astigmatism(age, axis, aca):
30
- """
31
- This function takes three arguments (age, axis, aca) as input,
32
- converts them to a tensor, makes a prediction using the loaded model,
33
- and returns the predicted value.
34
- """
35
- # Prepare the input data
36
- data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
37
-
38
- # Make prediction
39
- with torch.no_grad():
40
- prediction = model(data)
41
-
42
- # Return the predicted value
43
- return prediction.item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def main():
46
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
@@ -75,40 +160,33 @@ def main():
75
  padding: 10px 20px;
76
  font-size: 16px;
77
  }
 
 
 
 
78
  </style>
79
  """,
80
  unsafe_allow_html=True
81
  )
82
 
83
- # st.markdown(
84
- # """
85
- # <body>
86
- # <header>
87
- # <nav class="navbar">
88
- # <div class="logo"><img src="iol.png" alt="Image description"></div>
89
- # <ul class="menu">
90
- # <li><a href="#">Home</a></li>
91
- # <li><a href="#">About</a></li>
92
- # <li><a href="#">Contact</a></li>
93
- # </ul>
94
- # </nav>
95
- # <div class="text-content">
96
- # <h2>Enter Variables</h2>
97
- # <br>
98
- # </div>
99
- # </header>
100
- # </body>
101
- # """,
102
- # unsafe_allow_html=True
103
- # )
104
-
105
- age = st.number_input('Enter Patient Age:', step=0.1)
106
- aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1)
107
- aca_axis = st.number_input('Enter ACA Axis:', step=0.1)
108
 
109
  if st.button('Predict!'):
110
- astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
111
- st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
 
 
 
112
 
113
  if __name__ == '__main__':
114
- main()
 
1
+ # import streamlit as st
2
+ # import pandas as pd
3
+ # import torch
4
+ # import torch.nn as nn
5
+ # import torch.optim as optim
6
+ # from sklearn.metrics import r2_score
7
+
8
+ # class RegressionModel2(nn.Module):
9
+ # def __init__(self, input_dim2, hidden_dim2, output_dim2):
10
+ # super(RegressionModel2, self).__init__()
11
+ # self.fc1 = nn.Linear(input_dim2, hidden_dim2)
12
+ # self.relu1 = nn.ReLU()
13
+ # self.fc2 = nn.Linear(hidden_dim2, output_dim2)
14
+ # self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)
15
+
16
+ # def forward(self, x2):
17
+ # out = self.fc1(x2)
18
+ # out = self.relu1(out)
19
+ # out = self.batch_norm1(out)
20
+ # out = self.fc2(out)
21
+ # return out
22
+
23
+ # # Load the saved model state dictionary
24
+ # model = RegressionModel2(3, 32, 1)
25
+ # model.load_state_dict(torch.load('model.pt'))
26
+ # model.eval() # Set the model to evaluation mode
27
+
28
+ # # Define a function to make predictions
29
+ # def predict_astigmatism(age, axis, aca):
30
+ # """
31
+ # This function takes three arguments (age, axis, aca) as input,
32
+ # converts them to a tensor, makes a prediction using the loaded model,
33
+ # and returns the predicted value.
34
+ # """
35
+ # # Prepare the input data
36
+ # data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
37
+
38
+ # # Make prediction
39
+ # with torch.no_grad():
40
+ # prediction = model(data)
41
+
42
+ # # Return the predicted value
43
+ # return prediction.item()
44
+
45
+ # def main():
46
+ # st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
47
+ # st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
48
+ # st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
49
+ # st.markdown(
50
+ # """
51
+ # <style>
52
+ # .navbar {
53
+ # display: flex;
54
+ # justify-content: space-between;
55
+ # align-items: center;
56
+ # background-color: #f2f2f2;
57
+ # padding: 10px;
58
+ # }
59
+ # .logo img {
60
+ # height: 50px;
61
+ # }
62
+ # .menu {
63
+ # list-style-type: none;
64
+ # display: flex;
65
+ # }
66
+ # .menu li {
67
+ # margin-left: 20px;
68
+ # }
69
+ # .text-content {
70
+ # margin-top: 50px;
71
+ # text-align: center;
72
+ # }
73
+ # .button {
74
+ # margin-top: 20px;
75
+ # padding: 10px 20px;
76
+ # font-size: 16px;
77
+ # }
78
+ # </style>
79
+ # """,
80
+ # unsafe_allow_html=True
81
+ # )
82
+
83
+ # # st.markdown(
84
+ # # """
85
+ # # <body>
86
+ # # <header>
87
+ # # <nav class="navbar">
88
+ # # <div class="logo"><img src="iol.png" alt="Image description"></div>
89
+ # # <ul class="menu">
90
+ # # <li><a href="#">Home</a></li>
91
+ # # <li><a href="#">About</a></li>
92
+ # # <li><a href="#">Contact</a></li>
93
+ # # </ul>
94
+ # # </nav>
95
+ # # <div class="text-content">
96
+ # # <h2>Enter Variables</h2>
97
+ # # <br>
98
+ # # </div>
99
+ # # </header>
100
+ # # </body>
101
+ # # """,
102
+ # # unsafe_allow_html=True
103
+ # # )
104
+
105
+ # age = st.number_input('Enter Patient Age:', step=0.1)
106
+ # aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1)
107
+ # aca_axis = st.number_input('Enter ACA Axis:', step=0.1)
108
+
109
+ # if st.button('Predict!'):
110
+ # astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
111
+ # st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
112
+
113
+ # if __name__ == '__main__':
114
+ # main()
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
 
130
  def main():
131
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
 
160
  padding: 10px 20px;
161
  font-size: 16px;
162
  }
163
+ .error {
164
+ color: red;
165
+ font-weight: bold;
166
+ }
167
  </style>
168
  """,
169
  unsafe_allow_html=True
170
  )
171
 
172
+ age = st.number_input('Enter Patient Age:', min_value=0.0, max_value=120.0, step=0.1)
173
+ if age < 18 or age > 90:
174
+ st.markdown('<p class="error">Error: Age must be between 18 and 90.</p>', unsafe_allow_html=True)
175
+
176
+ aca_magnitude = st.number_input('Enter ACA Magnitude:', min_value=0.0, max_value=20.0, step=0.1)
177
+ if aca_magnitude < 0 or aca_magnitude > 10:
178
+ st.markdown('<p class="error">Error: ACA Magnitude must be between 0 and 10.</p>', unsafe_allow_html=True)
179
+
180
+ aca_axis = st.number_input('Enter ACA Axis:', min_value=0.0, max_value=180.0, step=0.1)
181
+ if aca_axis < 0 or aca_axis > 180:
182
+ st.markdown('<p class="error">Error: ACA Axis must be between 0 and 180.</p>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  if st.button('Predict!'):
185
+ if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
186
+ astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
187
+ st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
188
+ else:
189
+ st.error('Please correct the input errors before predicting.')
190
 
191
  if __name__ == '__main__':
192
+ main()