# FupBERT_Space / app.py
# Last change by c-dunlap (commit 9e117f6): "Update app.py input text"
"""
© Battelle Memorial Institute 2023
Made available under the GNU General Public License v 2.0
BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.
"""
import transformers
from transformers import AutoModel
import torch
import streamlit as st
from smiles import load_vocab, smiles_to_tensor
from huggingface_hub import login
# Use the full browser width for the page layout. Kept as the very first
# Streamlit call: Streamlit has historically required set_page_config to run
# before any other st.* command (NOTE(review): confirm for the pinned version).
st.set_page_config(layout="wide")
# Cache the model object so Streamlit reruns reuse it instead of reloading it.
# The spinner text belongs on the decorator; as a function parameter it has no
# effect on the spinner.
@st.cache_resource(show_spinner='Loading Model...')
def load_model(show_spinner='Loading Model...'):
    """Load the FupBERT model from the Hugging Face Hub and put it in eval mode.

    Args:
        show_spinner: Unused. Retained only so existing callers keep working;
            it was originally intended as the ``st.cache_resource``
            ``show_spinner`` option, which is now set on the decorator.

    Returns:
        The ``battelle/FupBERT`` model in evaluation (inference) mode.
    """
    # trust_remote_code=True executes Python code shipped in the Hub repository.
    # Only acceptable for a trusted repository; review before changing the id.
    bert_model = AutoModel.from_pretrained('battelle/FupBERT', trust_remote_code=True)
    # Inference only: switch off training-time behaviour such as dropout.
    bert_model.eval()
    return bert_model
@st.cache_data  # parse the vocabulary file once; reruns reuse the cached result
def load_data():
    """Return the tokenizer vocabulary parsed by ``load_vocab`` from ``./vocab.txt``."""
    return load_vocab(r'./vocab.txt')
# Both loaders are cached, so these calls are cheap after the first run.
model, vocab = load_model(), load_data()

# Page header and disclaimer.
st.title(':blue[Battelle] FupBERT')
st.write('Note: This is not an official Battelle product')
# Session-state key holding the text box's current value. The callback reads
# the value through this key at call time instead of capturing it as a
# default argument, which would freeze the value from an earlier rerun.
INPUT_KEY = 'smiles_input'
# Session-state key holding the last message shown in the results area.
RESULT_KEY = 'result'

input_text = st.text_input("Provide input as a smiles string: ", key=INPUT_KEY)


def predict(inp=None, max_seq_len=256):
    """Run FupBERT on a SMILES string and store the message for display.

    Used as the ``on_click`` callback of the Evaluate button. Streamlit runs
    callbacks before re-executing the script body, so this function only
    stores its message in ``st.session_state[RESULT_KEY]``. The message is
    rendered into the results placeholder created below the button.

    Args:
        inp: SMILES string to evaluate. When ``None`` (the default, which is
            what Streamlit passes for a callback), the text box's current value
            is read from session state at call time.
        max_seq_len: Maximum sequence length passed to ``smiles_to_tensor``.
    """
    if inp is None:
        inp = st.session_state.get(INPUT_KEY, '')
    # Treat None and whitespace-only text as "no input"; stray spaces around a
    # pasted SMILES string would otherwise reach the tokenizer.
    inp = (inp or '').strip()
    if not inp:  # escape if no input sequence
        st.session_state[RESULT_KEY] = 'Please Enter an Input Sequence'
        return
    try:
        model_input = smiles_to_tensor(inp, vocab, max_seq_len=max_seq_len)
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = model(model_input)
        out = f"log Fup Prediction: {outputs.item()}"
    except Exception as e:  # UI boundary: surface any failure to the user
        out = f"Error: {str(e)}"
    st.session_state[RESULT_KEY] = out


st.button('Evaluate', on_click=predict)
# The placeholder is recreated on every rerun and shows the latest stored result.
results = st.empty()
if RESULT_KEY in st.session_state:
    results.write(st.session_state[RESULT_KEY])