Jfink09 commited on
Commit
9cf86b7
·
verified ·
1 Parent(s): b197d08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -41
app.py CHANGED
@@ -1,9 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
4
  import torch.nn as nn
5
  import torch.optim as optim
6
  from sklearn.metrics import r2_score
 
7
 
8
  class RegressionModel2(nn.Module):
9
  def __init__(self, input_dim2, hidden_dim2, output_dim2):
@@ -25,27 +143,52 @@ model = RegressionModel2(3, 32, 1)
25
  model.load_state_dict(torch.load('model.pt'))
26
  model.eval() # Set the model to evaluation mode
27
 
28
- # Define a function to make predictions
29
  def predict_astigmatism(age, axis, aca):
30
- """
31
- This function takes three arguments (age, axis, aca) as input,
32
- converts them to a tensor, makes a prediction using the loaded model,
33
- and returns the predicted value.
34
- """
35
- # Prepare the input data
36
- data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
37
-
38
- # Make prediction
39
- with torch.no_grad():
40
- prediction = model(data)
 
 
 
41
 
42
- # Return the predicted value
43
- return prediction.item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def main():
46
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
47
  st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
48
  st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
 
49
  st.markdown(
50
  """
51
  <style>
@@ -80,35 +223,21 @@ def main():
80
  unsafe_allow_html=True
81
  )
82
 
83
- # st.markdown(
84
- # """
85
- # <body>
86
- # <header>
87
- # <nav class="navbar">
88
- # <div class="logo"><img src="iol.png" alt="Image description"></div>
89
- # <ul class="menu">
90
- # <li><a href="#">Home</a></li>
91
- # <li><a href="#">About</a></li>
92
- # <li><a href="#">Contact</a></li>
93
- # </ul>
94
- # </nav>
95
- # <div class="text-content">
96
- # <h2>Enter Variables</h2>
97
- # <br>
98
- # </div>
99
- # </header>
100
- # </body>
101
- # """,
102
- # unsafe_allow_html=True
103
- # )
104
-
105
- age = st.number_input('Enter Patient Age:', step=0.1)
106
- aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1)
107
- aca_axis = st.number_input('Enter ACA Axis:', step=0.1)
108
 
109
  if st.button('Predict!'):
110
- astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
111
- st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
 
 
 
 
 
 
112
 
113
  if __name__ == '__main__':
114
  main()
 
1
+ # import streamlit as st
2
+ # import pandas as pd
3
+ # import torch
4
+ # import torch.nn as nn
5
+ # import torch.optim as optim
6
+ # from sklearn.metrics import r2_score
7
+
8
+ # class RegressionModel2(nn.Module):
9
+ # def __init__(self, input_dim2, hidden_dim2, output_dim2):
10
+ # super(RegressionModel2, self).__init__()
11
+ # self.fc1 = nn.Linear(input_dim2, hidden_dim2)
12
+ # self.relu1 = nn.ReLU()
13
+ # self.fc2 = nn.Linear(hidden_dim2, output_dim2)
14
+ # self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)
15
+
16
+ # def forward(self, x2):
17
+ # out = self.fc1(x2)
18
+ # out = self.relu1(out)
19
+ # out = self.batch_norm1(out)
20
+ # out = self.fc2(out)
21
+ # return out
22
+
23
+ # # Load the saved model state dictionary
24
+ # model = RegressionModel2(3, 32, 1)
25
+ # model.load_state_dict(torch.load('model.pt'))
26
+ # model.eval() # Set the model to evaluation mode
27
+
28
+ # # Define a function to make predictions
29
+ # def predict_astigmatism(age, axis, aca):
30
+ # """
31
+ # This function takes three arguments (age, axis, aca) as input,
32
+ # converts them to a tensor, makes a prediction using the loaded model,
33
+ # and returns the predicted value.
34
+ # """
35
+ # # Prepare the input data
36
+ # data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
37
+
38
+ # # Make prediction
39
+ # with torch.no_grad():
40
+ # prediction = model(data)
41
+
42
+ # # Return the predicted value
43
+ # return prediction.item()
44
+
45
+ # def main():
46
+ # st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
47
+ # st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
48
+ # st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
49
+ # st.markdown(
50
+ # """
51
+ # <style>
52
+ # .navbar {
53
+ # display: flex;
54
+ # justify-content: space-between;
55
+ # align-items: center;
56
+ # background-color: #f2f2f2;
57
+ # padding: 10px;
58
+ # }
59
+ # .logo img {
60
+ # height: 50px;
61
+ # }
62
+ # .menu {
63
+ # list-style-type: none;
64
+ # display: flex;
65
+ # }
66
+ # .menu li {
67
+ # margin-left: 20px;
68
+ # }
69
+ # .text-content {
70
+ # margin-top: 50px;
71
+ # text-align: center;
72
+ # }
73
+ # .button {
74
+ # margin-top: 20px;
75
+ # padding: 10px 20px;
76
+ # font-size: 16px;
77
+ # }
78
+ # </style>
79
+ # """,
80
+ # unsafe_allow_html=True
81
+ # )
82
+
83
+ # # st.markdown(
84
+ # # """
85
+ # # <body>
86
+ # # <header>
87
+ # # <nav class="navbar">
88
+ # # <div class="logo"><img src="iol.png" alt="Image description"></div>
89
+ # # <ul class="menu">
90
+ # # <li><a href="#">Home</a></li>
91
+ # # <li><a href="#">About</a></li>
92
+ # # <li><a href="#">Contact</a></li>
93
+ # # </ul>
94
+ # # </nav>
95
+ # # <div class="text-content">
96
+ # # <h2>Enter Variables</h2>
97
+ # # <br>
98
+ # # </div>
99
+ # # </header>
100
+ # # </body>
101
+ # # """,
102
+ # # unsafe_allow_html=True
103
+ # # )
104
+
105
+ # age = st.number_input('Enter Patient Age:', step=0.1)
106
+ # aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1)
107
+ # aca_axis = st.number_input('Enter ACA Axis:', step=0.1)
108
+
109
+ # if st.button('Predict!'):
110
+ # astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
111
+ # st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
112
+
113
+ # if __name__ == '__main__':
114
+ # main()
115
+
116
+
117
+
118
  import streamlit as st
119
  import pandas as pd
120
  import torch
121
  import torch.nn as nn
122
  import torch.optim as optim
123
  from sklearn.metrics import r2_score
124
+ import math
125
 
126
  class RegressionModel2(nn.Module):
127
  def __init__(self, input_dim2, hidden_dim2, output_dim2):
 
143
  model.load_state_dict(torch.load('model.pt'))
144
  model.eval() # Set the model to evaluation mode
145
 
 
146
  def predict_astigmatism(age, axis, aca):
147
+ """
148
+ This function takes three arguments (age, axis, aca) as input,
149
+ converts them to a tensor, makes a prediction using the loaded model,
150
+ and returns the predicted value.
151
+ """
152
+ # Prepare the input data
153
+ data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
154
+
155
+ # Make prediction
156
+ with torch.no_grad():
157
+ prediction = model(data)
158
+
159
+ # Return the predicted value
160
+ return prediction.item()
161
 
162
+ def predict_axis(aca_magnitude, aca_axis):
163
+ # Convert axis to radians
164
+ aca_axis_rad = math.radians(aca_axis)
165
+
166
+ # Calculate X and Y components
167
+ X = aca_magnitude * math.cos(2 * aca_axis_rad)
168
+ Y = aca_magnitude * math.sin(2 * aca_axis_rad)
169
+
170
+ # Calculate intermediate axis prediction
171
+ Z = math.degrees(0.5 * math.atan2(Y, X))
172
+
173
+ # Determine final predicted axis
174
+ if X > 0:
175
+ if Y > 0:
176
+ predicted_axis = Z
177
+ else:
178
+ predicted_axis = Z + 180
179
+ else:
180
+ predicted_axis = Z + 90
181
+
182
+ # Ensure the axis is between 0 and 180 degrees
183
+ predicted_axis = predicted_axis % 180
184
+
185
+ return predicted_axis
186
 
187
  def main():
188
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
189
  st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
190
  st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
191
+
192
  st.markdown(
193
  """
194
  <style>
 
223
  unsafe_allow_html=True
224
  )
225
 
226
+ st.title('Total Corneal Astigmatism Prediction')
227
+
228
+ age = st.number_input('Enter Patient Age:', min_value=0.0, step=0.1)
229
+ aca_magnitude = st.number_input('Enter ACA Magnitude:', min_value=0.0, step=0.1)
230
+ aca_axis = st.number_input('Enter ACA Axis:', min_value=0.0, max_value=180.0, step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if st.button('Predict!'):
233
+ # Predict magnitude
234
+ tca_magnitude = predict_astigmatism(age, aca_axis, aca_magnitude)
235
+
236
+ # Predict axis
237
+ tca_axis = predict_axis(aca_magnitude, aca_axis)
238
+
239
+ st.success(f'Predicted Total Corneal Astigmatism Magnitude: {tca_magnitude:.4f} D')
240
+ st.success(f'Predicted Total Corneal Astigmatism Axis: {tca_axis:.2f}°')
241
 
242
  if __name__ == '__main__':
243
  main()