"""Gradio demo that predicts a laser prescription from pre-op refraction values.

The network weights are read from ``laser_prescription_model.pt`` when this
module is imported; the web UI is launched when the file is run as a script.
"""

import gradio as gr
import pandas as pd  # NOTE(review): not used in this script; consider removing
from sklearn.model_selection import train_test_split  # NOTE(review): not used in this script
import torch
import torch.nn as nn

# Saved state dict for LaserPredictions (loaded via load_state_dict below).
MODEL_PATH = "laser_prescription_model.pt"

# Architecture sizes; these must match the checkpoint or load_state_dict() fails.
INPUT_DIM = 6  # 3 pre-op values + 3 post-op values (see predict()).
HIDDEN_DIM = 32
OUTPUT_DIM = 3  # sphere, cylinder, axis (see predict()).


class LaserPredictions(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> BatchNorm1d -> Linear.

    The attribute names (fc1, relu1, fc2, batch_norm1) become the state-dict
    keys, so renaming them would stop the saved checkpoint from loading.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x2):
        """Map a (batch, input_dim) tensor to a (batch, output_dim) tensor."""
        out = self.fc1(x2)
        out = self.relu1(out)
        # BatchNorm comes after the activation here; keep this order, the saved
        # weights were presumably trained with it.
        out = self.batch_norm1(out)
        return self.fc2(out)


def _load_model(path=MODEL_PATH):
    """Build the network, load its weights from *path*, and put it in eval mode."""
    net = LaserPredictions(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    net.load_state_dict(torch.load(path, map_location="cpu"))
    # eval() makes BatchNorm use its running statistics. This is required for the
    # single-sample batches predict() sends; training-mode BatchNorm1d rejects a
    # batch of one.
    net.eval()
    return net


# Loaded once at import and shared by every predict() call.
model = _load_model()


def predict(pre_op_sphere, pre_op_cylinder, pre_op_axis):
    """Return a formatted laser prescription predicted from pre-op measurements.

    Args:
        pre_op_sphere: Pre-operative sphere value.
        pre_op_cylinder: Pre-operative cylinder value.
        pre_op_axis: Pre-operative axis value.

    Returns:
        A multi-line string with the predicted sphere, cylinder and axis
        (two decimals each), or a prompt asking for input when any field is empty.
    """
    pre_op_values = [pre_op_sphere, pre_op_cylinder, pre_op_axis]
    # gr.Number can pass None for an empty field, and torch.tensor cannot convert None.
    if any(value is None for value in pre_op_values):
        return "Please enter Pre-Op Sphere, Pre-Op Cylinder and Pre-Op Axis."

    # Use zero values for post-op features, as the target prescription is set to 0
    post_op_values = [0.0, 0.0, 0.0]
    input_tensor = torch.tensor([pre_op_values + post_op_values], dtype=torch.float32)

    with torch.no_grad():
        predicted_sphere, predicted_cylinder, predicted_axis = model(input_tensor)[0].tolist()

    # NOTE(review): the axis is the raw regression output. It is not wrapped or
    # clamped to any range; confirm whether that is intended.
    return (
        f"Predicted Laser Prescription:\nSphere: {predicted_sphere:.2f}"
        f"\nCylinder: {predicted_cylinder:.2f}\nAxis: {predicted_axis:.2f}"
    )


css = """
gradio-app {
    background: #131517;
}
"""

inputs = [
    gr.Number(label="Pre-Op Sphere"),
    gr.Number(label="Pre-Op Cylinder"),
    gr.Number(label="Pre-Op Axis"),
]
output = gr.Textbox(label="Predicted Laser Prescription")

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=output,
    title="Laser Prescription Prediction",
    css=css,
)

if __name__ == "__main__":
    # share=True also asks Gradio to create a public share link.
    demo.launch(share=True)