Jfink09 commited on
Commit
8a7bb30
·
verified ·
1 Parent(s): 9cf86b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -171
app.py CHANGED
@@ -1,127 +1,9 @@
1
- # import streamlit as st
2
- # import pandas as pd
3
- # import torch
4
- # import torch.nn as nn
5
- # import torch.optim as optim
6
- # from sklearn.metrics import r2_score
7
-
8
- # class RegressionModel2(nn.Module):
9
- # def __init__(self, input_dim2, hidden_dim2, output_dim2):
10
- # super(RegressionModel2, self).__init__()
11
- # self.fc1 = nn.Linear(input_dim2, hidden_dim2)
12
- # self.relu1 = nn.ReLU()
13
- # self.fc2 = nn.Linear(hidden_dim2, output_dim2)
14
- # self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)
15
-
16
- # def forward(self, x2):
17
- # out = self.fc1(x2)
18
- # out = self.relu1(out)
19
- # out = self.batch_norm1(out)
20
- # out = self.fc2(out)
21
- # return out
22
-
23
- # # Load the saved model state dictionary
24
- # model = RegressionModel2(3, 32, 1)
25
- # model.load_state_dict(torch.load('model.pt'))
26
- # model.eval() # Set the model to evaluation mode
27
-
28
- # # Define a function to make predictions
29
- # def predict_astigmatism(age, axis, aca):
30
- # """
31
- # This function takes three arguments (age, axis, aca) as input,
32
- # converts them to a tensor, makes a prediction using the loaded model,
33
- # and returns the predicted value.
34
- # """
35
- # # Prepare the input data
36
- # data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
37
-
38
- # # Make prediction
39
- # with torch.no_grad():
40
- # prediction = model(data)
41
-
42
- # # Return the predicted value
43
- # return prediction.item()
44
-
45
- # def main():
46
- # st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
47
- # st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
48
- # st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
49
- # st.markdown(
50
- # """
51
- # <style>
52
- # .navbar {
53
- # display: flex;
54
- # justify-content: space-between;
55
- # align-items: center;
56
- # background-color: #f2f2f2;
57
- # padding: 10px;
58
- # }
59
- # .logo img {
60
- # height: 50px;
61
- # }
62
- # .menu {
63
- # list-style-type: none;
64
- # display: flex;
65
- # }
66
- # .menu li {
67
- # margin-left: 20px;
68
- # }
69
- # .text-content {
70
- # margin-top: 50px;
71
- # text-align: center;
72
- # }
73
- # .button {
74
- # margin-top: 20px;
75
- # padding: 10px 20px;
76
- # font-size: 16px;
77
- # }
78
- # </style>
79
- # """,
80
- # unsafe_allow_html=True
81
- # )
82
-
83
- # # st.markdown(
84
- # # """
85
- # # <body>
86
- # # <header>
87
- # # <nav class="navbar">
88
- # # <div class="logo"><img src="iol.png" alt="Image description"></div>
89
- # # <ul class="menu">
90
- # # <li><a href="#">Home</a></li>
91
- # # <li><a href="#">About</a></li>
92
- # # <li><a href="#">Contact</a></li>
93
- # # </ul>
94
- # # </nav>
95
- # # <div class="text-content">
96
- # # <h2>Enter Variables</h2>
97
- # # <br>
98
- # # </div>
99
- # # </header>
100
- # # </body>
101
- # # """,
102
- # # unsafe_allow_html=True
103
- # # )
104
-
105
- # age = st.number_input('Enter Patient Age:', step=0.1)
106
- # aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1)
107
- # aca_axis = st.number_input('Enter ACA Axis:', step=0.1)
108
-
109
- # if st.button('Predict!'):
110
- # astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
111
- # st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
112
-
113
- # if __name__ == '__main__':
114
- # main()
115
-
116
-
117
-
118
  import streamlit as st
119
  import pandas as pd
120
  import torch
121
  import torch.nn as nn
122
  import torch.optim as optim
123
  from sklearn.metrics import r2_score
124
- import math
125
 
126
  class RegressionModel2(nn.Module):
127
  def __init__(self, input_dim2, hidden_dim2, output_dim2):
@@ -143,52 +25,27 @@ model = RegressionModel2(3, 32, 1)
143
  model.load_state_dict(torch.load('model.pt'))
144
  model.eval() # Set the model to evaluation mode
145
 
 
146
  def predict_astigmatism(age, axis, aca):
147
- """
148
- This function takes three arguments (age, axis, aca) as input,
149
- converts them to a tensor, makes a prediction using the loaded model,
150
- and returns the predicted value.
151
- """
152
- # Prepare the input data
153
- data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
154
-
155
- # Make prediction
156
- with torch.no_grad():
157
- prediction = model(data)
158
-
159
- # Return the predicted value
160
- return prediction.item()
161
 
162
- def predict_axis(aca_magnitude, aca_axis):
163
- # Convert axis to radians
164
- aca_axis_rad = math.radians(aca_axis)
165
-
166
- # Calculate X and Y components
167
- X = aca_magnitude * math.cos(2 * aca_axis_rad)
168
- Y = aca_magnitude * math.sin(2 * aca_axis_rad)
169
-
170
- # Calculate intermediate axis prediction
171
- Z = math.degrees(0.5 * math.atan2(Y, X))
172
-
173
- # Determine final predicted axis
174
- if X > 0:
175
- if Y > 0:
176
- predicted_axis = Z
177
- else:
178
- predicted_axis = Z + 180
179
- else:
180
- predicted_axis = Z + 90
181
-
182
- # Ensure the axis is between 0 and 180 degrees
183
- predicted_axis = predicted_axis % 180
184
-
185
- return predicted_axis
186
 
187
  def main():
188
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
189
  st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
190
  st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
191
-
192
  st.markdown(
193
  """
194
  <style>
@@ -223,21 +80,35 @@ def main():
223
  unsafe_allow_html=True
224
  )
225
 
226
- st.title('Total Corneal Astigmatism Prediction')
227
-
228
- age = st.number_input('Enter Patient Age:', min_value=0.0, step=0.1)
229
- aca_magnitude = st.number_input('Enter ACA Magnitude:', min_value=0.0, step=0.1)
230
- aca_axis = st.number_input('Enter ACA Axis:', min_value=0.0, max_value=180.0, step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if st.button('Predict!'):
233
- # Predict magnitude
234
- tca_magnitude = predict_astigmatism(age, aca_axis, aca_magnitude)
235
-
236
- # Predict axis
237
- tca_axis = predict_axis(aca_magnitude, aca_axis)
238
-
239
- st.success(f'Predicted Total Corneal Astigmatism Magnitude: {tca_magnitude:.4f} D')
240
- st.success(f'Predicted Total Corneal Astigmatism Axis: {tca_axis:.2f}°')
241
 
242
  if __name__ == '__main__':
243
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
4
  import torch.nn as nn
5
  import torch.optim as optim
6
  from sklearn.metrics import r2_score
 
7
 
8
  class RegressionModel2(nn.Module):
9
  def __init__(self, input_dim2, hidden_dim2, output_dim2):
 
25
  model.load_state_dict(torch.load('model.pt'))
26
  model.eval() # Set the model to evaluation mode
27
 
28
+ # Define a function to make predictions
29
  def predict_astigmatism(age, axis, aca):
30
+ """
31
+ This function takes three arguments (age, axis, aca) as input,
32
+ converts them to a tensor, makes a prediction using the loaded model,
33
+ and returns the predicted value.
34
+ """
35
+ # Prepare the input data
36
+ data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
37
+
38
+ # Make prediction
39
+ with torch.no_grad():
40
+ prediction = model(data)
 
 
 
41
 
42
+ # Return the predicted value
43
+ return prediction.item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def main():
46
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
47
  st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
48
  st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
 
49
  st.markdown(
50
  """
51
  <style>
 
80
  unsafe_allow_html=True
81
  )
82
 
83
+ # st.markdown(
84
+ # """
85
+ # <body>
86
+ # <header>
87
+ # <nav class="navbar">
88
+ # <div class="logo"><img src="iol.png" alt="Image description"></div>
89
+ # <ul class="menu">
90
+ # <li><a href="#">Home</a></li>
91
+ # <li><a href="#">About</a></li>
92
+ # <li><a href="#">Contact</a></li>
93
+ # </ul>
94
+ # </nav>
95
+ # <div class="text-content">
96
+ # <h2>Enter Variables</h2>
97
+ # <br>
98
+ # </div>
99
+ # </header>
100
+ # </body>
101
+ # """,
102
+ # unsafe_allow_html=True
103
+ # )
104
+
105
+ age = st.number_input('Enter Patient Age:', step=0.1)
106
+ aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1)
107
+ aca_axis = st.number_input('Enter ACA Axis:', step=0.1)
108
 
109
  if st.button('Predict!'):
110
+ astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
111
+ st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
 
 
 
 
 
 
112
 
113
  if __name__ == '__main__':
114
+ main()