|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class RegressionModel2(nn.Module):
    """Two-layer MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    The submodule attribute names (``fc1``, ``relu1``, ``fc2``,
    ``batch_norm1``) become the ``state_dict`` keys, so renaming any of them
    would stop previously saved checkpoints from loading.

    Args:
        input_dim2: Number of input features.
        hidden_dim2: Width of the hidden layer (also the BatchNorm1d size).
        output_dim2: Number of regression outputs.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Construction order (fc1, relu1, fc2, batch_norm1) is deliberate:
        # it fixes the parameter-initialisation RNG sequence and module order.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Map a 2-D ``(batch, input_dim2)`` tensor to ``(batch, output_dim2)``.

        Note that normalisation is applied *after* the activation.
        """
        activated = self.relu1(self.fc1(x2))
        return self.fc2(self.batch_norm1(activated))
|
|
|
|
|
@st.cache_resource
def _load_model(path="model.pt"):
    """Build a ``RegressionModel2`` from the state_dict at *path*, in eval mode.

    The layer sizes are read from the checkpoint's own weight shapes rather
    than from module-level constants, so the architecture always matches the
    saved weights. ``st.cache_resource`` keeps one loaded instance across
    Streamlit reruns instead of re-reading the file on every interaction.

    Args:
        path: Location of a checkpoint produced by ``torch.save(state_dict)``.

    Returns:
        The loaded model, switched to eval mode.
    """
    # map_location="cpu" lets a checkpoint saved from a GPU load on a
    # CPU-only host; inference below builds CPU tensors anyway.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on recent PyTorch versions).
    state_dict = torch.load(path, map_location="cpu")
    # nn.Linear stores its weight as (out_features, in_features).
    hidden_dim, input_dim = state_dict["fc1.weight"].shape
    output_dim = state_dict["fc2.weight"].shape[0]
    model = RegressionModel2(input_dim, hidden_dim, output_dim)
    model.load_state_dict(state_dict)
    # eval() makes BatchNorm use its running statistics, which is required
    # for the single-row batches predict() sends.
    model.eval()
    return model


model2 = _load_model()
|
|
|
def predict(age, aca, axis, *, model=None):
    """Run the regression model on a single patient record.

    The three features are packed into one float32 row in the order
    ``(age, aca, axis)``; this order is assumed to match the feature order
    used at training time — confirm against the training code.

    Args:
        age: Patient age.
        aca: ACA magnitude.
        axis: ACA axis.
        model: Optional callable that takes a ``(1, 3)`` float32 tensor and
            returns a single-element tensor. Defaults to the module-level
            ``model2``.

    Returns:
        The model's prediction as a Python number (via ``Tensor.item()``).
    """
    if model is None:
        model = model2

    features = torch.tensor([[age, aca, axis]], dtype=torch.float32)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        prediction = model(features)

    return prediction.item()
|
|
|
|
|
# Page layout. Streamlit re-executes this script top to bottom on every
# widget interaction.
st.title("Astigmatism Prediction App")
st.write("Enter the patient's information:")

# st.number_input infers its numeric type from min_value/value: an int
# min_value (0) yields an integer-only widget. The model consumes float32
# features, so the magnitude uses a float minimum to allow fractional input.
age = st.number_input("Age", min_value=0)
aca = st.number_input("ACA Magnitude", min_value=0.0)
axis = st.number_input("ACA Axis", min_value=0)

# Only run inference when the user explicitly asks for it.
if st.button("Predict"):
    predicted_value = predict(age, aca, axis)
    st.write(f"Predicted Astigmatism Value: {predicted_value}")