"""Sentence-level LIME explanations rendered as colour-highlighted HTML."""

import html
import logging

from lime.lime_text import LimeTextExplainer
from nltk.tokenize import sent_tokenize

from predictors import predict_for_explainanility

logger = logging.getLogger(__name__)

# Added to each bucket's weight range so that (a) a bucket whose weights are
# all equal still has a non-zero range, and (b) the weakest sentence in a
# bucket gets a faint tint instead of pure white.
_SMOOTHING_FACTOR = 0.001
_WHITE = "rgb(255, 255, 255)"


def explainer(text: str, model_type, *, num_samples: int = 500):
    """Explain the model's prediction for *text* one sentence at a time.

    Every sentence produced by NLTK's ``sent_tokenize`` is one LIME feature;
    the explanation is computed for label index 1 (``"positive"``).

    Args:
        text: Document to explain.
        model_type: Forwarded unchanged to ``predict_for_explainanility``.
        num_samples: Number of perturbed documents LIME generates.

    Returns:
        ``(sentences_weights, exp)``: a dict mapping each distinct sentence to
        its LIME weight for label 1 (0 if LIME reported no weight for it), and
        the LIME explanation object itself.
    """

    def predictor_wrapper(texts):
        # LIME calls this with a batch of perturbed copies of `text`.
        return predict_for_explainanility(text=texts, model_type=model_type)

    lime_explainer = LimeTextExplainer(
        class_names=["negative", "positive"], split_expression=sent_tokenize
    )
    sentences = sent_tokenize(text)
    exp = lime_explainer.explain_instance(
        text,
        predictor_wrapper,
        num_features=len(sentences),
        num_samples=num_samples,
    )

    sentences_weights = dict.fromkeys(sentences, 0)
    # Map features back to sentences by their text, not by index: LIME's
    # default bag-of-words mode gives a repeated sentence a single feature id,
    # so feature ids are not positions in `sentences` once a sentence repeats.
    for sentence, weight in exp.as_list(label=1):
        sentence = str(sentence)
        if sentence in sentences_weights:
            sentences_weights[sentence] = weight

    logger.debug("Sentence weights (%s): %s", model_type, sentences_weights)
    return sentences_weights, exp


def _weight_to_color(weight, pos_lo, pos_hi, neg_lo, neg_hi):
    """Map one sentence weight to a CSS ``rgb(...)`` background colour.

    Positive weights are shaded red and negative weights green; intensity
    grows with the weight's position inside its own bucket's (smoothed)
    range. A zero weight, or a bucket with no usable range, yields white.
    """
    if weight == 0:
        # No evidence either way; keep neutral rather than letting the
        # smoothing paint it at full intensity when all weights are zero.
        return _WHITE
    if weight > 0 and pos_hi != pos_lo:
        intensity = (weight - pos_lo + _SMOOTHING_FACTOR) / (pos_hi - pos_lo)
        fade = int(255 * (1 - min(max(intensity, 0.0), 1.0)))
        return f"rgb(255, {fade}, {fade})"
    if weight < 0 and neg_lo != neg_hi:
        intensity = (weight - neg_hi - _SMOOTHING_FACTOR) / (neg_lo - neg_hi)
        fade = int(255 * (1 - min(max(intensity, 0.0), 1.0)))
        return f"rgb({fade}, 255, {fade})"
    return _WHITE


def analyze_and_highlight(text: str, model_type, *, num_samples: int = 500) -> str:
    """Return *text* as HTML with every sentence tinted by its LIME weight.

    Each sentence (in original order, repeats included) is HTML-escaped and
    wrapped in a ``<span>`` whose background is red for positive weights and
    green for negative ones, followed by a single space.

    Args:
        text: Document to explain and highlight.
        model_type: Forwarded unchanged to ``explainer``.
        num_samples: Number of perturbed documents LIME generates.

    Returns:
        The highlighted HTML, or ``""`` if *text* contains no sentences.
    """
    raw_sentences = sent_tokenize(text)
    if not any(sentence.strip() for sentence in raw_sentences):
        # Nothing to explain; LIME cannot perturb a document with no features.
        return ""

    sentences_weights, _ = explainer(text, model_type, num_samples=num_samples)

    weights = list(sentences_weights.values())
    positive_weights = [w for w in weights if w >= 0]
    negative_weights = [w for w in weights if w < 0]
    pos_lo = min(positive_weights, default=0)
    pos_hi = max(positive_weights, default=0) + _SMOOTHING_FACTOR
    neg_lo = min(negative_weights, default=0) - _SMOOTHING_FACTOR  # most negative
    neg_hi = max(negative_weights, default=0)  # closest to zero

    parts = []
    # Walk the tokenized text (not the weights dict) so repeated sentences are
    # kept in place instead of being collapsed into one dict key.
    for raw in raw_sentences:
        sentence = raw.strip()
        if not sentence:
            continue
        weight = sentences_weights.get(raw, 0)
        color = _weight_to_color(weight, pos_lo, pos_hi, neg_lo, neg_hi)
        parts.append(
            f'<span style="background-color: {color};">'
            f"{html.escape(sentence)}</span> "
        )
    return "".join(parts)