Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.optim as optim
|
6 |
+
from sklearn.metrics import r2_score
|
7 |
+
import math
|
8 |
+
|
9 |
+
class RegressionModel2(nn.Module):
|
10 |
+
def __init__(self, input_dim2, hidden_dim2, output_dim2):
|
11 |
+
super(RegressionModel2, self).__init__()
|
12 |
+
self.fc1 = nn.Linear(input_dim2, hidden_dim2)
|
13 |
+
self.relu1 = nn.ReLU()
|
14 |
+
self.fc2 = nn.Linear(hidden_dim2, output_dim2)
|
15 |
+
self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)
|
16 |
+
|
17 |
+
def forward(self, x2):
|
18 |
+
out = self.fc1(x2)
|
19 |
+
out = self.relu1(out)
|
20 |
+
out = self.batch_norm1(out)
|
21 |
+
out = self.fc2(out)
|
22 |
+
return out
|
23 |
+
|
24 |
+
# Load the saved model state dictionaries
|
25 |
+
model_x = RegressionModel2(3, 32, 1)
|
26 |
+
model_x.load_state_dict(torch.load('j0_model.pt'))
|
27 |
+
model_x.eval()
|
28 |
+
|
29 |
+
model_y = RegressionModel2(3, 32, 1)
|
30 |
+
model_y.load_state_dict(torch.load('j45_model.pt'))
|
31 |
+
model_y.eval()
|
32 |
+
|
33 |
+
def predict_components(age, axis, aca):
|
34 |
+
"""
|
35 |
+
This function predicts both x and y components using the loaded models.
|
36 |
+
"""
|
37 |
+
data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
|
38 |
+
|
39 |
+
with torch.no_grad():
|
40 |
+
x_pred = model_x(data).item()
|
41 |
+
y_pred = model_y(data).item()
|
42 |
+
|
43 |
+
return x_pred, y_pred
|
44 |
+
|
45 |
+
def calculate_magnitude_and_axis(x, y):
|
46 |
+
"""
|
47 |
+
Calculate magnitude and axis from x and y components.
|
48 |
+
"""
|
49 |
+
magnitude = math.sqrt(x**2 + y**2)
|
50 |
+
|
51 |
+
# Calculate intermediate axis
|
52 |
+
intermediate_axis = math.degrees(0.5 * math.atan2(y, x))
|
53 |
+
|
54 |
+
# Calculate final axis
|
55 |
+
if x > 0:
|
56 |
+
if y > 0:
|
57 |
+
final_axis = intermediate_axis
|
58 |
+
else:
|
59 |
+
final_axis = intermediate_axis + 180
|
60 |
+
else:
|
61 |
+
final_axis = intermediate_axis + 90
|
62 |
+
|
63 |
+
# Ensure axis is between 0 and 180
|
64 |
+
final_axis = final_axis % 180
|
65 |
+
|
66 |
+
return magnitude, final_axis
|
67 |
+
|
68 |
+
def main():
|
69 |
+
st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
|
70 |
+
st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
|
71 |
+
st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
|
72 |
+
st.markdown(
|
73 |
+
"""
|
74 |
+
<style>
|
75 |
+
.navbar {
|
76 |
+
display: flex;
|
77 |
+
justify-content: space-between;
|
78 |
+
align-items: center;
|
79 |
+
background-color: #f2f2f2;
|
80 |
+
padding: 10px;
|
81 |
+
}
|
82 |
+
.logo img {
|
83 |
+
height: 50px;
|
84 |
+
}
|
85 |
+
.menu {
|
86 |
+
list-style-type: none;
|
87 |
+
display: flex;
|
88 |
+
}
|
89 |
+
.menu li {
|
90 |
+
margin-left: 20px;
|
91 |
+
}
|
92 |
+
.text-content {
|
93 |
+
margin-top: 50px;
|
94 |
+
text-align: center;
|
95 |
+
}
|
96 |
+
.button {
|
97 |
+
margin-top: 20px;
|
98 |
+
padding: 10px 20px;
|
99 |
+
font-size: 16px;
|
100 |
+
}
|
101 |
+
.error {
|
102 |
+
color: red;
|
103 |
+
font-weight: bold;
|
104 |
+
}
|
105 |
+
</style>
|
106 |
+
""",
|
107 |
+
unsafe_allow_html=True
|
108 |
+
)
|
109 |
+
|
110 |
+
# Use session state to store input values
|
111 |
+
if 'age' not in st.session_state:
|
112 |
+
st.session_state.age = None
|
113 |
+
if 'aca_magnitude' not in st.session_state:
|
114 |
+
st.session_state.aca_magnitude = None
|
115 |
+
if 'aca_axis' not in st.session_state:
|
116 |
+
st.session_state.aca_axis = None
|
117 |
+
|
118 |
+
# Age input
|
119 |
+
age = st.number_input('Enter Patient Age (15-90 Years):', min_value=18.0, max_value=90.0, step=0.1, value=st.session_state.age)
|
120 |
+
if age != st.session_state.age:
|
121 |
+
st.session_state.age = age
|
122 |
+
if age is not None and (age < 18 or age > 90):
|
123 |
+
st.markdown('<p class="error">Error: Age must be between 18 and 90.</p>', unsafe_allow_html=True)
|
124 |
+
|
125 |
+
# ACA Magnitude input
|
126 |
+
aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.1, value=st.session_state.aca_magnitude)
|
127 |
+
if aca_magnitude != st.session_state.aca_magnitude:
|
128 |
+
st.session_state.aca_magnitude = aca_magnitude
|
129 |
+
if aca_magnitude is not None and (aca_magnitude < 0 or aca_magnitude > 10):
|
130 |
+
st.markdown('<p class="error">Error: ACA Magnitude must be between 0 and 10.</p>', unsafe_allow_html=True)
|
131 |
+
|
132 |
+
# ACA Axis input
|
133 |
+
aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1, value=st.session_state.aca_axis)
|
134 |
+
if aca_axis != st.session_state.aca_axis:
|
135 |
+
st.session_state.aca_axis = aca_axis
|
136 |
+
if aca_axis is not None and (aca_axis < 0 or aca_axis > 180):
|
137 |
+
st.markdown('<p class="error">Error: ACA Axis must be between 0 and 180.</p>', unsafe_allow_html=True)
|
138 |
+
|
139 |
+
if st.button('Predict!'):
|
140 |
+
if age is not None and aca_magnitude is not None and aca_axis is not None:
|
141 |
+
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
142 |
+
x_pred, y_pred = predict_components(age, aca_axis, aca_magnitude)
|
143 |
+
predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(x_pred, y_pred)
|
144 |
+
st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.4f} D')
|
145 |
+
st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
|
146 |
+
else:
|
147 |
+
st.error('Please correct the input errors before predicting.')
|
148 |
+
else:
|
149 |
+
st.error('Please fill in all fields before predicting.')
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
main()
|