Jfink09 commited on
Commit
eefcc06
·
verified ·
1 Parent(s): 02cf73f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from sklearn.model_selection import train_test_split
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ class LaserPredictions(nn.Module):
8
+ def __init__(self, input_dim, hidden_dim, output_dim):
9
+ super(LaserPredictions, self).__init__()
10
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
11
+ self.relu1 = nn.ReLU()
12
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
13
+ self.batch_norm1 = nn.BatchNorm1d(hidden_dim)
14
+
15
+ def forward(self, x2):
16
+ out = self.fc1(x2)
17
+ out = self.relu1(out)
18
+ out = self.batch_norm1(out)
19
+ out = self.fc2(out)
20
+ return out
21
+
22
+ # Load the saved model state dictionary
23
+ model = LaserPredictions(6, 32, 3)
24
+ model.load_state_dict(torch.load('laser_prescription_model.pt'))
25
+ model.eval() # Set the model to evaluation mode
26
+
27
+ def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis, three_month_sphere, three_month_cylinder, three_month_axis):
28
+ input_data = torch.tensor([[pre_op_sphere, pre_op_cylinder, pre_op_axis, three_month_sphere, three_month_cylinder, three_month_axis]], dtype=torch.float32)
29
+
30
+ with torch.no_grad():
31
+ predicted_prescription = model(input_data)
32
+
33
+ predicted_sphere = predicted_prescription[0][0].item()
34
+ predicted_cylinder = predicted_prescription[0][1].item()
35
+ predicted_axis = predicted_prescription[0][2].item()
36
+
37
+ return f"Predicted Laser Prescription:\nSphere: {predicted_sphere:.2f}\nCylinder: {predicted_cylinder:.2f}\nAxis: {predicted_axis:.2f}"
38
+
39
+ inputs = [
40
+ gr.inputs.Number(label="Pre-Op Sphere"),
41
+ gr.inputs.Number(label="Pre-Op Cylinder"),
42
+ gr.inputs.Number(label="Pre-Op Axis"),
43
+ gr.inputs.Number(label="3-Month Sphere"),
44
+ gr.inputs.Number(label="3-Month Cylinder"),
45
+ gr.inputs.Number(label="3-Month Axis")
46
+ ]
47
+ output = gr.outputs.Textbox(label="Predicted Laser Prescription")
48
+
49
+ gr.Interface(fn=predict, inputs=inputs, outputs=output, title="Laser Prescription Prediction").launch()