Jfink09 commited on
Commit
06673ea
·
verified ·
1 Parent(s): 6f52cd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -19,11 +19,11 @@ class RegressionModel2(nn.Module):
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
- model_j0 = RegressionModel2(1, 32, 1)
23
  model_j0.load_state_dict(torch.load('j0_model.pt'))
24
  model_j0.eval()
25
 
26
- model_j45 = RegressionModel2(1, 32, 1)
27
  model_j45.load_state_dict(torch.load('j45_model.pt'))
28
  model_j45.eval()
29
 
@@ -34,11 +34,12 @@ def calculate_initial_j0_j45(magnitude, axis):
34
  j45 = magnitude * math.sin(2 * axis_rad)
35
  return j0, j45
36
 
37
- def predict_new_j0_j45(initial_j0, initial_j45):
38
  """Predict new J0 and J45 using the loaded models."""
 
39
  with torch.no_grad():
40
- new_j0 = model_j0(torch.tensor([[initial_j0]], dtype=torch.float32)).item()
41
- new_j45 = model_j45(torch.tensor([[initial_j45]], dtype=torch.float32)).item()
42
  return new_j0, new_j45
43
 
44
  def calculate_magnitude_and_axis(j0, j45):
@@ -70,11 +71,11 @@ def main():
70
 
71
  if st.button('Predict!'):
72
  if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
73
- # Calculate initial J0 and J45
74
  initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
75
 
76
  # Predict new J0 and J45 using the models
77
- new_j0, new_j45 = predict_new_j0_j45(initial_j0, initial_j45)
78
 
79
  # Calculate predicted magnitude and axis
80
  predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(new_j0, new_j45)
 
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
+ model_j0 = RegressionModel2(3, 32, 1)
23
  model_j0.load_state_dict(torch.load('j0_model.pt'))
24
  model_j0.eval()
25
 
26
+ model_j45 = RegressionModel2(3, 32, 1)
27
  model_j45.load_state_dict(torch.load('j45_model.pt'))
28
  model_j45.eval()
29
 
 
34
  j45 = magnitude * math.sin(2 * axis_rad)
35
  return j0, j45
36
 
37
+ def predict_new_j0_j45(age, aca_magnitude, aca_axis):
38
  """Predict new J0 and J45 using the loaded models."""
39
+ input_data = torch.tensor([[age, aca_magnitude, aca_axis]], dtype=torch.float32)
40
  with torch.no_grad():
41
+ new_j0 = model_j0(input_data).item()
42
+ new_j45 = model_j45(input_data).item()
43
  return new_j0, new_j45
44
 
45
  def calculate_magnitude_and_axis(j0, j45):
 
71
 
72
  if st.button('Predict!'):
73
  if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
74
+ # Calculate initial J0 and J45 (for display purposes only)
75
  initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
76
 
77
  # Predict new J0 and J45 using the models
78
+ new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
79
 
80
  # Calculate predicted magnitude and axis
81
  predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(new_j0, new_j45)