AlephBERT / app.py
Elron's picture
Update app.py
1f96fb2
import streamlit as st
from transformers import pipeline
from transformers.tokenization_utils import TruncationStrategy
import tokenizers
import pandas as pd
import requests
st.set_page_config(
page_title='AlephBERT Demo',
page_icon="🥙",
initial_sidebar_state="expanded",
)
models = {
"AlephBERT-base": {
"name_or_path":"onlplab/alephbert-base",
"description":"AlephBERT base model",
},
"HeBERT-base-TAU": {
"name_or_path":"avichr/heBERT",
"description":"HeBERT model created by TAU"
},
"mBERT-base-multilingual-cased": {
"name_or_path":"bert-base-multilingual-cased",
"description":"Multilingual BERT model"
}
}
@st.cache(show_spinner=False)
def get_json_from_url(url):
return models
return requests.get(url).json()
# models = get_json_from_url('https://huggingface.co/spaces/biu-nlp/AlephBERT/raw/main/models.json')
@st.cache(show_spinner=False, hash_funcs={tokenizers.Tokenizer: str})
def load_model(model):
pipe = pipeline('fill-mask', models[model]['name_or_path'])
def do_tokenize(inputs):
return pipe.tokenizer(
inputs,
add_special_tokens=True,
return_tensors=pipe.framework,
padding=True,
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
)
def _parse_and_tokenize(
inputs, tokenized=False, **kwargs
):
if not tokenized:
inputs = do_tokenize(inputs)
return inputs
pipe._parse_and_tokenize = _parse_and_tokenize
return pipe, do_tokenize
st.title('AlephBERT🥙')
st.sidebar.markdown(
"""<div><a target="_blank" href="https://nlp.biu.ac.il/~rtsarfaty/onlp#"><img src="https://nlp.biu.ac.il/~rtsarfaty/static/landing_static/img/onlp_logo.png" style="filter: invert(100%);display: block;margin-left: auto;margin-right: auto;
width: 70%;"></a>
<p style="color:white; font-size:13px; font-family:monospace; text-align: center">AlephBERT Demo &bull; <a href="https://nlp.biu.ac.il/~rtsarfaty/onlp#" style="text-decoration: none;color: white;" target="_blank">ONLP Lab</a></p></div>
<br>""",
unsafe_allow_html=True,
)
mode = 'Models'
if mode == 'Models':
model = st.sidebar.selectbox(
'Select Model',
list(models.keys()))
masking_level = st.sidebar.selectbox('Masking Level:', ['Tokens', 'SubWords'])
n_res = st.sidebar.number_input(
'Number Of Results',
format='%d',
value=5,
min_value=1,
max_value=100)
model_tags = model.split('-')
model_tags[0] = 'Model:' + model_tags[0]
st.markdown(''.join([f'<span style="color:white; font-size:13px; font-family:monospace; background-color: #f63766;margin:3px;padding:8px;border-radius: 5px;">{tag}</span>' for tag in model_tags]),unsafe_allow_html=True)
st.markdown('___')
unmasker, tokenize = load_model(model)
input_text = st.text_input('Insert text you want to mask', '')
if input_text:
input_masked = None
tokenized = tokenize(input_text)
ids = tokenized['input_ids'].tolist()[0]
subwords = unmasker.tokenizer.convert_ids_to_tokens(ids)
if masking_level == 'Tokens':
tokens = str(input_text).split()
mask_idx = st.selectbox('Select token to mask:', [None] + list(range(len(tokens))), format_func=lambda i: tokens[i] if i else '')
if mask_idx is not None:
input_masked = ' '.join(token if i != mask_idx else '[MASK]' for i, token in enumerate(tokens))
display_input = input_masked
if masking_level == 'SubWords':
tokens = subwords
idx = st.selectbox('Select token to mask:', list(range(0,len(tokens)-1)), format_func=lambda i: tokens[i] if i else '')
tokenized['input_ids'][0][idx] = unmasker.tokenizer.mask_token_id
ids = tokenized['input_ids'].tolist()[0]
display_input = ' '.join(unmasker.tokenizer.convert_ids_to_tokens(ids[1:-1]))
if idx:
input_masked = tokenized
if input_masked:
st.markdown('#### Input:')
ids = tokenized['input_ids'].tolist()[0]
subwords = unmasker.tokenizer.convert_ids_to_tokens(ids)
st.markdown(f'<p dir="rtl">{display_input}</p>',
unsafe_allow_html=True,
)
st.markdown('#### Outputs:')
with st.spinner(f'Running {model_tags[0]} (may take a minute)...'):
res = unmasker(input_masked, tokenized=masking_level == 'SubWords', top_k=n_res)
if res:
res = [{'Prediction':r['token_str'], 'Completed Sentence':r['sequence'].replace('[SEP]', '').replace('[CLS]', ''), 'Score':r['score']} for r in res]
res_table = pd.DataFrame(res)
st.table(res_table)