# NOTE(review): the six lines below are web-page chrome captured when this file
# was scraped from its hosting page (Spaces status banner, file size, commit
# hashes, and a line-number gutter). They are commented out so the module is
# valid Python; they are not part of the program.
# Spaces: | Runtime error | Runtime error | File size: 5,669 Bytes
# fddf3ff 2e6b9d1 fddf3ff eae7c24 fddf3ff eae7c24 fddf3ff 2e6b9d1 fddf3ff cac1b0e fddf3ff 72de74d fddf3ff cac1b0e fddf3ff 2f3e71b fddf3ff 2f3e71b fddf3ff 2f3e71b fddf3ff 2f3e71b 218de8e fddf3ff eae7c24 fddf3ff d0185e8 fddf3ff d0185e8 fddf3ff 218de8e d0185e8 fddf3ff |
# 1 2 3 ... 139 (line-number gutter from the scraped page)
import asyncio
import gc
import logging
import os
import pandas as pd
import psutil
import streamlit as st
from PIL import Image
from streamlit import components
#from streamlit.caching import clear_cache
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
#logging.basicConfig(
# format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
#)
#def print_memory_usage():
# logging.info(f"RAM memory % used: {psutil.virtual_memory()[2]}")
#@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=1)
def load_model(model_name):
    """Fetch a pretrained sequence-classification model and its tokenizer.

    Parameters
    ----------
    model_name : str
        Hugging Face Hub model identifier.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` — the ``AutoModelForSequenceClassification``
        and ``AutoTokenizer`` loaded for ``model_name``.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
print ("before main")  # debug trace emitted at import time, before main() is defined/run
def main():
    """Render the Streamlit demo UI for transformers-interpret.

    Lets the user pick a classification model in the sidebar, enter a piece
    of text, and view word-level attribution scores computed by
    ``SequenceClassificationExplainer`` for a chosen explanation class.
    """
    # Fixed typo in the visible title: "Interpet" -> "Interpret".
    st.title("Transformers Interpret Demo App")
    image = Image.open("./images/tight@1920x_transparent.png")
    st.sidebar.image(image, use_column_width=True)
    st.sidebar.markdown(
        "Check out the package on [Github](https://github.com/cdpierse/transformers-interpret)"
    )
    st.info(
        "Due to limited resources only low memory models are available. Run this [app locally](https://github.com/cdpierse/transformers-interpret-streamlit) to run the full selection of available models. "
    )

    # Model id -> human-readable description shown in the expander below.
    # Uncomment entries to test the app with a wider variety of models.
    models = {
        # "textattack/distilbert-base-uncased-rotten-tomatoes": "",
        # "textattack/bert-base-uncased-rotten-tomatoes": "",
        # "textattack/roberta-base-rotten-tomatoes": "",
        # "mrm8488/bert-mini-finetuned-age_news-classification": "BERT-Mini finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
        # "nateraw/bert-base-uncased-ag-news": "BERT finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
        "distilbert-base-uncased-finetuned-sst-2-english": "DistilBERT model finetuned on SST-2 sentiment analysis task. Predicts positive/negative sentiment.",
        # "ProsusAI/finbert": "BERT model finetuned to predict sentiment of financial text. Finetuned on Financial PhraseBank data. Predicts positive/negative/neutral.",
        "sampathkethineedi/industry-classification": "DistilBERT Model to classify a business description into one of 62 industry tags.",
        "MoritzLaurer/policy-distilbert-7d": "DistilBERT model finetuned to classify text into one of seven political categories.",
        # # "MoritzLaurer/covid-policy-roberta-21": "(Under active development ) RoBERTA model finetuned to identify COVID policy measure classes ",
        # "mrm8488/bert-tiny-finetuned-sms-spam-detection": "Tiny bert model finetuned for spam detection. 0 == not spam, 1 == spam",
    }

    model_name = st.sidebar.selectbox(
        "Choose a classification model", list(models.keys())
    )
    model, tokenizer = load_model(model_name)

    if model_name.startswith("textattack/"):
        # textattack checkpoints ship with generic labels; patch in readable ones.
        model.config.id2label = {0: "NEGATIVE (0) ", 1: "POSITIVE (1)"}
    model.eval()
    cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)

    # 0 == word embeddings, 1 == position embeddings. The position option is
    # only offered when the explainer supports position ids.
    if cls_explainer.accepts_position_ids:
        emb_type_name = st.sidebar.selectbox(
            "Choose embedding type for attribution.", ["word", "position"]
        )
        # Was two independent `if`s; selectbox yields exactly one of the two
        # options, so a conditional expression is equivalent and leaves no
        # path on which emb_type_num is unbound.
        emb_type_num = 0 if emb_type_name == "word" else 1
    else:
        emb_type_num = 0

    explanation_classes = ["predicted"] + list(model.config.label2id.keys())
    explanation_class_choice = st.sidebar.selectbox(
        "Explanation class: The class you would like to explain output with respect to.",
        explanation_classes,
    )

    # NOTE(review): st.beta_expander was removed in newer Streamlit releases
    # in favor of st.expander — kept as-is to match the Streamlit version this
    # app was written against; confirm before upgrading.
    my_expander = st.beta_expander(
        "Click here for a description of models and their tasks"
    )
    with my_expander:
        st.json(models)

    # st.info("Max char limit of 350 (memory management)")
    text = st.text_area(
        "Enter text to be interpreted",
        "I like you, I love you",
        height=400,
        max_chars=850,
    )

    # Removed leftover debug scaffolding: the "Say hello"/"Goodbye" test
    # button and the bare print() traces sprinkled through this function.
    if st.button("Interpret Text"):
        # print_memory_usage()
        st.text("Output")
        with st.spinner("Interpreting your text (This may take some time)"):
            if explanation_class_choice != "predicted":
                word_attributions = cls_explainer(
                    text,
                    class_name=explanation_class_choice,
                    embedding_type=emb_type_num,
                    internal_batch_size=2,
                )
            else:
                word_attributions = cls_explainer(
                    text, embedding_type=emb_type_num, internal_batch_size=2
                )

            if word_attributions:
                word_attributions_expander = st.beta_expander(
                    "Click here for raw word attributions"
                )
                with word_attributions_expander:
                    st.json(word_attributions)
                # components.v1.html(
                #     cls_explainer.visualize()._repr_html_(), scrolling=True, height=350
                # )
# Script entry point: run the app when executed directly (e.g. `streamlit run`).
if __name__ == "__main__":
    main()
# (trailing "|" gutter character from the scraped page removed)