Jfink09's picture
Update app.py
a5bcb7f verified
raw
history blame
4.79 kB
import streamlit as st
import torch
import torch.nn as nn
import math
class RegressionModel2(nn.Module):
    """Small MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    The submodule attribute names (fc1, relu1, fc2, batch_norm1) define the
    state_dict keys, so they must stay as-is for the saved checkpoints loaded
    via load_state_dict to match.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Registration order is kept identical to preserve state_dict ordering.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Map a (batch, input_dim2) tensor to a (batch, output_dim2) tensor."""
        hidden = self.batch_norm1(self.relu1(self.fc1(x2)))
        return self.fc2(hidden)
# Load the saved model state dictionaries.
def _load_regression_model(path):
    """Build a RegressionModel2(3, 32, 1), load its weights from *path*, and return it in eval mode.

    Args:
        path: file path of a saved state dict for RegressionModel2(3, 32, 1).

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    model = RegressionModel2(3, 32, 1)
    # map_location="cpu": prediction builds CPU tensors, and without this a
    # checkpoint saved from a CUDA session fails to load on a CPU-only host.
    model.load_state_dict(torch.load(path, map_location="cpu"))
    # eval() makes BatchNorm1d use its stored running statistics instead of
    # per-batch statistics; predictions are made on single-row inputs.
    model.eval()
    return model


model_j0 = _load_regression_model('j0_model-2.pt')
model_j45 = _load_regression_model('j45_model-2.pt')
def calculate_initial_j0_j45(magnitude, axis_deg):
    """Project an astigmatism (magnitude, axis in degrees) onto J0/J45.

    Uses the doubled-angle form J0 = M*cos(2*axis), J45 = M*sin(2*axis);
    each component is rounded to 2 decimal places.

    Returns:
        (j0, j45) tuple of floats.
    """
    doubled = 2 * math.radians(axis_deg)
    return (
        round(magnitude * math.cos(doubled), 2),
        round(magnitude * math.sin(doubled), 2),
    )
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Run the J0 and J45 regressors on one set of patient inputs.

    Each model receives a single-row float32 tensor
    [age, aca_axis_deg, component], where component is the ACA magnitude
    times cos(axis) for J0 and times sin(axis) for J45.

    NOTE(review): unlike calculate_initial_j0_j45, the axis is NOT doubled
    here; presumably this matches the features the models were trained on —
    confirm before changing.

    Returns:
        (new_j0, new_j45) as Python floats.
    """
    theta = math.radians(aca_axis_deg)
    features_j0 = [age, aca_axis_deg, aca_magnitude * math.cos(theta)]
    features_j45 = [age, aca_axis_deg, aca_magnitude * math.sin(theta)]

    with torch.no_grad():
        new_j0 = model_j0(torch.tensor([features_j0], dtype=torch.float32)).item()
        new_j45 = model_j45(torch.tensor([features_j45], dtype=torch.float32)).item()
    return new_j0, new_j45
def calculate_magnitude(j0, j45):
    """Return the magnitude sqrt(J0**2 + J45**2) of a (J0, J45) vector.

    Uses math.hypot, the standard-library 2-D Euclidean norm, which avoids
    intermediate overflow/underflow from squaring each component.

    Returns:
        Non-negative float.
    """
    return math.hypot(j0, j45)
def calculate_axis(j0, j45):
    """Recover the astigmatism axis in degrees from (J0, J45).

    Inverts the doubled-angle representation by halving atan2(J45, J0)
    (which lies in [-90, 90] degrees), then adds 180 to negative results so
    the axis is reported in the 0-180 degree range.
    """
    half_angle = math.degrees(math.atan2(j45, j0)) / 2
    return half_angle + 180 if half_angle < 0 else half_angle
def main():
    """Render the Streamlit astigmatism-prediction page.

    Collects age and ACA magnitude/axis. When "Predict!" is clicked and the
    inputs are in range, shows the predicted total corneal astigmatism
    (magnitude and axis) followed by intermediate J0/J45 values and
    debugging output; otherwise shows an error.
    """
    st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
    st.title('Total Corneal Astigmatism Prediction')

    # Input fields
    age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, step=0.1)
    aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.01)
    aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1)

    if not st.button('Predict!'):
        return

    inputs_valid = 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180
    if not inputs_valid:
        st.error('Please ensure all inputs are within the specified ranges.')
        return

    # Initial J0/J45 from the raw input (shown for comparison only).
    initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
    # Model predictions, converted back to magnitude/axis form.
    new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
    predicted_magnitude = calculate_magnitude(new_j0, new_j45)
    predicted_axis = calculate_axis(new_j0, new_j45)

    st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
    st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')

    # Intermediate values, shown so the computation can be checked by hand.
    full_angle = math.degrees(math.atan2(new_j45, new_j0))
    st.info(f'Input ACA - Magnitude: {aca_magnitude:.2f} D, Axis: {aca_axis:.1f}°')
    st.info(f'Initial J0: {initial_j0:.2f}, Initial J45: {initial_j45:.2f}')
    st.info(f'Predicted J0: {new_j0:.2f}, Predicted J45: {new_j45:.2f}')
    st.info('Intermediate calculations:')
    st.info(f' atan2(J45, J0): {full_angle:.2f}°')
    st.info(f' 0.5 * atan2(J45, J0): {0.5 * full_angle:.2f}°')

    # Debugging output: echoes the feature rows fed to each model.
    aca_x = aca_magnitude * math.cos(math.radians(aca_axis))
    aca_y = aca_magnitude * math.sin(math.radians(aca_axis))
    st.subheader('Debugging Information:')
    st.write(f'Input age: {age}')
    st.write(f'Input ACA magnitude: {aca_magnitude:.2f} D')
    st.write(f'Input ACA axis: {aca_axis:.1f}°')
    st.write(f'Calculated ACA X: {aca_x:.4f}')
    st.write(f'Calculated ACA Y: {aca_y:.4f}')
    st.write(f'Model J0 input: [{age}, {aca_axis}, {aca_x:.4f}]')
    st.write(f'Model J45 input: [{age}, {aca_axis}, {aca_y:.4f}]')
# Entry point: build the page only when this file is executed as the main
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()