"""Streamlit front end for Prot-xLSTM single-mutation fitness prediction.

The page lets the user enter a target protein sequence, optionally supply
context sequences (typed manually or sampled from an uploaded MSA file),
pick mutation positions and run the model on CPU.
"""
import os

# Kept ahead of the torch import (as in the original ordering) so that no
# CUDA device is visible to the libraries imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # disable cuda

import streamlit as st
import numpy as np
import torch
import time
from Bio import SeqIO

from protxlstm.applications.fitness_prediction import single_mutation_landscape_xlstm, create_mutation_df
from protxlstm.applications.msa_sampler import sample_msa
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.utils import load_model

import io

from frontend.constants import info_text, citation_text

if __name__ == "__main__":

    DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"

    # Both names are read by run_model() when the "Check Fitness" callback
    # fires; they are rebound further down once the widgets are rendered.
    mutation_positions = []
    msa_file = None

    # One-time initialisation of the per-session state.
    if 'fitness_done' not in st.session_state:
        st.session_state.fitness_done = False
        st.session_state.mutations = None
        st.session_state.fitness_duration = None
        st.session_state.fitness_error = None
        st.session_state.target_sequence = ""
        st.session_state.context_sequences = []
        st.session_state.num_context_sequences = 25

    def load_sequences_from_msa_file(file_obj):
        """Return the sequences of a FASTA-formatted binary file object as strings.

        The stream is rewound before parsing and the text wrapper is detached
        afterwards, so the underlying buffer is neither read from a stale
        cursor nor closed when the wrapper is garbage-collected.
        """
        file_obj.seek(0)
        text_io = io.TextIOWrapper(file_obj, encoding="utf-8")
        try:
            return [str(record.seq) for record in SeqIO.parse(text_io, "fasta")]
        finally:
            text_io.detach()

    def run_model():
        """Callback of the "Check Fitness" button: run the fitness prediction.

        Reads the target sequence, context sequences and context-sequence
        count from ``st.session_state`` and the script-level ``msa_file`` /
        ``mutation_positions``. On success it stores the result in
        ``st.session_state.mutations``, the elapsed seconds in
        ``fitness_duration`` and sets ``fitness_done``. On failure it leaves
        ``fitness_done`` False and stores a message in ``fitness_error``.
        """
        started_at = time.time()
        # Reset first so a failing run never shows stale results or timings.
        st.session_state.fitness_done = False
        st.session_state.fitness_error = None

        try:
            checkpoint = "protxlstm/checkpoints/small"
            num_context_tokens = 2**15

            df_mutations = create_mutation_df(st.session_state.target_sequence, mutation_positions)

            if msa_file is not None and st.session_state.num_context_sequences != 0:
                msa_sequences = [msa.upper() for msa in load_sequences_from_msa_file(msa_file)]
                st.session_state.context_sequences = sample_msa(
                    msa_sequences,
                    max_context_sequences=st.session_state.num_context_sequences,
                    context_length=num_context_tokens,
                )
                # The target sequence always goes last in the context.
                st.session_state.context_sequences += [st.session_state.target_sequence]

            config_update_kwargs = {
                "mlstm_backend": "chunkwise_variable",
                "mlstm_chunksize": 1024,
                "mlstm_return_last_state": True,
            }

            model = load_model(
                checkpoint,
                model_class=xLSTMLMHeadModel,
                device='cpu',
                dtype=torch.bfloat16,
                **config_update_kwargs,
            )
            model = model.eval()

            st.session_state.mutations, _ = single_mutation_landscape_xlstm(
                model,
                df_mutations,
                st.session_state.context_sequences,
                chunk_chunk_size=2**15,
            )
        except Exception as e:
            # Top-level UI boundary: keep the app alive, but log the failure
            # and surface it to the user instead of swallowing it.
            print(e)
            st.session_state.fitness_error = f"{type(e).__name__}: {e}"
            return

        print("fitness_done")
        st.session_state.fitness_duration = time.time() - started_at
        st.session_state.fitness_done = True

    # PAGE STYLE (mainly for custom aa selection)
    st.set_page_config(layout="wide")
    # NOTE(review): the injected markup is blank in this copy of the file;
    # confirm whether a style block is expected here.
    st.markdown(
        """ """,
        unsafe_allow_html=True
    )

    with st.sidebar:
        st.title("Prot-xLSTM Variant Fitness")

        # LOAD SEQUENCE
        st.session_state.target_sequence = st.text_area(
            "Target protein sequence",
            placeholder=DEFAULT_SEQUENCE,
            value=st.session_state.target_sequence
        )
        if st.button("Load sequence"):
            if st.session_state.target_sequence == "":
                st.session_state.target_sequence = DEFAULT_SEQUENCE

        # MANAGE CONTEXT SEQUENCES
        context_type = st.selectbox(
            "Choose how to enter context",
            ("Enter manually", "Use MSA file"),
            index=None,
            placeholder="Choose context",
        )
        if context_type == 'Enter manually':
            context_sequence_str = st.text_area(
                "Enter context protein sequences (separated by comma)",
                placeholder=DEFAULT_SEQUENCE,
            )
            # Strip whitespace and drop blank entries so that an empty text
            # area means "no context" rather than one empty sequence.
            manual_context = [
                entry.strip()
                for entry in context_sequence_str.split(",")
                if entry.strip()
            ]
            st.session_state.context_sequences = manual_context + [st.session_state.target_sequence]
            msa_file = None
        elif context_type == 'Use MSA file':
            msa_file = st.file_uploader("Choose MSA file")
            st.session_state.num_context_sequences = st.number_input(
                "How many of these sequences should be used?",
                min_value=0,
                step=1,
                value=25,
            )
            # Target-only baseline; run_model() replaces it with the sampled
            # MSA once a file is uploaded. Prevents reusing context left
            # over from a previously selected mode.
            st.session_state.context_sequences = [st.session_state.target_sequence]
        else:
            st.session_state.context_sequences = [st.session_state.target_sequence]
            msa_file = None

    if st.session_state.target_sequence != "":
        with st.container():
            # MUTATION POSITION SELECTION
            aas = list(st.session_state.target_sequence)
            # Positions are 1-based; format_func maps them back to residues.
            mutation_indices = np.arange(1, len(aas)+1)
            mutation_positions = st.segmented_control(
                "Choose mutation positions (click to select)",
                mutation_indices,
                selection_mode="multi",
                format_func=lambda i: aas[i-1],
            )
            st.button("Check Fitness", on_click=run_model)

        # .get() keeps sessions created before this key existed working.
        if st.session_state.get("fitness_error"):
            st.error(f"Fitness prediction failed: {st.session_state.fitness_error}")

        # DISPLAY RESULTS
        if st.session_state.fitness_done:
            st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
            selected_pos = st.selectbox(
                "Visualized mutation position",
                st.session_state.mutations['position'].unique()
            )
            # Boolean row selection; DataFrame.where() would keep every row
            # and fill the non-matching ones with NaN.
            position_mask = st.session_state.mutations['position'] == selected_pos
            selected_data = st.session_state.mutations.loc[position_mask]
            st.bar_chart(selected_data, x='mutation', y='effect', horizontal=True)
            st.dataframe(st.session_state.mutations, use_container_width=True)

    # TUTORIAL
    with st.expander("Info & Tutorial", expanded=True):
        st.warning('Due to computational constraints, processing may take up to a few minutes.', icon="⚠️")
        st.subheader("Tutorial")
        st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load Sequence'")
        st.markdown("**2.** Enter or upload your context sequences. (leave empty to use no context)")
        st.markdown("**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'")
        st.subheader("General Information")
        st.markdown(info_text, unsafe_allow_html=True)
        st.markdown("")
        st.subheader("Cite us / BibTex")
        st.code(citation_text, language=None)