Spaces:
Sleeping
Sleeping
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import shap | |
import torch | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# Load model and tokenizer
# Both objects are created from the same checkpoint id, so it is kept in one
# place to prevent the tokenizer and the classifier from drifting apart.
_CHECKPOINT = "nlptown/bert-base-multilingual-uncased-sentiment"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT)
# Define prediction function
def predict(texts):
    """Return class probabilities for a batch of texts.

    Each element of ``texts`` is either a string, which is used unchanged, or
    a list of tokens, which is joined back into a single string with
    ``tokenizer.convert_tokens_to_string`` before encoding.

    Returns:
        A NumPy array holding the softmax of the model's logits, taken over
        the last axis.
    """
    batch = [
        tokenizer.convert_tokens_to_string(item) if isinstance(item, list) else item
        for item in texts
    ]

    # Pad to the longest item in the batch and truncate at 512 tokens.
    encoded = tokenizer(
        batch,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
        padding=True,
        return_tensors="pt",
    )

    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits

    return torch.softmax(logits, dim=-1).numpy()
# Initialize SHAP components
# Label names in label-id order: id2label is indexed with 0 .. len - 1.
output_names_list = [
    model.config.id2label[label_id]
    for label_id in range(len(model.config.id2label))
]

# Text masker driven by the model's own tokenizer and its mask token.
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)

# The explainer calls `predict` as the model and names its outputs after the
# labels collected above.
explainer = shap.Explainer(
    model=predict,
    masker=masker,
    output_names=output_names_list,
)
def analyze_text(text):
    """Classify ``text`` and build one SHAP text plot per output class.

    Returns:
        A tuple ``(predicted_label, confidence_scores, *html_plots)``:
        the label whose probability is highest, a dict mapping every label
        to its probability as a ``float``, and then one SHAP text plot
        (as returned by ``shap.plots.text(..., display=False)``) for each
        index along the last axis of the explanation.
    """
    id2label = model.config.id2label

    # Probabilities for the single input; predict() works on batches.
    probs = predict([text])[0]
    top_label = id2label[np.argmax(probs)]
    scores = {id2label[idx]: float(prob) for idx, prob in enumerate(probs)}

    # Explain the same text, then render one plot per class. display=False
    # makes the plot call hand back its HTML instead of displaying it.
    explanation = explainer([text])
    plots = [
        shap.plots.text(explanation[0, :, class_idx], display=False)
        for class_idx in range(explanation.shape[-1])
    ]

    return (top_label, scores, *plots)
# Create Gradio interface with HTML components
#
# Layout, top to bottom: title, input textbox, analyze button, the predicted
# label next to the per-class confidence scores, an explanatory note, one HTML
# panel per sentiment class for the SHAP plots, and a few example inputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## π BERT Sentiment Analysis with SHAP Explanations")
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze...")
    with gr.Row():
        predict_btn = gr.Button("Analyze Sentiment")
    with gr.Row():
        label_output = gr.Label(label="Predicted Sentiment")
        prob_output = gr.Label(label="Confidence Scores")
    with gr.Row():
        gr.Markdown("""
        ### SHAP Explanations
        Below you can see how each word contributes to different sentiment scores (1-5 stars).
        Red text increases the score, blue decreases it.
        """)
    # Individual Explanation Rows
    # analyze_text returns one plot per model output class, so the number of
    # HTML panels is taken from the model's label map rather than hard-coded;
    # a mismatch would make the click handler's outputs disagree with the
    # values the callback returns.
    plot_components = []
    for i in range(len(model.config.id2label)):
        with gr.Row():
            plot_components.append(
                gr.HTML(
                    label=f"Explanation for {model.config.id2label[i]}",
                    elem_classes=f"shap-plot-{i+1}"
                )
            )
    # Output order must match analyze_text's return tuple:
    # (predicted label, confidence scores, *per-class plots).
    predict_btn.click(
        fn=analyze_text,
        inputs=input_text,
        outputs=[label_output, prob_output] + plot_components
    )
    # Clickable sample sentences that fill the input textbox.
    examples = gr.Examples(
        examples=[
            ["This product exceeded all my expectations!"],
            ["Terrible customer service experience."],
            ["The movie was okay, nothing special."],
            ["You are kinda cool"],
        ],
        inputs=input_text
    )
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(debug=True)