File size: 7,932 Bytes
48097f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f181ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48097f5
4f181ab
 
 
 
 
 
 
 
48097f5
4f181ab
48097f5
4f181ab
 
 
 
 
 
 
 
 
 
48097f5
4f181ab
 
 
 
 
 
28f312f
4f181ab
 
 
 
 
28f312f
4f181ab
 
 
 
 
 
 
 
 
48097f5
4f181ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0958d2
4f181ab
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "" #disable cuda

import streamlit as st
import numpy as np
import torch
import time
from Bio import SeqIO

from protxlstm.applications.fitness_prediction import single_mutation_landscape_xlstm, create_mutation_df
from protxlstm.applications.msa_sampler import sample_msa
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.utils import load_model
import io

from frontend.constants import info_text, citation_text

if __name__ == "__main__":
    DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"

    # Script-level inputs that run_model() reads; the widgets further down
    # overwrite them.
    msa_file = None
    mutation_positions = []

    # Seed the session state only while the 'fitness_done' key is absent, so
    # values already stored in the session are not overwritten.
    if 'fitness_done' not in st.session_state:
        initial_state = {
            "fitness_done": False,
            "mutations": None,
            "fitness_duration": None,
            "target_sequence": "",
            "context_sequences": [],
            "num_context_sequences": 25,
        }
        for state_key, state_value in initial_state.items():
            st.session_state[state_key] = state_value

    def run_model():
        """Score single-site mutations of the target sequence with Prot-xLSTM.

        Used as the ``on_click`` callback of the "Check Fitness" button. Inputs
        come from the enclosing script scope (``mutation_positions``,
        ``msa_file``) and from ``st.session_state`` (``target_sequence``,
        ``context_sequences``, ``num_context_sequences``).

        On success ``st.session_state.mutations`` holds the first value
        returned by ``single_mutation_landscape_xlstm``, ``fitness_duration``
        holds the elapsed wall-clock seconds and ``fitness_done`` is True.
        On any exception the error is printed and ``mutations``,
        ``fitness_duration`` and ``fitness_done`` keep their previous values.
        """
        # Keep the start time in a local: storing it in session state meant a
        # failed run left a raw timestamp behind as the "duration".
        started_at = time.time()
        try:
            checkpoint = "protxlstm/checkpoints/small"
            num_context_tokens = 2**15
            df_mutations = create_mutation_df(st.session_state.target_sequence, mutation_positions)

            if msa_file is not None and st.session_state.num_context_sequences != 0:
                # Parse a decoded copy of the upload. Wrapping the upload in a
                # TextIOWrapper would consume it and close it once the wrapper
                # is garbage collected, breaking a later run on the same object.
                msa_text = msa_file.getvalue().decode("utf-8")
                msa_sequences = [
                    str(record.seq).upper()
                    for record in SeqIO.parse(io.StringIO(msa_text), "fasta")
                ]
                context_sequences = sample_msa(
                    msa_sequences,
                    max_context_sequences=st.session_state.num_context_sequences,
                    context_length=num_context_tokens,
                )
                # The target sequence always goes last in the context.
                context_sequences += [st.session_state.target_sequence]
                st.session_state.context_sequences = context_sequences

            config_update_kwargs = {
                "mlstm_backend": "chunkwise_variable",
                "mlstm_chunksize": 1024,
                "mlstm_return_last_state": True,
            }

            model = load_model(
                checkpoint,
                model_class=xLSTMLMHeadModel,
                device='cpu',
                dtype=torch.bfloat16,
                **config_update_kwargs,
            )
            model = model.eval()
            st.session_state.mutations, _ = single_mutation_landscape_xlstm(
                model,
                df_mutations,
                st.session_state.context_sequences,
                chunk_chunk_size=2**15,
            )
            print("fitness_done")
            st.session_state.fitness_done = True
            st.session_state.fitness_duration = time.time() - started_at
        except Exception as e:
            # Best-effort callback: report the failure on stdout (including the
            # exception type, since some exceptions have an empty message) and
            # leave the previously stored results untouched.
            print(f"Fitness prediction failed: {type(e).__name__}: {e}")

    # PAGE STYLE (mainly for custom aa selection)
    # Wide layout, then raw CSS: compact, borderless buttons inside button
    # groups and a fixed 600px sidebar width.
    st.set_page_config(layout="wide")
    custom_css = """
        <style>
        .stButtonGroup button {
            padding: 0px 1px 0px 1px !important;
            border: 0 solid transparent !important;
            min-height: 0px !important;
            line-height: 120% !important;
            height: auto !important;
        }
        .stSidebar {
            width: 600px !important;
        }
        </style>
        """
    st.markdown(custom_css, unsafe_allow_html=True)


    # SIDEBAR: collects the target sequence and the context configuration.
    with st.sidebar:
        st.title("Prot-xLSTM Variant Fitness")

        # LOAD SEQUENCE
        st.session_state.target_sequence = st.text_area(
            "Target protein sequence",
            placeholder=DEFAULT_SEQUENCE,
            value=st.session_state.target_sequence
        )
        if st.button("Load sequence"):
            # An empty box falls back to the bundled sample sequence.
            if st.session_state.target_sequence == "":
                st.session_state.target_sequence = DEFAULT_SEQUENCE

        # MANAGE CONTEXT SEQUENCES
        context_type = st.selectbox(
            "Choose how to enter context",
            ("Enter manually", "Use MSA file"),
            index=None,
            placeholder="Choose context",
        )
        if context_type == 'Enter manually':
            context_sequence_str = st.text_area(
                "Enter context protein sequences (seperated by comma)",
                placeholder=DEFAULT_SEQUENCE,
            )
            # Trim each entry and drop blanks: "".split(",") is [""], so an
            # empty box or a trailing comma would otherwise add an empty
            # context sequence.
            manual_context = [
                entry.strip()
                for entry in context_sequence_str.split(",")
                if entry.strip()
            ]
            st.session_state.context_sequences = manual_context + [st.session_state.target_sequence]
            msa_file = None
        elif context_type == 'Use MSA file':
            # Start from target-only context so nothing stale from an earlier
            # mode or target is reused when no file is uploaded or the count
            # is 0; run_model() replaces it with the MSA sample otherwise.
            st.session_state.context_sequences = [st.session_state.target_sequence]
            msa_file = st.file_uploader("Choose MSA file")
            st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
        else:
            # No context chosen: the target sequence is the only context entry.
            st.session_state.context_sequences = [st.session_state.target_sequence]
            msa_file = None

    # MAIN PANEL: only shown once a target sequence is available.
    if st.session_state.target_sequence != "":
        with st.container():

            # MUTATION POSITION SELECTION
            # One clickable segment per residue: option values are 1-based
            # positions, the displayed label is the residue at that position.
            aas = list(st.session_state.target_sequence)
            mutation_indices = np.arange(1, len(aas)+1)
            mutation_positions = st.segmented_control(
                "Choose mutation positions (click to select)", mutation_indices, selection_mode="multi", format_func=lambda i: aas[i-1],
            )
            st.button("Check Fitness", on_click=run_model)

            # DISPLAY RESULTS
            if st.session_state.fitness_done:
                mutations_df = st.session_state.mutations
                st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
                selected_pos = st.selectbox(
                    "Visualized mutation position",
                    mutations_df['position'].unique()
                )
                # Boolean indexing keeps only the rows for the chosen position.
                # DataFrame.where() would keep every row and fill the
                # non-matching ones with NaN.
                selected_data = mutations_df[mutations_df['position'] == selected_pos]
                st.bar_chart(selected_data, x='mutation', y='effect', horizontal=True)
                st.dataframe(mutations_df, use_container_width=True)

    # TUTORIAL
    # Numbered usage steps, rendered in order inside the expander below.
    tutorial_steps = (
        "**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load Sequence'",
        "**2.** Enter or upload you context sequences. (leave empty to use no context)",
        "**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'",
    )
    with st.expander("Info & Tutorial", expanded=True):
        st.warning('Due to computational constraints, processing may take up to a few minutes.', icon="⚠️")

        st.subheader("Tutorial")
        for tutorial_step in tutorial_steps:
            st.markdown(tutorial_step)

        st.subheader("General Information")
        st.markdown(info_text, unsafe_allow_html=True)
        st.markdown("")

        st.subheader("Cite us / BibTex")
        st.code(citation_text, language=None)