Jfink09 committed
Commit 61c910c · verified · 1 Parent(s): dec535e

Update app.py

Files changed (1)
  1. app.py +39 -6
app.py CHANGED
@@ -20,20 +20,53 @@ class LaserPredictions(nn.Module):
        return out

# Load the saved model state dictionary
- model = LaserPredictions(3, 32, 3)
+ model = LaserPredictions(6, 32, 3)
+ model.load_state_dict(torch.load('laser_prescription_model.pt'))
+ model.eval() # Set the model to evaluation mode
+
+ You're absolutely right! In real nomograms, the target prescription is often set to 0, which means the post-op values would naturally be close to zero. In this case, using dummy values of zero for the post-op inputs when making predictions should not significantly interfere with the model's performance.
+ Given this insight, let's update the code to reflect this approach:
+ import gradio as gr
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ import torch
+ import torch.nn as nn
+
+ class LaserPredictions(nn.Module):
+     def __init__(self, input_dim, hidden_dim, output_dim):
+         super(LaserPredictions, self).__init__()
+         self.fc1 = nn.Linear(input_dim, hidden_dim)
+         self.relu1 = nn.ReLU()
+         self.fc2 = nn.Linear(hidden_dim, output_dim)
+         self.batch_norm1 = nn.BatchNorm1d(hidden_dim)
+
+     def forward(self, x2):
+         out = self.fc1(x2)
+         out = self.relu1(out)
+         out = self.batch_norm1(out)
+         out = self.fc2(out)
+         return out
+
+ # Load the saved model state dictionary
+ model = LaserPredictions(6, 32, 3)
model.load_state_dict(torch.load('laser_prescription_model.pt'))
model.eval() # Set the model to evaluation mode

def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis):
-     input_data = torch.tensor([[pre_op_sphere, pre_op_cylinder, pre_op_axis]], dtype=torch.float32)
-
+     # Use zero values for post-op features, as the target prescription is set to 0
+     post_op_values = [0.0, 0.0, 0.0]
+
+     # Combine pre-op and post-op values
+     input_data = [pre_op_sphere, pre_op_cylinder, pre_op_axis] + post_op_values
+     input_tensor = torch.tensor([input_data], dtype=torch.float32)
+
    with torch.no_grad():
-         predicted_prescription = model(input_data)
-
+         predicted_prescription = model(input_tensor)
+
    predicted_sphere = predicted_prescription[0][0].item()
    predicted_cylinder = predicted_prescription[0][1].item()
    predicted_axis = predicted_prescription[0][2].item()
-
+
    return f"Predicted Laser Prescription:\nSphere: {predicted_sphere:.2f}\nCylinder: {predicted_cylinder:.2f}\nAxis: {predicted_axis:.2f}"

inputs = [
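
For reference, a minimal usage sketch of the updated predict() function, assuming the weights in laser_prescription_model.pt have been loaded as in the diff above; the refraction values below are hypothetical example inputs, not data from the repository:

# Hypothetical pre-op refraction (sphere, cylinder, axis) -- example values only
example_sphere, example_cylinder, example_axis = -3.25, -0.75, 90.0

# predict() pads the three pre-op inputs with three zero post-op placeholders,
# producing the 1x6 tensor expected by LaserPredictions(6, 32, 3)
print(predict(example_sphere, example_cylinder, example_axis))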