# Enter Variables
# ## #
import streamlit as st
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import r2_score


class RegressionModel2(nn.Module):
    """Small feed-forward regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    Args:
        input_dim2: Number of input features per sample.
        hidden_dim2: Width of the hidden layer.
        output_dim2: Number of regression outputs.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Map a 2-D (batch, input_dim2) tensor to (batch, output_dim2) outputs."""
        out = self.fc1(x2)
        out = self.relu1(out)
        # Normalization is applied after the activation. The loaded checkpoint
        # presumably was trained with this exact order, so keep it as-is.
        out = self.batch_norm1(out)
        out = self.fc2(out)
        return out


# Load the saved model state dictionary.
# map_location='cpu' lets a checkpoint that was saved from a CUDA session load
# on a CPU-only host; inference below only ever feeds CPU tensors to the model.
# NOTE(review): Streamlit re-executes this script on every interaction, so the
# checkpoint is re-read on each rerun; consider wrapping the load in
# st.cache_resource if the deployed Streamlit version provides it.
model = RegressionModel2(3, 32, 1)
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
# Evaluation mode makes BatchNorm use its running statistics; in training mode
# a single-sample batch (as built in predict_astigmatism) would raise an error.
model.eval()


def predict_astigmatism(age, axis, aca):
    """Predict total corneal astigmatism for one set of inputs.

    The three values are packed into a single-row float32 tensor in the order
    [age, axis, aca] and passed through the module-level ``model``.

    Args:
        age: Patient age.
        axis: ACA axis (the UI collects it in degrees, 0-180).
        aca: ACA magnitude (the UI collects it in diopters, 0-10).

    Returns:
        The model's scalar prediction as a Python float.
    """
    # Feature order must match the order used when the model was trained
    # (assumed to be [age, axis, aca] -- TODO confirm against training code).
    data = torch.tensor([[age, axis, aca]], dtype=torch.float32)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        prediction = model(data)

    # The model has a single output unit, so the (1, 1) result collapses to a float.
    return prediction.item()


# NOTE(review): the line below appears to be a commented-out fragment of an
# earlier main(). Its HTML content and line breaks look damaged, and the active
# main() invoked at the bottom of the file is not fully recoverable here. Kept
# verbatim for reference.
# def main(): # st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide') # st.write('', unsafe_allow_html=True) # st.write("""""", unsafe_allow_html=True) # st.markdown( # """ # # """, # unsafe_allow_html=True # ) # # st.markdown( # # """ # #
# #Error: Age must be between 18 and 90.
', unsafe_allow_html=True) # ACA Magnitude input aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.1, value=st.session_state.aca_magnitude) if aca_magnitude != st.session_state.aca_magnitude: st.session_state.aca_magnitude = aca_magnitude if aca_magnitude is not None and (aca_magnitude < 0 or aca_magnitude > 10): st.markdown('Error: ACA Magnitude must be between 0 and 10.
', unsafe_allow_html=True) # ACA Axis input aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1, value=st.session_state.aca_axis) if aca_axis != st.session_state.aca_axis: st.session_state.aca_axis = aca_axis if aca_axis is not None and (aca_axis < 0 or aca_axis > 180): st.markdown('Error: ACA Axis must be between 0 and 180.
', unsafe_allow_html=True) if st.button('Predict!'): if age is not None and aca_magnitude is not None and aca_axis is not None: if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180: astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude) st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}') else: st.error('Please correct the input errors before predicting.') else: st.error('Please fill in all fields before predicting.') if __name__ == '__main__': main()