Jfink09 commited on
Commit
743e01a
·
verified ·
1 Parent(s): dc21578

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -24,11 +24,11 @@ class RegressionModel2(nn.Module):
24
  @st.cache_resource
25
  def load_models():
26
  model_j0 = RegressionModel2(3, 32, 1)
27
- model_j0.load_state_dict(torch.load('j0_model-2.pt'))
28
  model_j0.eval()
29
 
30
  model_j45 = RegressionModel2(3, 32, 1)
31
- model_j45.load_state_dict(torch.load('j45_model-2.pt'))
32
  model_j45.eval()
33
 
34
  return model_j0, model_j45
@@ -56,10 +56,32 @@ def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
56
  def main():
57
  st.title('Total Corneal Astigmatism Prediction')
58
 
59
- # Input fields
60
- age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, value=58.0, step=0.1)
61
- aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, value=2.3, step=0.01)
62
- aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  if st.button('Predict!'):
65
  if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
@@ -75,7 +97,7 @@ def main():
75
  if predicted_axis < 0:
76
  predicted_axis += 180
77
 
78
- # Display results in a green success box
79
  st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
80
  st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
81
 
 
24
  @st.cache_resource
25
  def load_models():
26
  model_j0 = RegressionModel2(3, 32, 1)
27
+ model_j0.load_state_dict(torch.load('j0_model.pt'))
28
  model_j0.eval()
29
 
30
  model_j45 = RegressionModel2(3, 32, 1)
31
+ model_j45.load_state_dict(torch.load('j45_model.pt'))
32
  model_j45.eval()
33
 
34
  return model_j0, model_j45
 
56
  def main():
57
  st.title('Total Corneal Astigmatism Prediction')
58
 
59
+ # Initialize session state for input values if not already present
60
+ if 'age' not in st.session_state:
61
+ st.session_state.age = ''
62
+ if 'aca_magnitude' not in st.session_state:
63
+ st.session_state.aca_magnitude = ''
64
+ if 'aca_axis' not in st.session_state:
65
+ st.session_state.aca_axis = ''
66
+
67
+ # Input fields with session state
68
+ age = st.number_input('Enter Patient Age (18-90 Years):',
69
+ min_value=18.0, max_value=90.0,
70
+ value=st.session_state.age,
71
+ step=0.1,
72
+ key='age')
73
+
74
+ aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):',
75
+ min_value=0.0, max_value=10.0,
76
+ value=st.session_state.aca_magnitude,
77
+ step=0.01,
78
+ key='aca_magnitude')
79
+
80
+ aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):',
81
+ min_value=0.0, max_value=180.0,
82
+ value=st.session_state.aca_axis,
83
+ step=0.1,
84
+ key='aca_axis')
85
 
86
  if st.button('Predict!'):
87
  if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
 
97
  if predicted_axis < 0:
98
  predicted_axis += 180
99
 
100
+ # Display results in green success boxes
101
  st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
102
  st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
103