Update app.py
Browse files
app.py
CHANGED
@@ -20,12 +20,12 @@ class LaserPredictions(nn.Module):
|
|
20 |
return out
|
21 |
|
22 |
# Load the saved model state dictionary
|
23 |
-
model = LaserPredictions(
|
24 |
model.load_state_dict(torch.load('laser_prescription_model.pt'))
|
25 |
model.eval() # Set the model to evaluation mode
|
26 |
|
27 |
-
def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis
|
28 |
-
input_data = torch.tensor([[pre_op_sphere, pre_op_cylinder, pre_op_axis
|
29 |
|
30 |
with torch.no_grad():
|
31 |
predicted_prescription = model(input_data)
|
@@ -40,9 +40,6 @@ inputs = [
|
|
40 |
gr.Number(label="Pre-Op Sphere"),
|
41 |
gr.Number(label="Pre-Op Cylinder"),
|
42 |
gr.Number(label="Pre-Op Axis"),
|
43 |
-
gr.Number(label="3-Month Sphere"),
|
44 |
-
gr.Number(label="3-Month Cylinder"),
|
45 |
-
gr.Number(label="3-Month Axis")
|
46 |
]
|
47 |
output = gr.Textbox(label="Predicted Laser Prescription")
|
48 |
|
|
|
20 |
return out
|
21 |
|
22 |
# Load the saved model state dictionary
|
23 |
+
model = LaserPredictions(3, 32, 3)
|
24 |
model.load_state_dict(torch.load('laser_prescription_model.pt'))
|
25 |
model.eval() # Set the model to evaluation mode
|
26 |
|
27 |
+
def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis):
|
28 |
+
input_data = torch.tensor([[pre_op_sphere, pre_op_cylinder, pre_op_axis]], dtype=torch.float32)
|
29 |
|
30 |
with torch.no_grad():
|
31 |
predicted_prescription = model(input_data)
|
|
|
40 |
gr.Number(label="Pre-Op Sphere"),
|
41 |
gr.Number(label="Pre-Op Cylinder"),
|
42 |
gr.Number(label="Pre-Op Axis"),
|
|
|
|
|
|
|
43 |
]
|
44 |
output = gr.Textbox(label="Predicted Laser Prescription")
|
45 |
|