Jfink09 commited on
Commit
0c0dd4a
·
verified ·
1 Parent(s): 61d583c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -19,11 +19,11 @@ class RegressionModel2(nn.Module):
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
- model_j0 = RegressionModel2(3, 32, 1)
23
  model_j0.load_state_dict(torch.load('j0_model.pt'))
24
  model_j0.eval()
25
 
26
- model_j45 = RegressionModel2(3, 32, 1)
27
  model_j45.load_state_dict(torch.load('j45_model.pt'))
28
  model_j45.eval()
29
 
@@ -36,11 +36,17 @@ def calculate_initial_j0_j45(magnitude, axis_deg):
36
 
37
  def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
38
  """Predict new J0 and J45 using the loaded models."""
39
- aca_axis_rad = math.radians(aca_axis_deg)
40
- input_data = torch.tensor([[age, aca_magnitude, aca_axis_rad]], dtype=torch.float32)
 
 
 
 
 
41
  with torch.no_grad():
42
- new_j0 = model_j0(input_data).item()
43
- new_j45 = model_j45(input_data).item()
 
44
  return new_j0, new_j45
45
 
46
  def calculate_magnitude_and_axis(j0, j45):
 
19
  return out
20
 
21
  # Load the saved model state dictionaries
22
+ model_j0 = RegressionModel2(1, 32, 1)
23
  model_j0.load_state_dict(torch.load('j0_model.pt'))
24
  model_j0.eval()
25
 
26
+ model_j45 = RegressionModel2(1, 32, 1)
27
  model_j45.load_state_dict(torch.load('j45_model.pt'))
28
  model_j45.eval()
29
 
 
36
 
37
  def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
38
  """Predict new J0 and J45 using the loaded models."""
39
+ # Calculate initial J0 and J45
40
+ initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
41
+
42
+ # Prepare inputs for each model
43
+ input_data_j0 = torch.tensor([[initial_j0]], dtype=torch.float32)
44
+ input_data_j45 = torch.tensor([[initial_j45]], dtype=torch.float32)
45
+
46
  with torch.no_grad():
47
+ new_j0 = model_j0(input_data_j0).item()
48
+ new_j45 = model_j45(input_data_j45).item()
49
+
50
  return new_j0, new_j45
51
 
52
  def calculate_magnitude_and_axis(j0, j45):