File size: 3,801 Bytes
52192c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241be24
 
 
e15a441
241be24
52192c2
241be24
e15a441
241be24
 
 
 
 
52192c2
61229f1
 
 
e473f84
 
485bcab
 
61229f1
8f8e201
e473f84
0c0dd4a
e473f84
 
 
8f8e201
0c0dd4a
 
e473f84
8f8e201
 
 
241be24
e473f84
241be24
 
 
 
e473f84
241be24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f8e201
52192c2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
import torch
import torch.nn as nn
import math

class RegressionModel2(nn.Module):
    """Two-layer MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    The submodule attribute names (``fc1``, ``relu1``, ``fc2``,
    ``batch_norm1``) define the ``state_dict`` keys, so they must match the
    saved checkpoints that are loaded into this class.

    Args:
        input_dim2: Number of input features per row.
        hidden_dim2: Width of the hidden layer (also the BatchNorm feature count).
        output_dim2: Number of output values per row.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Registration order is deliberately unchanged so parameter
        # iteration order and state_dict key order stay the same.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Map a batch of feature rows to predictions.

        Assumes ``x2`` is shaped (batch, input_dim2) — TODO confirm for any
        caller other than this file. Normalization is applied *after* the
        activation.
        """
        hidden = self.batch_norm1(self.relu1(self.fc1(x2)))
        return self.fc2(hidden)

# Load the saved model state dictionaries
@st.cache_resource
def load_models():
    """Build both regressors and load their saved weights.

    Decorated with ``st.cache_resource`` so the checkpoints are read once and
    the same model objects are reused across Streamlit reruns.

    Returns:
        tuple: ``(model_j0, model_j45)``, each a ``RegressionModel2(3, 32, 1)``
        switched to eval mode.
    """

    def _load(path):
        # Construct one regressor and fill it from the checkpoint at `path`.
        model = RegressionModel2(3, 32, 1)
        # map_location="cpu": the model is built on the CPU, so remap any
        # tensors that were saved from a GPU; without it, a CUDA-saved
        # checkpoint fails to load on a CPU-only host.
        model.load_state_dict(torch.load(path, map_location="cpu"))
        # eval(): BatchNorm1d must use its running statistics; in training
        # mode a single-row batch (as used for prediction) raises an error.
        model.eval()
        return model

    return _load('j0_model-2.pt'), _load('j45_model-2.pt')


model_j0, model_j45 = load_models()

def calculate_initial_j0_j45(magnitude, axis_deg):
    """Project a cylinder (magnitude, axis) onto the (J0, J45) plane.

    The axis is doubled before projecting, so axes 0° and 180° map to the
    same point (up to floating-point rounding).

    NOTE(review): the common power-vector definition scales by -magnitude/2;
    this one scales by +magnitude, presumably matching the models' training
    data — confirm before changing.

    Args:
        magnitude: Cylinder magnitude; scales both components.
        axis_deg: Cylinder axis in degrees.

    Returns:
        tuple: ``(j0, j45) = (magnitude * cos(2θ), magnitude * sin(2θ))``.
    """
    doubled = 2 * math.radians(axis_deg)
    return magnitude * math.cos(doubled), magnitude * math.sin(doubled)

def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Run both regressors for one eye and return the predicted (J0, J45).

    Uses the module-level ``model_j0`` and ``model_j45``. Each model receives
    a single-row float32 feature vector
    ``[age, aca_axis_deg, <its initial component>]``; the axis is fed in raw
    degrees, without conversion.

    Args:
        age: Patient age.
        aca_magnitude: Cylinder magnitude used to derive the initial components.
        aca_axis_deg: Cylinder axis in degrees.

    Returns:
        tuple: ``(new_j0, new_j45)`` as Python floats.
    """
    base_j0, base_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)

    def _features(component):
        # One-row feature matrix; only the last column differs per model.
        return torch.tensor([[age, aca_axis_deg, component]], dtype=torch.float32)

    with torch.no_grad():
        predicted_j0 = model_j0(_features(base_j0)).item()
        predicted_j45 = model_j45(_features(base_j45)).item()

    return predicted_j0, predicted_j45

def main():
    """Render the Streamlit UI and, on request, show the predicted TCA.

    Reads age, ACA magnitude and ACA axis from number inputs. When "Predict"
    is pressed, it shows the initial and predicted J0/J45 components and the
    TCA magnitude and axis derived from the predicted components. It
    optionally also shows the raw model inputs and outputs.
    """
    st.title('Total Corneal Astigmatism Prediction')

    # User input fields
    age = st.number_input('Enter Patient Age (18-90 Years):', min_value=18.0, max_value=90.0, value=58.0, step=1.0)
    aca_magnitude = st.number_input('Enter ACA Magnitude (0-10 Diopters):', min_value=0.0, max_value=10.0, value=2.3, step=0.1)
    aca_axis = st.number_input('Enter ACA Axis (0-180 Degrees):', min_value=0.0, max_value=180.0, value=97.7, step=0.1)

    # The checkbox must live outside the button branch. st.button is True only
    # on the rerun triggered by the click, and toggling a checkbox starts a
    # new rerun. A checkbox created inside `if st.button(...)` therefore
    # vanishes (with the results) the moment it is clicked, so the details
    # could never be displayed.
    show_details = st.checkbox('Show detailed model inputs and outputs')

    if not st.button('Predict'):
        return

    # Calculate initial J0 and J45
    initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis)

    # Make prediction
    new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)

    # Invert the (J0, J45) projection: the magnitude is the vector length, and
    # the axis is half the polar angle (calculate_initial_j0_j45 doubles it).
    tca_magnitude = math.hypot(new_j0, new_j45)
    tca_axis = 0.5 * math.degrees(math.atan2(new_j45, new_j0))
    # atan2 lies in (-180°, 180°], so the halved angle is in (-90°, 90°];
    # shift negative values by 180° to report an axis in [0°, 180°).
    if tca_axis < 0:
        tca_axis += 180

    # Display results
    st.subheader('Prediction Results')
    col1, col2 = st.columns(2)
    with col1:
        st.write(f"Initial J0: {initial_j0:.2f}")
        st.write(f"Initial J45: {initial_j45:.2f}")
        st.write(f"Predicted J0: {new_j0:.2f}")
        st.write(f"Predicted J45: {new_j45:.2f}")
    with col2:
        st.write(f"Predicted TCA Magnitude: {tca_magnitude:.2f} D")
        st.write(f"Predicted TCA Axis: {tca_axis:.1f}°")

    # Optional: Display input tensors and raw outputs for verification
    if show_details:
        st.subheader('Model Details')
        # These mirror the feature rows built inside predict_new_j0_j45.
        st.write("Input tensor for J0:", torch.tensor([[age, aca_axis, initial_j0]], dtype=torch.float32))
        st.write("Input tensor for J45:", torch.tensor([[age, aca_axis, initial_j45]], dtype=torch.float32))
        st.write("Raw J0 output:", new_j0)
        st.write("Raw J45 output:", new_j45)


if __name__ == '__main__':
    main()