import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
|
|
|
class LaserPredictions(nn.Module):
    """Small fully connected network: Linear -> ReLU -> BatchNorm1d -> Linear.

    The submodule attribute names (``fc1``, ``relu1``, ``fc2``,
    ``batch_norm1``) become the key prefixes of ``state_dict()``, so renaming
    any of them would stop previously saved checkpoints from loading.

    Args:
        input_dim: Number of features expected in each input row.
        hidden_dim: Width of the hidden layer; also the feature count given
            to the batch-norm layer.
        output_dim: Number of values produced for each input row.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x2: torch.Tensor) -> torch.Tensor:
        """Run one forward pass.

        Args:
            x2: Input batch whose last dimension is ``input_dim``
                (assumed to be 2-D ``(batch, input_dim)`` -- TODO confirm
                against the training code).

        Returns:
            The output of the final linear layer, with ``output_dim`` values
            per input row.
        """
        out = self.fc1(x2)
        out = self.relu1(out)
        # Batch norm is applied *after* the activation; keep this order, the
        # trained weights depend on it.
        out = self.batch_norm1(out)
        out = self.fc2(out)
        return out
|
|
|
|
|
# Build the network once and restore its trained weights.
# (The original file repeated the imports, the LaserPredictions class and this
# load sequence a second time verbatim; a single copy yields the same `model`.)
# 6 inputs = the 3 pre-op values plus the 3 zero-filled slots that predict()
# appends; 3 outputs = sphere, cylinder, axis.
model = LaserPredictions(6, 32, 3)

# map_location='cpu' lets a checkpoint that was saved from a GPU be restored on
# a CPU-only host; the model itself is never moved off the CPU in this file.
# NOTE(review): torch.load unpickles the file -- only ship a checkpoint from a
# trusted source (consider weights_only=True where the installed torch has it).
model.load_state_dict(torch.load('laser_prescription_model.pt', map_location='cpu'))

# Inference mode: BatchNorm1d then uses its stored running statistics instead
# of per-batch statistics, which is required for the single-row batches that
# predict() sends.
model.eval()
|
|
|
def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis):
    """Run the module-level ``model`` on one set of pre-op values.

    The model takes six inputs; the three values given here fill the first
    three and the remaining three are fixed at 0.0 (presumably the post-op
    target refraction -- confirm against the training data).

    Args:
        pre_op_sphere: Number from the "Pre-Op Sphere" field.
        pre_op_cylinder: Number from the "Pre-Op Cylinder" field.
        pre_op_axis: Number from the "Pre-Op Axis" field.

    Returns:
        A four-line string with the predicted sphere, cylinder and axis, each
        formatted to two decimal places.

    Raises:
        gr.Error: If any field was left empty (Gradio passes ``None`` for an
            empty Number field), so the user sees a readable message instead
            of a tensor-construction TypeError.
    """
    pre_op_values = [pre_op_sphere, pre_op_cylinder, pre_op_axis]
    if any(value is None for value in pre_op_values):
        raise gr.Error(
            "Please enter Pre-Op Sphere, Pre-Op Cylinder and Pre-Op Axis."
        )

    # Remaining three model inputs are always zero.
    post_op_values = [0.0, 0.0, 0.0]

    # Outer list makes this a batch containing a single row.
    input_tensor = torch.tensor(
        [pre_op_values + post_op_values], dtype=torch.float32
    )

    # No gradients are needed for inference.
    with torch.no_grad():
        predicted_prescription = model(input_tensor)

    # First (only) row, first three outputs -> plain Python floats.
    predicted_sphere, predicted_cylinder, predicted_axis = (
        predicted_prescription[0][:3].tolist()
    )

    return f"Predicted Laser Prescription:\nSphere: {predicted_sphere:.2f}\nCylinder: {predicted_cylinder:.2f}\nAxis: {predicted_axis:.2f}"
|
|
|
# Custom stylesheet passed to gr.Interface(css=...) at the bottom of the file:
# dark backgrounds, restyled buttons/inputs, and a hidden page footer.
# NOTE(review): the #component-N selectors are presumably Gradio's
# auto-generated element ids; they may shift between Gradio versions -- verify
# after upgrading.
css = """
.gradio-container {
    background-color: #131517;
}

button {
    background: #102534;
    border: 1px solid #0c538c;
    outline: none;
    border-radius: 5px;
    padding: 10px 20px;
    font-size: 16px;
    color: #fff;
}

button:hover {
    cursor: pointer;
    opacity: .7;
}

input {
    background-color: #202428;
    color: #fff;
    border: 1px solid #2d353c;
    border-radius: 5px;
    padding: 10px 20px;
    outline: none;
    font-size: 16px;
}

::-webkit-input-placeholder {
    color: #7a848f;
}

input:focus::placeholder {
    color: transparent;
}

#component-0 {
    background-color: #131517;
}

#component-1 {
    background-color: #131517;
}

#component-2 {
    background-color: #131517;
}

#component-3 {
    background-color: #131517;
}

#component-4 {
    background-color: #131517;
}

#component-5 {
    background-color: #131517;
}

#component-6 {
    background-color: #131517;
}

#component-7 {
    background-color: #131517;
}

#component-8 {
    background-color: #131517;
}

#component-9 {
    background-color: #131517;
}

#component-10 {
    background-color: #131517;
}

#component-13 {
    background-color: #131517;
}

textarea {
    background-color: #131517;
    resize: none;
}

footer {
    visibility: hidden;
}
"""
|
|
|
# Input widgets; Gradio passes their values to predict() positionally, so the
# order here must match predict's parameter order.
inputs = [
    gr.Number(label="Pre-Op Sphere"),
    gr.Number(label="Pre-Op Cylinder"),
    gr.Number(label="Pre-Op Axis"),
]

# predict() returns a single formatted string, shown in this text box.
output = gr.Textbox(label="Predicted Laser Prescription")

# Build the UI and start serving immediately.
# NOTE(review): this runs at import time, and share=True asks Gradio for a
# public share link -- confirm both are intended for the deployment target.
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=output,
    title="Laser Treatment Prediction",
    css=css,
).launch(share=True)