File size: 8,563 Bytes
71f7546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import streamlit as st 
import pandas as pd 
import pickle
import json

# Load the saved artefacts the app needs: the lists of numeric and
# categorical feature column names (stored as JSON) and the fitted model.
# Paths are relative, so they resolve against the current working directory.
# JSON text is read as UTF-8 explicitly; the default text encoding of open()
# is platform/locale dependent.
with open('list_col_num.txt', 'r', encoding='utf-8') as file_1:
    list_col_num_if = json.load(file_1)

with open('list_col_cat.txt', 'r', encoding='utf-8') as file_2:
    list_col_cat_if = json.load(file_2)

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load a model_dt.pkl that comes from a trusted source.
with open('model_dt.pkl', 'rb') as file_3:
    model_dt = pickle.load(file_3)

def get_trt_index(trt):
    """Encode the treatment-arm label as an integer code.

    'ZDV only' -> 0, 'ZDV + ddl' -> 1, 'ZDV + Zal' -> 2; any other value
    (such as 'ddl only') falls through to 3.
    """
    known_arms = (
        ('ZDV only', 0),
        ('ZDV + ddl', 1),
        ('ZDV + Zal', 2),
    )
    for label, code in known_arms:
        if trt == label:
            return code
    return 3

def get_yes_no_index(yes_no):
    """Encode a yes/no answer: 'no' -> 0, any other value -> 1."""
    return 0 if yes_no == 'no' else 1

def get_strat_index(strat):
    """Encode the antiretroviral-history stratum as 1, 2 or 3.

    'Antiretroviral Naive' -> 1, 'Antiretroviral <= 52 weeks' -> 2; every
    other value (such as 'Antiretroviral > 52 weeks') maps to 3.
    """
    strata = (
        ('Antiretroviral Naive', 1),
        ('Antiretroviral <= 52 weeks', 2),
    )
    for label, code in strata:
        if strat == label:
            return code
    return 3

def get_symptom_index(symptom):
    """Encode the symptomatic indicator: 'asymp' -> 0, any other value -> 1."""
    is_asymptomatic = symptom == 'asymp'
    return 0 if is_asymptomatic else 1

def get_treat_index(treat):
    """Encode the treatment indicator: 'ZDV only' -> 0, any other value -> 1."""
    zdv_only = treat == 'ZDV only'
    if zdv_only:
        return 0
    return 1

def get_race_index(race):
    """Encode race: 'White' -> 0, any other value (e.g. 'non-white') -> 1."""
    return 0 if race == 'White' else 1

def get_gender_index(gender):
    """Encode gender: 'Female' -> 0, any other value (e.g. 'Male') -> 1."""
    is_female = gender == 'Female'
    if is_female:
        return 0
    return 1

def get_str2_index(str2):
    """Encode antiretroviral history: 'naive' -> 0, any other value -> 1."""
    return 0 if str2 == 'naive' else 1

# Choices for the binary yes/no questions. All option sequences below are
# tuples, not sets: Python randomises string hashing per process, so a set's
# iteration order -- and therefore which option `index=0` pre-selects -- can
# change from one run of the app to the next.
_YES_NO = ('no', 'yes')


def _predict_label(data_inf):
    """Return the model's verdict for one encoded patient record.

    Args:
        data_inf: dict mapping feature name to encoded value for one patient.

    Returns:
        'yes' when ``model_dt`` predicts class 1, otherwise 'no'.
    """
    df = pd.DataFrame([data_inf])
    # Cast the categorical columns to object dtype and select columns in the
    # order given by the saved lists (categoricals first, then numerics).
    df[list_col_cat_if] = df[list_col_cat_if].astype(object)
    df = df[list_col_cat_if + list_col_num_if]
    predict_result = model_dt.predict(df)
    return 'yes' if predict_result[0] == 1 else 'no'


def run():
    """Render the patient-data form and show the model's prediction on submit.

    Categorical answers are encoded with the ``get_*_index`` helpers before
    being passed to the model. Nothing is predicted until the form's
    "Submit" button is pressed.
    """
    with st.form("prediction_form"):
        st.write('Personal Information')
        time = st.number_input('Input time to failure or censoring', value=100)
        trt = st.selectbox(
            'Select Treatment Indicator ',
            ('ZDV only', 'ZDV + ddl', 'ZDV + Zal', 'ddl only'),
            index=0,
        )
        age = st.number_input('Input age in years', value=20)
        wtkg = st.number_input('Input weight in kg', value=40.0)
        hemo = st.selectbox('Is patient has hemophilia ?', _YES_NO, index=0)
        homo = st.selectbox(
            'Is patient has experience do homosexuality activity ?', _YES_NO, index=0
        )
        drugs = st.selectbox(
            'Is patient has history of IV drug use ?', _YES_NO, index=0
        )
        karnof = st.number_input(
            'Input Karnofsky score (on scale 0 - 100)',
            value=40, min_value=0, max_value=100, step=1,
        )
        oprior = st.selectbox(
            'Is patient is Non-ZDV antiretroviral therapy pre-175 ?', _YES_NO, index=0
        )
        z30 = st.selectbox(
            'Is patient is ZDV in the 30 days prior to 175 ?', _YES_NO, index=0
        )
        preanti = st.number_input(
            'Input days pre-175 anti-retroviral therapy', value=40
        )
        race = st.selectbox('Input patient race ?', ('White', 'non-white'), index=0)
        gender = st.selectbox('Select gender ?', ('Female', 'Male'), index=0)
        str2 = st.selectbox(
            'Input antiretroviral history', ('naive', 'experienced'), index=0
        )
        strat = st.selectbox(
            'Input antiretroviral history stratification',
            (
                'Antiretroviral Naive',
                'Antiretroviral <= 52 weeks',
                'Antiretroviral > 52 weeks',
            ),
            index=0,
        )
        symptom = st.selectbox(
            'Input symptomatic indicator', ('asymp', 'symp'), index=0
        )
        treat = st.selectbox(
            'Input treatment indicator ', ('ZDV only', 'others'), index=0
        )
        offtrt = st.selectbox(
            'Input indicator of off-trt before 96+/-5 weeks ', _YES_NO, index=0
        )
        cd40 = st.number_input('Input CD4', value=40.0)
        cd420 = st.number_input('Input CD4 at 20+/-5 weeks', value=40.0)
        cd80 = st.number_input('Input CD8', value=40.0)
        cd820 = st.number_input('Input CD8 at 20+/-5 weeks', value=40.0)

        submitted = st.form_submit_button("Submit")

    if not submitted:
        return

    # Encode every answer into the numeric form used as model features.
    data_inf = {
        'time': time,
        'trt': get_trt_index(trt),
        'age': age,
        'wtkg': wtkg,
        'hemo': get_yes_no_index(hemo),
        'homo': get_yes_no_index(homo),
        'drugs': get_yes_no_index(drugs),
        'karnof': karnof,
        'oprior': get_yes_no_index(oprior),
        'z30': get_yes_no_index(z30),
        'preanti': preanti,
        'race': get_race_index(race),
        'gender': get_gender_index(gender),
        'str2': get_str2_index(str2),
        'strat': get_strat_index(strat),
        'symptom': get_symptom_index(symptom),
        'treat': get_treat_index(treat),
        'offtrt': get_yes_no_index(offtrt),
        'cd40': cd40,
        'cd420': cd420,
        'cd80': cd80,
        'cd820': cd820,
    }

    predicted_value = _predict_label(data_inf)
    st.write(f'## Is patient infected AIDS: {predicted_value}')



# Entry point: render the app when this file is executed as a script.
if __name__ == "__main__":
    run()