Jfink09 commited on
Commit
485bcab
·
verified ·
1 Parent(s): 2c582df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -39
app.py CHANGED
@@ -19,48 +19,42 @@ class RegressionModel2(nn.Module):
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
- model_x = RegressionModel2(3, 32, 1)
23
- model_x.load_state_dict(torch.load('j0_model.pt'))
24
- model_x.eval()
25
 
26
- model_y = RegressionModel2(3, 32, 1)
27
- model_y.load_state_dict(torch.load('j45_model.pt'))
28
- model_y.eval()
29
 
30
- def predict_components(age, axis, aca):
31
- """
32
- This function predicts both x (J0) and y (J45) components using the loaded models.
33
- """
34
- data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
35
-
 
 
 
36
  with torch.no_grad():
37
- x_pred = model_x(data).item()
38
- y_pred = model_y(data).item()
39
-
40
- return x_pred, y_pred
41
 
42
- def calculate_magnitude_and_axis(x, y):
43
- """
44
- Calculate magnitude and axis from x and y components using Excel formulas.
45
- """
46
- magnitude = math.sqrt(x**2 + y**2)
47
-
48
- # Calculate intermediate axis (in radians)
49
- intermediate_axis_rad = 0.5 * math.atan2(y, x)
50
 
51
- # Convert to degrees
52
- intermediate_axis_deg = math.degrees(intermediate_axis_rad)
53
-
54
- # Calculate final axis
55
- if x > 0:
56
- if y > 0:
57
- final_axis = intermediate_axis_deg
58
  else:
59
- final_axis = intermediate_axis_deg + 180
60
  else:
61
- final_axis = intermediate_axis_deg + 90
62
 
63
- # Ensure axis is between 0 and 180
64
  final_axis = final_axis % 180
65
 
66
  return magnitude, final_axis
@@ -77,14 +71,21 @@ def main():
77
 
78
  if st.button('Predict!'):
79
  if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
80
- x_pred, y_pred = predict_components(age, aca_axis, aca_magnitude)
81
- predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(x_pred, y_pred)
 
 
 
 
 
 
 
82
  st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.4f} D')
83
  st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
84
 
85
- # Display predicted x and y components for verification
86
- st.info(f'Predicted x (J0) component: {x_pred:.4f}')
87
- st.info(f'Predicted y (J45) component: {y_pred:.4f}')
88
  else:
89
  st.error('Please ensure all inputs are within the specified ranges.')
90
 
 
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
+ model_j0 = RegressionModel2(1, 32, 1)
23
+ model_j0.load_state_dict(torch.load('j0_model.pt'))
24
+ model_j0.eval()
25
 
26
+ model_j45 = RegressionModel2(1, 32, 1)
27
+ model_j45.load_state_dict(torch.load('j45_model.pt'))
28
+ model_j45.eval()
29
 
30
+ def calculate_initial_j0_j45(magnitude, axis):
31
+ """Calculate initial J0 and J45 from magnitude and axis."""
32
+ axis_rad = math.radians(axis)
33
+ j0 = magnitude * math.cos(2 * axis_rad)
34
+ j45 = magnitude * math.sin(2 * axis_rad)
35
+ return j0, j45
36
+
37
+ def predict_new_j0_j45(j0, j45):
38
+ """Predict new J0 and J45 using the loaded models."""
39
  with torch.no_grad():
40
+ new_j0 = model_j0(torch.tensor([[j0]], dtype=torch.float32)).item()
41
+ new_j45 = model_j45(torch.tensor([[j45]], dtype=torch.float32)).item()
42
+ return new_j0, new_j45
 
43
 
44
+ def calculate_magnitude_and_axis(j0, j45):
45
+ """Calculate magnitude and axis from J0 and J45."""
46
+ magnitude = math.sqrt(j0**2 + j45**2)
47
+ z = 0.5 * math.atan2(j45, j0)
48
+ z_deg = math.degrees(z)
 
 
 
49
 
50
+ if j0 > 0:
51
+ if j45 > 0:
52
+ final_axis = z_deg
 
 
 
 
53
  else:
54
+ final_axis = z_deg + 180
55
  else:
56
+ final_axis = z_deg + 90
57
 
 
58
  final_axis = final_axis % 180
59
 
60
  return magnitude, final_axis
 
71
 
72
  if st.button('Predict!'):
73
  if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
74
+ # Step 2: Calculate initial J0 and J45
75
+ initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
76
+
77
+ # Step 3: Predict new J0 and J45
78
+ new_j0, new_j45 = predict_new_j0_j45(initial_j0, initial_j45)
79
+
80
+ # Steps 4-6: Calculate predicted magnitude and axis
81
+ predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(new_j0, new_j45)
82
+
83
  st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.4f} D')
84
  st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
85
 
86
+ # Display intermediate values for verification
87
+ st.info(f'Initial J0: {initial_j0:.4f}, Initial J45: {initial_j45:.4f}')
88
+ st.info(f'Predicted J0: {new_j0:.4f}, Predicted J45: {new_j45:.4f}')
89
  else:
90
  st.error('Please ensure all inputs are within the specified ranges.')
91