Elias Buerger
warnings
a0958d2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "" #disable cuda
import streamlit as st
import numpy as np
import torch
import time
from Bio import SeqIO
from protxlstm.applications.fitness_prediction import single_mutation_landscape_xlstm, create_mutation_df
from protxlstm.applications.msa_sampler import sample_msa
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.utils import load_model
import io
from frontend.constants import info_text, citation_text
import logging

logger = logging.getLogger(__name__)

# Example target sequence: shown as placeholder text and loaded when the
# target field is left empty and "Load sequence" is pressed.
DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"

# Inference settings used by run_model().
CHECKPOINT = "protxlstm/checkpoints/small"
NUM_CONTEXT_TOKENS = 2**15
MODEL_CONFIG_UPDATES = {
    "mlstm_backend": "chunkwise_variable",
    "mlstm_chunksize": 1024,
    "mlstm_return_last_state": True,
}

# PAGE STYLE (mainly for custom aa selection): compact amino-acid buttons
# and a wider sidebar.
PAGE_CSS = """
<style>
    .stButtonGroup button {
        padding: 0px 1px 0px 1px !important;
        border: 0 solid transparent !important;
        min-height: 0px !important;
        line-height: 120% !important;
        height: auto !important;
    }
    .stSidebar {
        width: 600px !important;
    }
</style>
"""


def load_sequences_from_msa_file(file_obj):
    """Return the records of an uploaded FASTA/MSA file as upper-case strings.

    The upload is read in full with ``getvalue()`` rather than wrapped in an
    ``io.TextIOWrapper``: such a wrapper closes the underlying buffer when it
    is garbage-collected and leaves the read position at EOF, which can break
    re-reading the same upload on a later click.
    """
    text = file_obj.getvalue().decode("utf-8")
    return [str(record.seq).upper() for record in SeqIO.parse(io.StringIO(text), "fasta")]


def parse_context_sequences(text):
    """Split comma-separated context sequences, dropping blank entries.

    Whitespace around each entry (e.g. after ", " or line breaks) is removed,
    and an empty input yields ``[]`` instead of ``[""]``, so leaving the field
    empty really means "no context".
    """
    return [seq.strip() for seq in text.split(",") if seq.strip()]


def select_position(mutations, position):
    """Return only the rows of ``mutations`` whose ``position`` equals ``position``.

    Boolean indexing drops the other rows; ``DataFrame.where`` would keep them
    as all-NaN rows that end up in the chart.
    """
    return mutations[mutations["position"] == position]


def run_model(mutation_positions=None, msa_file=None):
    """Score every single mutation at the chosen positions ("Check Fitness" callback).

    Args:
        mutation_positions: 1-based positions selected in the UI, passed via the
            button's ``args`` so the callback does not rely on script globals.
        msa_file: Uploaded MSA/FASTA file, or None when no MSA is used.

    On success the mutation table is stored in ``st.session_state.mutations``,
    the elapsed wall-clock seconds in ``fitness_duration``, and ``fitness_done``
    is set. On failure ``fitness_done`` stays False and a user-facing message
    is stored in ``fitness_error``.
    """
    # Reset first so a failed run never shows stale results from a previous one.
    st.session_state.fitness_done = False
    st.session_state.fitness_error = None
    if not mutation_positions:
        st.session_state.fitness_error = "Select at least one mutation position first."
        return

    # Local start time: a failed run must not leave a raw timestamp in
    # fitness_duration (it would be rendered as the running time).
    start = time.time()
    try:
        target = st.session_state.target_sequence
        df_mutations = create_mutation_df(target, mutation_positions)
        if msa_file is not None and st.session_state.num_context_sequences > 0:
            msa_sequences = load_sequences_from_msa_file(msa_file)
            context = sample_msa(
                msa_sequences,
                max_context_sequences=st.session_state.num_context_sequences,
                context_length=NUM_CONTEXT_TOKENS,
            )
            context += [target]
            st.session_state.context_sequences = context
        model = load_model(
            CHECKPOINT,
            model_class=xLSTMLMHeadModel,
            device="cpu",
            dtype=torch.bfloat16,
            **MODEL_CONFIG_UPDATES,
        )
        model = model.eval()
        st.session_state.mutations, _ = single_mutation_landscape_xlstm(
            model,
            df_mutations,
            st.session_state.context_sequences,
            chunk_chunk_size=2**15,
        )
    except Exception as err:  # UI boundary: report to the user instead of crashing
        logger.exception("Fitness prediction failed")
        st.session_state.fitness_error = f"Fitness prediction failed: {err}"
        return

    st.session_state.fitness_duration = time.time() - start
    st.session_state.fitness_done = True
    logger.info("fitness_done")


def init_session_state():
    """Create any missing session-state keys with their default values."""
    defaults = {
        "fitness_done": False,
        "fitness_error": None,
        "mutations": None,
        "fitness_duration": None,
        "target_sequence": "",
        "context_sequences": [],
        "num_context_sequences": 25,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def render_sidebar():
    """Render the target/context inputs in the sidebar.

    Returns:
        The uploaded MSA file, or None when the MSA option is not selected
        or nothing has been uploaded yet.
    """
    msa_file = None
    with st.sidebar:
        st.title("Prot-xLSTM Variant Fitness")

        # LOAD SEQUENCE
        st.session_state.target_sequence = st.text_area(
            "Target protein sequence",
            placeholder=DEFAULT_SEQUENCE,
            value=st.session_state.target_sequence,
        )
        if st.button("Load sequence"):
            if st.session_state.target_sequence == "":
                st.session_state.target_sequence = DEFAULT_SEQUENCE
        target = st.session_state.target_sequence

        # MANAGE CONTEXT SEQUENCES (the target is always appended last)
        context_type = st.selectbox(
            "Choose how to enter context",
            ("Enter manually", "Use MSA file"),
            index=None,
            placeholder="Choose context",
        )
        if context_type == "Enter manually":
            context_sequence_str = st.text_area(
                "Enter context protein sequences (separated by comma)",
                placeholder=DEFAULT_SEQUENCE,
            )
            st.session_state.context_sequences = parse_context_sequences(context_sequence_str) + [target]
        elif context_type == "Use MSA file":
            msa_file = st.file_uploader("Choose MSA file")
            st.session_state.num_context_sequences = st.number_input(
                "How many of these sequences should be used?", min_value=0, step=1, value=25
            )
            # Target-only fallback so no stale context from an earlier choice is
            # used; run_model replaces it with sampled MSA sequences when a file
            # is uploaded and more than 0 sequences are requested.
            st.session_state.context_sequences = [target]
        else:
            st.session_state.context_sequences = [target]
    return msa_file


def render_mutation_selection(msa_file):
    """Let the user pick mutation positions and wire up the "Check Fitness" button."""
    target = st.session_state.target_sequence
    if target == "":
        return
    with st.container():
        # MUTATION POSITION SELECTION: one button per residue, labelled with its amino acid.
        aas = list(target)
        mutation_indices = np.arange(1, len(aas) + 1)
        mutation_positions = st.segmented_control(
            "Choose mutation positions (click to select)",
            mutation_indices,
            selection_mode="multi",
            format_func=lambda i: aas[i - 1],
        )
        st.button("Check Fitness", on_click=run_model, args=(mutation_positions, msa_file))


def render_results():
    """Show the last error message and/or the stored fitness results."""
    error = st.session_state.get("fitness_error")
    if error:
        st.error(error)
    if not st.session_state.fitness_done:
        return
    mutations = st.session_state.mutations
    st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
    selected_pos = st.selectbox(
        "Visualized mutation position",
        mutations["position"].unique(),
    )
    st.bar_chart(select_position(mutations, selected_pos), x="mutation", y="effect", horizontal=True)
    st.dataframe(mutations, use_container_width=True)


def render_tutorial():
    """Render the usage instructions, general information and citation."""
    with st.expander("Info & Tutorial", expanded=True):
        st.warning('Due to computational constraints, processing may take up to a few minutes.', icon="⚠️")
        st.subheader("Tutorial")
        st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load sequence'")
        st.markdown("**2.** Enter or upload your context sequences. (leave empty to use no context)")
        st.markdown("**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'")
        st.subheader("General Information")
        st.markdown(info_text, unsafe_allow_html=True)
        st.markdown("")
        st.subheader("Cite us / BibTex")
        st.code(citation_text, language=None)


def main():
    """Build the page; Streamlit re-executes this on every interaction."""
    # Called first: Streamlit expects set_page_config before other page elements.
    st.set_page_config(layout="wide")
    init_session_state()
    st.markdown(PAGE_CSS, unsafe_allow_html=True)
    msa_file = render_sidebar()
    render_mutation_selection(msa_file)
    render_results()
    render_tutorial()


if __name__ == "__main__":
    main()