""" © Battelle Memorial Institute 2023 Made available under the GNU General Public License v 2.0 BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. """ import transformers from transformers import AutoModel import torch import streamlit as st from smiles import load_vocab, smiles_to_tensor from huggingface_hub import login st.set_page_config(layout="wide") @st.cache_resource # Add the caching decorator to prevent model reloads def load_model(show_spinner='Loading Model...'): login(token=st.secrets['token']) bert_model = AutoModel.from_pretrained('battelle/FupBERT', trust_remote_code=True) bert_model.eval() return bert_model @st.cache_data # Add the caching decorator to prevent data reloads def load_data(): bert_vocab = load_vocab(r'./vocab.txt') return bert_vocab model = load_model() vocab = load_data() st.title(':blue[Battelle] FupBERT') st.write('Note: This is not an official Battelle product') input_text = st.text_input("Provide Input: ") def predict(inp=input_text): if not len(inp): # escape if no input sequence out = 'Please Enter an Input Sequence' results.write(out) st.session_state['result'] = out return max_seq_len = 256 try: model_input = smiles_to_tensor(inp, vocab, max_seq_len=max_seq_len) with torch.no_grad(): outputs = model(model_input) out = f"log Fup Prediction: {outputs.item()}" except Exception as e: out = f"Error: {str(e)}" st.session_state['result'] = out results.write(out) st.button('Evaluate', on_click=predict) results = st.empty() if 'result' in st.session_state: results.write(st.session_state['result'])