# Page residue captured together with this file (not program code): "Spaces: Sleeping"
import io
import os
import time
import traceback

# Set before torch is imported so CUDA devices are hidden from it.
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # disable cuda

import numpy as np
import streamlit as st
import torch
from Bio import SeqIO

from protxlstm.applications.fitness_prediction import single_mutation_landscape_xlstm, create_mutation_df
from protxlstm.applications.msa_sampler import sample_msa
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.utils import load_model
from frontend.constants import info_text, citation_text
# Streamlit front end for Prot-xLSTM variant-fitness prediction.


def parse_context_sequences(text):
    """Split the comma-separated "Enter manually" input into context sequences.

    Whitespace (spaces, tabs, line breaks) is removed from every entry and empty
    entries are dropped, so an empty or whitespace-only input means "no context".

    Args:
        text: Raw content of the context text area (``None`` is treated as empty).

    Returns:
        The cleaned, non-empty sequences in input order.
    """
    if not text:
        return []
    # "".split(",") == [""]; without filtering, an empty sequence would be
    # added to the model context.
    cleaned = ("".join(chunk.split()) for chunk in text.split(","))
    return [sequence for sequence in cleaned if sequence]


def read_msa_sequences(file_obj):
    """Return the upper-cased sequences of a FASTA-formatted MSA upload.

    Args:
        file_obj: Binary file-like object holding UTF-8 FASTA text, such as the
            object returned by ``st.file_uploader``.

    Returns:
        List of upper-cased sequence strings in file order.
    """
    # Read the bytes directly rather than wrapping file_obj in io.TextIOWrapper:
    # the wrapper closes the underlying buffer when it is garbage-collected.
    # Rewinding first makes repeated reads of the same object see the whole file.
    file_obj.seek(0)
    text = file_obj.read().decode("utf-8")
    return [str(record.seq).upper() for record in SeqIO.parse(io.StringIO(text), "fasta")]


if __name__ == "__main__":
    DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"
    # Script-level values, reassigned on each rerun by the widgets below;
    # run_model reads them as module globals.
    mutation_positions = []
    msa_file = None

    # Per-session state is initialised once; the script itself re-runs on
    # every widget interaction.
    if 'fitness_done' not in st.session_state:
        st.session_state.fitness_done = False
        st.session_state.mutations = None
        st.session_state.fitness_duration = None
        st.session_state.fitness_error = None
        st.session_state.target_sequence = ""
        st.session_state.context_sequences = []
        st.session_state.num_context_sequences = 25

    def run_model():
        """Run Prot-xLSTM fitness prediction for the selected mutation positions.

        Used as the ``on_click`` callback of the "Check Fitness" button. Streamlit
        runs callbacks before re-running the script, so ``mutation_positions`` and
        ``msa_file`` hold the values from the run that rendered the button.

        On success, sets ``st.session_state.mutations``, ``fitness_duration``
        (seconds) and ``fitness_done = True``. On failure, or when no position is
        selected, stores a message in ``st.session_state.fitness_error`` for the
        page to display. Nothing is rendered from inside the callback itself.
        """
        # Clear earlier results/errors so a failed or skipped run never leaves
        # stale output on screen.
        st.session_state.fitness_done = False
        st.session_state.fitness_error = None
        if not mutation_positions:
            st.session_state.fitness_error = "Select at least one mutation position before checking fitness."
            return
        # Keep the start time local so a failure cannot leave a raw timestamp
        # in fitness_duration.
        start_time = time.time()
        try:
            checkpoint = "protxlstm/checkpoints/small"
            num_context_tokens = 2**15
            df_mutations = create_mutation_df(st.session_state.target_sequence, mutation_positions)
            if msa_file is not None and st.session_state.num_context_sequences != 0:
                msa_sequences = read_msa_sequences(msa_file)
                st.session_state.context_sequences = sample_msa(
                    msa_sequences,
                    max_context_sequences=st.session_state.num_context_sequences,
                    context_length=num_context_tokens,
                )
                # The target is appended last, as in the other context branches.
                st.session_state.context_sequences += [st.session_state.target_sequence]
            config_update_kwargs = {
                "mlstm_backend": "chunkwise_variable",
                "mlstm_chunksize": 1024,
                "mlstm_return_last_state": True,
            }
            model = load_model(
                checkpoint,
                model_class=xLSTMLMHeadModel,
                device='cpu',
                dtype=torch.bfloat16,
                **config_update_kwargs,
            )
            model = model.eval()
            st.session_state.mutations, _ = single_mutation_landscape_xlstm(
                model,
                df_mutations,
                st.session_state.context_sequences,
                chunk_chunk_size=2**15,
            )
            print("fitness_done")
            st.session_state.fitness_duration = time.time() - start_time
            st.session_state.fitness_done = True
        except Exception as e:  # UI boundary: report the failure instead of crashing the app
            traceback.print_exc()
            st.session_state.fitness_error = f"Fitness prediction failed: {e}"

    # PAGE STYLE (mainly for custom aa selection)
    st.set_page_config(layout="wide")
    st.markdown(
        """
<style>
.stButtonGroup button {
padding: 0px 1px 0px 1px !important;
border: 0 solid transparent !important;
min-height: 0px !important;
line-height: 120% !important;
height: auto !important;
}
.stSidebar {
width: 600px !important;
}
</style>
""",
        unsafe_allow_html=True
    )

    with st.sidebar:
        st.title("Prot-xLSTM Variant Fitness")

        # LOAD SEQUENCE
        st.session_state.target_sequence = st.text_area(
            "Target protein sequence",
            placeholder=DEFAULT_SEQUENCE,
            value=st.session_state.target_sequence
        )
        if st.button("Load sequence"):
            # An empty text area means "use the sample sequence".
            if st.session_state.target_sequence == "":
                st.session_state.target_sequence = DEFAULT_SEQUENCE

        # MANAGE CONTEXT SEQUENCES
        context_type = st.selectbox(
            "Choose how to enter context",
            ("Enter manually", "Use MSA file"),
            index=None,
            placeholder="Choose context",
        )
        if context_type == 'Enter manually':
            context_sequence_str = st.text_area(
                "Enter context protein sequences (separated by comma)",
                placeholder=DEFAULT_SEQUENCE,
            )
            st.session_state.context_sequences = parse_context_sequences(context_sequence_str) + [st.session_state.target_sequence]
            msa_file = None
        elif context_type == 'Use MSA file':
            msa_file = st.file_uploader("Choose MSA file")
            st.session_state.num_context_sequences = st.number_input("How many of these sequences should be used?", min_value=0, step=1, value=25)
            # Target-only fallback for when no file is uploaded or 0 sequences
            # are requested; run_model replaces it with sampled MSA sequences
            # otherwise. Without this reset, a context left over from an earlier
            # rerun (possibly for a different target) would be reused.
            st.session_state.context_sequences = [st.session_state.target_sequence]
        else:
            st.session_state.context_sequences = [st.session_state.target_sequence]
            msa_file = None

    if st.session_state.target_sequence != "":
        with st.container():
            # MUTATION POSITION SELECTION
            aas = list(st.session_state.target_sequence)
            # 1-based positions; each option is labelled with its residue.
            mutation_indices = np.arange(1, len(aas) + 1)
            mutation_positions = st.segmented_control(
                "Choose mutation positions (click to select)",
                mutation_indices,
                selection_mode="multi",
                format_func=lambda i: aas[i - 1],
            )
            st.button("Check Fitness", on_click=run_model)
            if st.session_state.get("fitness_error"):
                st.error(st.session_state.fitness_error)

            # DISPLAY RESULTS
            if st.session_state.fitness_done:
                mutations = st.session_state.mutations
                st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
                selected_pos = st.selectbox(
                    "Visualized mutation position",
                    mutations['position'].unique()
                )
                # Boolean-mask the rows: DataFrame.where would keep every row and
                # fill non-matching ones with NaN, which then reach the chart.
                selected_data = mutations[mutations['position'] == selected_pos]
                st.bar_chart(selected_data, x='mutation', y='effect', horizontal=True)
                st.dataframe(mutations, use_container_width=True)

    # TUTORIAL
    with st.expander("Info & Tutorial", expanded=True):
        st.warning('Due to computational constraints, processing may take up to a few minutes.', icon="⚠️")
        st.subheader("Tutorial")
        st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load Sequence'")
        st.markdown("**2.** Enter or upload your context sequences. (leave empty to use no context)")
        st.markdown("**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'")
        st.subheader("General Information")
        st.markdown(info_text, unsafe_allow_html=True)
        st.markdown("")
        st.subheader("Cite us / BibTex")
        st.code(citation_text, language=None)