Jfink09 commited on
Commit
e473f84
·
verified ·
1 Parent(s): a5bcb7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -60
app.py CHANGED
@@ -30,79 +30,62 @@ model_j45.eval()
30
  def calculate_initial_j0_j45(magnitude, axis_deg):
31
  """Calculate initial J0 and J45 from magnitude and axis (in degrees)."""
32
  axis_rad = math.radians(axis_deg)
33
- j0 = round(magnitude * math.cos(2 * axis_rad), 2)
34
- j45 = round(magnitude * math.sin(2 * axis_rad), 2)
35
  return j0, j45
36
 
37
  def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
38
  """Predict new J0 and J45 using the loaded models."""
39
- aca_axis_rad = math.radians(aca_axis_deg)
40
- aca_x = aca_magnitude * math.cos(aca_axis_rad)
41
- aca_y = aca_magnitude * math.sin(aca_axis_rad)
42
 
43
- input_data_j0 = torch.tensor([[age, aca_axis_deg, aca_x]], dtype=torch.float32)
44
- input_data_j45 = torch.tensor([[age, aca_axis_deg, aca_y]], dtype=torch.float32)
 
 
 
45
 
46
  with torch.no_grad():
47
  new_j0 = model_j0(input_data_j0).item()
48
  new_j45 = model_j45(input_data_j45).item()
 
 
 
 
49
  return new_j0, new_j45
50
 
51
- def calculate_magnitude(j0, j45):
52
- """Calculate magnitude from J0 and J45."""
53
- return math.sqrt(j0**2 + j45**2)
54
-
55
- def calculate_axis(j0, j45):
56
- """Calculate axis from J0 and J45."""
57
- axis = 0.5 * math.degrees(math.atan2(j45, j0))
58
- if axis < 0:
59
- axis += 180
60
- return axis
61
-
62
  def main():
63
- st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
64
 
65
- st.title('Total Corneal Astigmatism Prediction')
66
-
67
- # Input fields
68
- age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, step=0.1)
69
- aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.01)
70
- aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1)
71
-
72
- if st.button('Predict!'):
73
- if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
74
- # Calculate initial J0 and J45 (for comparison)
75
- initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
76
-
77
- # Predict new J0 and J45 using the models
78
- new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
79
-
80
- # Calculate predicted magnitude and axis
81
- predicted_magnitude = calculate_magnitude(new_j0, new_j45)
82
- predicted_axis = calculate_axis(new_j0, new_j45)
83
-
84
- st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
85
- st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
86
-
87
- # Display intermediate values for verification
88
- st.info(f'Input ACA - Magnitude: {aca_magnitude:.2f} D, Axis: {aca_axis:.1f}°')
89
- st.info(f'Initial J0: {initial_j0:.2f}, Initial J45: {initial_j45:.2f}')
90
- st.info(f'Predicted J0: {new_j0:.2f}, Predicted J45: {new_j45:.2f}')
91
- st.info(f'Intermediate calculations:')
92
- st.info(f' atan2(J45, J0): {math.degrees(math.atan2(new_j45, new_j0)):.2f}°')
93
- st.info(f' 0.5 * atan2(J45, J0): {0.5 * math.degrees(math.atan2(new_j45, new_j0)):.2f}°')
94
-
95
- # Additional debugging information
96
- st.subheader('Debugging Information:')
97
- st.write(f'Input age: {age}')
98
- st.write(f'Input ACA magnitude: {aca_magnitude:.2f} D')
99
- st.write(f'Input ACA axis: {aca_axis:.1f}°')
100
- st.write(f'Calculated ACA X: {aca_magnitude * math.cos(math.radians(aca_axis)):.4f}')
101
- st.write(f'Calculated ACA Y: {aca_magnitude * math.sin(math.radians(aca_axis)):.4f}')
102
- st.write(f'Model J0 input: [{age}, {aca_axis}, {aca_magnitude * math.cos(math.radians(aca_axis)):.4f}]')
103
- st.write(f'Model J45 input: [{age}, {aca_axis}, {aca_magnitude * math.sin(math.radians(aca_axis)):.4f}]')
104
- else:
105
- st.error('Please ensure all inputs are within the specified ranges.')
106
 
107
  if __name__ == '__main__':
108
  main()
 
30
  def calculate_initial_j0_j45(magnitude, axis_deg):
31
  """Calculate initial J0 and J45 from magnitude and axis (in degrees)."""
32
  axis_rad = math.radians(axis_deg)
33
+ j0 = magnitude * math.cos(2 * axis_rad)
34
+ j45 = magnitude * math.sin(2 * axis_rad)
35
  return j0, j45
36
 
37
  def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
38
  """Predict new J0 and J45 using the loaded models."""
39
+ initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
 
 
40
 
41
+ input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
42
+ input_data_j45 = torch.tensor([[age, aca_axis_deg, initial_j45]], dtype=torch.float32)
43
+
44
+ st.write("Input tensor for J0:", input_data_j0)
45
+ st.write("Input tensor for J45:", input_data_j45)
46
 
47
  with torch.no_grad():
48
  new_j0 = model_j0(input_data_j0).item()
49
  new_j45 = model_j45(input_data_j45).item()
50
+
51
+ st.write("Raw J0 output:", new_j0)
52
+ st.write("Raw J45 output:", new_j45)
53
+
54
  return new_j0, new_j45
55
 
 
 
 
 
 
 
 
 
 
 
 
56
  def main():
57
+ st.title('Astigmatism Prediction Debugging')
58
 
59
+ # Fixed inputs for debugging
60
+ age = 58
61
+ aca_magnitude = 2.3
62
+ aca_axis = 97.7
63
+
64
+ st.write(f"Debugging with fixed inputs: Age={age}, ACA Magnitude={aca_magnitude}, ACA Axis={aca_axis}")
65
+
66
+ # Model architecture
67
+ st.subheader("Model Architecture")
68
+ st.write("J0 Model:", model_j0)
69
+ st.write("J45 Model:", model_j45)
70
+
71
+ # Calculate initial J0 and J45
72
+ initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
73
+ st.write(f"Initial J0: {initial_j0:.2f}")
74
+ st.write(f"Initial J45: {initial_j45:.2f}")
75
+
76
+ # Make prediction
77
+ new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
78
+
79
+ st.subheader("Prediction Results")
80
+ st.write(f"Predicted J0: {new_j0:.2f}")
81
+ st.write(f"Predicted J45: {new_j45:.2f}")
82
+ st.write(f"Expected J0 (from Colab): -1.72")
83
+ st.write(f"Expected J45 (from Colab): -0.53")
84
+
85
+ # Calculate TCA magnitude
86
+ tca_magnitude = math.sqrt(new_j0**2 + new_j45**2)
87
+ st.write(f"Calculated TCA Magnitude: {tca_magnitude:.2f}")
88
+ st.write(f"Expected TCA Magnitude (from Colab): 1.80")
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  if __name__ == '__main__':
91
  main()