# TODO consider if this can be collapsed back down into the pipeline_train.py import argparse import json import logging import os import random from collections import OrderedDict from itertools import chain from typing import List, Tuple import numpy as np import torch import torch.nn as nn from BERT_explainability.modules.BERT.BERT_cls_lrp import \ BertForSequenceClassification as BertForClsOrigLrp from BERT_explainability.modules.BERT.BertForSequenceClassification import \ BertForSequenceClassification as BertForSequenceClassificationTest from BERT_explainability.modules.BERT.ExplanationGenerator import Generator from BERT_rationale_benchmark.utils import (Annotation, Evidence, load_datasets, load_documents, write_jsonl) from sklearn.metrics import accuracy_score from transformers import BertForSequenceClassification, BertTokenizer logging.basicConfig( level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s" ) logger = logging.getLogger(__name__) # let's make this more or less deterministic (not resistent to restarts) random.seed(12345) np.random.seed(67890) torch.manual_seed(10111213) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False import numpy as np latex_special_token = ["!@#$%^&*()"] def generate(text_list, attention_list, latex_file, color="red", rescale_value=False): attention_list = attention_list[: len(text_list)] if attention_list.max() == attention_list.min(): attention_list = torch.zeros_like(attention_list) else: attention_list = ( 100 * (attention_list - attention_list.min()) / (attention_list.max() - attention_list.min()) ) attention_list[attention_list < 1] = 0 attention_list = attention_list.tolist() text_list = [text_list[i].replace("$", "") for i in range(len(text_list))] if rescale_value: attention_list = rescale(attention_list) word_num = len(text_list) text_list = clean_word(text_list) with open(latex_file, "w") as f: f.write( r"""\documentclass[varwidth=150mm]{standalone} \special{papersize=210mm,297mm} \usepackage{color} \usepackage{tcolorbox} \usepackage{CJK} \usepackage{adjustbox} \tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt} \begin{document} \begin{CJK*}{UTF8}{gbsn}""" + "\n" ) string = ( r"""{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{""" + "\n" ) for idx in range(word_num): # string += "\\colorbox{%s!%s}{"%(color, attention_list[idx])+"\\strut " + text_list[idx]+"} " # print(text_list[idx]) if "\#\#" in text_list[idx]: token = text_list[idx].replace("\#\#", "") string += ( "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + token + "}" ) else: string += ( " " + "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "}" ) string += "\n}}}" f.write(string + "\n") f.write( r"""\end{CJK*} \end{document}""" ) def clean_word(word_list): new_word_list = [] for word in word_list: for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]: if latex_sensitive in word: word = word.replace(latex_sensitive, "\\" + latex_sensitive) new_word_list.append(word) return new_word_list def scores_per_word_from_scores_per_token(input, tokenizer, input_ids, scores_per_id): words = tokenizer.convert_ids_to_tokens(input_ids) words = [word.replace("##", "") for word in words] score_per_char = [] # TODO: DELETE input_ids_chars = [] for word in words: if word in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]: continue input_ids_chars += list(word) # TODO: DELETE for i in range(len(scores_per_id)): if words[i] in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]: continue score_per_char += [scores_per_id[i]] * len(words[i]) score_per_word = [] start_idx = 0 end_idx = 0 # TODO: DELETE words_from_chars = [] for inp in input: if start_idx >= len(score_per_char): break end_idx = end_idx + len(inp) score_per_word.append(np.max(score_per_char[start_idx:end_idx])) # TODO: DELETE words_from_chars.append("".join(input_ids_chars[start_idx:end_idx])) start_idx = end_idx if words_from_chars[:-1] != input[: len(words_from_chars) - 1]: print(words_from_chars) print(input[: len(words_from_chars)]) print(words) print(tokenizer.convert_ids_to_tokens(input_ids)) assert False return torch.tensor(score_per_word) def get_input_words(input, tokenizer, input_ids): words = tokenizer.convert_ids_to_tokens(input_ids) words = [word.replace("##", "") for word in words] input_ids_chars = [] for word in words: if word in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]: continue input_ids_chars += list(word) start_idx = 0 end_idx = 0 words_from_chars = [] for inp in input: if start_idx >= len(input_ids_chars): break end_idx = end_idx + len(inp) words_from_chars.append("".join(input_ids_chars[start_idx:end_idx])) start_idx = end_idx if words_from_chars[:-1] != input[: len(words_from_chars) - 1]: print(words_from_chars) print(input[: len(words_from_chars)]) print(words) print(tokenizer.convert_ids_to_tokens(input_ids)) assert False return words_from_chars def bert_tokenize_doc( doc: List[List[str]], tokenizer, special_token_map ) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]: """Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words""" sents = [] sent_token_spans = [] for sent in doc: tokens = [] spans = [] start = 0 for w in sent: if w in special_token_map: tokens.append(w) else: tokens.extend(tokenizer.tokenize(w)) end = len(tokens) spans.append((start, end)) start = end sents.append(tokens) sent_token_spans.append(spans) return sents, sent_token_spans def initialize_models(params: dict, batch_first: bool, use_half_precision=False): assert batch_first max_length = params["max_length"] tokenizer = BertTokenizer.from_pretrained(params["bert_vocab"]) pad_token_id = tokenizer.pad_token_id cls_token_id = tokenizer.cls_token_id sep_token_id = tokenizer.sep_token_id bert_dir = params["bert_dir"] evidence_classes = dict( (y, x) for (x, y) in enumerate(params["evidence_classifier"]["classes"]) ) evidence_classifier = BertForSequenceClassification.from_pretrained( bert_dir, num_labels=len(evidence_classes) ) word_interner = tokenizer.vocab de_interner = tokenizer.ids_to_tokens return evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer BATCH_FIRST = True def extract_docid_from_dataset_element(element): return next(iter(element.evidences))[0].docid def extract_evidence_from_dataset_element(element): return next(iter(element.evidences)) def main(): parser = argparse.ArgumentParser( description="""Trains a pipeline model. Step 1 is evidence identification, that is identify if a given sentence is evidence or not Step 2 is evidence classification, that is given an evidence sentence, classify the final outcome for the final task (e.g. sentiment or significance). These models should be separated into two separate steps, but at the moment: * prep data (load, intern documents, load json) * convert data for evidence identification - in the case of training data we take all the positives and sample some negatives * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a broader sampling of negative values. * train evidence identification * convert data for evidence classification - take all rationales + decisions and use this as input * train evidence classification * decode first the evidence, then run classification for each split """, formatter_class=argparse.RawTextHelpFormatter, ) parser.add_argument( "--data_dir", dest="data_dir", required=True, help="Which directory contains a {train,val,test}.jsonl file?", ) parser.add_argument( "--output_dir", dest="output_dir", required=True, help="Where shall we write intermediate models + final data to?", ) parser.add_argument( "--model_params", dest="model_params", required=True, help="JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.", ) args = parser.parse_args() assert BATCH_FIRST os.makedirs(args.output_dir, exist_ok=True) with open(args.model_params, "r") as fp: logger.info(f"Loading model parameters from {args.model_params}") model_params = json.load(fp) logger.info(f"Params: {json.dumps(model_params, indent=2, sort_keys=True)}") train, val, test = load_datasets(args.data_dir) docids = set( e.docid for e in chain.from_iterable( chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test))) ) ) documents = load_documents(args.data_dir, docids) logger.info(f"Loaded {len(documents)} documents") ( evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer, ) = initialize_models(model_params, batch_first=BATCH_FIRST) logger.info(f"We have {len(word_interner)} wordpieces") cache = os.path.join(args.output_dir, "preprocessed.pkl") if os.path.exists(cache): logger.info(f"Loading interned documents from {cache}") (interned_documents) = torch.load(cache) else: logger.info(f"Interning documents") interned_documents = {} for d, doc in documents.items(): encoding = tokenizer.encode_plus( doc, add_special_tokens=True, max_length=model_params["max_length"], return_token_type_ids=False, pad_to_max_length=False, return_attention_mask=True, return_tensors="pt", truncation=True, ) interned_documents[d] = encoding torch.save((interned_documents), cache) evidence_classifier = evidence_classifier.cuda() optimizer = None scheduler = None save_dir = args.output_dir logging.info(f"Beginning training classifier") evidence_classifier_output_dir = os.path.join(save_dir, "classifier") os.makedirs(save_dir, exist_ok=True) os.makedirs(evidence_classifier_output_dir, exist_ok=True) model_save_file = os.path.join(evidence_classifier_output_dir, "classifier.pt") epoch_save_file = os.path.join( evidence_classifier_output_dir, "classifier_epoch_data.pt" ) device = next(evidence_classifier.parameters()).device if optimizer is None: optimizer = torch.optim.Adam( evidence_classifier.parameters(), lr=model_params["evidence_classifier"]["lr"], ) criterion = nn.CrossEntropyLoss(reduction="none") batch_size = model_params["evidence_classifier"]["batch_size"] epochs = model_params["evidence_classifier"]["epochs"] patience = model_params["evidence_classifier"]["patience"] max_grad_norm = model_params["evidence_classifier"].get("max_grad_norm", None) class_labels = [k for k, v in sorted(evidence_classes.items())] results = { "train_loss": [], "train_f1": [], "train_acc": [], "val_loss": [], "val_f1": [], "val_acc": [], } best_epoch = -1 best_val_acc = 0 best_val_loss = float("inf") best_model_state_dict = None start_epoch = 0 epoch_data = {} if os.path.exists(epoch_save_file): logging.info(f"Restoring model from {model_save_file}") evidence_classifier.load_state_dict(torch.load(model_save_file)) epoch_data = torch.load(epoch_save_file) start_epoch = epoch_data["epoch"] + 1 # handle finishing because patience was exceeded or we didn't get the best final epoch if bool(epoch_data.get("done", 0)): start_epoch = epochs results = epoch_data["results"] best_epoch = start_epoch best_model_state_dict = OrderedDict( {k: v.cpu() for k, v in evidence_classifier.state_dict().items()} ) logging.info(f"Restoring training from epoch {start_epoch}") logging.info( f"Training evidence classifier from epoch {start_epoch} until epoch {epochs}" ) optimizer.zero_grad() for epoch in range(start_epoch, epochs): epoch_train_data = random.sample(train, k=len(train)) epoch_train_loss = 0 epoch_training_acc = 0 evidence_classifier.train() logging.info( f"Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples" ) for batch_start in range(0, len(epoch_train_data), batch_size): batch_elements = epoch_train_data[ batch_start : min(batch_start + batch_size, len(epoch_train_data)) ] targets = [evidence_classes[s.classification] for s in batch_elements] targets = torch.tensor(targets, dtype=torch.long, device=device) samples_encoding = [ interned_documents[extract_docid_from_dataset_element(s)] for s in batch_elements ] input_ids = ( torch.stack( [ samples_encoding[i]["input_ids"] for i in range(len(samples_encoding)) ] ) .squeeze(1) .to(device) ) attention_masks = ( torch.stack( [ samples_encoding[i]["attention_mask"] for i in range(len(samples_encoding)) ] ) .squeeze(1) .to(device) ) preds = evidence_classifier( input_ids=input_ids, attention_mask=attention_masks )[0] epoch_training_acc += accuracy_score( preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False ) loss = criterion(preds, targets.to(device=preds.device)).sum() epoch_train_loss += loss.item() loss.backward() assert loss == loss # for nans if max_grad_norm: torch.nn.utils.clip_grad_norm_( evidence_classifier.parameters(), max_grad_norm ) optimizer.step() if scheduler: scheduler.step() optimizer.zero_grad() epoch_train_loss /= len(epoch_train_data) epoch_training_acc /= len(epoch_train_data) assert epoch_train_loss == epoch_train_loss # for nans results["train_loss"].append(epoch_train_loss) logging.info(f"Epoch {epoch} training loss {epoch_train_loss}") logging.info(f"Epoch {epoch} training accuracy {epoch_training_acc}") with torch.no_grad(): epoch_val_loss = 0 epoch_val_acc = 0 epoch_val_data = random.sample(val, k=len(val)) evidence_classifier.eval() val_batch_size = 32 logging.info( f"Validating with {len(epoch_val_data) // val_batch_size} batches with {len(epoch_val_data)} examples" ) for batch_start in range(0, len(epoch_val_data), val_batch_size): batch_elements = epoch_val_data[ batch_start : min(batch_start + val_batch_size, len(epoch_val_data)) ] targets = [evidence_classes[s.classification] for s in batch_elements] targets = torch.tensor(targets, dtype=torch.long, device=device) samples_encoding = [ interned_documents[extract_docid_from_dataset_element(s)] for s in batch_elements ] input_ids = ( torch.stack( [ samples_encoding[i]["input_ids"] for i in range(len(samples_encoding)) ] ) .squeeze(1) .to(device) ) attention_masks = ( torch.stack( [ samples_encoding[i]["attention_mask"] for i in range(len(samples_encoding)) ] ) .squeeze(1) .to(device) ) preds = evidence_classifier( input_ids=input_ids, attention_mask=attention_masks )[0] epoch_val_acc += accuracy_score( preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False ) loss = criterion(preds, targets.to(device=preds.device)).sum() epoch_val_loss += loss.item() epoch_val_loss /= len(val) epoch_val_acc /= len(val) results["val_acc"].append(epoch_val_acc) results["val_loss"] = epoch_val_loss logging.info(f"Epoch {epoch} val loss {epoch_val_loss}") logging.info(f"Epoch {epoch} val acc {epoch_val_acc}") if epoch_val_acc > best_val_acc or ( epoch_val_acc == best_val_acc and epoch_val_loss < best_val_loss ): best_model_state_dict = OrderedDict( {k: v.cpu() for k, v in evidence_classifier.state_dict().items()} ) best_epoch = epoch best_val_acc = epoch_val_acc best_val_loss = epoch_val_loss epoch_data = { "epoch": epoch, "results": results, "best_val_acc": best_val_acc, "done": 0, } torch.save(evidence_classifier.state_dict(), model_save_file) torch.save(epoch_data, epoch_save_file) logging.debug( f"Epoch {epoch} new best model with val accuracy {epoch_val_acc}" ) if epoch - best_epoch > patience: logging.info(f"Exiting after epoch {epoch} due to no improvement") epoch_data["done"] = 1 torch.save(epoch_data, epoch_save_file) break epoch_data["done"] = 1 epoch_data["results"] = results torch.save(epoch_data, epoch_save_file) evidence_classifier.load_state_dict(best_model_state_dict) evidence_classifier = evidence_classifier.to(device=device) evidence_classifier.eval() # test test_classifier = BertForSequenceClassificationTest.from_pretrained( model_params["bert_dir"], num_labels=len(evidence_classes) ).to(device) orig_lrp_classifier = BertForClsOrigLrp.from_pretrained( model_params["bert_dir"], num_labels=len(evidence_classes) ).to(device) if os.path.exists(epoch_save_file): logging.info(f"Restoring model from {model_save_file}") test_classifier.load_state_dict(torch.load(model_save_file)) orig_lrp_classifier.load_state_dict(torch.load(model_save_file)) test_classifier.eval() orig_lrp_classifier.eval() test_batch_size = 1 logging.info( f"Testing with {len(test) // test_batch_size} batches with {len(test)} examples" ) # explainability explanations = Generator(test_classifier) explanations_orig_lrp = Generator(orig_lrp_classifier) method = "transformer_attribution" method_folder = { "transformer_attribution": "ours", "partial_lrp": "partial_lrp", "last_attn": "last_attn", "attn_gradcam": "attn_gradcam", "lrp": "lrp", "rollout": "rollout", "ground_truth": "ground_truth", "generate_all": "generate_all", } method_expl = { "transformer_attribution": explanations.generate_LRP, "partial_lrp": explanations_orig_lrp.generate_LRP_last_layer, "last_attn": explanations_orig_lrp.generate_attn_last_layer, "attn_gradcam": explanations_orig_lrp.generate_attn_gradcam, "lrp": explanations_orig_lrp.generate_full_lrp, "rollout": explanations_orig_lrp.generate_rollout, } os.makedirs(os.path.join(args.output_dir, method_folder[method]), exist_ok=True) result_files = [] for i in range(5, 85, 5): result_files.append( open( os.path.join( args.output_dir, "{0}/identifier_results_{1}.json" ).format(method_folder[method], i), "w", ) ) j = 0 for batch_start in range(0, len(test), test_batch_size): batch_elements = test[ batch_start : min(batch_start + test_batch_size, len(test)) ] targets = [evidence_classes[s.classification] for s in batch_elements] targets = torch.tensor(targets, dtype=torch.long, device=device) samples_encoding = [ interned_documents[extract_docid_from_dataset_element(s)] for s in batch_elements ] input_ids = ( torch.stack( [ samples_encoding[i]["input_ids"] for i in range(len(samples_encoding)) ] ) .squeeze(1) .to(device) ) attention_masks = ( torch.stack( [ samples_encoding[i]["attention_mask"] for i in range(len(samples_encoding)) ] ) .squeeze(1) .to(device) ) preds = test_classifier( input_ids=input_ids, attention_mask=attention_masks )[0] for s in batch_elements: doc_name = extract_docid_from_dataset_element(s) inp = documents[doc_name].split() classification = "neg" if targets.item() == 0 else "pos" is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0 if method == "generate_all": file_name = "{0}_{1}_{2}.tex".format( j, classification, is_classification_correct ) GT_global = os.path.join( args.output_dir, "{0}/visual_results_{1}.pdf" ).format(method_folder["ground_truth"], j) GT_ours = os.path.join( args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf" ).format( method_folder["transformer_attribution"], j, classification, is_classification_correct, ) CF_ours = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format( method_folder["transformer_attribution"], j ) GT_partial = os.path.join( args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf" ).format( method_folder["partial_lrp"], j, classification, is_classification_correct, ) CF_partial = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format( method_folder["partial_lrp"], j ) GT_gradcam = os.path.join( args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf" ).format( method_folder["attn_gradcam"], j, classification, is_classification_correct, ) CF_gradcam = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format( method_folder["attn_gradcam"], j ) GT_lrp = os.path.join( args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf" ).format( method_folder["lrp"], j, classification, is_classification_correct, ) CF_lrp = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format( method_folder["lrp"], j ) GT_lastattn = os.path.join( args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf" ).format( method_folder["last_attn"], j, classification, is_classification_correct, ) GT_rollout = os.path.join( args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf" ).format( method_folder["rollout"], j, classification, is_classification_correct, ) with open(file_name, "w") as f: f.write( r"""\documentclass[varwidth]{standalone} \usepackage{color} \usepackage{tcolorbox} \usepackage{CJK} \tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt} \begin{document} \begin{CJK*}{UTF8}{gbsn} {\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{ \setlength{\tabcolsep}{2pt} % Default value: 6pt \begin{tabular}{ccc} \includegraphics[width=0.32\linewidth]{""" + GT_global + """}& \includegraphics[width=0.32\linewidth]{""" + GT_ours + """}& \includegraphics[width=0.32\linewidth]{""" + CF_ours + """}\\\\ (a) & (b) & (c)\\\\ \includegraphics[width=0.32\linewidth]{""" + GT_partial + """}& \includegraphics[width=0.32\linewidth]{""" + CF_partial + """}& \includegraphics[width=0.32\linewidth]{""" + GT_gradcam + """}\\\\ (d) & (e) & (f)\\\\ \includegraphics[width=0.32\linewidth]{""" + CF_gradcam + """}& \includegraphics[width=0.32\linewidth]{""" + GT_lrp + """}& \includegraphics[width=0.32\linewidth]{""" + CF_lrp + """}\\\\ (g) & (h) & (i)\\\\ \includegraphics[width=0.32\linewidth]{""" + GT_lastattn + """}& \includegraphics[width=0.32\linewidth]{""" + GT_rollout + """}&\\\\ (j) & (k)&\\\\ \end{tabular} }}} \end{CJK*} \end{document} )""" ) j += 1 break if method == "ground_truth": inp_cropped = get_input_words(inp, tokenizer, input_ids[0]) cam = torch.zeros(len(inp_cropped)) for evidence in extract_evidence_from_dataset_element(s): start_idx = evidence.start_token if start_idx >= len(cam): break end_idx = evidence.end_token cam[start_idx:end_idx] = 1 generate( inp_cropped, cam, ( os.path.join( args.output_dir, "{0}/visual_results_{1}.tex" ).format(method_folder[method], j) ), color="green", ) j = j + 1 break text = tokenizer.convert_ids_to_tokens(input_ids[0]) classification = "neg" if targets.item() == 0 else "pos" is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0 target_idx = targets.item() cam_target = method_expl[method]( input_ids=input_ids, attention_mask=attention_masks, index=target_idx, )[0] cam_target = cam_target.clamp(min=0) generate( text, cam_target, ( os.path.join(args.output_dir, "{0}/{1}_GT_{2}_{3}.tex").format( method_folder[method], j, classification, is_classification_correct, ) ), ) if method in [ "transformer_attribution", "partial_lrp", "attn_gradcam", "lrp", ]: cam_false_class = method_expl[method]( input_ids=input_ids, attention_mask=attention_masks, index=1 - target_idx, )[0] cam_false_class = cam_false_class.clamp(min=0) generate( text, cam_false_class, ( os.path.join(args.output_dir, "{0}/{1}_CF.tex").format( method_folder[method], j ) ), ) cam = cam_target cam = scores_per_word_from_scores_per_token( inp, tokenizer, input_ids[0], cam ) j = j + 1 doc_name = extract_docid_from_dataset_element(s) hard_rationales = [] for res, i in enumerate(range(5, 85, 5)): print("calculating top ", i) _, indices = cam.topk(k=i) for index in indices.tolist(): hard_rationales.append( {"start_token": index, "end_token": index + 1} ) result_dict = { "annotation_id": doc_name, "rationales": [ { "docid": doc_name, "hard_rationale_predictions": hard_rationales, } ], } result_files[res].write(json.dumps(result_dict) + "\n") for i in range(len(result_files)): result_files[i].close() if __name__ == "__main__": main()