Spaces:
Running
Running
import subprocess | |
import sys | |
## Lines 1-8 are necessary because the normal requirements.txt path for installing a package from disk doesn't work on HF Spaces. Thank you to Omar Sanseviero for the help!
import numpy as np | |
import pandas as pd | |
import shap | |
import streamlit as st | |
import streamlit.components.v1 as components | |
from datasets import load_dataset | |
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering, | |
AutoModelForSeq2SeqLM, | |
AutoModelForSequenceClassification, AutoTokenizer, | |
pipeline) | |
# Page chrome: browser-tab title, app title and author credits.
st.set_page_config(page_title="HF-SHAP")
st.title("HF-SHAP: A front end for SHAP")
for author_line in (
    "By Allen Roush",
    "github: https://github.com/Hellisotherpeople",
    "Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/",
):
    st.caption(author_line)

# Attribution and further reading for the SHAP library itself.
st.title("SHAP (SHapley Additive exPlanations)")
st.image(
    "https://shap.readthedocs.io/en/latest/_images/shap_header.png",
    width=700,
)
for shap_credit in (
    "By Lundberg, Scott M and Lee, Su-In",
    "Slightly modified by Allen Roush to fix a bug with text plotting not working outside of Jupyter Notebooks",
    "Full Citation: https://raw.githubusercontent.com/slundberg/shap/master/docs/references/shap_nips.bib",
    "See on github:: https://github.com/slundberg/shap",
    "More details of how SHAP works: https://christophm.github.io/interpretable-ml-book/shap.html",
):
    st.caption(shap_credit)
# Sidebar form that collects every setting used further down the script.
form = st.sidebar.form("Main Settings")
form.header("Main Settings")

task_done = form.selectbox(
    "Which NLP task do you want to solve?",
    ["Text Generation", "Sentiment Analysis", "Translation", "Summarization"],
)

# Input source: either rows pulled from a Hugging Face dataset, or free text
# typed into the main page.
custom_doc = form.checkbox(
    "Use a document from an existing dataset?",
    value=False,
)
if custom_doc:
    dataset_name = form.text_area(
        "Enter the name of the huggingface Dataset to do analysis of:",
        value="Hellisotherpeople/DebateSum",
    )
    dataset_name_2 = form.text_area(
        "Enter the name of the config for the dataset if it has one",
        value="",
    )
    split_name = form.text_area(
        "Enter the name of the split of the dataset that you want to use",
        value="train",
    )
    number_of_records = form.number_input(
        "Enter the number of documents that you want to analyze from the dataset",
        value=200,
    )
    column_name = form.text_area(
        "Enter the name of the column that we are doing analysis on (the X value)",
        value="Full-Document",
    )
    # Start/end are later used as slice bounds over the loaded column.
    index_to_analyze_start = form.number_input(
        "Enter the index start of the document that you want to analyze of the dataset",
        value=1,
    )
    index_to_analyze_end = form.number_input(
        "Enter the index end of the document that you want to analyze of the dataset",
        value=2,
    )
    form.caption("Multiple documents may not work on certain tasks")
else:
    # Note: this widget is created on `st`, not on `form`, so it is rendered
    # in the main page rather than in the sidebar form.
    doc = st.text_area(
        "Enter a custom document",
        value="This is an example custom document",
    )
# Task-specific settings. Each branch asks for the checkpoint name to load and
# any extra options that only make sense for that task.
if task_done == "Text Generation":
    model_name = form.text_area(
        "Enter the name of the pre-trained model from transformers that we are using for Text Generation",
        value="gpt2",
    )
    form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
    decoder = form.checkbox("Is this a decoder model?", value=True)
    form.caption("This should be true for models like GPT-2, and false for models like BERT")
    max_length = form.number_input("What's the max length of the text?", value=50)
    # The minimum is capped by the currently selected maximum.
    min_length = form.number_input(
        "What's the min length of the text?",
        value=20,
        max_value=max_length,
    )
    # Consumed by load_model as the "no_repeat_ngram_size" generation setting.
    penalize_repetion = form.number_input(
        "How strongly do we want to penalize repetition in the text generation?",
        value=2,
    )
    sample = form.checkbox("Shall we use top-k and top-p decoding?", value=True)
    form.caption("Setting this to false makes it greedy")
    if sample:
        # Sampling knobs only exist when sampling is enabled; load_model only
        # reads them in that case.
        top_k = form.number_input(
            "What value of K should we use for Top-K sampling? Set to zero to disable",
            value=50,
        )
        form.caption("In Top-K sampling, the K most likely next words are filtered and the probability mass is redistributed among only those K next words. ")
        top_p = form.number_input(
            "What value of P should we use for Top-p sampling? Set to zero to disable",
            value=0.95,
            max_value=1.0,
            min_value=0.0,
        )
        form.caption("Top-p sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability p. The probability mass is then redistributed among this set of words.")
        temperature = form.number_input(
            "How spicy/interesting do we want our models output to be",
            value=1.05,
            min_value=0.0,
        )
        form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
        form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
elif task_done == "Sentiment Analysis":
    model_name = form.text_area(
        "Enter the name of the pre-trained model from transformers that we are using for Sentiment Analysis",
        value="nateraw/bert-base-uncased-emotion",
    )
    rescale_logits = form.checkbox(
        "Do we rescale the probabilities in terms of log odds?",
        value=False,
    )
elif task_done == "Translation":
    model_name = form.text_area(
        "Enter the name of the pre-trained model from transformers that we are using for Translation",
        value="Helsinki-NLP/opus-mt-en-es",
    )
elif task_done == "Summarization":
    # Fixed: this label previously said "Translation" (copy-paste slip).
    model_name = form.text_area(
        "Enter the name of the pre-trained model from transformers that we are using for Summarization",
        value="sshleifer/distilbart-xsum-12-1",
    )
else:
    # Fallback branch for Question Answering; the task selectbox does not
    # currently offer that option, so this is not reachable from the UI.
    model_name = form.text_area(
        "Enter the name of the pre-trained model from transformers that we are using for Question Answering",
        value="deepset/roberta-base-squad2",
    )
# Size of the embedded HTML explanation, followed by the form's submit button.
form.header("Model Explanation Display Settings")
output_width = form.number_input(
    "Enter the number of pixels for width of model explanation html display",
    value=800,
)
output_height = form.number_input(
    "Enter the number of pixels for height of model explanation html display",
    value=1000,
)
form.form_submit_button("Submit")
def load_and_process_data(path, name, streaming, split_name, number_of_records, column=None):
    """Load the head of a Hugging Face dataset split and return one column.

    Args:
        path: Dataset identifier passed to ``load_dataset``.
        name: Dataset config name. An empty string (the form's default) is
            treated as "no config" and forwarded as ``None``.
        streaming: Forwarded to ``load_dataset``. The split is read with
            ``.take(...)``, so callers are expected to pass ``True``.
        split_name: Name of the split to read from.
        number_of_records: How many records to take from the start of the split.
        column: Column to return. When omitted, falls back to the module-level
            ``column_name`` chosen in the settings form, which is what the
            original implementation always used.

    Returns:
        The selected column of a DataFrame built from the taken records.
    """
    # The form yields "" when the user leaves the config blank; normalise it so
    # the dataset's default config is used instead of a config literally named "".
    dataset = load_dataset(path=path, name=name or None, streaming=streaming)
    dataset_head = dataset[split_name].take(number_of_records)
    df = pd.DataFrame.from_dict(dataset_head)
    selected_column = column_name if column is None else column
    return df[selected_column]
def load_model(model_name):
    """Load the tokenizer and the task-appropriate model for ``model_name``.

    The model class is picked from the module-level ``task_done`` selection.
    For text generation, the generation settings chosen in the sidebar form
    are stored under ``model.config.task_specific_params["text-generation"]``.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``task_done`` is not one of the supported tasks.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    if task_done == "Text Generation":
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model.config.is_decoder = decoder

        generation_params = {
            "do_sample": sample,
            "max_length": max_length,
            "min_length": min_length,
            "no_repeat_ngram_size": penalize_repetion,
        }
        if sample:
            # The sampling widgets (and their values) only exist when
            # sampling is enabled in the form.
            generation_params.update(
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )

        # Not every checkpoint ships task_specific_params; create the dict
        # when it is missing so the assignment below cannot fail on None.
        if getattr(model.config, "task_specific_params", None) is None:
            model.config.task_specific_params = {}
        model.config.task_specific_params["text-generation"] = generation_params
    elif task_done == "Sentiment Analysis":
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
    elif task_done in ("Translation", "Summarization"):
        # Both tasks use the same sequence-to-sequence model class.
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    elif task_done == "Question Answering":
        # TODO: This one is going to be harder...
        # https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/question_answering/Explaining%20a%20Question%20Answering%20Transformers%20Model.html
        model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    else:
        # Previously fell through and failed with an unbound `model`.
        raise ValueError(f"Unsupported task: {task_done!r}")

    return tokenizer, model
# Load the model chosen in the sidebar form.
tokenizer, model = load_model(model_name)

# When a dataset is the input source, pull the requested slice of documents
# and echo it to the page.
if custom_doc:
    df = load_and_process_data(
        dataset_name,
        dataset_name_2,
        True,
        split_name,
        number_of_records,
    )
    doc = list(df[index_to_analyze_start:index_to_analyze_end])
    st.write(doc)

# Build the SHAP explainer: sentiment analysis explains a text-classification
# pipeline, every other task explains the model/tokenizer pair directly.
if task_done != "Sentiment Analysis":
    explainer = shap.Explainer(model, tokenizer)
else:
    pred = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
    )
    explainer = shap.Explainer(pred, rescale_to_logits=rescale_logits)

# Dataset input is already a list of documents; a typed document is wrapped
# in a one-element list.
shap_values = explainer(doc if custom_doc else [doc])

# Render the SHAP text plot as HTML and embed it in the page.
the_plot = shap.plots.text(shap_values, display=False)
st.caption("The plot is interactive! Try Hovering over or clicking on the input or output text")
components.html(
    the_plot,
    height=output_height,
    width=output_width,
    scrolling=True,
)