Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,13 +19,19 @@ class RegressionModel2(nn.Module):
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
-
|
23 |
-
|
24 |
-
model_j0
|
|
|
|
|
25 |
|
26 |
-
model_j45 = RegressionModel2(3, 32, 1)
|
27 |
-
model_j45.load_state_dict(torch.load('j45_model
|
28 |
-
model_j45.eval()
|
|
|
|
|
|
|
|
|
29 |
|
30 |
def calculate_initial_j0_j45(magnitude, axis_deg):
|
31 |
"""Calculate initial J0 and J45 from magnitude and axis (in degrees)."""
|
@@ -41,51 +47,52 @@ def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
|
41 |
input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
|
42 |
input_data_j45 = torch.tensor([[age, aca_axis_deg, initial_j45]], dtype=torch.float32)
|
43 |
|
44 |
-
st.write("Input tensor for J0:", input_data_j0)
|
45 |
-
st.write("Input tensor for J45:", input_data_j45)
|
46 |
-
|
47 |
with torch.no_grad():
|
48 |
new_j0 = model_j0(input_data_j0).item()
|
49 |
new_j45 = model_j45(input_data_j45).item()
|
50 |
|
51 |
-
st.write("Raw J0 output:", new_j0)
|
52 |
-
st.write("Raw J45 output:", new_j45)
|
53 |
-
|
54 |
return new_j0, new_j45
|
55 |
|
56 |
def main():
|
57 |
-
st.title('Astigmatism Prediction
|
58 |
-
|
59 |
-
# Fixed inputs for debugging
|
60 |
-
age = 58
|
61 |
-
aca_magnitude = 2.3
|
62 |
-
aca_axis = 97.7
|
63 |
|
64 |
-
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
if __name__ == '__main__':
|
91 |
main()
|
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
+
@st.cache_resource
|
23 |
+
def load_models():
|
24 |
+
model_j0 = RegressionModel2(3, 32, 1)
|
25 |
+
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
26 |
+
model_j0.eval()
|
27 |
|
28 |
+
model_j45 = RegressionModel2(3, 32, 1)
|
29 |
+
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
30 |
+
model_j45.eval()
|
31 |
+
|
32 |
+
return model_j0, model_j45
|
33 |
+
|
34 |
+
model_j0, model_j45 = load_models()
|
35 |
|
36 |
def calculate_initial_j0_j45(magnitude, axis_deg):
|
37 |
"""Calculate initial J0 and J45 from magnitude and axis (in degrees)."""
|
|
|
47 |
input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
|
48 |
input_data_j45 = torch.tensor([[age, aca_axis_deg, initial_j45]], dtype=torch.float32)
|
49 |
|
|
|
|
|
|
|
50 |
with torch.no_grad():
|
51 |
new_j0 = model_j0(input_data_j0).item()
|
52 |
new_j45 = model_j45(input_data_j45).item()
|
53 |
|
|
|
|
|
|
|
54 |
return new_j0, new_j45
|
55 |
|
56 |
def main():
|
57 |
+
st.title('Total Corneal Astigmatism Prediction')
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
+
# User input fields
|
60 |
+
age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, value=58.0, step=1.0)
|
61 |
+
aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, value=2.3, step=0.1)
|
62 |
+
aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)
|
63 |
|
64 |
+
if st.button('Predict'):
|
65 |
+
# Calculate initial J0 and J45
|
66 |
+
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
|
67 |
+
|
68 |
+
# Make prediction
|
69 |
+
new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
|
70 |
+
|
71 |
+
# Calculate TCA magnitude and axis
|
72 |
+
tca_magnitude = math.sqrt(new_j0**2 + new_j45**2)
|
73 |
+
tca_axis = 0.5 * math.degrees(math.atan2(new_j45, new_j0))
|
74 |
+
if tca_axis < 0:
|
75 |
+
tca_axis += 180
|
76 |
+
|
77 |
+
# Display results
|
78 |
+
st.subheader('Prediction Results')
|
79 |
+
col1, col2 = st.columns(2)
|
80 |
+
with col1:
|
81 |
+
st.write(f"Initial J0: {initial_j0:.2f}")
|
82 |
+
st.write(f"Initial J45: {initial_j45:.2f}")
|
83 |
+
st.write(f"Predicted J0: {new_j0:.2f}")
|
84 |
+
st.write(f"Predicted J45: {new_j45:.2f}")
|
85 |
+
with col2:
|
86 |
+
st.write(f"Predicted TCA Magnitude: {tca_magnitude:.2f} D")
|
87 |
+
st.write(f"Predicted TCA Axis: {tca_axis:.1f}°")
|
88 |
+
|
89 |
+
# Optional: Display input tensors and raw outputs for verification
|
90 |
+
if st.checkbox('Show detailed model inputs and outputs'):
|
91 |
+
st.subheader('Model Details')
|
92 |
+
st.write("Input tensor for J0:", torch.tensor([[age, aca_axis, initial_j0]], dtype=torch.float32))
|
93 |
+
st.write("Input tensor for J45:", torch.tensor([[age, aca_axis, initial_j45]], dtype=torch.float32))
|
94 |
+
st.write("Raw J0 output:", new_j0)
|
95 |
+
st.write("Raw J45 output:", new_j45)
|
96 |
|
97 |
if __name__ == '__main__':
|
98 |
main()
|