Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -24,11 +24,11 @@ class RegressionModel2(nn.Module):
|
|
24 |
@st.cache_resource
|
25 |
def load_models():
|
26 |
model_j0 = RegressionModel2(3, 32, 1)
|
27 |
-
model_j0.load_state_dict(torch.load('j0_model
|
28 |
model_j0.eval()
|
29 |
|
30 |
model_j45 = RegressionModel2(3, 32, 1)
|
31 |
-
model_j45.load_state_dict(torch.load('j45_model
|
32 |
model_j45.eval()
|
33 |
|
34 |
return model_j0, model_j45
|
@@ -56,10 +56,32 @@ def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
|
56 |
def main():
|
57 |
st.title('Total Corneal Astigmatism Prediction')
|
58 |
|
59 |
-
#
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
if st.button('Predict!'):
|
65 |
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
@@ -75,7 +97,7 @@ def main():
|
|
75 |
if predicted_axis < 0:
|
76 |
predicted_axis += 180
|
77 |
|
78 |
-
# Display results in
|
79 |
st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
|
80 |
st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
|
81 |
|
|
|
24 |
@st.cache_resource
|
25 |
def load_models():
|
26 |
model_j0 = RegressionModel2(3, 32, 1)
|
27 |
+
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
28 |
model_j0.eval()
|
29 |
|
30 |
model_j45 = RegressionModel2(3, 32, 1)
|
31 |
+
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
32 |
model_j45.eval()
|
33 |
|
34 |
return model_j0, model_j45
|
|
|
56 |
def main():
|
57 |
st.title('Total Corneal Astigmatism Prediction')
|
58 |
|
59 |
+
# Initialize session state for input values if not already present
|
60 |
+
if 'age' not in st.session_state:
|
61 |
+
st.session_state.age = ''
|
62 |
+
if 'aca_magnitude' not in st.session_state:
|
63 |
+
st.session_state.aca_magnitude = ''
|
64 |
+
if 'aca_axis' not in st.session_state:
|
65 |
+
st.session_state.aca_axis = ''
|
66 |
+
|
67 |
+
# Input fields with session state
|
68 |
+
age = st.number_input('Enter Patient Age (18-90 Years):',
|
69 |
+
min_value=18.0, max_value=90.0,
|
70 |
+
value=st.session_state.age,
|
71 |
+
step=0.1,
|
72 |
+
key='age')
|
73 |
+
|
74 |
+
aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):',
|
75 |
+
min_value=0.0, max_value=10.0,
|
76 |
+
value=st.session_state.aca_magnitude,
|
77 |
+
step=0.01,
|
78 |
+
key='aca_magnitude')
|
79 |
+
|
80 |
+
aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):',
|
81 |
+
min_value=0.0, max_value=180.0,
|
82 |
+
value=st.session_state.aca_axis,
|
83 |
+
step=0.1,
|
84 |
+
key='aca_axis')
|
85 |
|
86 |
if st.button('Predict!'):
|
87 |
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
|
|
97 |
if predicted_axis < 0:
|
98 |
predicted_axis += 180
|
99 |
|
100 |
+
# Display results in green success boxes
|
101 |
st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
|
102 |
st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
|
103 |
|