Tschoui committed on
Commit
4f181ab
·
1 Parent(s): cb65138

main script statement added

Browse files
Files changed (1) hide show
  1. app.py +130 -129
app.py CHANGED
@@ -15,137 +15,138 @@ import io
15
 
16
  from frontend.constants import info_text, citation_text
17
 
18
- DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"
19
-
20
- mutation_positions = []
21
- msa_file = None
22
-
23
- if 'fitness_done' not in st.session_state:
24
- st.session_state.fitness_done = False
25
- st.session_state.mutations = None
26
- st.session_state.fitness_duration = None
27
- st.session_state.target_sequence = ""
28
- st.session_state.context_sequences = []
29
- st.session_state.num_context_sequences = 25
30
-
31
- def run_model():
32
- try:
33
- st.session_state.fitness_duration = time.time()
34
- checkpoint = "protxlstm/checkpoints/small"
35
- num_context_tokens = 2**15
36
- df_mutations = create_mutation_df(st.session_state.target_sequence, mutation_positions)
37
- if msa_file != None and st.session_state.num_context_sequences != 0:
38
- def load_sequences_from_msa_file(file_obj):
39
- text_io = io.TextIOWrapper(file_obj, encoding="utf-8")
40
- sequences = [str(record.seq) for record in SeqIO.parse(text_io, "fasta")]
41
- return sequences
42
- msa_sequences = [msa.upper() for msa in load_sequences_from_msa_file(msa_file)]
43
- st.session_state.context_sequences = sample_msa(msa_sequences, max_context_sequences=st.session_state.num_context_sequences, context_length=num_context_tokens)
44
- st.session_state.context_sequences += [st.session_state.target_sequence]
45
-
46
- config_update_kwargs = {
47
- "mlstm_backend": "chunkwise_variable",
48
- "mlstm_chunksize": 1024,
49
- "mlstm_return_last_state": True}
50
-
51
- model = load_model(
52
- checkpoint,
53
- model_class=xLSTMLMHeadModel,
54
- device='cpu',
55
- dtype=torch.bfloat16,
56
- **config_update_kwargs,
57
- )
58
- model = model.eval()
59
- st.session_state.mutations, _ = single_mutation_landscape_xlstm(model, df_mutations, st.session_state.context_sequences, chunk_chunk_size=2**15)
60
- print("fitness_done")
61
- st.session_state.fitness_done = True
62
- st.session_state.fitness_duration = time.time() - st.session_state.fitness_duration
63
- except Exception as e:
64
- print(e)
65
-
66
- # PAGE STYLE (mainly for custom aa selection)
67
- st.set_page_config(layout="wide")
68
- st.markdown(
69
- """
70
- <style>
71
- .stButtonGroup button {
72
- padding: 0px 1px 0px 1px !important;
73
- border: 0 solid transparent !important;
74
- min-height: 0px !important;
75
- line-height: 120% !important;
76
- height: auto !important;
77
- }
78
- .stSidebar {
79
- width: 600px !important;
80
- }
81
- </style>
82
- """,
83
- unsafe_allow_html=True
84
- )
85
-
86
-
87
- with st.sidebar:
88
- st.title("Prot-xLSTM Variant Fitness")
89
-
90
- # LOAD SEQUENCE
91
- st.session_state.target_sequence = st.text_area(
92
- "Target protein sequence",
93
- placeholder=DEFAULT_SEQUENCE,
94
- value=st.session_state.target_sequence
95
- )
96
- if st.button("Load sequence"):
97
- if st.session_state.target_sequence == "":
98
- st.session_state.target_sequence = DEFAULT_SEQUENCE
99
-
100
- # MANAGE CONTEXT SEQUENCES
101
- context_type = st.selectbox(
102
- "Choose how to enter context",
103
- ("Enter manually", "Use MSA file"),
104
- index=None,
105
- placeholder="Choose context",
106
  )
107
- if context_type == 'Enter manually':
108
- context_sequence_str = st.text_area(
109
- "Enter context protein sequences (seperated by comma)",
 
 
 
 
 
110
  placeholder=DEFAULT_SEQUENCE,
 
111
  )
112
- st.session_state.context_sequences = context_sequence_str.split(",") + [st.session_state.target_sequence]
113
- elif context_type == 'Use MSA file':
114
- msa_file = st.file_uploader("Choose MSA file")
115
- st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
116
- else:
117
- st.session_state.context_sequences = [st.session_state.target_sequence]
118
-
119
- if st.session_state.target_sequence != "":
120
- with st.container():
121
-
122
- # MUTATION POSITION SELECTION
123
- aas = list(st.session_state.target_sequence)
124
- mutation_indices = np.arange(1, len(aas)+1)
125
- mutation_positions = st.segmented_control(
126
- "Choose mutation positions (click to select)", mutation_indices, selection_mode="multi", format_func=lambda i: aas[i-1],
127
  )
128
- st.button("Check Fitness", on_click=run_model)
129
-
130
- # DISPLAY RESULTS
131
- if st.session_state.fitness_done:
132
- st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
133
- selected_pos = st.selectbox(
134
- "Visualized mutation position",
135
- st.session_state.mutations['position'].unique()
 
 
 
 
 
 
 
 
 
 
 
 
136
  )
137
- selected_data = st.session_state.mutations.where(st.session_state.mutations['position'] == selected_pos)
138
- st.bar_chart(selected_data, x='mutation', y='effect', horizontal=True)
139
- st.dataframe(st.session_state.mutations, use_container_width=True)
140
-
141
- # TUTORIAL
142
- with st.expander("Info & Tutorial", expanded=True):
143
- st.subheader("Tutorial")
144
- st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load Sequence'")
145
- st.markdown("**2.** Enter or upload you context sequences. (leave empty to use no context)")
146
- st.markdown("**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'")
147
- st.subheader("General Information")
148
- st.markdown(info_text, unsafe_allow_html=True)
149
- st.markdown("")
150
- st.subheader("Cite us / BibTex")
151
- st.code(citation_text, language=None)
 
 
 
 
 
 
 
 
 
 
15
 
16
  from frontend.constants import info_text, citation_text
17
 
18
+ if __name__ == "__main__":
19
+ DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"
20
+
21
+ mutation_positions = []
22
+ msa_file = None
23
+
24
+ if 'fitness_done' not in st.session_state:
25
+ st.session_state.fitness_done = False
26
+ st.session_state.mutations = None
27
+ st.session_state.fitness_duration = None
28
+ st.session_state.target_sequence = ""
29
+ st.session_state.context_sequences = []
30
+ st.session_state.num_context_sequences = 25
31
+
32
def _read_msa_sequences(file_obj):
    """Return the upper-cased sequences of a FASTA-formatted upload.

    ``file_obj`` is the binary file-like object produced by
    ``st.file_uploader``. Its bytes are decoded into a fresh ``io.StringIO``
    rather than wrapping the upload in ``io.TextIOWrapper``: such a wrapper
    closes the underlying buffer when it is garbage-collected, so a second
    "Check Fitness" click on the same upload would fail with
    "I/O operation on closed file".
    """
    # Rewind: the same upload object may already have been consumed by an
    # earlier run of the model.
    file_obj.seek(0)
    text_io = io.StringIO(file_obj.read().decode("utf-8"))
    return [str(record.seq).upper() for record in SeqIO.parse(text_io, "fasta")]


def run_model():
    """Score the selected mutations and publish the result via session state.

    Used as the ``on_click`` callback of the "Check Fitness" button. Reads the
    target sequence, context sequences and requested number of context
    sequences from ``st.session_state``, and the module-level
    ``mutation_positions`` and ``msa_file`` values set by the widgets.

    On success it stores the result of ``single_mutation_landscape_xlstm`` in
    ``st.session_state.mutations``, the elapsed wall-clock seconds in
    ``st.session_state.fitness_duration`` and sets
    ``st.session_state.fitness_done`` to ``True``.

    On failure the traceback is printed, ``fitness_done`` stays ``False`` and
    ``fitness_duration`` is left untouched, so the page never shows stale
    results or a raw timestamp as the running time. Exceptions are not
    re-raised (best-effort callback, as before).
    """
    import traceback  # local import: only needed on the error path

    started_at = time.time()
    # Hide the results of any previous run until this run has succeeded.
    st.session_state.fitness_done = False
    try:
        checkpoint = "protxlstm/checkpoints/small"
        num_context_tokens = 2**15
        df_mutations = create_mutation_df(st.session_state.target_sequence, mutation_positions)

        if msa_file is not None and st.session_state.num_context_sequences != 0:
            msa_sequences = _read_msa_sequences(msa_file)
            st.session_state.context_sequences = sample_msa(
                msa_sequences,
                max_context_sequences=st.session_state.num_context_sequences,
                context_length=num_context_tokens,
            )
            # NOTE(review): the original indentation was lost; appending the
            # target is assumed to belong to the MSA branch only, because the
            # other context modes already append it where the context widgets
            # are handled.
            st.session_state.context_sequences += [st.session_state.target_sequence]

        config_update_kwargs = {
            "mlstm_backend": "chunkwise_variable",
            "mlstm_chunksize": 1024,
            "mlstm_return_last_state": True,
        }

        model = load_model(
            checkpoint,
            model_class=xLSTMLMHeadModel,
            device='cpu',
            dtype=torch.bfloat16,
            **config_update_kwargs,
        )
        model = model.eval()
        st.session_state.mutations, _ = single_mutation_landscape_xlstm(
            model,
            df_mutations,
            st.session_state.context_sequences,
            chunk_chunk_size=2**15,
        )
        print("fitness_done")
        # Record the duration before flipping the flag so the results view
        # never sees fitness_done=True without a valid duration.
        st.session_state.fitness_duration = time.time() - started_at
        st.session_state.fitness_done = True
    except Exception:
        # Keep the UI alive, but log the full traceback instead of only the
        # exception message.
        traceback.print_exc()
66
+
67
+ # PAGE STYLE (mainly for custom aa selection)
68
+ st.set_page_config(layout="wide")
69
+ st.markdown(
70
+ """
71
+ <style>
72
+ .stButtonGroup button {
73
+ padding: 0px 1px 0px 1px !important;
74
+ border: 0 solid transparent !important;
75
+ min-height: 0px !important;
76
+ line-height: 120% !important;
77
+ height: auto !important;
78
+ }
79
+ .stSidebar {
80
+ width: 600px !important;
81
+ }
82
+ </style>
83
+ """,
84
+ unsafe_allow_html=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  )
86
+
87
+
88
+ with st.sidebar:
89
+ st.title("Prot-xLSTM Variant Fitness")
90
+
91
+ # LOAD SEQUENCE
92
+ st.session_state.target_sequence = st.text_area(
93
+ "Target protein sequence",
94
  placeholder=DEFAULT_SEQUENCE,
95
+ value=st.session_state.target_sequence
96
  )
97
+ if st.button("Load sequence"):
98
+ if st.session_state.target_sequence == "":
99
+ st.session_state.target_sequence = DEFAULT_SEQUENCE
100
+
101
+ # MANAGE CONTEXT SEQUENCES
102
+ context_type = st.selectbox(
103
+ "Choose how to enter context",
104
+ ("Enter manually", "Use MSA file"),
105
+ index=None,
106
+ placeholder="Choose context",
 
 
 
 
 
107
  )
108
+ if context_type == 'Enter manually':
109
+ context_sequence_str = st.text_area(
110
+ "Enter context protein sequences (seperated by comma)",
111
+ placeholder=DEFAULT_SEQUENCE,
112
+ )
113
+ st.session_state.context_sequences = context_sequence_str.split(",") + [st.session_state.target_sequence]
114
+ elif context_type == 'Use MSA file':
115
+ msa_file = st.file_uploader("Choose MSA file")
116
+ st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
117
+ else:
118
+ st.session_state.context_sequences = [st.session_state.target_sequence]
119
+
120
+ if st.session_state.target_sequence != "":
121
+ with st.container():
122
+
123
+ # MUTATION POSITION SELECTION
124
+ aas = list(st.session_state.target_sequence)
125
+ mutation_indices = np.arange(1, len(aas)+1)
126
+ mutation_positions = st.segmented_control(
127
+ "Choose mutation positions (click to select)", mutation_indices, selection_mode="multi", format_func=lambda i: aas[i-1],
128
  )
129
+ st.button("Check Fitness", on_click=run_model)
130
+
131
+ # DISPLAY RESULTS
132
+ if st.session_state.fitness_done:
133
+ st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
134
+ selected_pos = st.selectbox(
135
+ "Visualized mutation position",
136
+ st.session_state.mutations['position'].unique()
137
+ )
138
+ selected_data = st.session_state.mutations.where(st.session_state.mutations['position'] == selected_pos)
139
+ st.bar_chart(selected_data, x='mutation', y='effect', horizontal=True)
140
+ st.dataframe(st.session_state.mutations, use_container_width=True)
141
+
142
+ # TUTORIAL
143
+ with st.expander("Info & Tutorial", expanded=True):
144
+ st.subheader("Tutorial")
145
+ st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load Sequence'")
146
+ st.markdown("**2.** Enter or upload you context sequences. (leave empty to use no context)")
147
+ st.markdown("**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'")
148
+ st.subheader("General Information")
149
+ st.markdown(info_text, unsafe_allow_html=True)
150
+ st.markdown("")
151
+ st.subheader("Cite us / BibTex")
152
+ st.code(citation_text, language=None)