Jfink09 commited on
Commit
2c582df
·
verified ·
1 Parent(s): 2f398a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -21
app.py CHANGED
@@ -19,40 +19,51 @@ class RegressionModel2(nn.Module):
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
- model_j0 = RegressionModel2(3, 32, 1)
23
- model_j0.load_state_dict(torch.load('j0_model.pt'))
24
- model_j0.eval()
25
 
26
- model_j45 = RegressionModel2(3, 32, 1)
27
- model_j45.load_state_dict(torch.load('j45_model.pt'))
28
- model_j45.eval()
29
 
30
  def predict_components(age, axis, aca):
31
  """
32
- This function predicts both J0 and J45 components using the loaded models.
33
  """
34
  data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
35
 
36
  with torch.no_grad():
37
- j0_pred = model_j0(data).item()
38
- j45_pred = model_j45(data).item()
39
 
40
- return j0_pred, j45_pred
41
 
42
- def calculate_magnitude_and_axis(j0, j45):
43
  """
44
- Calculate magnitude and axis from J0 and J45 components.
45
  """
46
- magnitude = 2 * math.sqrt(j0**2 + j45**2)
47
 
48
- axis = 0.5 * math.atan2(j45, j0)
49
- axis = math.degrees(axis)
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Ensure axis is between 0 and 180
52
- if axis < 0:
53
- axis += 180
54
 
55
- return magnitude, axis
56
 
57
  def main():
58
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
@@ -66,10 +77,14 @@ def main():
66
 
67
  if st.button('Predict!'):
68
  if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
69
- j0_pred, j45_pred = predict_components(age, aca_axis, aca_magnitude)
70
- predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(j0_pred, j45_pred)
71
- st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
72
  st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
 
 
 
 
73
  else:
74
  st.error('Please ensure all inputs are within the specified ranges.')
75
 
 
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
+ model_x = RegressionModel2(3, 32, 1)
23
+ model_x.load_state_dict(torch.load('j0_model.pt'))
24
+ model_x.eval()
25
 
26
+ model_y = RegressionModel2(3, 32, 1)
27
+ model_y.load_state_dict(torch.load('j45_model.pt'))
28
+ model_y.eval()
29
 
30
  def predict_components(age, axis, aca):
31
  """
32
+ This function predicts both x (J0) and y (J45) components using the loaded models.
33
  """
34
  data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
35
 
36
  with torch.no_grad():
37
+ x_pred = model_x(data).item()
38
+ y_pred = model_y(data).item()
39
 
40
+ return x_pred, y_pred
41
 
42
+ def calculate_magnitude_and_axis(x, y):
43
  """
44
+ Calculate magnitude and axis from x and y components using Excel formulas.
45
  """
46
+ magnitude = math.sqrt(x**2 + y**2)
47
 
48
+ # Calculate intermediate axis (in radians)
49
+ intermediate_axis_rad = 0.5 * math.atan2(y, x)
50
+
51
+ # Convert to degrees
52
+ intermediate_axis_deg = math.degrees(intermediate_axis_rad)
53
+
54
+ # Calculate final axis
55
+ if x > 0:
56
+ if y > 0:
57
+ final_axis = intermediate_axis_deg
58
+ else:
59
+ final_axis = intermediate_axis_deg + 180
60
+ else:
61
+ final_axis = intermediate_axis_deg + 90
62
 
63
  # Ensure axis is between 0 and 180
64
+ final_axis = final_axis % 180
 
65
 
66
+ return magnitude, final_axis
67
 
68
  def main():
69
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
 
77
 
78
  if st.button('Predict!'):
79
  if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
80
+ x_pred, y_pred = predict_components(age, aca_axis, aca_magnitude)
81
+ predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(x_pred, y_pred)
82
+ st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.4f} D')
83
  st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
84
+
85
+ # Display predicted x and y components for verification
86
+ st.info(f'Predicted x (J0) component: {x_pred:.4f}')
87
+ st.info(f'Predicted y (J45) component: {y_pred:.4f}')
88
  else:
89
  st.error('Please ensure all inputs are within the specified ranges.')
90