import streamlit as st
import pandas as pd  # NOTE(review): unused in this file
import torch
import torch.nn as nn
import torch.optim as optim  # NOTE(review): unused in this file
from sklearn.metrics import r2_score  # NOTE(review): unused in this file


class RegressionModel2(nn.Module):
    """Two-layer MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    The layer order (batch norm applied *after* the activation) must match the
    network the checkpoint was trained with, so it is kept exactly as-is.

    Args:
        input_dim2: Number of input features.
        hidden_dim2: Width of the hidden layer.
        output_dim2: Number of regression outputs.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Attribute names below become the state_dict keys stored in model.pt;
        # renaming any of them breaks load_state_dict.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Run the network on a 2-D ``(batch, input_dim2)`` tensor."""
        out = self.fc1(x2)
        out = self.relu1(out)
        # In training mode BatchNorm1d rejects a batch of one sample; the
        # single-row inference in predict_astigmatism relies on eval() mode,
        # which uses the stored running statistics instead.
        out = self.batch_norm1(out)
        out = self.fc2(out)
        return out


MODEL_PATH = 'model.pt'


@st.cache_resource(show_spinner=False)
def _load_model(path=MODEL_PATH):
    """Build the network, load trained weights from *path*, return it in eval mode.

    Streamlit re-executes this whole script on every widget interaction, so
    the load is cached with ``st.cache_resource`` to read the checkpoint only
    once per server process. ``show_spinner=False`` stops the cache from
    emitting a UI element before ``st.set_page_config`` runs in ``main``.
    """
    net = RegressionModel2(3, 32, 1)
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(path, map_location='cpu'))
    net.eval()  # Set the model to evaluation mode
    return net


# Load the saved model state dictionary; predict_astigmatism reads this global.
model = _load_model()


def predict_astigmatism(age, axis, aca):
    """Predict total corneal astigmatism for a single patient.

    Args:
        age: Patient age.
        axis: ACA axis (``main`` passes the "ACA Axis" input here).
        aca: ACA magnitude (``main`` passes the "ACA Magnitude" input here).

    Returns:
        The model's scalar prediction as a Python float.
    """
    # One-row batch in [age, axis, aca] order -- presumably the feature order
    # the model was trained on; TODO confirm against the training pipeline.
    data = torch.tensor([[age, axis, aca]], dtype=torch.float32)
    with torch.no_grad():
        prediction = model(data)
    return prediction.item()


def main():
    """Render the input form and show a prediction when 'Predict!' is clicked."""
    st.set_page_config(
        page_title='Astigmatism Prediction',
        page_icon=':eyeglasses:',
        layout='wide',
    )
    # NOTE(review): the three calls below pass empty or whitespace-only strings
    # and so render no visible content; the HTML/CSS they were meant to inject
    # appears to be missing. Restore the markup or remove the calls.
    st.write('', unsafe_allow_html=True)
    st.write("""""", unsafe_allow_html=True)
    st.markdown(
        """ """,
        unsafe_allow_html=True
    )
    # Disabled "Enter Variables" header (its HTML markup has been lost):
    # st.markdown(
    #     """
    #     Enter Variables
    #     """,
    #     unsafe_allow_html=True
    # )
    age = st.number_input('Enter Patient Age:', step=0.1)
    aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1)
    aca_axis = st.number_input('Enter ACA Axis:', step=0.1)
    if st.button('Predict!'):
        # Positional order follows predict_astigmatism(age, axis, aca).
        astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
        st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')


if __name__ == '__main__':
    main()