Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,40 +19,51 @@ class RegressionModel2(nn.Module):
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
|
30 |
def predict_components(age, axis, aca):
|
31 |
"""
|
32 |
-
This function predicts both J0 and J45 components using the loaded models.
|
33 |
"""
|
34 |
data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
|
35 |
|
36 |
with torch.no_grad():
|
37 |
-
|
38 |
-
|
39 |
|
40 |
-
return
|
41 |
|
42 |
-
def calculate_magnitude_and_axis(
|
43 |
"""
|
44 |
-
Calculate magnitude and axis from
|
45 |
"""
|
46 |
-
magnitude =
|
47 |
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
# Ensure axis is between 0 and 180
|
52 |
-
|
53 |
-
axis += 180
|
54 |
|
55 |
-
return magnitude,
|
56 |
|
57 |
def main():
|
58 |
st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
|
@@ -66,10 +77,14 @@ def main():
|
|
66 |
|
67 |
if st.button('Predict!'):
|
68 |
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
69 |
-
|
70 |
-
predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(
|
71 |
-
st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.
|
72 |
st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
|
|
|
|
|
|
|
|
|
73 |
else:
|
74 |
st.error('Please ensure all inputs are within the specified ranges.')
|
75 |
|
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
+
model_x = RegressionModel2(3, 32, 1)
|
23 |
+
model_x.load_state_dict(torch.load('j0_model.pt'))
|
24 |
+
model_x.eval()
|
25 |
|
26 |
+
model_y = RegressionModel2(3, 32, 1)
|
27 |
+
model_y.load_state_dict(torch.load('j45_model.pt'))
|
28 |
+
model_y.eval()
|
29 |
|
30 |
def predict_components(age, axis, aca):
|
31 |
"""
|
32 |
+
This function predicts both x (J0) and y (J45) components using the loaded models.
|
33 |
"""
|
34 |
data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
|
35 |
|
36 |
with torch.no_grad():
|
37 |
+
x_pred = model_x(data).item()
|
38 |
+
y_pred = model_y(data).item()
|
39 |
|
40 |
+
return x_pred, y_pred
|
41 |
|
42 |
+
def calculate_magnitude_and_axis(x, y):
|
43 |
"""
|
44 |
+
Calculate magnitude and axis from x and y components using Excel formulas.
|
45 |
"""
|
46 |
+
magnitude = math.sqrt(x**2 + y**2)
|
47 |
|
48 |
+
# Calculate intermediate axis (in radians)
|
49 |
+
intermediate_axis_rad = 0.5 * math.atan2(y, x)
|
50 |
+
|
51 |
+
# Convert to degrees
|
52 |
+
intermediate_axis_deg = math.degrees(intermediate_axis_rad)
|
53 |
+
|
54 |
+
# Calculate final axis
|
55 |
+
if x > 0:
|
56 |
+
if y > 0:
|
57 |
+
final_axis = intermediate_axis_deg
|
58 |
+
else:
|
59 |
+
final_axis = intermediate_axis_deg + 180
|
60 |
+
else:
|
61 |
+
final_axis = intermediate_axis_deg + 90
|
62 |
|
63 |
# Ensure axis is between 0 and 180
|
64 |
+
final_axis = final_axis % 180
|
|
|
65 |
|
66 |
+
return magnitude, final_axis
|
67 |
|
68 |
def main():
|
69 |
st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
|
|
|
77 |
|
78 |
if st.button('Predict!'):
|
79 |
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
80 |
+
x_pred, y_pred = predict_components(age, aca_axis, aca_magnitude)
|
81 |
+
predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(x_pred, y_pred)
|
82 |
+
st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.4f} D')
|
83 |
st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
|
84 |
+
|
85 |
+
# Display predicted x and y components for verification
|
86 |
+
st.info(f'Predicted x (J0) component: {x_pred:.4f}')
|
87 |
+
st.info(f'Predicted y (J45) component: {y_pred:.4f}')
|
88 |
else:
|
89 |
st.error('Please ensure all inputs are within the specified ranges.')
|
90 |
|