Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -18,15 +18,14 @@ class RegressionModel2(nn.Module):
|
|
18 |
out = self.fc2(out)
|
19 |
return out
|
20 |
|
21 |
-
# Load the saved model state dictionaries
|
22 |
@st.cache_resource
|
23 |
def load_models():
|
24 |
model_j0 = RegressionModel2(3, 32, 1)
|
25 |
-
model_j0.load_state_dict(torch.load('j0_model
|
26 |
model_j0.eval()
|
27 |
|
28 |
model_j45 = RegressionModel2(3, 32, 1)
|
29 |
-
model_j45.load_state_dict(torch.load('j45_model
|
30 |
model_j45.eval()
|
31 |
|
32 |
return model_j0, model_j45
|
@@ -34,14 +33,12 @@ def load_models():
|
|
34 |
model_j0, model_j45 = load_models()
|
35 |
|
36 |
def calculate_initial_j0_j45(magnitude, axis_deg):
|
37 |
-
"""Calculate initial J0 and J45 from magnitude and axis (in degrees)."""
|
38 |
axis_rad = math.radians(axis_deg)
|
39 |
j0 = magnitude * math.cos(2 * axis_rad)
|
40 |
j45 = magnitude * math.sin(2 * axis_rad)
|
41 |
return j0, j45
|
42 |
|
43 |
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
44 |
-
"""Predict new J0 and J45 using the loaded models."""
|
45 |
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
|
46 |
|
47 |
input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
|
@@ -54,45 +51,51 @@ def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
|
54 |
return new_j0, new_j45
|
55 |
|
56 |
def main():
|
57 |
-
st.
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
62 |
aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)
|
63 |
-
|
64 |
-
if st.button('Predict'):
|
65 |
-
# Calculate initial J0 and J45
|
66 |
-
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
|
67 |
-
|
68 |
-
# Make prediction
|
69 |
-
new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
|
70 |
-
|
71 |
-
# Calculate TCA magnitude and axis
|
72 |
-
tca_magnitude = math.sqrt(new_j0**2 + new_j45**2)
|
73 |
-
tca_axis = 0.5 * math.degrees(math.atan2(new_j45, new_j0))
|
74 |
-
if tca_axis < 0:
|
75 |
-
tca_axis += 180
|
76 |
-
|
77 |
-
# Display results
|
78 |
-
st.subheader('Prediction Results')
|
79 |
-
col1, col2 = st.columns(2)
|
80 |
-
with col1:
|
81 |
-
st.write(f"Initial J0: {initial_j0:.2f}")
|
82 |
-
st.write(f"Initial J45: {initial_j45:.2f}")
|
83 |
-
st.write(f"Predicted J0: {new_j0:.2f}")
|
84 |
-
st.write(f"Predicted J45: {new_j45:.2f}")
|
85 |
-
with col2:
|
86 |
-
st.write(f"Predicted TCA Magnitude: {tca_magnitude:.2f} D")
|
87 |
-
st.write(f"Predicted TCA Axis: {tca_axis:.1f}°")
|
88 |
|
89 |
-
|
90 |
-
if
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
if __name__ == '__main__':
|
98 |
main()
|
|
|
18 |
out = self.fc2(out)
|
19 |
return out
|
20 |
|
|
|
21 |
@st.cache_resource
|
22 |
def load_models():
|
23 |
model_j0 = RegressionModel2(3, 32, 1)
|
24 |
+
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
25 |
model_j0.eval()
|
26 |
|
27 |
model_j45 = RegressionModel2(3, 32, 1)
|
28 |
+
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
29 |
model_j45.eval()
|
30 |
|
31 |
return model_j0, model_j45
|
|
|
33 |
model_j0, model_j45 = load_models()
|
34 |
|
35 |
def calculate_initial_j0_j45(magnitude, axis_deg):
|
|
|
36 |
axis_rad = math.radians(axis_deg)
|
37 |
j0 = magnitude * math.cos(2 * axis_rad)
|
38 |
j45 = magnitude * math.sin(2 * axis_rad)
|
39 |
return j0, j45
|
40 |
|
41 |
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
|
|
42 |
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
|
43 |
|
44 |
input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
|
|
|
51 |
return new_j0, new_j45
|
52 |
|
53 |
def main():
|
54 |
+
st.set_page_config(page_title='Total Corneal Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
|
55 |
|
56 |
+
st.title('Total Corneal Astigmatism Prediction')
|
57 |
+
|
58 |
+
# Input fields
|
59 |
+
age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, value=58.0, step=0.1)
|
60 |
+
aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, value=2.3, step=0.01)
|
61 |
aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
if st.button('Predict!'):
|
64 |
+
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
65 |
+
# Calculate initial J0 and J45
|
66 |
+
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
|
67 |
+
|
68 |
+
# Predict new J0 and J45 using the models
|
69 |
+
new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
|
70 |
+
|
71 |
+
# Calculate predicted magnitude and axis
|
72 |
+
predicted_magnitude = math.sqrt(new_j0**2 + new_j45**2)
|
73 |
+
predicted_axis = 0.5 * math.degrees(math.atan2(new_j45, new_j0))
|
74 |
+
if predicted_axis < 0:
|
75 |
+
predicted_axis += 180
|
76 |
+
|
77 |
+
# Display results in a green success box
|
78 |
+
st.success(f'''
|
79 |
+
Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D
|
80 |
+
Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°
|
81 |
+
''')
|
82 |
+
|
83 |
+
# Display intermediate values for verification
|
84 |
+
st.info(f'''
|
85 |
+
Input ACA - Magnitude: {aca_magnitude:.2f} D, Axis: {aca_axis:.1f}°
|
86 |
+
Initial J0: {initial_j0:.2f}, Initial J45: {initial_j45:.2f}
|
87 |
+
Predicted J0: {new_j0:.2f}, Predicted J45: {new_j45:.2f}
|
88 |
+
''')
|
89 |
+
|
90 |
+
# Additional debugging information (optional)
|
91 |
+
if st.checkbox('Show detailed model inputs and outputs'):
|
92 |
+
st.subheader('Debugging Information:')
|
93 |
+
st.write(f"Input tensor for J0: {torch.tensor([[age, aca_axis, initial_j0]], dtype=torch.float32)}")
|
94 |
+
st.write(f"Input tensor for J45: {torch.tensor([[age, aca_axis, initial_j45]], dtype=torch.float32)}")
|
95 |
+
st.write(f"Raw J0 output: {new_j0}")
|
96 |
+
st.write(f"Raw J45 output: {new_j45}")
|
97 |
+
else:
|
98 |
+
st.error('Please ensure all inputs are within the specified ranges.')
|
99 |
|
100 |
if __name__ == '__main__':
|
101 |
main()
|