Jfink09 commited on
Commit
dec535e
·
verified ·
1 Parent(s): 7bf4a3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -20,12 +20,12 @@ class LaserPredictions(nn.Module):
20
  return out
21
 
22
  # Load the saved model state dictionary
23
- model = LaserPredictions(6, 32, 3)
24
  model.load_state_dict(torch.load('laser_prescription_model.pt'))
25
  model.eval() # Set the model to evaluation mode
26
 
27
- def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis, three_month_sphere, three_month_cylinder, three_month_axis):
28
- input_data = torch.tensor([[pre_op_sphere, pre_op_cylinder, pre_op_axis, three_month_sphere, three_month_cylinder, three_month_axis]], dtype=torch.float32)
29
 
30
  with torch.no_grad():
31
  predicted_prescription = model(input_data)
@@ -40,9 +40,6 @@ inputs = [
40
  gr.Number(label="Pre-Op Sphere"),
41
  gr.Number(label="Pre-Op Cylinder"),
42
  gr.Number(label="Pre-Op Axis"),
43
- gr.Number(label="3-Month Sphere"),
44
- gr.Number(label="3-Month Cylinder"),
45
- gr.Number(label="3-Month Axis")
46
  ]
47
  output = gr.Textbox(label="Predicted Laser Prescription")
48
 
 
20
  return out
21
 
22
  # Load the saved model state dictionary
23
+ model = LaserPredictions(3, 32, 3)
24
  model.load_state_dict(torch.load('laser_prescription_model.pt'))
25
  model.eval() # Set the model to evaluation mode
26
 
27
+ def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis):
28
+ input_data = torch.tensor([[pre_op_sphere, pre_op_cylinder, pre_op_axis]], dtype=torch.float32)
29
 
30
  with torch.no_grad():
31
  predicted_prescription = model(input_data)
 
40
  gr.Number(label="Pre-Op Sphere"),
41
  gr.Number(label="Pre-Op Cylinder"),
42
  gr.Number(label="Pre-Op Axis"),
 
 
 
43
  ]
44
  output = gr.Textbox(label="Predicted Laser Prescription")
45