File size: 12,031 Bytes
89677ab
 
e696f95
89677ab
 
 
 
 
 
e696f95
24c5c6a
 
 
 
 
 
 
 
 
 
 
301ba09
 
 
 
 
 
 
24c5c6a
301ba09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24c5c6a
301ba09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24c5c6a
301ba09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import os

# Streamlit has to be imported before the spinner below can be used; without
# this import the very first statement of the script raises NameError.
import streamlit as st


# One-time environment bootstrap: run setup.sh the first time the app starts
# and drop a marker file so later script reruns skip it.
with st.spinner("Please wait while we prepare the environment. This may take a few minutes only on the first run..."):
    # Run setup script if not already executed
    if not os.path.exists(".setup_done"):
        setup_status = os.system("bash setup.sh")
        # The marker is written regardless of the exit status so the script
        # is not re-executed on every Streamlit rerun.
        with open(".setup_done", "w") as f:
            f.write("done")
        if setup_status != 0:
            # Surface the failure instead of silently ignoring the exit status.
            st.warning(
                f"Environment setup script exited with status {setup_status}; "
                "the application may not work correctly."
            )

import streamlit as st
import streamlit.components.v1 as components
import os
import time
import pandas as pd

from run_prothgt_app import *

def convert_df(df):
    """Serialize *df* to CSV (without the index column) as UTF-8 bytes.

    The bytes are suitable for handing to a download widget.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')

# Seed the per-session state on the first run so the rest of the script can
# read these keys unconditionally. Existing values are left untouched.
for state_key, initial_value in (('predictions_df', None), ('submitted', False)):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value


# Sidebar: app branding, protein selection (search box or uploaded ID file),
# GO category choice, and the button that triggers prediction.
with st.sidebar:
    st.markdown("""
        <style>
        .title {
            font-size: 35px;
            font-weight: bold;
            color: #424242;
            margin-bottom: 0px;
        }
        .subtitle {
            font-size: 20px;
            color: #424242;
            margin-bottom: 20px;
            line-height: 1.5;
        }
        .badges {
            margin-top: 10px;
            margin-bottom: 20px;
        }
        </style>
        
        <div class="title">ProtHGT</div>
        <div class="subtitle">Heterogeneous Graph Transformers for Automated Protein Function Prediction Using Knowledge Graphs and Language Models</div>
        <div class="badges">
            <a href="">
                <img src="https://img.shields.io/badge/DOI-10.1002/pro.4988-b31b1b.svg" alt="publication">
            </a>
            <a href="https://github.com/HUBioDataLab/ProtHGT">
                <img src="https://img.shields.io/badge/GitHub-black?logo=github" alt="github-repository">
            </a>
        </div>
    """, unsafe_allow_html=True)

    available_proteins = get_available_proteins()

    # Stays empty until the user picks proteins or uploads a valid file.
    selected_proteins = []

    # Add protein selection methods
    selection_method = st.radio(
        "Choose input method:",
        ["Search proteins", "Upload protein ID file"]
    )

    if selection_method == "Search proteins":
        # Add custom CSS to make multiselect scrollable
        st.markdown("""
            <style>
            [data-testid="stMultiSelect"] div:nth-child(2) {
                max-height: 200px;
                overflow-y: auto;
            }
            </style>
            """, unsafe_allow_html=True)

        selected_proteins = st.multiselect(
            "Select or search for proteins (UniProt IDs)",
            options=available_proteins,
            placeholder="Start typing to search...",
            max_selections=100
        )

        if selected_proteins:
            st.write(f"Selected {len(selected_proteins)} proteins")

    else:
        uploaded_file = st.file_uploader(
            "Upload a text file with UniProt IDs (one per line, max 100)*",
            type=['txt']
        )

        if uploaded_file:
            protein_list = [line.decode('utf-8').strip() for line in uploaded_file]
            # Remove empty lines and duplicates (dict.fromkeys keeps the
            # first occurrence of each ID, in upload order)
            protein_list = list(filter(None, protein_list))
            protein_list = list(dict.fromkeys(protein_list))

            # Collect the unknown IDs BEFORE dropping them from protein_list.
            # Computing this after the filter always yields an empty list, so
            # the "not found" messages below could never be shown.
            proteins_not_found = [p for p in protein_list if p not in available_proteins]
            # Keep only proteins that are present in available_proteins
            protein_list = [p for p in protein_list if p in available_proteins]

            if len(protein_list) > 100:
                st.error("Please upload a file with maximum 100 protein IDs.")
                selected_proteins = []
            else:
                selected_proteins = protein_list
                st.write(f"Loaded {len(selected_proteins)} proteins")
                if proteins_not_found:
                    st.error(f"Proteins not found in input knowledge graph: {', '.join(proteins_not_found)}")
                    st.warning("Currently, our system can generate predictions only for proteins included in our input knowledge graph. Real-time retrieval of relationship data from external source databases is not yet supported. However, we are actively working on integrating this capability in future updates.")

    if selected_proteins:
        # Option 1: Collapsible expander
        with st.expander("View Selected Proteins"):
            st.write(f"Total proteins selected: {len(selected_proteins)}")

            # Create scrollable container with fixed height
            st.markdown(
                f"""
                <div style="
                    height: 150px; 
                    overflow-y: scroll;
                    border: 1px solid #ccc;
                    border-radius: 4px;
                    padding: 8px;
                    background-color: white;">
                    {'<br>'.join(selected_proteins)}
                </div>
                """, 
                unsafe_allow_html=True
            )

            # Spacer between the list and the download button
            st.markdown("<div style='padding-top: 10px;'></div>", unsafe_allow_html=True)

            # Add download button
            proteins_text = '\n'.join(selected_proteins)
            st.download_button(
                label="Download List",
                data=proteins_text,
                file_name="selected_proteins.txt",
                mime="text/plain",
                key="download_button"
            )

        # Add GO category selection. Values are the category identifiers
        # passed on to prediction; None stands for "all three categories".
        go_category_options = {
            'All Categories': None,
            'Molecular Function': 'GO_term_F',
            'Biological Process': 'GO_term_P',
            'Cellular Component': 'GO_term_C'
        }
        selected_go_category = st.selectbox(
            "Select GO Category for predictions",
            options=list(go_category_options.keys()),
            help="Choose which GO category to generate predictions for. Selecting 'All Categories' will generate predictions for all three categories."
        )

    st.warning("⚠️ Due to memory and computational constraints, the maximum number of proteins that can be processed at once is limited to 100 proteins. For larger datasets, please consider running the model locally using our GitHub repository.")

    # selected_go_category only exists when proteins were selected; the
    # short-circuiting `and` keeps it from being read otherwise.
    if selected_proteins and selected_go_category:
        # Add a button to trigger predictions
        if st.button("Generate Predictions"):
            st.session_state.submitted = True

# Main panel: once the user has submitted, compute the predictions (only if
# they are not cached in the session yet) and render the filterable table.
if st.session_state.submitted:
    with st.spinner("Generating predictions..."):
        # Generate predictions only if not already in session state
        if st.session_state.predictions_df is None:

            import json
            import os

            # Model files live under data/models
            data_dir = "data"
            models_dir = os.path.join(data_dir, "models")

            # Per-category model configuration files
            config_path_by_category = {
                'GO_term_F': os.path.join(models_dir, "prothgt-config-molecular-function.yaml"),
                'GO_term_P': os.path.join(models_dir, "prothgt-config-biological-process.yaml"),
                'GO_term_C': os.path.join(models_dir, "prothgt-config-cellular-component.yaml")
            }

            # Per-category model checkpoint files
            checkpoint_path_by_category = {
                'GO_term_F': os.path.join(models_dir, "prothgt-model-molecular-function.pt"),
                'GO_term_P': os.path.join(models_dir, "prothgt-model-biological-process.pt"),
                'GO_term_C': os.path.join(models_dir, "prothgt-model-cellular-component.pt")
            }

            # A specific category restricts the run to that single model;
            # None ("All Categories") runs all three in a fixed order.
            go_category = go_category_options[selected_go_category]
            if go_category:
                go_categories = [go_category]
            else:
                go_categories = ['GO_term_F', 'GO_term_P', 'GO_term_C']

            model_config_paths = [config_path_by_category[cat] for cat in go_categories]
            model_paths = [checkpoint_path_by_category[cat] for cat in go_categories]

            # Generate predictions and cache them for subsequent reruns
            predictions_df = generate_prediction_df(
                protein_ids=selected_proteins,
                model_paths=model_paths,
                model_config_paths=model_config_paths,
                go_category=go_categories
            )
            st.session_state.predictions_df = predictions_df

        # Display and filter predictions
        st.success("Predictions generated successfully!")
        st.markdown("### Filter and View Predictions")

        # Filter widgets, laid out in three columns
        st.markdown("### Filter Predictions")
        all_predictions = st.session_state.predictions_df
        protein_choices = ['All'] + sorted(all_predictions['Protein'].unique().tolist())
        category_choices = ['All'] + sorted(all_predictions['GO_category'].unique().tolist())
        protein_col, category_col, probability_col = st.columns(3)

        with protein_col:
            selected_protein = st.selectbox("Filter by Protein", options=protein_choices)

        with category_col:
            selected_category = st.selectbox("Filter by GO Category", options=category_choices)

        with probability_col:
            min_probability_threshold = st.slider(
                "Minimum Probability", min_value=0.0, max_value=1.0, value=0.5, step=0.05
            )
            max_probability_threshold = st.slider(
                "Maximum Probability", min_value=0.0, max_value=1.0, value=1.0, step=0.05
            )

        # Build one row mask from the active filters; the probability bounds
        # are inclusive on both ends.
        keep_rows = all_predictions['Probability'].between(
            min_probability_threshold, max_probability_threshold
        )
        if selected_protein != 'All':
            keep_rows = keep_rows & (all_predictions['Protein'] == selected_protein)
        if selected_category != 'All':
            keep_rows = keep_rows & (all_predictions['GO_category'] == selected_category)

        # Highest-probability predictions first
        filtered_df = all_predictions[keep_rows].sort_values('Probability', ascending=False)

        # Column presentation for the results table
        results_column_config = {
            "Probability": st.column_config.ProgressColumn(
                "Probability",
                format="%.2f",
                min_value=0,
                max_value=1,
            ),
            "Protein": st.column_config.TextColumn("Protein", help="UniProt ID"),
            "GO_category": st.column_config.TextColumn(
                "GO Category", help="Gene Ontology Category"
            ),
            "GO_term": st.column_config.TextColumn(
                "GO Term", help="Gene Ontology Term ID"
            ),
        }
        st.dataframe(filtered_df, hide_index=True, column_config=results_column_config)

        # Download filtered results
        st.download_button(
            label="Download Filtered Results",
            data=convert_df(filtered_df),
            file_name="filtered_predictions.csv",
            mime="text/csv",
            key="download_filtered_predictions"
        )

# Reset button in the sidebar: shown only after a submission; clears the
# cached predictions and restarts the script from a clean state.
with st.sidebar:
    if st.session_state.submitted and st.button("Reset"):
        st.session_state.predictions_df = None
        st.session_state.submitted = False
        # st.experimental_rerun is deprecated in favour of st.rerun and newer
        # Streamlit releases may no longer ship it; prefer st.rerun and fall
        # back to the old name only on versions that predate it.
        rerun_script = getattr(st, "rerun", None) or st.experimental_rerun
        rerun_script()