"""Gradio demo that predicts a laser-treatment prescription from pre-op values.

A small fully connected network (``LaserPredictions``) is rebuilt, its trained
weights are loaded from ``MODEL_PATH``, and ``predict`` is served through a
``gr.Interface`` with a dark custom stylesheet.
"""

import gradio as gr
import pandas as pd  # NOTE(review): unused in this file; kept in case other code relies on it
from sklearn.model_selection import train_test_split  # NOTE(review): unused in this file
import torch
import torch.nn as nn

# Checkpoint produced by training; resolved relative to the current working directory.
MODEL_PATH = "laser_prescription_model.pt"

# Layer sizes must match the checkpoint: 3 pre-op + 3 post-op features in,
# 3 predicted values (sphere, cylinder, axis) out.
INPUT_DIM = 6
HIDDEN_DIM = 32
OUTPUT_DIM = 3


class LaserPredictions(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> BatchNorm1d -> Linear.

    The attribute names ``fc1``, ``fc2`` and ``batch_norm1`` are the keys of the
    saved state dict, so they must not be renamed or the checkpoint will not load.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x2):
        """Run the network on ``x2``.

        ``predict`` below passes a float32 tensor of shape (1, input_dim);
        presumably any (batch, input_dim) tensor works — confirm for other callers.
        """
        out = self.fc1(x2)
        out = self.relu1(out)
        out = self.batch_norm1(out)
        out = self.fc2(out)
        return out


def _load_model(path=MODEL_PATH):
    """Build ``LaserPredictions`` and load its trained weights from *path*."""
    net = LaserPredictions(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
    # map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (weights_only=True is available on torch >= 1.13).
    state_dict = torch.load(path, map_location="cpu")
    net.load_state_dict(state_dict)
    # Eval mode makes BatchNorm use its running statistics, which is required
    # for the single-sample batches that predict() sends.
    net.eval()
    return net


# Loaded once at import time; predict() reuses it for every request.
model = _load_model()


def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis):
    """Return a formatted laser prescription predicted from pre-op values.

    Args:
        pre_op_sphere, pre_op_cylinder, pre_op_axis: numbers from the Gradio
            ``Number`` inputs; a field left empty arrives as ``None``.

    Returns:
        A multi-line string with the predicted sphere, cylinder and axis
        (two decimals each), or a prompt asking for the missing inputs.
    """
    pre_op_values = [pre_op_sphere, pre_op_cylinder, pre_op_axis]
    if any(value is None for value in pre_op_values):
        # torch.tensor cannot build a float tensor containing None; ask the
        # user to fill in the fields instead of surfacing a TypeError.
        return "Please enter Pre-Op Sphere, Cylinder and Axis."

    # Use zero values for post-op features, as the target prescription is set to 0
    post_op_values = [0.0, 0.0, 0.0]
    input_tensor = torch.tensor([pre_op_values + post_op_values], dtype=torch.float32)

    with torch.no_grad():
        predicted = model(input_tensor)[0]
    predicted_sphere, predicted_cylinder, predicted_axis = predicted.tolist()

    return (
        "Predicted Laser Prescription:\n"
        f"Sphere: {predicted_sphere:.2f}\n"
        f"Cylinder: {predicted_cylinder:.2f}\n"
        f"Axis: {predicted_axis:.2f}"
    )


# Dark theme for the Interface.
# NOTE(review): the #component-N selectors look like Gradio's auto-generated
# component ids; they may shift if the layout or the Gradio version changes.
css = """
.gradio-container { background-color: #131517; }
button {
    background: #102534;
    border: 1px solid #0c538c;
    outline: none;
    border-radius: 5px;
    padding: 10px 20px;
    font-size: 16px;
    color: #fff;
}
button:hover { cursor: pointer; opacity: .7; }
input {
    background-color: #202428;
    color: #fff;
    border: 1px solid #2d353c;
    border-radius: 5px;
    padding: 10px 20px;
    outline: none;
    font-size: 16px;
}
::-webkit-input-placeholder { color: #7a848f; }
input:focus::placeholder { color: transparent; }
#component-0 { background-color: #131517; }
#component-1 { background-color: #131517; }
#component-2 { background-color: #131517; }
#component-3 { background-color: #131517; }
#component-4 { background-color: #131517; }
#component-5 { background-color: #131517; }
#component-6 { background-color: #131517; }
#component-7 { background-color: #131517; }
#component-8 { background-color: #131517; }
#component-9 { background-color: #131517; }
#component-10 { background-color: #131517; }
#component-13 { background-color: #131517; }
textarea { background-color: #131517; resize: none; }
footer { visibility: hidden; }
"""

inputs = [
    gr.Number(label="Pre-Op Sphere"),
    gr.Number(label="Pre-Op Cylinder"),
    gr.Number(label="Pre-Op Axis"),
]
output = gr.Textbox(label="Predicted Laser Prescription")

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=output,
    title="Laser Treatment Prediction",
    css=css,
)

if __name__ == "__main__":
    # share=True additionally publishes a temporary public URL for the app.
    demo.launch(share=True)