"""Streamlit app that predicts total corneal astigmatism (TCA).

Inputs are patient age and the ACA magnitude/axis (ACA presumably = anterior
corneal astigmatism; TODO confirm). The ACA is converted to J0/J45
components, two pre-trained regression models predict new J0 and J45 values,
and those are converted back to a TCA magnitude and axis for display.
"""
import math

import streamlit as st
import torch
import torch.nn as nn

# Architecture hyper-parameters; they must match the checkpoints below.
INPUT_DIM = 3  # features per sample: [age, ACA axis (deg), initial J component]
HIDDEN_DIM = 32
OUTPUT_DIM = 1

J0_MODEL_PATH = 'j0_model-2.pt'
J45_MODEL_PATH = 'j45_model-2.pt'

# Session-state key under which the most recent prediction is kept between
# Streamlit reruns.
_RESULT_KEY = 'tca_prediction'


class RegressionModel2(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> BatchNorm1d -> Linear.

    The attribute names (fc1, relu1, fc2, batch_norm1) become the state_dict
    keys, so renaming them would break loading of the saved checkpoints.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Map a float tensor, e.g. of shape (batch, input_dim2), to (batch, output_dim2)."""
        out = self.fc1(x2)
        out = self.relu1(out)
        # Note: normalisation is applied *after* the ReLU in this architecture.
        out = self.batch_norm1(out)
        return self.fc2(out)


def _load_model(path):
    """Build a RegressionModel2, load its weights from *path*, and set eval mode."""
    model = RegressionModel2(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
    # map_location='cpu' lets checkpoints saved from a GPU session load on a
    # CPU-only host; without it torch.load tries to restore CUDA tensors.
    # torch.load unpickles the file: only ever point this at trusted checkpoints.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    # eval() makes BatchNorm1d use its running statistics; in training mode a
    # batch of a single sample (as used below) would raise an error.
    model.eval()
    return model


# Load the saved model state dictionaries
@st.cache_resource
def load_models():
    """Return (model_j0, model_j45), cached across reruns/sessions by st.cache_resource."""
    return _load_model(J0_MODEL_PATH), _load_model(J45_MODEL_PATH)


model_j0, model_j45 = load_models()


def calculate_initial_j0_j45(magnitude, axis_deg):
    """Calculate initial J0 and J45 from magnitude and axis (in degrees).

    Returns ``(magnitude * cos(2*axis), magnitude * sin(2*axis))``.

    NOTE(review): this omits the -C/2 scaling of the Thibos power-vector
    definition. The models were presumably trained on exactly this convention,
    so do not change it without retraining (TODO confirm).
    """
    axis_rad = math.radians(axis_deg)
    j0 = magnitude * math.cos(2 * axis_rad)
    j45 = magnitude * math.sin(2 * axis_rad)
    return j0, j45


def _j0_j45_to_magnitude_axis(j0, j45):
    """Invert calculate_initial_j0_j45.

    Returns (magnitude, axis_deg), with the axis folded into the 0-180 degree
    range (atan2 yields -90..90 after halving; negative values get +180).
    """
    magnitude = math.hypot(j0, j45)
    axis = 0.5 * math.degrees(math.atan2(j45, j0))
    if axis < 0:
        axis += 180
    return magnitude, axis


def _build_model_inputs(age, aca_axis_deg, initial_j0, initial_j45):
    """Return the (1, 3) float32 input tensors for the J0 and J45 models.

    Each row is [age, ACA axis in degrees, initial J component].
    """
    input_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
    input_j45 = torch.tensor([[age, aca_axis_deg, initial_j45]], dtype=torch.float32)
    return input_j0, input_j45


def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Predict new J0 and J45 using the loaded models.

    Returns a (new_j0, new_j45) tuple of Python floats.
    """
    initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
    input_j0, input_j45 = _build_model_inputs(age, aca_axis_deg, initial_j0, initial_j45)
    with torch.no_grad():
        new_j0 = model_j0(input_j0).item()
        new_j45 = model_j45(input_j45).item()
    return new_j0, new_j45


def _render_results(result):
    """Display a stored prediction, plus optional model inputs/outputs."""
    st.subheader('Prediction Results')
    col1, col2 = st.columns(2)
    with col1:
        st.write(f"Initial J0: {result['initial_j0']:.2f}")
        st.write(f"Initial J45: {result['initial_j45']:.2f}")
        st.write(f"Predicted J0: {result['new_j0']:.2f}")
        st.write(f"Predicted J45: {result['new_j45']:.2f}")
    with col2:
        st.write(f"Predicted TCA Magnitude: {result['tca_magnitude']:.2f} D")
        st.write(f"Predicted TCA Axis: {result['tca_axis']:.1f}°")

    # Optional: Display input tensors and raw outputs for verification
    if st.checkbox('Show detailed model inputs and outputs'):
        st.subheader('Model Details')
        age, _, aca_axis = result['inputs']
        # Rebuilt with the same helper predict_new_j0_j45 uses, so these are
        # exactly the tensors fed to the models.
        input_j0, input_j45 = _build_model_inputs(
            age, aca_axis, result['initial_j0'], result['initial_j45'])
        st.write("Input tensor for J0:", input_j0)
        st.write("Input tensor for J45:", input_j45)
        st.write("Raw J0 output:", result['new_j0'])
        st.write("Raw J45 output:", result['new_j45'])


def main():
    """Render the input form, run predictions on demand, and show results."""
    st.title('Total Corneal Astigmatism Prediction')

    # User input fields
    age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, value=58.0, step=1.0)
    aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, value=2.3, step=0.1)
    aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)
    inputs = (age, aca_magnitude, aca_axis)

    if st.button('Predict'):
        initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
        new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
        tca_magnitude, tca_axis = _j0_j45_to_magnitude_axis(new_j0, new_j45)
        st.session_state[_RESULT_KEY] = {
            'inputs': inputs,
            'initial_j0': initial_j0,
            'initial_j45': initial_j45,
            'new_j0': new_j0,
            'new_j45': new_j45,
            'tca_magnitude': tca_magnitude,
            'tca_axis': tca_axis,
        }

    # st.button is True only on the rerun caused by the click. Any later widget
    # interaction, such as ticking the details checkbox, reruns the script with
    # the button False. Rendering from session_state keeps the results (and the
    # checkbox) alive on those reruns. Results computed for different inputs
    # are hidden rather than shown next to values they don't belong to.
    result = st.session_state.get(_RESULT_KEY)
    if result is not None and result['inputs'] == inputs:
        _render_results(result)


if __name__ == '__main__':
    main()