Spaces:
Sleeping
Sleeping
File size: 7,932 Bytes
48097f5 4f181ab 48097f5 4f181ab 48097f5 4f181ab 48097f5 4f181ab 48097f5 4f181ab 28f312f 4f181ab 28f312f 4f181ab 48097f5 4f181ab a0958d2 4f181ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "" #disable cuda
import streamlit as st
import numpy as np
import torch
import time
from Bio import SeqIO
from protxlstm.applications.fitness_prediction import single_mutation_landscape_xlstm, create_mutation_df
from protxlstm.applications.msa_sampler import sample_msa
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.utils import load_model
import io
from frontend.constants import info_text, citation_text
# Example target, loaded when "Load sequence" is pressed with an empty box and
# shown as placeholder text in the sequence inputs.
DEFAULT_SEQUENCE = "MTARGLALGLLLLLLCPAQVFSQSCVWYGECGIAYGDKRYNCEYSGPPKPLPKDGYDLVQELCPGFFFGNVSLCCDVRQLQTLKDNLQLPLQFLSRCPSCFYNLLNLFCELTCSPRQSQFLNVTATEDYVDPVTNQTKTNVKELQYYVGQSFANAMYNACRDVEAPSSNDKALGLLCGKDADACNATNWIEYMFNKDNGQAPFTITPVFSDFPVHGMEPMNNATKGCDESVDEVTAPCSCQDCSIVCGPKPQPPPPPAPWTILGLDAMYVIMWITYMAFLLVFFGAFFAVWCYRKRYFVSEYTPIDSNIAFSVNASDKGEASCCDPVSAAFEGCLRRLFTRWGSFCVRNPGCVIFFSLVFITACSSGLVFVRVTTNPVDLWSAPSSQARLEKEYFDQHFGPFFRTEQLIIRAPLTDKHIYQPYPSGADVPFGPPLDIQILHQVLDLQIAIENITASYDNETVTLQDICLAPLSPYNTNCTILSVLNYFQNSHSVLDHKKGDDFFVYADYHTHFLYCVRAPASLNDTSLLHDPCLGTFGGPVFPWLVLGGYDDQNYNNATALVITFPVNNYYNDTEKLQRAQAWEKEFINFVKNYKNPNLTISFTAERSIEDELNRESDSDVFTVVISYAIMFLYISLALGHMKSCRRLLVDSKVSLGIAGILIVLSSVACSLGVFSYIGLPLTLIVIEVIPFLVLAVGVDNIFILVQAYQRDERLQGETLDQQLGRVLGEVAPSMFLSSFSETVAFFLGALSVMPAVHTFSLFAGLAVFIDFLLQITCFVSLLGLDIKRQEKNRLDIFCCVRGAEDGTSVQASESCLFRFFKNSYSPLLLKDWMRPIVIAIFVGVLSFSIAVLNKVDIGLDQSLSMPDDSYMVDYFKSISQYLHAGPPVYFVLEEGHDYTSSKGQNMVCGGMGCNNDSLVQQIFNAAQLDNYTRIGFAPSSWIDDYFDWVKPQSSCCRVDNITDQFCNASVVDPACVRCRPLTPEGKQRPQGGDFMRFLPMFLSDNPNPKCGKGGHAAYSSAVNILLGHGTRVGATYFMTYHTVLQTSADFIDALKKARLIASNVTETMGINGSAYRVFPYSVFYVFYEQYLTIIDDTIFNLGVSLGAIFLVTMVLLGCELWSAVIMCATIAMVLVNMFGVMWLWGISLNAVSLVNLVMSCGISVEFCSHITRAFTVSMKGSRVERAEEALAHMGSSVFSGITLTKFGGIVVLAFAKSQIFQIFYFRMYLAMVLLGATHGLIFLPVLLSYIGPSVNKAKSCATEERYKGTERERLLNF"

# Checkpoint and inference settings used for every fitness run.
CHECKPOINT = "protxlstm/checkpoints/small"
NUM_CONTEXT_TOKENS = 2**15
MODEL_CONFIG_UPDATES = {
    "mlstm_backend": "chunkwise_variable",
    "mlstm_chunksize": 1024,
    "mlstm_return_last_state": True,
}

# PAGE STYLE (mainly for custom aa selection)
PAGE_CSS = """
<style>
.stButtonGroup button {
    padding: 0px 1px 0px 1px !important;
    border: 0 solid transparent !important;
    min-height: 0px !important;
    line-height: 120% !important;
    height: auto !important;
}
.stSidebar {
    width: 600px !important;
}
</style>
"""


def clean_sequence(text):
    """Return *text* with all whitespace removed and letters uppercased.

    Pasted sequences often contain spaces or line breaks (e.g. copied from a
    FASTA file); left in place, each one would appear as a selectable
    "residue" in the mutation picker. ``None`` is treated as ``""``.
    """
    return "".join((text or "").split()).upper()


def parse_context_sequences(text):
    """Split comma-separated context sequences, cleaning each and dropping empties.

    An empty text box therefore yields ``[]`` rather than ``[""]``, so no
    empty sequence is ever passed to the model as context.
    """
    cleaned = (clean_sequence(chunk) for chunk in (text or "").split(","))
    return [seq for seq in cleaned if seq]


def load_sequences_from_msa_file(file_obj):
    """Return the uppercased sequences of every FASTA record in a binary file object.

    The bytes are decoded into an in-memory string rather than wrapping
    *file_obj* in ``io.TextIOWrapper``: that wrapper closes the underlying
    buffer when it is garbage-collected, which would leave the upload closed.
    """
    file_obj.seek(0)  # rewind in case the object was read before
    text = file_obj.read().decode("utf-8")
    return [str(record.seq).upper() for record in SeqIO.parse(io.StringIO(text), "fasta")]


def load_fitness_model():
    """Load the small Prot-xLSTM checkpoint on CPU in bfloat16 and switch it to eval mode."""
    model = load_model(
        CHECKPOINT,
        model_class=xLSTMLMHeadModel,
        device="cpu",
        dtype=torch.bfloat16,
        **MODEL_CONFIG_UPDATES,
    )
    return model.eval()


def run_model(mutation_positions=(), msa_file=None):
    """Score single mutations at *mutation_positions* and store the results in session state.

    Used as the ``on_click`` callback of the "Check Fitness" button, so it runs
    at the start of the next rerun. Its inputs are passed explicitly via
    ``args=`` instead of being read from module globals of a previous run.

    On success sets ``mutations``, ``fitness_duration`` and ``fitness_done``.
    On failure ``fitness_done`` stays False and a message is stored in
    ``fitness_error`` for the page to display.
    """
    state = st.session_state
    # Clear the previous outcome so a failed run never shows stale results
    # next to a bogus running time.
    state.fitness_done = False
    state.fitness_error = None
    if not mutation_positions:
        state.fitness_error = "Select at least one mutation position."
        return
    start = time.time()
    try:
        df_mutations = create_mutation_df(state.target_sequence, mutation_positions)
        if msa_file is not None and state.num_context_sequences > 0:
            msa_sequences = load_sequences_from_msa_file(msa_file)
            sampled = sample_msa(
                msa_sequences,
                max_context_sequences=state.num_context_sequences,
                context_length=NUM_CONTEXT_TOKENS,
            )
            # The target sequence is always appended as the last context entry.
            state.context_sequences = list(sampled) + [state.target_sequence]
        model = load_fitness_model()
        mutations, _ = single_mutation_landscape_xlstm(
            model, df_mutations, state.context_sequences, chunk_chunk_size=2**15
        )
    except Exception as err:  # UI boundary: report the failure instead of hiding it
        import traceback

        traceback.print_exc()
        state.fitness_error = f"{type(err).__name__}: {err}"
        return
    state.mutations = mutations
    state.fitness_duration = time.time() - start
    state.fitness_done = True


if __name__ == "__main__":
    # set_page_config must precede every other Streamlit call on the page.
    st.set_page_config(layout="wide")

    # Initialise per-session state key by key, so sessions that already exist
    # still pick up keys added later (e.g. fitness_error).
    session_defaults = {
        "fitness_done": False,
        "mutations": None,
        "fitness_duration": None,
        "fitness_error": None,
        "target_sequence": "",
        "context_sequences": [],
        "num_context_sequences": 25,
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    st.markdown(PAGE_CSS, unsafe_allow_html=True)

    msa_file = None
    mutation_positions = []

    # NOTE(review): the original indentation was lost; placing the mutation
    # picker inside the sidebar is inferred from the 600px sidebar CSS — confirm.
    with st.sidebar:
        st.title("Prot-xLSTM Variant Fitness")

        # LOAD SEQUENCE
        st.session_state.target_sequence = clean_sequence(
            st.text_area(
                "Target protein sequence",
                placeholder=DEFAULT_SEQUENCE,
                value=st.session_state.target_sequence,
            )
        )
        if st.button("Load sequence"):
            if st.session_state.target_sequence == "":
                st.session_state.target_sequence = DEFAULT_SEQUENCE

        # MANAGE CONTEXT SEQUENCES
        context_type = st.selectbox(
            "Choose how to enter context",
            ("Enter manually", "Use MSA file"),
            index=None,
            placeholder="Choose context",
        )
        if context_type == "Enter manually":
            context_sequence_str = st.text_area(
                "Enter context protein sequences (separated by comma)",
                placeholder=DEFAULT_SEQUENCE,
            )
            st.session_state.context_sequences = parse_context_sequences(context_sequence_str) + [
                st.session_state.target_sequence
            ]
        elif context_type == "Use MSA file":
            msa_file = st.file_uploader("Choose MSA file")
            st.session_state.num_context_sequences = st.number_input(
                "How many of these sequences should be used?", min_value=0, step=1, value=25
            )
            # run_model replaces this with sampled MSA sequences when a file is
            # uploaded; otherwise only the target is used (never stale context).
            st.session_state.context_sequences = [st.session_state.target_sequence]
        else:
            st.session_state.context_sequences = [st.session_state.target_sequence]

        if st.session_state.target_sequence:
            with st.container():
                # MUTATION POSITION SELECTION: 1-based positions, labelled by residue.
                residues = list(st.session_state.target_sequence)
                mutation_positions = st.segmented_control(
                    "Choose mutation positions (click to select)",
                    np.arange(1, len(residues) + 1),
                    selection_mode="multi",
                    format_func=lambda pos: residues[pos - 1],
                )
                st.button(
                    "Check Fitness",
                    on_click=run_model,
                    args=(mutation_positions, msa_file),
                )

    if st.session_state.fitness_error:
        st.error(f"Fitness prediction failed: {st.session_state.fitness_error}")

    # DISPLAY RESULTS
    if st.session_state.fitness_done:
        mutations = st.session_state.mutations
        st.metric(label="Running time", value=f"{st.session_state.fitness_duration:.2f} sec.")
        selected_pos = st.selectbox(
            "Visualized mutation position",
            mutations["position"].unique(),
        )
        # Boolean indexing keeps only the selected position's rows;
        # DataFrame.where would keep every row and fill the rest with NaN.
        selected_data = mutations[mutations["position"] == selected_pos]
        st.bar_chart(selected_data, x="mutation", y="effect", horizontal=True)
        st.dataframe(mutations, use_container_width=True)

    # TUTORIAL
    with st.expander("Info & Tutorial", expanded=True):
        st.warning('Due to computational constraints, processing may take up to a few minutes.', icon="⚠️")
        st.subheader("Tutorial")
        st.markdown("**1.** Choose a target protein sequence (leave empty to use a sample sequence) and press 'Load Sequence'")
        st.markdown("**2.** Enter or upload your context sequences. (leave empty to use no context)")
        st.markdown("**3.** Choose which amino acids to mutate (click on the AA's to select them) and press 'Check Fitness'")
        st.subheader("General Information")
        st.markdown(info_text, unsafe_allow_html=True)
        st.markdown("")
        st.subheader("Cite us / BibTex")
        st.code(citation_text, language=None)