Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,11 +19,11 @@ class RegressionModel2(nn.Module):
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
-
model_j0 = RegressionModel2(
|
23 |
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
24 |
model_j0.eval()
|
25 |
|
26 |
-
model_j45 = RegressionModel2(
|
27 |
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
28 |
model_j45.eval()
|
29 |
|
@@ -36,11 +36,17 @@ def calculate_initial_j0_j45(magnitude, axis_deg):
|
|
36 |
|
37 |
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
38 |
"""Predict new J0 and J45 using the loaded models."""
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
41 |
with torch.no_grad():
|
42 |
-
new_j0 = model_j0(
|
43 |
-
new_j45 = model_j45(
|
|
|
44 |
return new_j0, new_j45
|
45 |
|
46 |
def calculate_magnitude_and_axis(j0, j45):
|
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
+
model_j0 = RegressionModel2(1, 32, 1)
|
23 |
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
24 |
model_j0.eval()
|
25 |
|
26 |
+
model_j45 = RegressionModel2(1, 32, 1)
|
27 |
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
28 |
model_j45.eval()
|
29 |
|
|
|
36 |
|
37 |
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
38 |
"""Predict new J0 and J45 using the loaded models."""
|
39 |
+
# Calculate initial J0 and J45
|
40 |
+
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
|
41 |
+
|
42 |
+
# Prepare inputs for each model
|
43 |
+
input_data_j0 = torch.tensor([[initial_j0]], dtype=torch.float32)
|
44 |
+
input_data_j45 = torch.tensor([[initial_j45]], dtype=torch.float32)
|
45 |
+
|
46 |
with torch.no_grad():
|
47 |
+
new_j0 = model_j0(input_data_j0).item()
|
48 |
+
new_j45 = model_j45(input_data_j45).item()
|
49 |
+
|
50 |
return new_j0, new_j45
|
51 |
|
52 |
def calculate_magnitude_and_axis(j0, j45):
|