Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,11 +19,11 @@ class RegressionModel2(nn.Module):
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
-
model_j0 = RegressionModel2(
|
23 |
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
24 |
model_j0.eval()
|
25 |
|
26 |
-
model_j45 = RegressionModel2(
|
27 |
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
28 |
model_j45.eval()
|
29 |
|
@@ -34,11 +34,12 @@ def calculate_initial_j0_j45(magnitude, axis):
|
|
34 |
j45 = magnitude * math.sin(2 * axis_rad)
|
35 |
return j0, j45
|
36 |
|
37 |
-
def predict_new_j0_j45(
|
38 |
"""Predict new J0 and J45 using the loaded models."""
|
|
|
39 |
with torch.no_grad():
|
40 |
-
new_j0 = model_j0(
|
41 |
-
new_j45 = model_j45(
|
42 |
return new_j0, new_j45
|
43 |
|
44 |
def calculate_magnitude_and_axis(j0, j45):
|
@@ -70,11 +71,11 @@ def main():
|
|
70 |
|
71 |
if st.button('Predict!'):
|
72 |
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
73 |
-
# Calculate initial J0 and J45
|
74 |
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
|
75 |
|
76 |
# Predict new J0 and J45 using the models
|
77 |
-
new_j0, new_j45 = predict_new_j0_j45(
|
78 |
|
79 |
# Calculate predicted magnitude and axis
|
80 |
predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(new_j0, new_j45)
|
|
|
19 |
return out
|
20 |
|
21 |
# Load the saved model state dictionaries
|
22 |
+
model_j0 = RegressionModel2(3, 32, 1)
|
23 |
model_j0.load_state_dict(torch.load('j0_model.pt'))
|
24 |
model_j0.eval()
|
25 |
|
26 |
+
model_j45 = RegressionModel2(3, 32, 1)
|
27 |
model_j45.load_state_dict(torch.load('j45_model.pt'))
|
28 |
model_j45.eval()
|
29 |
|
|
|
34 |
j45 = magnitude * math.sin(2 * axis_rad)
|
35 |
return j0, j45
|
36 |
|
37 |
+
def predict_new_j0_j45(age, aca_magnitude, aca_axis):
|
38 |
"""Predict new J0 and J45 using the loaded models."""
|
39 |
+
input_data = torch.tensor([[age, aca_magnitude, aca_axis]], dtype=torch.float32)
|
40 |
with torch.no_grad():
|
41 |
+
new_j0 = model_j0(input_data).item()
|
42 |
+
new_j45 = model_j45(input_data).item()
|
43 |
return new_j0, new_j45
|
44 |
|
45 |
def calculate_magnitude_and_axis(j0, j45):
|
|
|
71 |
|
72 |
if st.button('Predict!'):
|
73 |
if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
|
74 |
+
# Calculate initial J0 and J45 (for display purposes only)
|
75 |
initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
|
76 |
|
77 |
# Predict new J0 and J45 using the models
|
78 |
+
new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
|
79 |
|
80 |
# Calculate predicted magnitude and axis
|
81 |
predicted_magnitude, predicted_axis = calculate_magnitude_and_axis(new_j0, new_j45)
|