"""Gradio demo: BERT star-rating sentiment classification with SHAP explanations.

Loads a pretrained sequence-classification model, wraps it in a prediction
function that SHAP can call, and exposes a small web UI that shows the
predicted label, per-class confidences, and one SHAP text plot per class.
"""

import gradio as gr
import matplotlib.pyplot as plt  # not referenced below; retained from the original import list
import numpy as np
import shap
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)


# Define prediction function
def predict(texts):
    """Return class probabilities for a batch of inputs.

    Args:
        texts: Iterable of inputs. Each item is either a string or a list of
            tokens; token lists are joined back into a single string with the
            tokenizer before encoding.

    Returns:
        A NumPy array with the softmax of the model logits, one row per
        input and one column per class.
    """
    processed_texts = [
        tokenizer.convert_tokens_to_string(text) if isinstance(text, list) else text
        for text in texts
    ]

    inputs = tokenizer(
        processed_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=True,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    return probabilities.numpy()


# Initialize SHAP components
output_names_list = [
    model.config.id2label[i] for i in range(len(model.config.id2label))
]
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)
explainer = shap.Explainer(
    model=predict,
    masker=masker,
    output_names=output_names_list,
)


def analyze_text(text):
    """Classify ``text`` and build one SHAP text plot per class.

    Args:
        text: The raw string entered by the user.

    Returns:
        A tuple ``(predicted_label, confidence_scores, *html_plots)`` where
        ``predicted_label`` is the label with the highest probability,
        ``confidence_scores`` maps every label to its probability as a
        ``float``, and ``html_plots`` holds one HTML string per class, in
        class-index order.

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    # Guard clause: there is nothing to explain for blank input, so report it
    # to the user instead of running the model and the explainer on it.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    # Get predictions
    probabilities = predict([text])[0]
    predicted_class = int(np.argmax(probabilities))
    predicted_label = model.config.id2label[predicted_class]

    # Generate SHAP explanations
    shap_values = explainer([text])

    # Create HTML visualizations for all classes; display=False makes
    # shap.plots.text return the markup instead of rendering it.
    html_plots = [
        shap.plots.text(shap_values[0, :, class_idx], display=False)
        for class_idx in range(shap_values.shape[-1])
    ]

    # Format confidence scores
    confidence_scores = {
        model.config.id2label[class_idx]: float(probability)
        for class_idx, probability in enumerate(probabilities)
    }

    return (predicted_label, confidence_scores, *html_plots)


# Create Gradio interface with HTML components
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🔍 BERT Sentiment Analysis with SHAP Explanations")

    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter text to analyze...",
        )

    with gr.Row():
        predict_btn = gr.Button("Analyze Sentiment")

    with gr.Row():
        label_output = gr.Label(label="Predicted Sentiment")
        prob_output = gr.Label(label="Confidence Scores")

    with gr.Row():
        gr.Markdown("""
        ### SHAP Explanations
        Below you can see how each word contributes to different sentiment scores (1-5 stars).
        Red text increases the score, blue decreases it.
        """)

    # Individual Explanation Rows: one HTML slot per model class, so the
    # number of outputs always matches what analyze_text returns.
    plot_components = []
    for i, class_name in enumerate(output_names_list):
        with gr.Row():
            plot_components.append(
                gr.HTML(
                    label=f"Explanation for {class_name}",
                    elem_classes=f"shap-plot-{i+1}",
                )
            )

    predict_btn.click(
        fn=analyze_text,
        inputs=input_text,
        outputs=[label_output, prob_output] + plot_components,
    )

    examples = gr.Examples(
        examples=[
            ["This product exceeded all my expectations!"],
            ["Terrible customer service experience."],
            ["The movie was okay, nothing special."],
            ["You are kinda cool"],
        ],
        inputs=input_text,
    )

if __name__ == "__main__":
    demo.launch(debug=True)