File size: 2,244 Bytes
83b1a49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8698e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7e5b7b
6e8698e
9e117f6
6e8698e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
© Battelle Memorial Institute 2023
Made available under the GNU General Public License v 2.0

BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW.  EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.  THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU.  SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.
"""


import transformers
from transformers import AutoModel
import torch
import streamlit as st
from smiles import load_vocab, smiles_to_tensor
from huggingface_hub import login

st.set_page_config(layout="wide")


# Bug fix: `show_spinner` belongs to the decorator, not the function.
# Previously it was a dead default parameter of load_model(), so the
# custom spinner text was never shown.
@st.cache_resource(show_spinner='Loading Model...')
def load_model():
    """Download and cache the pretrained FupBERT model.

    Returns:
        The ``battelle/FupBERT`` model from the Hugging Face Hub,
        switched to eval mode for inference.
    """
    bert_model = AutoModel.from_pretrained('battelle/FupBERT', trust_remote_code=True)
    bert_model.eval()  # inference mode: disables dropout / BN updates
    return bert_model


@st.cache_data  # cache so the vocab file is parsed only once per session
def load_data():
    """Load the token vocabulary used to encode SMILES strings.

    Returns:
        The vocabulary object produced by ``load_vocab`` from
        ``./vocab.txt`` (expected alongside this script).
    """
    return load_vocab('./vocab.txt')


# Materialize the cached model and vocabulary once at script start-up
# (Streamlit re-runs this file on every interaction; the @st.cache_*
# decorators make these calls cheap after the first run).
model = load_model()
vocab = load_data()

# Page header and the single text box the user types a SMILES string into.
st.title(':blue[Battelle] FupBERT')
st.write('Note: This is not an official Battelle product')
input_text = st.text_input("Provide input as a smiles string: ")


def predict(inp=input_text):
    """Run a FupBERT prediction on a SMILES string and display the result.

    Intended as the ``on_click`` callback of the Evaluate button. Writes
    the outcome (prediction, validation message, or error text) to the
    ``results`` placeholder and persists it in ``st.session_state['result']``
    so it survives Streamlit's script re-run.

    Args:
        inp: SMILES string to evaluate; defaults to the current value of
            the text input at the time this script run defined the function
            (fine under Streamlit's rerun-per-interaction model).
    """
    if not inp:  # empty input: show a prompt instead of running the model
        out = 'Please Enter an Input Sequence'
    else:
        max_seq_len = 256  # token budget expected by smiles_to_tensor
        try:
            model_input = smiles_to_tensor(inp, vocab, max_seq_len=max_seq_len)
            with torch.no_grad():  # inference only; skip autograd bookkeeping
                outputs = model(model_input)
            out = f"log Fup Prediction: {outputs.item()}"
        except Exception as e:
            # UI boundary: surface any tokenization/model failure to the user
            # rather than crashing the app.
            out = f"Error: {str(e)}"
    # Single exit path: publish the message once (previously duplicated on
    # both the empty-input and normal branches).
    st.session_state['result'] = out
    results.write(out)


# Wire the Evaluate button to the predict callback; the empty placeholder
# below it is where predict() writes its output.
st.button('Evaluate', on_click=predict)
results = st.empty()

# Re-display the last result after Streamlit's automatic script re-run,
# so the output doesn't vanish between interactions.
if 'result' in st.session_state:
    results.write(st.session_state['result'])