Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,11 +19,11 @@ class RegressionModel2(nn.Module):
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
-
model_j0 = RegressionModel2(
|
23 |
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
24 |
model_j0.eval()
|
25 |
|
26 |
-
model_j45 = RegressionModel2(
|
27 |
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
28 |
model_j45.eval()
|
29 |
|
@@ -34,7 +34,64 @@ def calculate_initial_j0_j45(magnitude, axis):
|
|
34 |
j45 = magnitude * math.sin(2 * axis_rad)
|
35 |
return j0, j45
|
36 |
|
37 |
-
def predict_new_j0_j45(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
"""Predict new J0 and J45 using the loaded models."""
|
39 |
with torch.no_grad():
|
40 |
new_j0 = model_j0(torch.tensor([[j0]], dtype=torch.float32)).item()
|
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
+
model_j0 = RegressionModel2(3, 32, 1)
|
23 |
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
24 |
model_j0.eval()
|
25 |
|
26 |
+
model_j45 = RegressionModel2(3, 32, 1)
|
27 |
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
28 |
model_j45.eval()
|
29 |
|
|
|
34 |
j45 = magnitude * math.sin(2 * axis_rad)
|
35 |
return j0, j45
|
36 |
|
37 |
+
def predict_new_j0_j45(age, magnitude, axis):
|
38 |
+
"""Predict new J0 and J45 using the loaded models."""
|
39 |
+
input_data = torch.tensor([[age, magnitude, axis]], dtype=torch.float32)
|
40 |
+
with torch.no_grad():
|
41 |
+
new_j0 = model_j0(input_data).item()
|
42 |
+
new_j45 = model_j45(input_data).item()
|
43 |
+
return new_j0, new_j45
|
44 |
+
|
45 |
+
def calculate_magnitude_and_axis(j0, j45):
|
46 |
+
"""Calculate magnitude and axis from J0 and J45."""
|
47 |
+
magnitude = math.sqrt(j0**2 + j45**2)
|
48 |
+
z = 0.5 * math.atan2(j45, j0)
|
49 |
+
z_deg = math.degrees(z)
|
50 |
+
|
51 |
+
if j0 > 0:
|
52 |
+
if j45 > 0:
|
53 |
+
final_axis = z_deg
|
54 |
+
else:
|
55 |
+
final_axis = z_deg + 180
|
56 |
+
else:
|
57 |
+
final_axis = z_deg + 90
|
58 |
+
|
59 |
+
final_axis = final_axis % 180
|
60 |
+
|
61 |
+
return magnitude, final_axis
|
62 |
+
|
63 |
+
def main():
|
64 |
+
st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
|
65 |
+
|
66 |
+
st.title('Total Corneal Astigmatism Prediction')
|
67 |
+
|
68 |
+
# Input fields
|
69 |
+
age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, step=0.1)
|
70 |
+
aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.01)
|
71 |
+
aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1)
|
72 |
+
|
73 |
+
if st.button('Predict!'):
|
74 |
+
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
75 |
+
# Step 2: Calculate initial J0 and J45 (for display purposes only)
|
76 |
+
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
|
77 |
+
|
78 |
+
# Step 3: Predict new J0 and J45 using the models
|
79 |
+
new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
|
80 |
+
|
81 |
+
# Steps 4-6: Calculate predicted magnitude and axis
|
82 |
+
predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(new_j0, new_j45)
|
83 |
+
|
84 |
+
st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.4f} D')
|
85 |
+
st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
|
86 |
+
|
87 |
+
# Display intermediate values for verification
|
88 |
+
st.info(f'Initial J0: {initial_j0:.4f}, Initial J45: {initial_j45:.4f}')
|
89 |
+
st.info(f'Predicted J0: {new_j0:.4f}, Predicted J45: {new_j45:.4f}')
|
90 |
+
else:
|
91 |
+
st.error('Please ensure all inputs are within the specified ranges.')
|
92 |
+
|
93 |
+
if __name__ == '__main__':
|
94 |
+
main()def predict_new_j0_j45(j0, j45):
|
95 |
"""Predict new J0 and J45 using the loaded models."""
|
96 |
with torch.no_grad():
|
97 |
new_j0 = model_j0(torch.tensor([[j0]], dtype=torch.float32)).item()
|