Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,11 +19,11 @@ class RegressionModel2(nn.Module):
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
-
model_j0 = RegressionModel2(
|
23 |
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
24 |
model_j0.eval()
|
25 |
|
26 |
-
model_j45 = RegressionModel2(
|
27 |
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
28 |
model_j45.eval()
|
29 |
|
@@ -36,17 +36,16 @@ def calculate_initial_j0_j45(magnitude, axis_deg):
|
|
36 |
|
37 |
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
38 |
"""Predict new J0 and J45 using the loaded models."""
|
39 |
-
|
40 |
-
|
|
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
input_data_j45 = torch.tensor([[initial_j45]], dtype=torch.float32)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
new_j0 = model_j0(input_data_j0).item()
|
48 |
new_j45 = model_j45(input_data_j45).item()
|
49 |
-
|
50 |
return new_j0, new_j45
|
51 |
|
52 |
def calculate_magnitude_and_axis(j0, j45):
|
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
+
model_j0 = RegressionModel2(3, 32, 1)
|
23 |
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
24 |
model_j0.eval()
|
25 |
|
26 |
+
model_j45 = RegressionModel2(3, 32, 1)
|
27 |
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
28 |
model_j45.eval()
|
29 |
|
|
|
36 |
|
37 |
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
|
38 |
"""Predict new J0 and J45 using the loaded models."""
|
39 |
+
aca_axis_rad = math.radians(aca_axis_deg)
|
40 |
+
aca_x = aca_magnitude * math.cos(aca_axis_rad)
|
41 |
+
aca_y = aca_magnitude * math.sin(aca_axis_rad)
|
42 |
|
43 |
+
input_data_j0 = torch.tensor([[age, aca_axis_deg, aca_x]], dtype=torch.float32)
|
44 |
+
input_data_j45 = torch.tensor([[age, aca_axis_deg, aca_y]], dtype=torch.float32)
|
|
|
45 |
|
46 |
with torch.no_grad():
|
47 |
new_j0 = model_j0(input_data_j0).item()
|
48 |
new_j45 = model_j45(input_data_j45).item()
|
|
|
49 |
return new_j0, new_j45
|
50 |
|
51 |
def calculate_magnitude_and_axis(j0, j45):
|