Jfink09 commited on
Commit
970f4ce
·
verified ·
1 Parent(s): a1d5259

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -20,15 +20,27 @@ class RegressionModel2(nn.Module):
20
  out = self.fc2(out)
21
  return out
22
 
23
- # Load the trained model
24
- model2 = RegressionModel2(3, 32, 1)
25
- model2.load_state_dict(torch.load('model.pt'))
26
- model2.eval()
27
 
28
- def predict_astigmatism(age, aca_magnitude, aca_axis):
29
- input_data = torch.tensor([[age, aca_magnitude, aca_axis]], dtype=torch.float32)
30
- output = model2(input_data)
31
- return output.item()
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def main():
34
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
 
20
  out = self.fc2(out)
21
  return out
22
 
23
+ # Load the saved model state dictionary
24
+ model = RegressionModel2(input_dim2, hidden_dim2, output_dim2)
25
+ model.load_state_dict(torch.load('model.pt'))
26
+ model.eval() # Set the model to evaluation mode
27
 
28
+ # Define a function to make predictions
29
+ def predict_astigmatism(age, axis, aca):
30
+ """
31
+ This function takes three arguments (age, axis, aca) as input,
32
+ converts them to a tensor, makes a prediction using the loaded model,
33
+ and returns the predicted value.
34
+ """
35
+ # Prepare the input data
36
+ data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
37
+
38
+ # Make prediction
39
+ with torch.no_grad():
40
+ prediction = model(data)
41
+
42
+ # Return the predicted value
43
+ return prediction.item()
44
 
45
  def main():
46
  st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')