# FupBERT_Space / app.py
from transformers import AutoModel
import torch
import streamlit as st
from smiles import load_vocab, smiles_to_tensor
from huggingface_hub import login
st.set_page_config(layout="wide")

@st.cache_resource(show_spinner='Loading Model...')  # Cache so the model is only loaded once
def load_model():
    # Authenticate with the Hugging Face Hub using the token stored in Streamlit secrets
    login(token=st.secrets['token'])
    bert_model = AutoModel.from_pretrained('battelle/FupBERT', trust_remote_code=True)
    bert_model.eval()
    return bert_model

@st.cache_data  # Cache the vocabulary so it is only read once
def load_data():
    bert_vocab = load_vocab(r'./vocab.txt')
    return bert_vocab

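# Load the cached model and vocabulary; thanks to the caching decorators, the Hub login,
# model download, and vocab read happen only on the first run, not on every Streamlit rerun.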
model = load_model()
vocab = load_data()
st.title(':blue[Battelle] FupBERT')
st.write('Note: This is not an official Battelle product')
input_text = st.text_input("Provide Input: ")
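
# predict() converts the SMILES string from the text input into a token tensor, runs it
# through FupBERT, and stores the resulting log Fup prediction in st.session_state.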
def predict(inp=input_text):
    if not inp:  # escape if no input sequence was provided
        out = 'Please Enter an Input Sequence'
        results.write(out)
        st.session_state['result'] = out
        return
    max_seq_len = 256
    try:
        # Tokenize the SMILES string and convert it into a model-ready tensor
        model_input = smiles_to_tensor(inp, vocab, max_seq_len=max_seq_len)
        with torch.no_grad():
            outputs = model(model_input)
        out = f"log Fup Prediction: {outputs.item()}"
    except Exception as e:
        out = f"Error: {str(e)}"
    st.session_state['result'] = out
    results.write(out)

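# The button triggers predict() as an on_click callback, which runs before the script reruns;
# the stored result is then re-displayed in the 'results' placeholder created below the button.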
st.button('Evaluate', on_click=predict)
results = st.empty()
if 'result' in st.session_state:
    results.write(st.session_state['result'])