Jfink09 commited on
Commit
8f8e201
·
verified ·
1 Parent(s): 485bcab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -3
app.py CHANGED
@@ -19,11 +19,11 @@ class RegressionModel2(nn.Module):
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
- model_j0 = RegressionModel2(1, 32, 1)
23
  model_j0.load_state_dict(torch.load('j0_model.pt'))
24
  model_j0.eval()
25
 
26
- model_j45 = RegressionModel2(1, 32, 1)
27
  model_j45.load_state_dict(torch.load('j45_model.pt'))
28
  model_j45.eval()
29
 
@@ -34,7 +34,64 @@ def calculate_initial_j0_j45(magnitude, axis):
34
  j45 = magnitude * math.sin(2 * axis_rad)
35
  return j0, j45
36
 
37
- def predict_new_j0_j45(j0, j45):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  """Predict new J0 and J45 using the loaded models."""
39
  with torch.no_grad():
40
  new_j0 = model_j0(torch.tensor([[j0]], dtype=torch.float32)).item()
 
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
+ model_j0 = RegressionModel2(3, 32, 1)
23
  model_j0.load_state_dict(torch.load('j0_model.pt'))
24
  model_j0.eval()
25
 
26
+ model_j45 = RegressionModel2(3, 32, 1)
27
  model_j45.load_state_dict(torch.load('j45_model.pt'))
28
  model_j45.eval()
29
 
 
34
  j45 = magnitude * math.sin(2 * axis_rad)
35
  return j0, j45
36
 
37
+ def predict_new_j0_j45(age, magnitude, axis):
38
+ """Predict new J0 and J45 using the loaded models."""
39
+ input_data = torch.tensor([[age, magnitude, axis]], dtype=torch.float32)
40
+ with torch.no_grad():
41
+ new_j0 = model_j0(input_data).item()
42
+ new_j45 = model_j45(input_data).item()
43
+ return new_j0, new_j45
44
+
45
+ def calculate_magnitude_and_axis(j0, j45):
46
+ """Calculate magnitude and axis from J0 and J45."""
47
+ magnitude = math.sqrt(j0**2 + j45**2)
48
+ z = 0.5 * math.atan2(j45, j0)
49
+ z_deg = math.degrees(z)
50
+
51
+ if j0 > 0:
52
+ if j45 > 0:
53
+ final_axis = z_deg
54
+ else:
55
+ final_axis = z_deg + 180
56
+ else:
57
+ final_axis = z_deg + 90
58
+
59
+ final_axis = final_axis % 180
60
+
61
+ return magnitude, final_axis
62
+
63
+ def main():
64
+ st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
65
+
66
+ st.title('Total Corneal Astigmatism Prediction')
67
+
68
+ # Input fields
69
+ age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, step=0.1)
70
+ aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.01)
71
+ aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1)
72
+
73
+ if st.button('Predict!'):
74
+ if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
75
+ # Step 2: Calculate initial J0 and J45 (for display purposes only)
76
+ initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
77
+
78
+ # Step 3: Predict new J0 and J45 using the models
79
+ new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
80
+
81
+ # Steps 4-6: Calculate predicted magnitude and axis
82
+ predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(new_j0, new_j45)
83
+
84
+ st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.4f} D')
85
+ st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
86
+
87
+ # Display intermediate values for verification
88
+ st.info(f'Initial J0: {initial_j0:.4f}, Initial J45: {initial_j45:.4f}')
89
+ st.info(f'Predicted J0: {new_j0:.4f}, Predicted J45: {new_j45:.4f}')
90
+ else:
91
+ st.error('Please ensure all inputs are within the specified ranges.')
92
+
93
+ if __name__ == '__main__':
94
+ main()def predict_new_j0_j45(j0, j45):
95
  """Predict new J0 and J45 using the loaded models."""
96
  with torch.no_grad():
97
  new_j0 = model_j0(torch.tensor([[j0]], dtype=torch.float32)).item()