|
import os |
|
os.environ["CUDA_VISIBLE_DEVICES"] = "" |
|
|
|
import streamlit as st |
|
import numpy as np |
|
import torch |
|
import time |
|
from Bio import SeqIO |
|
|
|
from protxlstm.applications.fitness_prediction import single_mutation_landscape_xlstm, create_mutation_df |
|
from protxlstm.applications.msa_sampler import sample_msa |
|
from protxlstm.models.xlstm import xLSTMLMHeadModel |
|
from protxlstm.utils import load_model |
|
import io |
|
|
|
from frontend.constants import info_text, citation_text |
|
|
|
if __name__ == "__main__":

    DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"

    # Script-level values read by the ``run_model`` callback. Streamlit runs an
    # ``on_click`` callback at the start of the next rerun, before this script
    # body executes again, so the callback sees the values left here by the
    # previous run.
    mutation_positions = []
    msa_file = None

    # One-time per-session initialisation of the values shared between the
    # callback and the page layout.
    if "fitness_done" not in st.session_state:
        st.session_state.fitness_done = False
        st.session_state.mutations = None
        st.session_state.fitness_duration = None
        st.session_state.fitness_error = None
        st.session_state.target_sequence = ""
        st.session_state.context_sequences = []
        st.session_state.num_context_sequences = 25

    def normalize_sequence(sequence):
        """Return *sequence* with all whitespace removed and letters upper-cased.

        Mirrors the ``.upper()`` applied to MSA sequences, and drops line
        breaks/spaces that would otherwise be treated as residues.
        """
        return "".join(sequence.split()).upper()

    def load_sequences_from_msa_file(file_obj):
        """Parse an uploaded FASTA/MSA file into a list of sequence strings.

        The upload's bytes are decoded into an independent ``StringIO`` instead
        of wrapping the upload in ``io.TextIOWrapper``: a wrapper closes the
        underlying buffer when it is garbage-collected and consumes its read
        position, which would break re-reading the same upload.
        """
        text_io = io.StringIO(file_obj.getvalue().decode("utf-8"))
        return [str(record.seq) for record in SeqIO.parse(text_io, "fasta")]

    def run_model():
        """Run the Prot-xLSTM fitness model for mutations at the selected positions.

        Used as the "Check Fitness" ``on_click`` callback. Reads the target
        sequence, context sequences and context size from ``st.session_state``
        and the selected positions / MSA upload from the script-level
        ``mutation_positions`` / ``msa_file``.

        On success sets ``st.session_state.mutations`` (the result returned by
        ``single_mutation_landscape_xlstm``), ``fitness_done = True`` and
        ``fitness_duration`` (seconds). On failure sets ``fitness_done = False``
        and a user-facing message in ``st.session_state.fitness_error``; the
        traceback is printed to the server log.
        """
        st.session_state.fitness_error = None

        if not mutation_positions:
            st.session_state.fitness_error = "Select at least one mutation position first."
            return

        # Keep the start time local so a failed run can never leave a raw
        # timestamp in ``fitness_duration`` (shown as "Running time").
        start_time = time.time()
        try:
            checkpoint = "protxlstm/checkpoints/small"
            num_context_tokens = 2**15

            df_mutations = create_mutation_df(st.session_state.target_sequence, mutation_positions)

            # Replace the target-only placeholder context with sampled MSA
            # sequences when a file was uploaded and sequences were requested.
            if msa_file is not None and st.session_state.num_context_sequences != 0:
                msa_sequences = [msa.upper() for msa in load_sequences_from_msa_file(msa_file)]
                st.session_state.context_sequences = sample_msa(
                    msa_sequences,
                    max_context_sequences=st.session_state.num_context_sequences,
                    context_length=num_context_tokens,
                )
                st.session_state.context_sequences += [st.session_state.target_sequence]

            config_update_kwargs = {
                "mlstm_backend": "chunkwise_variable",
                "mlstm_chunksize": 1024,
                "mlstm_return_last_state": True,
            }

            model = load_model(
                checkpoint,
                model_class=xLSTMLMHeadModel,
                device="cpu",
                dtype=torch.bfloat16,
                **config_update_kwargs,
            )
            model = model.eval()

            st.session_state.mutations, _ = single_mutation_landscape_xlstm(
                model,
                df_mutations,
                st.session_state.context_sequences,
                chunk_chunk_size=2**15,
            )
        except Exception as exc:  # UI boundary: report the failure instead of hiding it
            import traceback

            traceback.print_exc()
            # Do not keep presenting results from an earlier run as current.
            st.session_state.fitness_done = False
            st.session_state.fitness_error = f"Fitness prediction failed: {exc}"
        else:
            print("fitness_done")
            st.session_state.fitness_done = True
            st.session_state.fitness_duration = time.time() - start_time

    st.set_page_config(layout="wide")

    # Compact residue buttons in the segmented control and a wider sidebar.
    st.markdown(
        """
        <style>
        .stButtonGroup button {
            padding: 0px 1px 0px 1px !important;
            border: 0 solid transparent !important;
            min-height: 0px !important;
            line-height: 120% !important;
            height: auto !important;
        }
        .stSidebar {
            width: 600px !important;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    with st.sidebar:
        st.title("Prot-xLSTM Variant Fitness")

        st.session_state.target_sequence = normalize_sequence(
            st.text_area(
                "Target protein sequence",
                placeholder=DEFAULT_SEQUENCE,
                value=st.session_state.target_sequence,
            )
        )
        if st.button("Load sequence"):
            # An empty input falls back to the bundled sample sequence.
            if not st.session_state.target_sequence:
                st.session_state.target_sequence = DEFAULT_SEQUENCE

        context_type = st.selectbox(
            "Choose how to enter context",
            ("Enter manually", "Use MSA file"),
            index=None,
            placeholder="Choose context",
        )

        if context_type == "Enter manually":
            context_sequence_str = st.text_area(
                "Enter context protein sequences (separated by commas)",
                placeholder=DEFAULT_SEQUENCE,
            )
            # Blank entries (including an entirely empty box) are dropped, so
            # an empty input means "no context" as the tutorial states.
            manual_context = [normalize_sequence(seq) for seq in context_sequence_str.split(",")]
            st.session_state.context_sequences = [seq for seq in manual_context if seq] + [
                st.session_state.target_sequence
            ]
            msa_file = None
        elif context_type == "Use MSA file":
            msa_file = st.file_uploader("Choose MSA file")
            st.session_state.num_context_sequences = st.number_input(
                "How many of these sequences should be used?", min_value=0, step=1, value=25
            )
            # Target-only placeholder so stale context from another mode is
            # never reused; run_model swaps in sampled MSA sequences when a file
            # is uploaded and more than zero sequences are requested.
            st.session_state.context_sequences = [st.session_state.target_sequence]
        else:
            st.session_state.context_sequences = [st.session_state.target_sequence]
            msa_file = None

    if st.session_state.target_sequence:
        with st.container():
            aas = list(st.session_state.target_sequence)
            # Positions are 1-based; format_func shows the residue letter.
            mutation_indices = np.arange(1, len(aas) + 1)
            mutation_positions = st.segmented_control(
                "Choose mutation positions (click to select)",
                mutation_indices,
                selection_mode="multi",
                format_func=lambda i: aas[i - 1],
            )

            st.button("Check Fitness", on_click=run_model)

            # ``.get`` because sessions created before this key existed lack it.
            if st.session_state.get("fitness_error"):
                st.error(st.session_state.fitness_error)

            if st.session_state.fitness_done:
                st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")

                mutations = st.session_state.mutations
                selected_pos = st.selectbox(
                    "Visualized mutation position",
                    mutations["position"].unique(),
                )
                # Boolean indexing keeps only the matching rows; DataFrame.where
                # would keep every row and NaN-fill the others.
                selected_data = mutations[mutations["position"] == selected_pos]
                st.bar_chart(selected_data, x="mutation", y="effect", horizontal=True)

                st.dataframe(mutations, use_container_width=True)

    with st.expander("Info & Tutorial", expanded=True):
        st.warning("Due to computational constraints, processing may take up to a few minutes.", icon="⚠️")

        st.subheader("Tutorial")
        st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load sequence'")
        st.markdown("**2.** Enter or upload your context sequences. (leave empty to use no context)")
        st.markdown("**3.** Choose which amino acids to mutate (click on the AAs to select them) and press 'Check Fitness'")

        st.subheader("General Information")
        st.markdown(info_text, unsafe_allow_html=True)
        st.markdown("")

        st.subheader("Cite us / BibTex")
        st.code(citation_text, language=None)