Jfink09 commited on
Commit
52192c2
·
verified ·
1 Parent(s): bc3ac55

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from sklearn.metrics import r2_score
7
+ import math
8
+
9
+ class RegressionModel2(nn.Module):
10
+ def __init__(self, input_dim2, hidden_dim2, output_dim2):
11
+ super(RegressionModel2, self).__init__()
12
+ self.fc1 = nn.Linear(input_dim2, hidden_dim2)
13
+ self.relu1 = nn.ReLU()
14
+ self.fc2 = nn.Linear(hidden_dim2, output_dim2)
15
+ self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)
16
+
17
+ def forward(self, x2):
18
+ out = self.fc1(x2)
19
+ out = self.relu1(out)
20
+ out = self.batch_norm1(out)
21
+ out = self.fc2(out)
22
+ return out
23
+
24
+ # Load the saved model state dictionaries
25
+ model_x = RegressionModel2(3, 32, 1)
26
+ model_x.load_state_dict(torch.load('j0_model.pt'))
27
+ model_x.eval()
28
+
29
+ model_y = RegressionModel2(3, 32, 1)
30
+ model_y.load_state_dict(torch.load('j45_model.pt'))
31
+ model_y.eval()
32
+
33
+ def predict_components(age, axis, aca):
34
+ """
35
+ This function predicts both x and y components using the loaded models.
36
+ """
37
+ data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
38
+
39
+ with torch.no_grad():
40
+ x_pred = model_x(data).item()
41
+ y_pred = model_y(data).item()
42
+
43
+ return x_pred, y_pred
44
+
45
+ def calculate_magnitude_and_axis(x, y):
46
+ """
47
+ Calculate magnitude and axis from x and y components.
48
+ """
49
+ magnitude = math.sqrt(x**2 + y**2)
50
+
51
+ # Calculate intermediate axis
52
+ intermediate_axis = math.degrees(0.5 * math.atan2(y, x))
53
+
54
+ # Calculate final axis
55
+ if x > 0:
56
+ if y > 0:
57
+ final_axis = intermediate_axis
58
+ else:
59
+ final_axis = intermediate_axis + 180
60
+ else:
61
+ final_axis = intermediate_axis + 90
62
+
63
+ # Ensure axis is between 0 and 180
64
+ final_axis = final_axis % 180
65
+
66
+ return magnitude, final_axis
67
+
68
+ def main():
69
+ st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
70
+ st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
71
+ st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
72
+ st.markdown(
73
+ """
74
+ <style>
75
+ .navbar {
76
+ display: flex;
77
+ justify-content: space-between;
78
+ align-items: center;
79
+ background-color: #f2f2f2;
80
+ padding: 10px;
81
+ }
82
+ .logo img {
83
+ height: 50px;
84
+ }
85
+ .menu {
86
+ list-style-type: none;
87
+ display: flex;
88
+ }
89
+ .menu li {
90
+ margin-left: 20px;
91
+ }
92
+ .text-content {
93
+ margin-top: 50px;
94
+ text-align: center;
95
+ }
96
+ .button {
97
+ margin-top: 20px;
98
+ padding: 10px 20px;
99
+ font-size: 16px;
100
+ }
101
+ .error {
102
+ color: red;
103
+ font-weight: bold;
104
+ }
105
+ </style>
106
+ """,
107
+ unsafe_allow_html=True
108
+ )
109
+
110
+ # Use session state to store input values
111
+ if 'age' not in st.session_state:
112
+ st.session_state.age = None
113
+ if 'aca_magnitude' not in st.session_state:
114
+ st.session_state.aca_magnitude = None
115
+ if 'aca_axis' not in st.session_state:
116
+ st.session_state.aca_axis = None
117
+
118
+ # Age input
119
+ age = st.number_input('Enter Patient Age (15-90 Years):', min_value=18.0, max_value=90.0, step=0.1, value=st.session_state.age)
120
+ if age != st.session_state.age:
121
+ st.session_state.age = age
122
+ if age is not None and (age < 18 or age > 90):
123
+ st.markdown('<p class="error">Error: Age must be between 18 and 90.</p>', unsafe_allow_html=True)
124
+
125
+ # ACA Magnitude input
126
+ aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.1, value=st.session_state.aca_magnitude)
127
+ if aca_magnitude != st.session_state.aca_magnitude:
128
+ st.session_state.aca_magnitude = aca_magnitude
129
+ if aca_magnitude is not None and (aca_magnitude < 0 or aca_magnitude > 10):
130
+ st.markdown('<p class="error">Error: ACA Magnitude must be between 0 and 10.</p>', unsafe_allow_html=True)
131
+
132
+ # ACA Axis input
133
+ aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1, value=st.session_state.aca_axis)
134
+ if aca_axis != st.session_state.aca_axis:
135
+ st.session_state.aca_axis = aca_axis
136
+ if aca_axis is not None and (aca_axis < 0 or aca_axis > 180):
137
+ st.markdown('<p class="error">Error: ACA Axis must be between 0 and 180.</p>', unsafe_allow_html=True)
138
+
139
+ if st.button('Predict!'):
140
+ if age is not None and aca_magnitude is not None and aca_axis is not None:
141
+ if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
142
+ x_pred, y_pred = predict_components(age, aca_axis, aca_magnitude)
143
+ predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(x_pred, y_pred)
144
+ st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.4f} D')
145
+ st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
146
+ else:
147
+ st.error('Please correct the input errors before predicting.')
148
+ else:
149
+ st.error('Please fill in all fields before predicting.')
150
+
151
+ if __name__ == '__main__':
152
+ main()