Jfink09 commited on
Commit
241be24
·
verified ·
1 Parent(s): e473f84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -42
app.py CHANGED
@@ -19,13 +19,19 @@ class RegressionModel2(nn.Module):
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
- model_j0 = RegressionModel2(3, 32, 1)
23
- model_j0.load_state_dict(torch.load('j0_model-2.pt'))
24
- model_j0.eval()
 
 
25
 
26
- model_j45 = RegressionModel2(3, 32, 1)
27
- model_j45.load_state_dict(torch.load('j45_model-2.pt'))
28
- model_j45.eval()
 
 
 
 
29
 
30
  def calculate_initial_j0_j45(magnitude, axis_deg):
31
  """Calculate initial J0 and J45 from magnitude and axis (in degrees)."""
@@ -41,51 +47,52 @@ def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
41
  input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
42
  input_data_j45 = torch.tensor([[age, aca_axis_deg, initial_j45]], dtype=torch.float32)
43
 
44
- st.write("Input tensor for J0:", input_data_j0)
45
- st.write("Input tensor for J45:", input_data_j45)
46
-
47
  with torch.no_grad():
48
  new_j0 = model_j0(input_data_j0).item()
49
  new_j45 = model_j45(input_data_j45).item()
50
 
51
- st.write("Raw J0 output:", new_j0)
52
- st.write("Raw J45 output:", new_j45)
53
-
54
  return new_j0, new_j45
55
 
56
  def main():
57
- st.title('Astigmatism Prediction Debugging')
58
-
59
- # Fixed inputs for debugging
60
- age = 58
61
- aca_magnitude = 2.3
62
- aca_axis = 97.7
63
 
64
- st.write(f"Debugging with fixed inputs: Age={age}, ACA Magnitude={aca_magnitude}, ACA Axis={aca_axis}")
 
 
 
65
 
66
- # Model architecture
67
- st.subheader("Model Architecture")
68
- st.write("J0 Model:", model_j0)
69
- st.write("J45 Model:", model_j45)
70
-
71
- # Calculate initial J0 and J45
72
- initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
73
- st.write(f"Initial J0: {initial_j0:.2f}")
74
- st.write(f"Initial J45: {initial_j45:.2f}")
75
-
76
- # Make prediction
77
- new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
78
-
79
- st.subheader("Prediction Results")
80
- st.write(f"Predicted J0: {new_j0:.2f}")
81
- st.write(f"Predicted J45: {new_j45:.2f}")
82
- st.write(f"Expected J0 (from Colab): -1.72")
83
- st.write(f"Expected J45 (from Colab): -0.53")
84
-
85
- # Calculate TCA magnitude
86
- tca_magnitude = math.sqrt(new_j0**2 + new_j45**2)
87
- st.write(f"Calculated TCA Magnitude: {tca_magnitude:.2f}")
88
- st.write(f"Expected TCA Magnitude (from Colab): 1.80")
 
 
 
 
 
 
 
 
 
89
 
90
  if __name__ == '__main__':
91
  main()
 
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
+ @st.cache_resource
23
+ def load_models():
24
+ model_j0 = RegressionModel2(3, 32, 1)
25
+ model_j0.load_state_dict(torch.load('j0_model.pt'))
26
+ model_j0.eval()
27
 
28
+ model_j45 = RegressionModel2(3, 32, 1)
29
+ model_j45.load_state_dict(torch.load('j45_model.pt'))
30
+ model_j45.eval()
31
+
32
+ return model_j0, model_j45
33
+
34
+ model_j0, model_j45 = load_models()
35
 
36
  def calculate_initial_j0_j45(magnitude, axis_deg):
37
  """Calculate initial J0 and J45 from magnitude and axis (in degrees)."""
 
47
  input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
48
  input_data_j45 = torch.tensor([[age, aca_axis_deg, initial_j45]], dtype=torch.float32)
49
 
 
 
 
50
  with torch.no_grad():
51
  new_j0 = model_j0(input_data_j0).item()
52
  new_j45 = model_j45(input_data_j45).item()
53
 
 
 
 
54
  return new_j0, new_j45
55
 
56
  def main():
57
+ st.title('Total Corneal Astigmatism Prediction')
 
 
 
 
 
58
 
59
+ # User input fields
60
+ age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, value=58.0, step=1.0)
61
+ aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, value=2.3, step=0.1)
62
+ aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)
63
 
64
+ if st.button('Predict'):
65
+ # Calculate initial J0 and J45
66
+ initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
67
+
68
+ # Make prediction
69
+ new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
70
+
71
+ # Calculate TCA magnitude and axis
72
+ tca_magnitude = math.sqrt(new_j0**2 + new_j45**2)
73
+ tca_axis = 0.5 * math.degrees(math.atan2(new_j45, new_j0))
74
+ if tca_axis < 0:
75
+ tca_axis += 180
76
+
77
+ # Display results
78
+ st.subheader('Prediction Results')
79
+ col1, col2 = st.columns(2)
80
+ with col1:
81
+ st.write(f"Initial J0: {initial_j0:.2f}")
82
+ st.write(f"Initial J45: {initial_j45:.2f}")
83
+ st.write(f"Predicted J0: {new_j0:.2f}")
84
+ st.write(f"Predicted J45: {new_j45:.2f}")
85
+ with col2:
86
+ st.write(f"Predicted TCA Magnitude: {tca_magnitude:.2f} D")
87
+ st.write(f"Predicted TCA Axis: {tca_axis:.1f}°")
88
+
89
+ # Optional: Display input tensors and raw outputs for verification
90
+ if st.checkbox('Show detailed model inputs and outputs'):
91
+ st.subheader('Model Details')
92
+ st.write("Input tensor for J0:", torch.tensor([[age, aca_axis, initial_j0]], dtype=torch.float32))
93
+ st.write("Input tensor for J45:", torch.tensor([[age, aca_axis, initial_j45]], dtype=torch.float32))
94
+ st.write("Raw J0 output:", new_j0)
95
+ st.write("Raw J45 output:", new_j45)
96
 
97
  if __name__ == '__main__':
98
  main()