File size: 3,478 Bytes
d7784f0
 
47e9ce2
ddff90b
 
7ce5b82
ddff90b
7ce5b82
ddff90b
3279179
 
 
ddff90b
0416a61
1b98715
f2852e3
847adc5
e7caceb
6d2e57c
e7caceb
 
6d2e57c
e7caceb
388fbdd
e7caceb
 
388fbdd
 
e7caceb
6d2e57c
e7caceb
 
 
6d2e57c
f2852e3
38efeba
847adc5
f2852e3
0416a61
f2852e3
ddff90b
b102419
 
 
 
 
 
 
2360c00
388fbdd
81c44d2
 
 
 
 
971a385
81c44d2
 
8f768aa
 
4e45f70
3c6a305
8f768aa
b102419
 
80b9099
b102419
 
 
 
8f768aa
388fbdd
8f768aa
7ead1f4
6d2e57c
916be64
8f768aa
 
 
0416a61
1656abd
 
f2852e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import nltk
nltk.download('stopwords')
nltk.download('punkt')
import pandas as pd
#classify_abs is a dependency for extract_abs
import classify_abs
import extract_abs
#pd.set_option('display.max_colwidth', None)
import streamlit as st
import spacy
import tensorflow as tf
import pickle

########## Page header: logo, title and subtitle ##########
# The logo is rendered from a raw HTML <img> tag, which is why
# unsafe_allow_html is switched on for this markdown call.
_LOGO_HTML = '''<img src="https://huggingface.co/spaces/ncats/EpiPipeline4GARD/raw/main/NCATS_logo.svg?sanitize=true" alt="National Center for Advancing Translational Sciences Logo" width="700" height="600">'''
st.markdown(_LOGO_HTML, unsafe_allow_html=True)
st.title("Epidemiology Extraction Pipeline for Rare Diseases")
st.subheader("National Center for Advancing Translational Sciences (NIH/NCATS)")

########## Sidebar width override ##########
# CSS injected into the page: sets the sidebar container to 275px wide in
# both the expanded and collapsed states (the collapsed state also gets a
# negative left margin).
_SIDEBAR_CSS = """
    <style>
    [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
        width: 275px;
    }
    [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
        width: 275px;
        margin-left: -400px;
    }
    </style>
    """
st.markdown(_SIDEBAR_CSS, unsafe_allow_html=True)

########## Sidebar controls ##########
# max_results is the maximum number of PubMed IDs to retrieve BEFORE filtering.
max_results = st.sidebar.number_input(
    "Maximum number of articles to find in PubMed",
    min_value=1,
    max_value=None,
    value=50,
)

# Filtering mode chosen by the user; passed through to the extraction call.
filtering = st.sidebar.radio(
    "What type of filtering would you like?",
    ('Strict', 'Lenient', 'None'),
)

# Opt-in flag (unchecked by default); passed through to the extraction call.
extract_diseases = st.sidebar.checkbox("Extract Rare Diseases", value=False)

@st.experimental_singleton
def load_models_experimental():
    """Load every model and lookup resource used by the extraction pipeline.

    The function is wrapped in ``st.experimental_singleton`` (see the
    Streamlit documentation for its caching semantics). The loaders are
    called in a fixed order: the abstract classifier, then the NER
    pipeline, then the GARD disease lookup.

    Returns:
        A 5-tuple, in this order:
        the value returned by ``classify_abs.init_classify_model()``,
        the two values returned by ``extract_abs.init_NER_pipeline()``
        (pipeline, entity classes), and the two values returned by
        ``extract_abs.load_GARD_diseases()`` (GARD dictionary, max length).
    """
    classifier_vars = classify_abs.init_classify_model()
    ner, ner_labels = extract_abs.init_NER_pipeline()
    gard_lookup, longest = extract_abs.load_GARD_diseases()
    return classifier_vars, ner, ner_labels, gard_lookup, longest

@st.cache(allow_output_mutation=True)
def load_models():
    """Load the classifier tokenizer and model plus the NER/GARD resources.

    Unlike ``load_models_experimental``, this variant loads the
    classification tokenizer and Keras model directly instead of going
    through ``classify_abs.init_classify_model()``.

    Returns:
        A 6-tuple, in this order: the unpickled tokenizer, the Keras model
        loaded from ``"LSTM_RNN_Model"``, the two values returned by
        ``extract_abs.init_NER_pipeline()`` (pipeline, entity classes), and
        the two values returned by ``extract_abs.load_GARD_diseases()``
        (GARD dictionary, max length).
    """
    # Tokenizer: unpickled from the relative path 'tokenizer.pickle'.
    # NOTE: unpickling runs code embedded in the file, so this file must
    # come from a trusted source.
    with open('tokenizer.pickle', 'rb') as tokenizer_file:
        tokenizer = pickle.load(tokenizer_file)

    # Classifier: saved Keras model loaded from "LSTM_RNN_Model".
    lstm_model = tf.keras.models.load_model("LSTM_RNN_Model")

    ner, ner_labels = extract_abs.init_NER_pipeline()
    gard_lookup, longest = extract_abs.load_GARD_diseases()
    return tokenizer, lstm_model, ner, ner_labels, gard_lookup, longest
    
########## Load models (with a spinner shown while loading) ##########
# NOTE: an earlier variant called load_models() here and then loaded the
# spaCy models (en_core_web_lg, en_ner_bc5cdr_md, en_ner_bionlp13cg_md)
# outside the cached function, because they could not be cached due to a
# hash function error. The singleton loader below is what is used now.
with st.spinner('Loading Epidemiology Models and Dependencies...'):
    _loaded_resources = load_models_experimental()
    (
        classify_model_vars,
        NER_pipeline,
        entity_classes,
        GARD_dict,
        max_length,
    ) = _loaded_resources
st.success('All Models and Dependencies Loaded!')

########## Query input and extraction ##########
disease_or_gard_id = st.text_input("Input a rare disease term or GARD ID.")

# Run the extraction only once the user has typed something (empty string
# is falsy), then show the result as a table.
if disease_or_gard_id:
    # Positional arguments, in the order streamlit_extraction is called with.
    _extraction_args = (
        disease_or_gard_id,
        max_results,
        filtering,
        NER_pipeline,
        entity_classes,
        extract_diseases,
        GARD_dict,
        max_length,
        classify_model_vars,
    )
    df = extract_abs.streamlit_extraction(*_extraction_args)
    st.dataframe(df)