Jfink09 commited on
Commit
9df995a
·
verified ·
1 Parent(s): e15a441

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -41
app.py CHANGED
@@ -18,15 +18,14 @@ class RegressionModel2(nn.Module):
18
  out = self.fc2(out)
19
  return out
20
 
21
- # Load the saved model state dictionaries
22
  @st.cache_resource
23
  def load_models():
24
  model_j0 = RegressionModel2(3, 32, 1)
25
- model_j0.load_state_dict(torch.load('j0_model-2.pt'))
26
  model_j0.eval()
27
 
28
  model_j45 = RegressionModel2(3, 32, 1)
29
- model_j45.load_state_dict(torch.load('j45_model-2.pt'))
30
  model_j45.eval()
31
 
32
  return model_j0, model_j45
@@ -34,14 +33,12 @@ def load_models():
34
  model_j0, model_j45 = load_models()
35
 
36
  def calculate_initial_j0_j45(magnitude, axis_deg):
37
- """Calculate initial J0 and J45 from magnitude and axis (in degrees)."""
38
  axis_rad = math.radians(axis_deg)
39
  j0 = magnitude * math.cos(2 * axis_rad)
40
  j45 = magnitude * math.sin(2 * axis_rad)
41
  return j0, j45
42
 
43
  def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
44
- """Predict new J0 and J45 using the loaded models."""
45
  initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
46
 
47
  input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
@@ -54,45 +51,51 @@ def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
54
  return new_j0, new_j45
55
 
56
  def main():
57
- st.title('Total Corneal Astigmatism Prediction')
58
 
59
- # User input fields
60
- age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, value=58.0, step=1.0)
61
- aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, value=2.3, step=0.1)
 
 
62
  aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)
63
-
64
- if st.button('Predict'):
65
- # Calculate initial J0 and J45
66
- initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
67
-
68
- # Make prediction
69
- new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
70
-
71
- # Calculate TCA magnitude and axis
72
- tca_magnitude = math.sqrt(new_j0**2 + new_j45**2)
73
- tca_axis = 0.5 * math.degrees(math.atan2(new_j45, new_j0))
74
- if tca_axis < 0:
75
- tca_axis += 180
76
-
77
- # Display results
78
- st.subheader('Prediction Results')
79
- col1, col2 = st.columns(2)
80
- with col1:
81
- st.write(f"Initial J0: {initial_j0:.2f}")
82
- st.write(f"Initial J45: {initial_j45:.2f}")
83
- st.write(f"Predicted J0: {new_j0:.2f}")
84
- st.write(f"Predicted J45: {new_j45:.2f}")
85
- with col2:
86
- st.write(f"Predicted TCA Magnitude: {tca_magnitude:.2f} D")
87
- st.write(f"Predicted TCA Axis: {tca_axis:.1f}°")
88
 
89
- # Optional: Display input tensors and raw outputs for verification
90
- if st.checkbox('Show detailed model inputs and outputs'):
91
- st.subheader('Model Details')
92
- st.write("Input tensor for J0:", torch.tensor([[age, aca_axis, initial_j0]], dtype=torch.float32))
93
- st.write("Input tensor for J45:", torch.tensor([[age, aca_axis, initial_j45]], dtype=torch.float32))
94
- st.write("Raw J0 output:", new_j0)
95
- st.write("Raw J45 output:", new_j45)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  if __name__ == '__main__':
98
  main()
 
18
  out = self.fc2(out)
19
  return out
20
 
 
21
  @st.cache_resource
22
  def load_models():
23
  model_j0 = RegressionModel2(3, 32, 1)
24
+ model_j0.load_state_dict(torch.load('j0_model.pt'))
25
  model_j0.eval()
26
 
27
  model_j45 = RegressionModel2(3, 32, 1)
28
+ model_j45.load_state_dict(torch.load('j45_model.pt'))
29
  model_j45.eval()
30
 
31
  return model_j0, model_j45
 
33
  model_j0, model_j45 = load_models()
34
 
35
  def calculate_initial_j0_j45(magnitude, axis_deg):
 
36
  axis_rad = math.radians(axis_deg)
37
  j0 = magnitude * math.cos(2 * axis_rad)
38
  j45 = magnitude * math.sin(2 * axis_rad)
39
  return j0, j45
40
 
41
  def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
 
42
  initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
43
 
44
  input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
 
51
  return new_j0, new_j45
52
 
53
  def main():
54
+ st.set_page_config(page_title='Total Corneal Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
55
 
56
+ st.title('Total Corneal Astigmatism Prediction')
57
+
58
+ # Input fields
59
+ age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, value=58.0, step=0.1)
60
+ aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, value=2.3, step=0.01)
61
  aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ if st.button('Predict!'):
64
+ if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
65
+ # Calculate initial J0 and J45
66
+ initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
67
+
68
+ # Predict new J0 and J45 using the models
69
+ new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
70
+
71
+ # Calculate predicted magnitude and axis
72
+ predicted_magnitude = math.sqrt(new_j0**2 + new_j45**2)
73
+ predicted_axis = 0.5 * math.degrees(math.atan2(new_j45, new_j0))
74
+ if predicted_axis < 0:
75
+ predicted_axis += 180
76
+
77
+ # Display results in a green success box
78
+ st.success(f'''
79
+ Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D
80
+ Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°
81
+ ''')
82
+
83
+ # Display intermediate values for verification
84
+ st.info(f'''
85
+ Input ACA - Magnitude: {aca_magnitude:.2f} D, Axis: {aca_axis:.1f}°
86
+ Initial J0: {initial_j0:.2f}, Initial J45: {initial_j45:.2f}
87
+ Predicted J0: {new_j0:.2f}, Predicted J45: {new_j45:.2f}
88
+ ''')
89
+
90
+ # Additional debugging information (optional)
91
+ if st.checkbox('Show detailed model inputs and outputs'):
92
+ st.subheader('Debugging Information:')
93
+ st.write(f"Input tensor for J0: {torch.tensor([[age, aca_axis, initial_j0]], dtype=torch.float32)}")
94
+ st.write(f"Input tensor for J45: {torch.tensor([[age, aca_axis, initial_j45]], dtype=torch.float32)}")
95
+ st.write(f"Raw J0 output: {new_j0}")
96
+ st.write(f"Raw J45 output: {new_j45}")
97
+ else:
98
+ st.error('Please ensure all inputs are within the specified ranges.')
99
 
100
  if __name__ == '__main__':
101
  main()