Spaces:
Running
Running
File size: 4,786 Bytes
52192c2 8e1c747 dcaeb83 485bcab 52192c2 8e1c747 a5bcb7f 485bcab 52192c2 61229f1 3ceae9e 485bcab 61229f1 8f8e201 8e1c747 0c0dd4a 8e1c747 0c0dd4a 8f8e201 0c0dd4a 8f8e201 be5ae35 8f8e201 be5ae35 8f8e201 6f52cd7 06673ea 8f8e201 6f52cd7 be5ae35 8f8e201 3ceae9e 8f8e201 be5ae35 3ceae9e be5ae35 3ceaab3 be5ae35 8f8e201 52192c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import streamlit as st
import torch
import torch.nn as nn
import math
class RegressionModel2(nn.Module):
    """Small MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    The attribute names (fc1, relu1, fc2, batch_norm1) become the state_dict
    keys, and the saved checkpoints are loaded into this class by key, so they
    must not be renamed.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Creation order is kept as-is so parameter initialisation consumes the
        # RNG in the same sequence.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        # Note: batch norm is applied *after* the activation in this architecture.
        hidden = self.batch_norm1(self.relu1(self.fc1(x2)))
        return self.fc2(hidden)
# Load the saved model state dictionaries.
def _load_model(checkpoint_path):
    """Build a RegressionModel2(3, 32, 1), load its weights, and set eval mode.

    Args:
        checkpoint_path: path to a saved state_dict for RegressionModel2.

    Returns:
        The model, in eval mode so BatchNorm uses its stored running statistics.
    """
    model = RegressionModel2(3, 32, 1)
    # map_location='cpu' lets a checkpoint that was saved from GPU tensors load
    # on a CPU-only host instead of failing on device deserialisation.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model


# NOTE(review): Streamlit re-executes the script on each interaction, so these
# loads repeat on every rerun; consider st.cache_resource if the installed
# Streamlit version provides it.
model_j0 = _load_model('j0_model-2.pt')
model_j45 = _load_model('j45_model-2.pt')
def calculate_initial_j0_j45(magnitude, axis_deg):
    """Encode a (magnitude, axis in degrees) pair as J0/J45 double-angle components.

    J0 = magnitude * cos(2 * axis) and J45 = magnitude * sin(2 * axis), each
    rounded to 2 decimal places.

    NOTE(review): this omits the -C/2 scaling of the Thibos power-vector
    definition; presumably intentional, since calculate_magnitude and
    calculate_axis invert exactly this unscaled form.
    """
    doubled = 2 * math.radians(axis_deg)
    return (
        round(magnitude * math.cos(doubled), 2),
        round(magnitude * math.sin(doubled), 2),
    )
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Predict the new (J0, J45) pair with the module-level regression models.

    Each model receives one float32 row ``[age, aca_axis_deg, component]``.
    The component is the ACA vector's x-part (J0 model) or y-part (J45 model).

    Returns:
        tuple[float, float]: (predicted J0, predicted J45).
    """
    theta = math.radians(aca_axis_deg)
    # NOTE(review): x/y use the single axis angle, whereas
    # calculate_initial_j0_j45 uses the doubled angle. Presumably this matches
    # how the training features were built; confirm against the training code
    # before changing it, since the saved weights depend on it.
    components = (aca_magnitude * math.cos(theta), aca_magnitude * math.sin(theta))
    rows = [
        torch.tensor([[age, aca_axis_deg, value]], dtype=torch.float32)
        for value in components
    ]
    with torch.no_grad():
        predicted_j0 = model_j0(rows[0]).item()
        predicted_j45 = model_j45(rows[1]).item()
    return predicted_j0, predicted_j45
def calculate_magnitude(j0, j45):
    """Return the length of the (J0, J45) vector, i.e. sqrt(J0**2 + J45**2).

    math.hypot computes the same value without squaring each component
    explicitly, so it does not overflow to inf (or underflow to 0) for
    extreme inputs.
    """
    return math.hypot(j0, j45)
def calculate_axis(j0, j45):
    """Recover the axis in degrees from its J0/J45 double-angle components.

    The atan2 angle is halved. Negative results are shifted by 180 degrees, so
    the axis lies in [0, 180) up to floating-point rounding.
    """
    half_angle = 0.5 * math.degrees(math.atan2(j45, j0))
    return half_angle + 180 if half_angle < 0 else half_angle
def _show_results(aca_magnitude, aca_axis, initial_j0, initial_j45, new_j0, new_j45):
    """Render the predicted magnitude/axis plus the J0/J45 values behind them."""
    predicted_magnitude = calculate_magnitude(new_j0, new_j45)
    predicted_axis = calculate_axis(new_j0, new_j45)
    st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
    st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')

    # Intermediate values, displayed for verification; the atan2 angle is
    # computed once and reused for both lines below.
    full_angle = math.degrees(math.atan2(new_j45, new_j0))
    st.info(f'Input ACA - Magnitude: {aca_magnitude:.2f} D, Axis: {aca_axis:.1f}°')
    st.info(f'Initial J0: {initial_j0:.2f}, Initial J45: {initial_j45:.2f}')
    st.info(f'Predicted J0: {new_j0:.2f}, Predicted J45: {new_j45:.2f}')
    st.info('Intermediate calculations:')
    st.info(f' atan2(J45, J0): {full_angle:.2f}°')
    st.info(f' 0.5 * atan2(J45, J0): {0.5 * full_angle:.2f}°')


def _show_debug_info(age, aca_magnitude, aca_axis):
    """Echo the raw inputs and the feature rows that predict_new_j0_j45 builds."""
    axis_rad = math.radians(aca_axis)
    # Same single-angle components that predict_new_j0_j45 feeds to the models.
    aca_x = aca_magnitude * math.cos(axis_rad)
    aca_y = aca_magnitude * math.sin(axis_rad)
    st.subheader('Debugging Information:')
    st.write(f'Input age: {age}')
    st.write(f'Input ACA magnitude: {aca_magnitude:.2f} D')
    st.write(f'Input ACA axis: {aca_axis:.1f}°')
    st.write(f'Calculated ACA X: {aca_x:.4f}')
    st.write(f'Calculated ACA Y: {aca_y:.4f}')
    st.write(f'Model J0 input: [{age}, {aca_axis}, {aca_x:.4f}]')
    st.write(f'Model J45 input: [{age}, {aca_axis}, {aca_y:.4f}]')


def main():
    """Streamlit entry point: collect inputs, run both models, and show results."""
    st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
    st.title('Total Corneal Astigmatism Prediction')

    # Input fields
    age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, step=0.1)
    aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.01)
    aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1)

    if not st.button('Predict!'):
        return

    # The widgets above are given the same min/max bounds, so this is
    # presumably a defensive re-check rather than the primary validation.
    if not (18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180):
        st.error('Please ensure all inputs are within the specified ranges.')
        return

    # Initial J0/J45 are shown only for comparison with the model output.
    initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)
    new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)

    _show_results(aca_magnitude, aca_axis, initial_j0, initial_j45, new_j0, new_j45)
    _show_debug_info(age, aca_magnitude, aca_axis)
# Script entry point (presumably launched via `streamlit run`).
if __name__ == "__main__":
    main()