File size: 4,786 Bytes
52192c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e1c747
dcaeb83
485bcab
52192c2
8e1c747
a5bcb7f
485bcab
52192c2
61229f1
 
 
3ceae9e
 
485bcab
 
61229f1
8f8e201
8e1c747
 
 
0c0dd4a
8e1c747
 
0c0dd4a
8f8e201
0c0dd4a
 
8f8e201
 
be5ae35
 
 
 
 
 
 
 
 
 
8f8e201
 
 
 
 
 
 
 
 
 
 
 
 
be5ae35
8f8e201
 
6f52cd7
06673ea
8f8e201
6f52cd7
be5ae35
 
8f8e201
3ceae9e
8f8e201
 
 
be5ae35
3ceae9e
 
be5ae35
 
 
3ceaab3
 
 
be5ae35
 
 
 
 
 
 
8f8e201
 
 
52192c2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
import torch
import torch.nn as nn
import math

class RegressionModel2(nn.Module):
    """Two-layer MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    The submodule attribute names (``fc1``, ``relu1``, ``fc2``,
    ``batch_norm1``) become the ``state_dict`` keys, so they must not be
    renamed or previously saved checkpoints will no longer load.

    Args:
        input_dim2: Number of input features per row.
        hidden_dim2: Width of the hidden layer (also BatchNorm1d's feature count).
        output_dim2: Number of regression outputs per row.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Layers are created in the same order as before (fc1, relu1, fc2,
        # batch_norm1) so random initialisation consumes the RNG identically.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Map a 2-D ``(batch, input_dim2)`` tensor to ``(batch, output_dim2)``.

        Note that normalisation is applied *after* the ReLU activation.
        """
        hidden = self.batch_norm1(self.relu1(self.fc1(x2)))
        return self.fc2(hidden)

def _load_model(path):
    """Build a ``RegressionModel2(3, 32, 1)``, load weights from *path*, return it in eval mode.

    Eval mode makes BatchNorm use its stored running statistics instead of
    per-batch statistics, which the single-row prediction batches require.
    """
    model = RegressionModel2(3, 32, 1)
    # map_location: without it, tensors saved from a CUDA session are restored
    # onto CUDA and torch.load fails on a machine with no GPU.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Load the saved model state dictionaries (one regressor for J0, one for J45).
model_j0 = _load_model('j0_model-2.pt')
model_j45 = _load_model('j45_model-2.pt')

def calculate_initial_j0_j45(magnitude, axis_deg):
    """Return the ``(J0, J45)`` pair for a magnitude and an axis in degrees.

    Uses the double-angle projection J0 = M*cos(2*axis), J45 = M*sin(2*axis).
    Both components are rounded to two decimal places.
    """
    doubled = 2 * math.radians(axis_deg)
    components = (magnitude * math.cos(doubled), magnitude * math.sin(doubled))
    j0, j45 = (round(value, 2) for value in components)
    return j0, j45

def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Run both loaded regressors and return the predicted ``(J0, J45)`` pair.

    Each model receives a single-row float32 feature vector
    ``[age, aca_axis_deg, projection]``. For ``model_j0`` the projection is
    ``aca_magnitude * cos(axis)``; for ``model_j45`` it is
    ``aca_magnitude * sin(axis)``. Both use the single, not doubled, axis angle.
    """
    theta = math.radians(aca_axis_deg)
    projections = (
        aca_magnitude * math.cos(theta),  # third feature for model_j0
        aca_magnitude * math.sin(theta),  # third feature for model_j45
    )
    predictions = []
    with torch.no_grad():  # inference only; no autograd graph is needed
        for model, projection in zip((model_j0, model_j45), projections):
            features = torch.tensor(
                [[age, aca_axis_deg, projection]], dtype=torch.float32
            )
            predictions.append(model(features).item())
    return predictions[0], predictions[1]

def calculate_magnitude(j0, j45):
    """Return the Euclidean magnitude of the ``(J0, J45)`` vector.

    Mathematically equal to sqrt(J0**2 + J45**2). ``math.hypot`` computes it
    without squaring the inputs first, so extreme values cannot overflow to
    ``inf`` or underflow to 0.
    """
    return math.hypot(j0, j45)

def calculate_axis(j0, j45):
    """Return the axis, in degrees, encoded by a ``(J0, J45)`` pair.

    Inverts the double-angle representation: axis = atan2(J45, J0) / 2,
    wrapped into the 0-180 degree range.

    Why ``% 180`` instead of ``if axis < 0: axis += 180``: the results are
    identical for every input except a signed zero. With J45 == -0.0 and
    J0 > 0, the old form returned -0.0, which formats as "-0.0°". Python's
    float modulo with a positive divisor returns +0.0 there.
    """
    return (0.5 * math.degrees(math.atan2(j45, j0))) % 180

def _render_results(age, aca_magnitude, aca_axis):
    """Run the prediction pipeline on validated inputs and display everything.

    Shows the predicted total corneal astigmatism (magnitude and axis), the
    intermediate J0/J45 values, and a debugging section echoing the features
    passed to the models.
    """
    # Calculate initial J0 and J45 (for comparison)
    initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)

    # Predict new J0 and J45 using the models
    new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)

    # Calculate predicted magnitude and axis
    predicted_magnitude = calculate_magnitude(new_j0, new_j45)
    predicted_axis = calculate_axis(new_j0, new_j45)

    st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
    st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')

    # Display intermediate values for verification. The raw (un-halved, un-wrapped)
    # atan2 angle is computed once and shown both as-is and halved.
    raw_angle_deg = math.degrees(math.atan2(new_j45, new_j0))
    st.info(f'Input ACA - Magnitude: {aca_magnitude:.2f} D, Axis: {aca_axis:.1f}°')
    st.info(f'Initial J0: {initial_j0:.2f}, Initial J45: {initial_j45:.2f}')
    st.info(f'Predicted J0: {new_j0:.2f}, Predicted J45: {new_j45:.2f}')
    st.info('Intermediate calculations:')
    st.info(f'  atan2(J45, J0): {raw_angle_deg:.2f}°')
    st.info(f'  0.5 * atan2(J45, J0): {0.5 * raw_angle_deg:.2f}°')

    _render_debug_info(age, aca_magnitude, aca_axis)


def _render_debug_info(age, aca_magnitude, aca_axis):
    """Echo the raw inputs and the model feature vectors for debugging."""
    # Single-angle (not doubled) projection of the ACA vector, mirroring the
    # third feature computed in predict_new_j0_j45.
    aca_rad = math.radians(aca_axis)
    aca_x = aca_magnitude * math.cos(aca_rad)
    aca_y = aca_magnitude * math.sin(aca_rad)

    st.subheader('Debugging Information:')
    st.write(f'Input age: {age}')
    st.write(f'Input ACA magnitude: {aca_magnitude:.2f} D')
    st.write(f'Input ACA axis: {aca_axis:.1f}°')
    st.write(f'Calculated ACA X: {aca_x:.4f}')
    st.write(f'Calculated ACA Y: {aca_y:.4f}')
    st.write(f'Model J0 input: [{age}, {aca_axis}, {aca_x:.4f}]')
    st.write(f'Model J45 input: [{age}, {aca_axis}, {aca_y:.4f}]')


def main():
    """Build the Streamlit page: title, input widgets, Predict button and results."""
    # set_page_config is kept as the first Streamlit call on the page.
    st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')

    st.title('Total Corneal Astigmatism Prediction')

    # Input fields
    age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, step=0.1)
    aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, step=0.01)
    aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, step=0.1)

    if not st.button('Predict!'):
        return

    # The widgets above are configured with these same bounds, so this is
    # presumably a defensive re-check only.
    if 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180:
        _render_results(age, aca_magnitude, aca_axis)
    else:
        st.error('Please ensure all inputs are within the specified ranges.')

# Script entry point: build the page when this file is executed as the main module.
if __name__ == '__main__':
    main()