add files
Co-authored-by: Hila <[email protected]>
This view is limited to 50 files because it contains too many changes.
- README.md +2 -2
- Transformer-Explainability/BERT_explainability.ipynb +581 -0
- Transformer-Explainability/BERT_explainability/modules/BERT/BERT.py +748 -0
- Transformer-Explainability/BERT_explainability/modules/BERT/BERT_cls_lrp.py +240 -0
- Transformer-Explainability/BERT_explainability/modules/BERT/BERT_orig_lrp.py +748 -0
- Transformer-Explainability/BERT_explainability/modules/BERT/BertForSequenceClassification.py +241 -0
- Transformer-Explainability/BERT_explainability/modules/BERT/ExplanationGenerator.py +165 -0
- Transformer-Explainability/BERT_explainability/modules/__init__.py +0 -0
- Transformer-Explainability/BERT_explainability/modules/layers_lrp.py +352 -0
- Transformer-Explainability/BERT_explainability/modules/layers_ours.py +373 -0
- Transformer-Explainability/BERT_params/boolq.json +26 -0
- Transformer-Explainability/BERT_params/boolq_baas.json +26 -0
- Transformer-Explainability/BERT_params/boolq_bert.json +32 -0
- Transformer-Explainability/BERT_params/boolq_soft.json +21 -0
- Transformer-Explainability/BERT_params/cose_bert.json +30 -0
- Transformer-Explainability/BERT_params/cose_multiclass.json +35 -0
- Transformer-Explainability/BERT_params/esnli_bert.json +28 -0
- Transformer-Explainability/BERT_params/evidence_inference.json +26 -0
- Transformer-Explainability/BERT_params/evidence_inference_bert.json +33 -0
- Transformer-Explainability/BERT_params/evidence_inference_soft.json +22 -0
- Transformer-Explainability/BERT_params/fever.json +26 -0
- Transformer-Explainability/BERT_params/fever_baas.json +25 -0
- Transformer-Explainability/BERT_params/fever_bert.json +32 -0
- Transformer-Explainability/BERT_params/fever_soft.json +21 -0
- Transformer-Explainability/BERT_params/movies.json +26 -0
- Transformer-Explainability/BERT_params/movies_baas.json +26 -0
- Transformer-Explainability/BERT_params/movies_bert.json +32 -0
- Transformer-Explainability/BERT_params/movies_soft.json +21 -0
- Transformer-Explainability/BERT_params/multirc.json +26 -0
- Transformer-Explainability/BERT_params/multirc_baas.json +26 -0
- Transformer-Explainability/BERT_params/multirc_bert.json +32 -0
- Transformer-Explainability/BERT_params/multirc_soft.json +21 -0
- Transformer-Explainability/BERT_rationale_benchmark/__init__.py +0 -0
- Transformer-Explainability/BERT_rationale_benchmark/metrics.py +1007 -0
- Transformer-Explainability/BERT_rationale_benchmark/models/model_utils.py +186 -0
- Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/__init__.py +0 -0
- Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/bert_pipeline.py +852 -0
- Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_train.py +235 -0
- Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_utils.py +1045 -0
- Transformer-Explainability/BERT_rationale_benchmark/models/sequence_taggers.py +78 -0
- Transformer-Explainability/BERT_rationale_benchmark/utils.py +251 -0
- Transformer-Explainability/DeiT.PNG +0 -0
- Transformer-Explainability/DeiT_example.ipynb +0 -0
- Transformer-Explainability/LICENSE +21 -0
- Transformer-Explainability/README.md +153 -0
- Transformer-Explainability/Transformer_explainability.ipynb +0 -0
- Transformer-Explainability/baselines/ViT/ViT_LRP.py +535 -0
- Transformer-Explainability/baselines/ViT/ViT_explanation_generator.py +107 -0
- Transformer-Explainability/baselines/ViT/ViT_new.py +329 -0
- Transformer-Explainability/baselines/ViT/ViT_orig_LRP.py +508 -0
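
The bulk of the commit vendors the hila-chefer/Transformer-Explainability code together with the BERT_explainability.ipynb notebook shown below. As a quick orientation, the added BERT modules are wired together roughly as follows — a minimal sketch distilled from the notebook's own cells, assuming a CUDA device and the pinned environment from the repository's requirements.txt (transformers==3.5.1, torch==1.7.0):

import torch
from transformers import AutoTokenizer
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator

# Load the SST-2 sentiment classifier through the repository's LRP-aware BERT wrapper.
model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2").to("cuda")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
explanations = Generator(model)

# Per-token relevance scores for the predicted class (start_layer=0, as in the notebook).
encoding = tokenizer(["some scenes were ridiculous, but acting was great."], return_tensors="pt")
input_ids = encoding["input_ids"].to("cuda")
attention_mask = encoding["attention_mask"].to("cuda")
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]
expl = (expl - expl.min()) / (expl.max() - expl.min())  # normalize to [0, 1]
print(list(zip(tokenizer.convert_ids_to_tokens(input_ids.flatten()), expl.tolist())))
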
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
 title: Comparative Explainability
-emoji:
+emoji: 🏆
 colorFrom: red
-colorTo:
+colorTo: gray
 sdk: gradio
 sdk_version: 3.34.0
 app_file: app.py
Transformer-Explainability/BERT_explainability.ipynb
ADDED
@@ -0,0 +1,581 @@
The notebook's metadata records nbformat 4, a Colab provenance block ("BERT-explainability.ipynb"), a Python 3 kernelspec, and the GPU accelerator. Its cells are as follows.

Markdown cell — "Open In Colab" badge linking to https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb

Code cell 1 (execution_count 1) — clone the repository and install dependencies:

!git clone https://github.com/hila-chefer/Transformer-Explainability.git

import os
os.chdir(f'./Transformer-Explainability')

!pip install -r requirements.txt
!pip install captum

Recorded output: the pip log. The pinned requirements (Pillow>=8.1.1, einops==0.3.0, h5py==2.8.0, imageio==2.9.0, matplotlib==3.3.2, opencv_python, scikit_image==0.17.2, scipy==1.5.2, sklearn, torch==1.7.0, torchvision==0.8.1, tqdm==4.51.0, transformers==3.5.1, utils==1.0.1, Pygments>=2.7.4) are already satisfied or installed, matplotlib 3.6.3 is replaced by 3.3.2, pip warns that fastai 2.7.10 requires torchvision>=0.8.2 while torchvision 0.8.1 is installed, and captum 0.6.0 is already satisfied.

Code cell 2 (execution_count 9) — pin captum and matplotlib explicitly:

!pip install captum==0.6.0
!pip install matplotlib==3.3.2

Recorded output: captum 0.6.0 already satisfied; matplotlib 3.6.3 is uninstalled and matplotlib 3.3.2 installed, again with the fastai/torchvision warning.

Code cell 3 (execution_count 10) — imports:

from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from transformers import AutoTokenizer

from captum.attr import visualization
import torch

Code cell 4 (execution_count 11) — load the SST-2 model and the explanation generator:

model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2").to("cuda")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
# initialize the explanations generator
explanations = Generator(model)

classifications = ["NEGATIVE", "POSITIVE"]

Markdown cell — Positive sentiment example

Code cell 5 (execution_count 12) — explain a positive review with respect to the predicted class:

# encode a sentence
text_batch = ["This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].to("cuda")
attention_mask = encoding['attention_mask'].to("cuda")

# true class is positive - 1
true_class = 1

# generate an explanation for the input
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]
# normalize scores
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model classification
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# get class name
class_name = classifications[classification]
# if the classification is negative, higher explanation scores are more negative
# flip for visualization
if class_name == "NEGATIVE":
    expl *= (-1)

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(tokens[i], expl[i].item()) for i in range(len(tokens))])
vis_data_records = [visualization.VisualizationDataRecord(
                        expl,
                        output[0][classification],
                        classification,
                        true_class,
                        true_class,
                        1,
                        tokens,
                        1)]
visualization.visualize_text(vis_data_records)

Printed token relevances:
[('[CLS]', 0.0), ('this', 0.4267549514770508), ('movie', 0.30920878052711487), ('was', 0.2684089243412018), ('the', 0.33637329936027527), ('best', 0.6280889511108398), ('movie', 0.28546375036239624), ('i', 0.1863601952791214), ('have', 0.10115814208984375), ('ever', 0.1419338583946228), ('seen', 0.1898290067911148), ('!', 0.5944811105728149), ('some', 0.003896803595125675), ('scenes', 0.033401958644390106), ('were', 0.018588582053780556), ('ridiculous', 0.018908796831965446), (',', 0.0), ('but', 0.42920616269111633), ('acting', 0.43855082988739014), ('was', 0.500239372253418), ('great', 1.0), ('.', 0.014817383140325546), ('[SEP]', 0.0868983045220375)]

Rendered output: the Captum word-importance table (True Label 1, Predicted Label 1 (1.00), Attribution Label 1, Attribution Score 1.00), with "great" shaded most strongly positive, followed by "best" and "!".

Markdown cell — Negative sentiment example

Code cell 6 (execution_count 13) — the same pipeline applied to a negative review; identical to cell 5 except for the input sentence, the absence of the true_class variable, and the true/attribution labels passed to VisualizationDataRecord being hard-coded to 1:

text_batch = ["I really didn't like this movie. Some of the actors were good, but overall the movie was boring."]

Printed token relevances:
[('[CLS]', -0.0), ('i', -0.19109757244586945), ('really', -0.1888734996318817), ('didn', -0.2894313633441925), ("'", -0.006574898026883602), ('t', -0.36788827180862427), ('like', -0.15249046683311462), ('this', -0.18922168016433716), ('movie', -0.0404353104531765), ('.', -0.019592661410570145), ('some', -0.02311306819319725), ('of', -0.0), ('the', -0.02295113168656826), ('actors', -0.09577538073062897), ('were', -0.013370633125305176), ('good', -0.0323222391307354), (',', -0.004366681911051273), ('but', -0.05878860130906105), ('overall', -0.33596664667129517), ('the', -0.21820111572742462), ('movie', -0.05482065677642822), ('was', -0.6248231530189514), ('boring', -1.0), ('.', -0.031107747927308083), ('[SEP]', -0.052539654076099396)]

Rendered output: the Captum table (True Label 1, Predicted Label 0 (1.00)), with "boring" shaded most strongly negative, followed by "was".

Markdown cell — Choosing class for visualization example

Code cell 7 (execution_count 14) — the sentence "I hate that I love you." explained with respect to a chosen class rather than the predicted one; the pipeline matches cell 5 except for these lines (the cell also omits the classification = output.argmax(...) line, so VisualizationDataRecord reuses the classification value left over from the earlier cells):

text_batch = ["I hate that I love you."]
target_class = 0
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]
class_name = classifications[target_class]

Printed token relevances:
[('[CLS]', -0.0), ('i', -0.19790242612361908), ('hate', -1.0), ('that', -0.40287283062934875), ('i', -0.12505637109279633), ('love', -0.1307140290737152), ('you', -0.05467141419649124), ('.', -6.108225989009952e-06), ('[SEP]', -0.0)]

Rendered output: the Captum table (True Label 1, Predicted Label 0 (0.91)), with "hate" shaded most strongly negative.

Code cell 8 (execution_count 15) — identical to cell 7 except target_class = 1 (the POSITIVE class).

Printed token relevances:
[('[CLS]', 0.0), ('i', 0.2725590765476227), ('hate', 0.17270179092884064), ('that', 0.23211266100406647), ('i', 0.17642731964588165), ('love', 1.0), ('you', 0.2465524971485138), ('.', 0.0), ('[SEP]', 0.00015733683540020138)]

Rendered output: the Captum table (True Label 1, Predicted Label 0 (0.91)), with "love" shaded most strongly positive.

… (the remaining lines of the notebook are not shown in this view)
|
573 |
+
]
|
574 |
+
},
|
575 |
+
"metadata": {},
|
576 |
+
"execution_count": 15
|
577 |
+
}
|
578 |
+
]
|
579 |
+
}
|
580 |
+
]
|
581 |
+
}
|
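One note on the last cell above: it reads `classification` (the predicted class index) from an earlier cell that is not shown here. Below is a minimal sketch, assuming the notebook's `output` and `classifications` variables, of how that index and the displayed probability are typically derived; it is illustrative only and not part of the notebook.

# sketch only -- `output` is the softmax over the logits computed in the cell above,
# and `classifications` maps class indices to names (defined earlier in the notebook)
classification = output.argmax(dim=-1).item()  # predicted class index; 0 ("NEGATIVE", 0.91) in the run above
pred_prob = output[0][classification]          # shown as the "Predicted Label" probability in the HTML table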
Transformer-Explainability/BERT_explainability/modules/BERT/BERT.py
ADDED
@@ -0,0 +1,748 @@
from __future__ import absolute_import

import math

import torch
import torch.nn.functional as F
from BERT_explainability.modules.layers_ours import *
from torch import nn
from transformers import BertConfig, BertPreTrainedModel, PreTrainedModel
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)

ACT2FN = {
    "relu": ReLU,
    "tanh": Tanh,
    "gelu": GELU,
}


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(
            "function {} not found in ACT2FN mapping {}".format(
                activation_string, list(ACT2FN.keys())
            )
        )


def compute_rollout_attention(all_layer_matrices, start_layer=0):
    # adding residual consideration
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    eye = (
        torch.eye(num_tokens)
        .expand(batch_size, num_tokens, num_tokens)
        .to(all_layer_matrices[0].device)
    )
    all_layer_matrices = [
        all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
    ]
    all_layer_matrices = [
        all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
        for i in range(len(all_layer_matrices))
    ]
    joint_attention = all_layer_matrices[start_layer]
    for i in range(start_layer + 1, len(all_layer_matrices)):
        joint_attention = all_layer_matrices[i].bmm(joint_attention)
    return joint_attention


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
        )

        self.add1 = Add()
        self.add2 = Add()

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=self.position_ids.device
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.add1([token_type_embeddings, position_embeddings])
        embeddings = self.add2([embeddings, inputs_embeds])
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def relprop(self, cam, **kwargs):
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.LayerNorm.relprop(cam, **kwargs)

        # [inputs_embeds, position_embeddings, token_type_embeddings]
        (cam) = self.add2.relprop(cam, **kwargs)

        return cam


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def relprop(self, cam, **kwargs):
        # assuming output_hidden_states is False
        for layer_module in reversed(self.layer):
            cam = layer_module.relprop(cam, **kwargs)
        return cam


# not adding relprop since this is only pooling at the end of the network, does not impact tokens importance
class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.activation = Tanh()
        self.pool = IndexSelect()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        self._seq_size = hidden_states.shape[1]

        # first_token_tensor = hidden_states[:, 0]
        first_token_tensor = self.pool(
            hidden_states, 1, torch.tensor(0, device=hidden_states.device)
        )
        first_token_tensor = first_token_tensor.squeeze(1)
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

    def relprop(self, cam, **kwargs):
        cam = self.activation.relprop(cam, **kwargs)
        # print(cam.sum())
        cam = self.dense.relprop(cam, **kwargs)
        # print(cam.sum())
        cam = cam.unsqueeze(1)
        cam = self.pool.relprop(cam, **kwargs)
        # print(cam.sum())

        return cam


class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()
        self.clone = Clone()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = (
            self.self.attention_head_size * self.self.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        h1, h2 = self.clone(hidden_states, 2)
        self_outputs = self.self(
            h1,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], h2)
        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs

    def relprop(self, cam, **kwargs):
        # assuming that we don't ouput the attentions (outputs = (attention_output,)), self_outputs=(context_layer,)
        (cam1, cam2) = self.output.relprop(cam, **kwargs)
        # print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
        cam1 = self.self.relprop(cam1, **kwargs)
        # print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())

        return self.clone.relprop((cam1, cam2), **kwargs)


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.dropout = Dropout(config.attention_probs_dropout_prob)

        self.matmul1 = MatMul()
        self.matmul2 = MatMul()
        self.softmax = Softmax(dim=-1)
        self.add = Add()
        self.mul = Mul()
        self.head_mask = None
        self.attention_mask = None
        self.clone = Clone()

        self.attn_cam = None
        self.attn = None
        self.attn_gradients = None

    def get_attn(self):
        return self.attn

    def save_attn(self, attn):
        self.attn = attn

    def save_attn_cam(self, cam):
        self.attn_cam = cam

    def get_attn_cam(self):
        return self.attn_cam

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def transpose_for_scores_relprop(self, x):
        return x.permute(0, 2, 1, 3).flatten(2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        self.head_mask = head_mask
        self.attention_mask = attention_mask

        h1, h2, h3 = self.clone(hidden_states, 3)
        mixed_query_layer = self.query(h1)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(h2)
            mixed_value_layer = self.value(h3)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = self.add([attention_scores, attention_mask])

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_scores)

        self.save_attn(attention_probs)
        attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = self.matmul2([attention_probs, value_layer])

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )
        return outputs

    def relprop(self, cam, **kwargs):
        # Assume output_attentions == False
        cam = self.transpose_for_scores(cam)

        # [attention_probs, value_layer]
        (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
        cam1 /= 2
        cam2 /= 2
        if self.head_mask is not None:
            # [attention_probs, head_mask]
            (cam1, _) = self.mul.relprop(cam1, **kwargs)

        self.save_attn_cam(cam1)

        cam1 = self.dropout.relprop(cam1, **kwargs)

        cam1 = self.softmax.relprop(cam1, **kwargs)

        if self.attention_mask is not None:
            # [attention_scores, attention_mask]
            (cam1, _) = self.add.relprop(cam1, **kwargs)

        # [query_layer, key_layer.transpose(-1, -2)]
        (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
        cam1_1 /= 2
        cam1_2 /= 2

        # query
        cam1_1 = self.transpose_for_scores_relprop(cam1_1)
        cam1_1 = self.query.relprop(cam1_1, **kwargs)

        # key
        cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
        cam1_2 = self.key.relprop(cam1_2, **kwargs)

        # value
        cam2 = self.transpose_for_scores_relprop(cam2)
        cam2 = self.value.relprop(cam2, **kwargs)

        cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)

        return cam


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        add = self.add([hidden_states, input_tensor])
        hidden_states = self.LayerNorm(add)
        return hidden_states

    def relprop(self, cam, **kwargs):
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # [hidden_states, input_tensor]
        (cam1, cam2) = self.add.relprop(cam, **kwargs)
        cam1 = self.dropout.relprop(cam1, **kwargs)
        cam1 = self.dense.relprop(cam1, **kwargs)

        return (cam1, cam2)


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]()
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

    def relprop(self, cam, **kwargs):
        cam = self.intermediate_act_fn.relprop(cam, **kwargs)  # FIXME only ReLU
        # print(cam.sum())
        cam = self.dense.relprop(cam, **kwargs)
        # print(cam.sum())
        return cam


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        add = self.add([hidden_states, input_tensor])
        hidden_states = self.LayerNorm(add)
        return hidden_states

    def relprop(self, cam, **kwargs):
        # print("in", cam.sum())
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # print(cam.sum())
        # [hidden_states, input_tensor]
        (cam1, cam2) = self.add.relprop(cam, **kwargs)
        # print("add", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
        cam1 = self.dropout.relprop(cam1, **kwargs)
        # print(cam1.sum())
        cam1 = self.dense.relprop(cam1, **kwargs)
        # print("dense", cam1.sum())

        # print("out", cam1.sum() + cam2.sum(), cam1.sum(), cam2.sum())
        return (cam1, cam2)


class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.clone = Clone()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights

        ao1, ao2 = self.clone(attention_output, 2)
        intermediate_output = self.intermediate(ao1)
        layer_output = self.output(intermediate_output, ao2)

        outputs = (layer_output,) + outputs
        return outputs

    def relprop(self, cam, **kwargs):
        (cam1, cam2) = self.output.relprop(cam, **kwargs)
        # print("output", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
        cam1 = self.intermediate.relprop(cam1, **kwargs)
        # print("intermediate", cam1.sum())
        cam = self.clone.relprop((cam1, cam2), **kwargs)
        # print("clone", cam.sum())
        cam = self.attention.relprop(cam, **kwargs)
        # print("attention", cam.sum())
        return cam


class BertModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, device
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def relprop(self, cam, **kwargs):
        cam = self.pooler.relprop(cam, **kwargs)
        # print("111111111111",cam.sum())
        cam = self.encoder.relprop(cam, **kwargs)
        # print("222222222222222", cam.sum())
        # print("conservation: ", cam.sum())
        return cam


if __name__ == "__main__":

    class Config:
        def __init__(
            self, hidden_size, num_attention_heads, attention_probs_dropout_prob
        ):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob

    model = BertSelfAttention(Config(1024, 4, 0.1))
    x = torch.rand(2, 20, 1024)
    x.requires_grad_()

    model.eval()

    y = model.forward(x)

    relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))

    print(relprop[1][0].shape)
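The module above saves each layer's attention map and its gradient (save_attn / save_attn_gradients) and defines compute_rollout_attention, but nothing in this file shows them being wired together. What follows is a minimal sketch, not the repository's ExplanationGenerator, of one way to turn those hooks into per-token relevance; it assumes a sequence-classification wrapper `model` that exposes the BertModel defined above as `model.bert` and returns logits as its first output.

import torch

def rollout_relevance(model, input_ids, attention_mask, index=None, start_layer=0):
    # forward pass: BertSelfAttention.save_attn() stores each layer's attention map
    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    if index is None:
        index = logits.argmax(dim=-1).item()

    # backprop a one-hot of the chosen class so the register_hook() call above
    # records attention gradients in every layer
    one_hot = torch.zeros_like(logits)
    one_hot[0, index] = 1.0
    model.zero_grad()
    (one_hot * logits).sum().backward(retain_graph=True)

    # gradient-weighted attention per layer, averaged over heads, then rolled out
    cams = []
    for layer in model.bert.encoder.layer:
        attn = layer.attention.self.get_attn()            # (batch, heads, tokens, tokens)
        grad = layer.attention.self.get_attn_gradients()
        cams.append((grad * attn).clamp(min=0).mean(dim=1))
    rollout = compute_rollout_attention(cams, start_layer=start_layer)
    return rollout[:, 0]  # relevance of every token with respect to [CLS]

The repository's own explanation generators additionally run relprop() and use the stored attention cams; the gradient-weighted variant above is only the simplest instantiation of the same hooks.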
Transformer-Explainability/BERT_explainability/modules/BERT/BERT_cls_lrp.py
ADDED
@@ -0,0 +1,240 @@
from typing import Any, List

import torch
import torch.nn as nn
from BERT_explainability.modules.BERT.BERT_orig_lrp import BertModel
from BERT_explainability.modules.layers_lrp import *
from BERT_rationale_benchmark.models.model_utils import PaddedSequence
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput  # needed for the return_dict branch below
from transformers.utils import logging


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def relprop(self, cam=None, **kwargs):
        cam = self.classifier.relprop(cam, **kwargs)
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.bert.relprop(cam, **kwargs)
        return cam


# this is the actual classifier we will be using
class BertClassifier(nn.Module):
    """Thin wrapper around BertForSequenceClassification"""

    def __init__(
        self,
        bert_dir: str,
        pad_token_id: int,
        cls_token_id: int,
        sep_token_id: int,
        num_labels: int,
        max_length: int = 512,
        use_half_precision=True,
    ):
        super(BertClassifier, self).__init__()
        bert = BertForSequenceClassification.from_pretrained(
            bert_dir, num_labels=num_labels
        )
        if use_half_precision:
            import apex

            bert = bert.half()
        self.bert = bert
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.max_length = max_length

    def forward(
        self,
        query: List[torch.tensor],
        docids: List[Any],
        document_batch: List[torch.tensor],
    ):
        assert len(query) == len(document_batch)
        print(query)
        # note about device management:
        # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
        # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
        target_device = next(self.parameters()).device
        cls_token = torch.tensor([self.cls_token_id]).to(
            device=document_batch[0].device
        )
        sep_token = torch.tensor([self.sep_token_id]).to(
            device=document_batch[0].device
        )
        input_tensors = []
        position_ids = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 2 > self.max_length:
                d = d[: (self.max_length - len(q) - 2)]
            input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
            position_ids.append(
                torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))
            )
        bert_input = PaddedSequence.autopad(
            input_tensors,
            batch_first=True,
            padding_value=self.pad_token_id,
            device=target_device,
        )
        positions = PaddedSequence.autopad(
            position_ids, batch_first=True, padding_value=0, device=target_device
        )
        (classes,) = self.bert(
            bert_input.data,
            attention_mask=bert_input.mask(
                on=0.0, off=float("-inf"), device=target_device
            ),
            position_ids=positions.data,
        )
        assert torch.all(classes == classes)  # for nans

        print(input_tensors[0])
        print(self.relprop()[0])

        return classes

    def relprop(self, cam=None, **kwargs):
        return self.bert.relprop(cam, **kwargs)


if __name__ == "__main__":
    import os

    from transformers import BertTokenizer

    class Config:
        def __init__(
            self,
            hidden_size,
            num_attention_heads,
            attention_probs_dropout_prob,
            num_labels,
            hidden_dropout_prob,
        ):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.num_labels = num_labels
            self.hidden_dropout_prob = hidden_dropout_prob

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    x = tokenizer.encode_plus(
        "In this movie the acting is great. The movie is perfect! [sep]",
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        return_attention_mask=True,
        pad_to_max_length=True,
        return_tensors="pt",
        truncation=True,
    )

    print(x["input_ids"])

    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )
    model_save_file = os.path.join(
        "./BERT_explainability/output_bert/movies/classifier/", "classifier.pt"
    )
    model.load_state_dict(torch.load(model_save_file))

    # x = torch.randint(100, (2, 20))
    # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
    #                    101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
    #                    2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
    #                    102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
    #                    101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
    #                    2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
    #                    102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
    #                    101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
    #                    1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
    #                    102, 101, 1012, 102]])
    # x.requires_grad_()

    model.eval()

    y = model(x["input_ids"], x["attention_mask"])
    print(y)

    cam, _ = model.relprop()

    # print(cam.shape)

    cam = cam.sum(-1)
    # print(cam)
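A small worked example (not from the repo) of the position_ids layout that BertClassifier.forward builds above: positions restart at the [SEP] that separates query from document.

# query of 3 tokens, document of 4 tokens
q_len, d_len = 3, 4
# input layout:   [CLS] q0 q1 q2 [SEP] d0 d1 d2 d3
# position ids:     0    1  2  3   0    1  2  3  4
positions = list(range(0, q_len + 1)) + list(range(0, d_len + 1))
assert positions == [0, 1, 2, 3, 0, 1, 2, 3, 4]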
Transformer-Explainability/BERT_explainability/modules/BERT/BERT_orig_lrp.py
ADDED
@@ -0,0 +1,748 @@
from __future__ import absolute_import

import math

import torch
import torch.nn.functional as F
from BERT_explainability.modules.layers_lrp import *
from torch import nn
from transformers import BertConfig, BertPreTrainedModel, PreTrainedModel
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)

ACT2FN = {
    "relu": ReLU,
    "tanh": Tanh,
    "gelu": GELU,
}


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(
            "function {} not found in ACT2FN mapping {}".format(
                activation_string, list(ACT2FN.keys())
            )
        )


def compute_rollout_attention(all_layer_matrices, start_layer=0):
    # adding residual consideration
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    eye = (
        torch.eye(num_tokens)
        .expand(batch_size, num_tokens, num_tokens)
        .to(all_layer_matrices[0].device)
    )
    all_layer_matrices = [
        all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
    ]
    all_layer_matrices = [
        all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
        for i in range(len(all_layer_matrices))
    ]
    joint_attention = all_layer_matrices[start_layer]
    for i in range(start_layer + 1, len(all_layer_matrices)):
        joint_attention = all_layer_matrices[i].bmm(joint_attention)
    return joint_attention


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
        )

        self.add1 = Add()
        self.add2 = Add()

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=self.position_ids.device
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.add1([token_type_embeddings, position_embeddings])
        embeddings = self.add2([embeddings, inputs_embeds])
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def relprop(self, cam, **kwargs):
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.LayerNorm.relprop(cam, **kwargs)

        # [inputs_embeds, position_embeddings, token_type_embeddings]
        (cam) = self.add2.relprop(cam, **kwargs)

        return cam


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def relprop(self, cam, **kwargs):
        # assuming output_hidden_states is False
        for layer_module in reversed(self.layer):
            cam = layer_module.relprop(cam, **kwargs)
        return cam


# not adding relprop since this is only pooling at the end of the network, does not impact tokens importance
class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.activation = Tanh()
        self.pool = IndexSelect()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        self._seq_size = hidden_states.shape[1]

        # first_token_tensor = hidden_states[:, 0]
        first_token_tensor = self.pool(
            hidden_states, 1, torch.tensor(0, device=hidden_states.device)
        )
        first_token_tensor = first_token_tensor.squeeze(1)
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

    def relprop(self, cam, **kwargs):
        cam = self.activation.relprop(cam, **kwargs)
        # print(cam.sum())
        cam = self.dense.relprop(cam, **kwargs)
        # print(cam.sum())
        cam = cam.unsqueeze(1)
        cam = self.pool.relprop(cam, **kwargs)
        # print(cam.sum())

        return cam


class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()
        self.clone = Clone()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = (
            self.self.attention_head_size * self.self.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        h1, h2 = self.clone(hidden_states, 2)
        self_outputs = self.self(
            h1,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], h2)
        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs

    def relprop(self, cam, **kwargs):
        # assuming that we don't ouput the attentions (outputs = (attention_output,)), self_outputs=(context_layer,)
        (cam1, cam2) = self.output.relprop(cam, **kwargs)
        # print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
        cam1 = self.self.relprop(cam1, **kwargs)
        # print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())

        return self.clone.relprop((cam1, cam2), **kwargs)


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
298 |
+
config, "embedding_size"
|
299 |
+
):
|
300 |
+
raise ValueError(
|
301 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
302 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
303 |
+
)
|
304 |
+
|
305 |
+
self.num_attention_heads = config.num_attention_heads
|
306 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
307 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
308 |
+
|
309 |
+
self.query = Linear(config.hidden_size, self.all_head_size)
|
310 |
+
self.key = Linear(config.hidden_size, self.all_head_size)
|
311 |
+
self.value = Linear(config.hidden_size, self.all_head_size)
|
312 |
+
|
313 |
+
self.dropout = Dropout(config.attention_probs_dropout_prob)
|
314 |
+
|
315 |
+
self.matmul1 = MatMul()
|
316 |
+
self.matmul2 = MatMul()
|
317 |
+
self.softmax = Softmax(dim=-1)
|
318 |
+
self.add = Add()
|
319 |
+
self.mul = Mul()
|
320 |
+
self.head_mask = None
|
321 |
+
self.attention_mask = None
|
322 |
+
self.clone = Clone()
|
323 |
+
|
324 |
+
self.attn_cam = None
|
325 |
+
self.attn = None
|
326 |
+
self.attn_gradients = None
|
327 |
+
|
328 |
+
def get_attn(self):
|
329 |
+
return self.attn
|
330 |
+
|
331 |
+
def save_attn(self, attn):
|
332 |
+
self.attn = attn
|
333 |
+
|
334 |
+
def save_attn_cam(self, cam):
|
335 |
+
self.attn_cam = cam
|
336 |
+
|
337 |
+
def get_attn_cam(self):
|
338 |
+
return self.attn_cam
|
339 |
+
|
340 |
+
def save_attn_gradients(self, attn_gradients):
|
341 |
+
self.attn_gradients = attn_gradients
|
342 |
+
|
343 |
+
def get_attn_gradients(self):
|
344 |
+
return self.attn_gradients
|
345 |
+
|
346 |
+
def transpose_for_scores(self, x):
|
347 |
+
new_x_shape = x.size()[:-1] + (
|
348 |
+
self.num_attention_heads,
|
349 |
+
self.attention_head_size,
|
350 |
+
)
|
351 |
+
x = x.view(*new_x_shape)
|
352 |
+
return x.permute(0, 2, 1, 3)
|
353 |
+
|
354 |
+
def transpose_for_scores_relprop(self, x):
|
355 |
+
return x.permute(0, 2, 1, 3).flatten(2)
|
356 |
+
|
357 |
+
def forward(
|
358 |
+
self,
|
359 |
+
hidden_states,
|
360 |
+
attention_mask=None,
|
361 |
+
head_mask=None,
|
362 |
+
encoder_hidden_states=None,
|
363 |
+
encoder_attention_mask=None,
|
364 |
+
output_attentions=False,
|
365 |
+
):
|
366 |
+
self.head_mask = head_mask
|
367 |
+
self.attention_mask = attention_mask
|
368 |
+
|
369 |
+
h1, h2, h3 = self.clone(hidden_states, 3)
|
370 |
+
mixed_query_layer = self.query(h1)
|
371 |
+
|
372 |
+
# If this is instantiated as a cross-attention module, the keys
|
373 |
+
# and values come from an encoder; the attention mask needs to be
|
374 |
+
# such that the encoder's padding tokens are not attended to.
|
375 |
+
if encoder_hidden_states is not None:
|
376 |
+
mixed_key_layer = self.key(encoder_hidden_states)
|
377 |
+
mixed_value_layer = self.value(encoder_hidden_states)
|
378 |
+
attention_mask = encoder_attention_mask
|
379 |
+
else:
|
380 |
+
mixed_key_layer = self.key(h2)
|
381 |
+
mixed_value_layer = self.value(h3)
|
382 |
+
|
383 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
384 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
385 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
386 |
+
|
387 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
388 |
+
attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
|
389 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
390 |
+
if attention_mask is not None:
|
391 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
392 |
+
attention_scores = self.add([attention_scores, attention_mask])
|
393 |
+
|
394 |
+
# Normalize the attention scores to probabilities.
|
395 |
+
attention_probs = self.softmax(attention_scores)
|
396 |
+
|
397 |
+
self.save_attn(attention_probs)
|
398 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
399 |
+
|
400 |
+
# This is actually dropping out entire tokens to attend to, which might
|
401 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
402 |
+
attention_probs = self.dropout(attention_probs)
|
403 |
+
|
404 |
+
# Mask heads if we want to
|
405 |
+
if head_mask is not None:
|
406 |
+
attention_probs = attention_probs * head_mask
|
407 |
+
|
408 |
+
context_layer = self.matmul2([attention_probs, value_layer])
|
409 |
+
|
410 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
411 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
412 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
413 |
+
|
414 |
+
outputs = (
|
415 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
416 |
+
)
|
417 |
+
return outputs
|
418 |
+
|
419 |
+
def relprop(self, cam, **kwargs):
|
420 |
+
# Assume output_attentions == False
|
421 |
+
cam = self.transpose_for_scores(cam)
|
422 |
+
|
423 |
+
# [attention_probs, value_layer]
|
424 |
+
(cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
|
425 |
+
cam1 /= 2
|
426 |
+
cam2 /= 2
|
427 |
+
if self.head_mask is not None:
|
428 |
+
# [attention_probs, head_mask]
|
429 |
+
(cam1, _) = self.mul.relprop(cam1, **kwargs)
|
430 |
+
|
431 |
+
self.save_attn_cam(cam1)
|
432 |
+
|
433 |
+
cam1 = self.dropout.relprop(cam1, **kwargs)
|
434 |
+
|
435 |
+
cam1 = self.softmax.relprop(cam1, **kwargs)
|
436 |
+
|
437 |
+
if self.attention_mask is not None:
|
438 |
+
# [attention_scores, attention_mask]
|
439 |
+
(cam1, _) = self.add.relprop(cam1, **kwargs)
|
440 |
+
|
441 |
+
# [query_layer, key_layer.transpose(-1, -2)]
|
442 |
+
(cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
|
443 |
+
cam1_1 /= 2
|
444 |
+
cam1_2 /= 2
|
445 |
+
|
446 |
+
# query
|
447 |
+
cam1_1 = self.transpose_for_scores_relprop(cam1_1)
|
448 |
+
cam1_1 = self.query.relprop(cam1_1, **kwargs)
|
449 |
+
|
450 |
+
# key
|
451 |
+
cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
|
452 |
+
cam1_2 = self.key.relprop(cam1_2, **kwargs)
|
453 |
+
|
454 |
+
# value
|
455 |
+
cam2 = self.transpose_for_scores_relprop(cam2)
|
456 |
+
cam2 = self.value.relprop(cam2, **kwargs)
|
457 |
+
|
458 |
+
cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)
|
459 |
+
|
460 |
+
return cam
|
461 |
+
|
462 |
+
|
463 |
+
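Editor's note: the class above caches both the attention map (save_attn / register_hook for its gradients) and its LRP relevance (save_attn_cam); this is the per-layer signal that the explanation generator later combines. A minimal sketch of that combination for one layer, assuming `layer` is a BertLayer on which a forward pass, a backward pass, and relprop have already been run (illustrative only, not part of the committed diff):

# Illustrative sketch (editor-added), not part of the repository file.
grad = layer.attention.self.get_attn_gradients()   # gradient of the score w.r.t. the attention map
cam = layer.attention.self.get_attn_cam()          # LRP relevance assigned to the attention map
head_relevance = (grad * cam).clamp(min=0).mean(dim=1)  # positive grad*relevance, averaged over heads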
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        add = self.add([hidden_states, input_tensor])
        hidden_states = self.LayerNorm(add)
        return hidden_states

    def relprop(self, cam, **kwargs):
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # [hidden_states, input_tensor]
        (cam1, cam2) = self.add.relprop(cam, **kwargs)
        cam1 = self.dropout.relprop(cam1, **kwargs)
        cam1 = self.dense.relprop(cam1, **kwargs)

        return (cam1, cam2)


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]()
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

    def relprop(self, cam, **kwargs):
        cam = self.intermediate_act_fn.relprop(cam, **kwargs)  # FIXME only ReLU
        # print(cam.sum())
        cam = self.dense.relprop(cam, **kwargs)
        # print(cam.sum())
        return cam


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        add = self.add([hidden_states, input_tensor])
        hidden_states = self.LayerNorm(add)
        return hidden_states

    def relprop(self, cam, **kwargs):
        # print("in", cam.sum())
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # print(cam.sum())
        # [hidden_states, input_tensor]
        (cam1, cam2) = self.add.relprop(cam, **kwargs)
        # print("add", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
        cam1 = self.dropout.relprop(cam1, **kwargs)
        # print(cam1.sum())
        cam1 = self.dense.relprop(cam1, **kwargs)
        # print("dense", cam1.sum())

        # print("out", cam1.sum() + cam2.sum(), cam1.sum(), cam2.sum())
        return (cam1, cam2)


class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.clone = Clone()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights

        ao1, ao2 = self.clone(attention_output, 2)
        intermediate_output = self.intermediate(ao1)
        layer_output = self.output(intermediate_output, ao2)

        outputs = (layer_output,) + outputs
        return outputs

    def relprop(self, cam, **kwargs):
        (cam1, cam2) = self.output.relprop(cam, **kwargs)
        # print("output", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
        cam1 = self.intermediate.relprop(cam1, **kwargs)
        # print("intermediate", cam1.sum())
        cam = self.clone.relprop((cam1, cam2), **kwargs)
        # print("clone", cam.sum())
        cam = self.attention.relprop(cam, **kwargs)
        # print("attention", cam.sum())
        return cam


class BertModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, device
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def relprop(self, cam, **kwargs):
        cam = self.pooler.relprop(cam, **kwargs)
        # print("111111111111",cam.sum())
        cam = self.encoder.relprop(cam, **kwargs)
        # print("222222222222222", cam.sum())
        # print("conservation: ", cam.sum())
        return cam


if __name__ == "__main__":

    class Config:
        def __init__(
            self, hidden_size, num_attention_heads, attention_probs_dropout_prob
        ):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob

    model = BertSelfAttention(Config(1024, 4, 0.1))
    x = torch.rand(2, 20, 1024)
    x.requires_grad_()

    model.eval()

    y = model.forward(x)

    relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))

    print(relprop[1][0].shape)
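Editor's note: a minimal sketch (editor-added, not part of the committed file) of the relprop call pattern used elsewhere in this repository, where the relevance seed has the shape of the module's output and the LRP rule is selected via the `alpha` keyword; it assumes the same toy `Config` as the `__main__` block above and an eval-mode forward pass before relprop:

# Illustrative sketch (editor-added); assumes the Config class defined above.
attn = BertSelfAttention(Config(1024, 4, 0.1))
attn.eval()
x = torch.rand(2, 20, 1024, requires_grad=True)
(context,) = attn(x)                                  # forward pass caches inputs for relprop
cam = attn.relprop(torch.rand_like(context), alpha=1) # alpha=1 is what the explanation generator uses
print(cam.shape)                                      # expected: torch.Size([2, 20, 1024])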
Transformer-Explainability/BERT_explainability/modules/BERT/BertForSequenceClassification.py
ADDED
@@ -0,0 +1,241 @@
from typing import Any, List

import torch
import torch.nn as nn
from BERT_explainability.modules.BERT.BERT import BertModel
from BERT_explainability.modules.layers_ours import *
from BERT_rationale_benchmark.models.model_utils import PaddedSequence
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertPreTrainedModel
from transformers.utils import logging


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def relprop(self, cam=None, **kwargs):
        cam = self.classifier.relprop(cam, **kwargs)
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.bert.relprop(cam, **kwargs)
        # print("conservation: ", cam.sum())
        return cam


# this is the actual classifier we will be using
class BertClassifier(nn.Module):
    """Thin wrapper around BertForSequenceClassification"""

    def __init__(
        self,
        bert_dir: str,
        pad_token_id: int,
        cls_token_id: int,
        sep_token_id: int,
        num_labels: int,
        max_length: int = 512,
        use_half_precision=True,
    ):
        super(BertClassifier, self).__init__()
        bert = BertForSequenceClassification.from_pretrained(
            bert_dir, num_labels=num_labels
        )
        if use_half_precision:
            import apex

            bert = bert.half()
        self.bert = bert
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.max_length = max_length

    def forward(
        self,
        query: List[torch.tensor],
        docids: List[Any],
        document_batch: List[torch.tensor],
    ):
        assert len(query) == len(document_batch)
        print(query)
        # note about device management:
        # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
        # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
        target_device = next(self.parameters()).device
        cls_token = torch.tensor([self.cls_token_id]).to(
            device=document_batch[0].device
        )
        sep_token = torch.tensor([self.sep_token_id]).to(
            device=document_batch[0].device
        )
        input_tensors = []
        position_ids = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 2 > self.max_length:
                d = d[: (self.max_length - len(q) - 2)]
            input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
            position_ids.append(
                torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))
            )
        bert_input = PaddedSequence.autopad(
            input_tensors,
            batch_first=True,
            padding_value=self.pad_token_id,
            device=target_device,
        )
        positions = PaddedSequence.autopad(
            position_ids, batch_first=True, padding_value=0, device=target_device
        )
        (classes,) = self.bert(
            bert_input.data,
            attention_mask=bert_input.mask(
                on=0.0, off=float("-inf"), device=target_device
            ),
            position_ids=positions.data,
        )
        assert torch.all(classes == classes)  # for nans

        print(input_tensors[0])
        print(self.relprop()[0])

        return classes

    def relprop(self, cam=None, **kwargs):
        return self.bert.relprop(cam, **kwargs)


if __name__ == "__main__":
    import os

    from transformers import BertTokenizer

    class Config:
        def __init__(
            self,
            hidden_size,
            num_attention_heads,
            attention_probs_dropout_prob,
            num_labels,
            hidden_dropout_prob,
        ):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.num_labels = num_labels
            self.hidden_dropout_prob = hidden_dropout_prob

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    x = tokenizer.encode_plus(
        "In this movie the acting is great. The movie is perfect! [sep]",
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        return_attention_mask=True,
        pad_to_max_length=True,
        return_tensors="pt",
        truncation=True,
    )

    print(x["input_ids"])

    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )
    model_save_file = os.path.join(
        "./BERT_explainability/output_bert/movies/classifier/", "classifier.pt"
    )
    model.load_state_dict(torch.load(model_save_file))

    # x = torch.randint(100, (2, 20))
    # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
    #                    101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
    #                    2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
    #                    102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
    #                    101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
    #                    2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
    #                    102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
    #                    101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
    #                    1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
    #                    102, 101, 1012, 102]])
    # x.requires_grad_()

    model.eval()

    y = model(x["input_ids"], x["attention_mask"])
    print(y)

    cam, _ = model.relprop()

    # print(cam.shape)

    cam = cam.sum(-1)
    # print(cam)
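Editor's note: the relprop chain above expects a relevance seed shaped like the logits. A minimal sketch (editor-added, not part of the committed file) of seeding it with a one-hot on the predicted class, as the explanation generator in the next file does; it assumes the `model` and tokenized `x` from the `__main__` block above and plain `bert-base-uncased` weights:

# Illustrative sketch (editor-added).
logits = model(x["input_ids"], x["attention_mask"])[0]
pred = logits.argmax(dim=-1).item()
seed = torch.zeros_like(logits)
seed[0, pred] = 1.0                                  # one-hot relevance on the chosen class
token_relevance = model.relprop(seed, alpha=1).sum(dim=-1)  # one relevance score per input token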
Transformer-Explainability/BERT_explainability/modules/BERT/ExplanationGenerator.py
ADDED
@@ -0,0 +1,165 @@
import argparse
import glob

import numpy as np
import torch


# compute rollout between attention layers
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    eye = (
        torch.eye(num_tokens)
        .expand(batch_size, num_tokens, num_tokens)
        .to(all_layer_matrices[0].device)
    )
    all_layer_matrices = [
        all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
    ]
    matrices_aug = [
        all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
        for i in range(len(all_layer_matrices))
    ]
    joint_attention = matrices_aug[start_layer]
    for i in range(start_layer + 1, len(matrices_aug)):
        joint_attention = matrices_aug[i].bmm(joint_attention)
    return joint_attention


class Generator:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids, attention_mask)

    def generate_LRP(self, input_ids, attention_mask, index=None, start_layer=11):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        kwargs = {"alpha": 1}

        if index == None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0, index] = 1
        one_hot_vector = one_hot
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.cuda() * output)

        self.model.zero_grad()
        one_hot.backward(retain_graph=True)

        self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)

        cams = []
        blocks = self.model.bert.encoder.layer
        for blk in blocks:
            grad = blk.attention.self.get_attn_gradients()
            cam = blk.attention.self.get_attn_cam()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
            cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cams.append(cam.unsqueeze(0))
        rollout = compute_rollout_attention(cams, start_layer=start_layer)
        rollout[:, 0, 0] = rollout[:, 0].min()
        return rollout[:, 0]

    def generate_LRP_last_layer(self, input_ids, attention_mask, index=None):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        kwargs = {"alpha": 1}
        if index == None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0, index] = 1
        one_hot_vector = one_hot
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.cuda() * output)

        self.model.zero_grad()
        one_hot.backward(retain_graph=True)

        self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)

        cam = self.model.bert.encoder.layer[-1].attention.self.get_attn_cam()[0]
        cam = cam.clamp(min=0).mean(dim=0).unsqueeze(0)
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_full_lrp(self, input_ids, attention_mask, index=None):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        kwargs = {"alpha": 1}

        if index == None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0, index] = 1
        one_hot_vector = one_hot
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.cuda() * output)

        self.model.zero_grad()
        one_hot.backward(retain_graph=True)

        cam = self.model.relprop(
            torch.tensor(one_hot_vector).to(input_ids.device), **kwargs
        )
        cam = cam.sum(dim=2)
        cam[:, 0] = 0
        return cam

    def generate_attn_last_layer(self, input_ids, attention_mask, index=None):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()[0]
        cam = cam.mean(dim=0).unsqueeze(0)
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_rollout(self, input_ids, attention_mask, start_layer=0, index=None):
        self.model.zero_grad()
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        blocks = self.model.bert.encoder.layer
        all_layer_attentions = []
        for blk in blocks:
            attn_heads = blk.attention.self.get_attn()
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
            all_layer_attentions.append(avg_heads)
        rollout = compute_rollout_attention(
            all_layer_attentions, start_layer=start_layer
        )
        rollout[:, 0, 0] = 0
        return rollout[:, 0]

    def generate_attn_gradcam(self, input_ids, attention_mask, index=None):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        kwargs = {"alpha": 1}

        if index == None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0, index] = 1
        one_hot_vector = one_hot
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.cuda() * output)

        self.model.zero_grad()
        one_hot.backward(retain_graph=True)

        self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)

        cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()
        grad = self.model.bert.encoder.layer[-1].attention.self.get_attn_gradients()

        cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
        grad = grad.mean(dim=[1, 2], keepdim=True)
        cam = (cam * grad).mean(0).clamp(min=0).unsqueeze(0)
        cam = (cam - cam.min()) / (cam.max() - cam.min())
        cam[:, 0, 0] = 0
        return cam[:, 0]
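Editor's note: a minimal end-to-end sketch (editor-added, not part of the committed file) of how this generator is typically driven; the checkpoint name is only an example of a sequence-classification checkpoint, and a CUDA device is assumed because generate_LRP calls .cuda() internally:

# Illustrative sketch (editor-added).
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator

model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2").cuda().eval()
tokenizer = BertTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
generator = Generator(model)

encoding = tokenizer("This movie was great.", return_tensors="pt")
input_ids = encoding["input_ids"].cuda()
attention_mask = encoding["attention_mask"].cuda()

scores = generator.generate_LRP(input_ids, attention_mask, start_layer=0)[0]
scores = (scores - scores.min()) / (scores.max() - scores.min())  # normalize for display
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
print(list(zip(tokens, scores.tolist())))                          # per-token relevance scores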
Transformer-Explainability/BERT_explainability/modules/__init__.py
ADDED
File without changes
Transformer-Explainability/BERT_explainability/modules/layers_lrp.py
ADDED
@@ -0,0 +1,352 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "forward_hook",
    "Clone",
    "Add",
    "Cat",
    "ReLU",
    "GELU",
    "Dropout",
    "BatchNorm2d",
    "Linear",
    "MaxPool2d",
    "AdaptiveAvgPool2d",
    "AvgPool2d",
    "Conv2d",
    "Sequential",
    "safe_divide",
    "einsum",
    "Softmax",
    "IndexSelect",
    "LayerNorm",
    "AddEye",
    "Tanh",
    "MatMul",
    "Mul",
]


def safe_divide(a, b):
    den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
    den = den + den.eq(0).type(den.type()) * 1e-9
    return a / den * b.ne(0).type(b.type())


def forward_hook(self, input, output):
    if type(input[0]) in (list, tuple):
        self.X = []
        for i in input[0]:
            x = i.detach()
            x.requires_grad = True
            self.X.append(x)
    else:
        self.X = input[0].detach()
        self.X.requires_grad = True

    self.Y = output


def backward_hook(self, grad_input, grad_output):
    self.grad_input = grad_input
    self.grad_output = grad_output


class RelProp(nn.Module):
    def __init__(self):
        super(RelProp, self).__init__()
        # if not self.training:
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        C = torch.autograd.grad(Z, X, S, retain_graph=True)
        return C

    def relprop(self, R, alpha):
        return R


class RelPropSimple(RelProp):
    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs


class AddEye(RelPropSimple):
    # input of shape B, C, seq_len, seq_len
    def forward(self, input):
        return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)


class ReLU(nn.ReLU, RelProp):
    pass


class Tanh(nn.Tanh, RelProp):
    pass


class GELU(nn.GELU, RelProp):
    pass


class Softmax(nn.Softmax, RelProp):
    pass


class LayerNorm(nn.LayerNorm, RelProp):
    pass


class Dropout(nn.Dropout, RelProp):
    pass


class MaxPool2d(nn.MaxPool2d, RelPropSimple):
    pass


class LayerNorm(nn.LayerNorm, RelProp):
    pass


class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
    pass


class MatMul(RelPropSimple):
    def forward(self, inputs):
        return torch.matmul(*inputs)


class Mul(RelPropSimple):
    def forward(self, inputs):
        return torch.mul(*inputs)


class AvgPool2d(nn.AvgPool2d, RelPropSimple):
    pass


class Add(RelPropSimple):
    def forward(self, inputs):
        return torch.add(*inputs)


class einsum(RelPropSimple):
    def __init__(self, equation):
        super().__init__()
        self.equation = equation

    def forward(self, *operands):
        return torch.einsum(self.equation, *operands)


class IndexSelect(RelProp):
    def forward(self, inputs, dim, indices):
        self.__setattr__("dim", dim)
        self.__setattr__("indices", indices)

        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs


class Clone(RelProp):
    def forward(self, input, num):
        self.__setattr__("num", num)
        outputs = []
        for _ in range(num):
            outputs.append(input)

        return outputs

    def relprop(self, R, alpha):
        Z = []
        for _ in range(self.num):
            Z.append(self.X)
        S = [safe_divide(r, z) for r, z in zip(R, Z)]
        C = self.gradprop(Z, self.X, S)[0]

        R = self.X * C

        return R


class Cat(RelProp):
    def forward(self, inputs, dim):
        self.__setattr__("dim", dim)
        return torch.cat(inputs, dim)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        outputs = []
        for x, c in zip(self.X, C):
            outputs.append(x * c)

        return outputs


class Sequential(nn.Sequential):
    def relprop(self, R, alpha):
        for m in reversed(self._modules.values()):
            R = m.relprop(R, alpha)
        return R


class BatchNorm2d(nn.BatchNorm2d, RelProp):
    def relprop(self, R, alpha):
        X = self.X
        beta = 1 - alpha
        weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
            (
                self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2)
                + self.eps
            ).pow(0.5)
        )
        Z = X * weight + 1e-9
        S = R / Z
        Ca = S * weight
        R = self.X * (Ca)
        return R


class Linear(nn.Linear, RelProp):
    def relprop(self, R, alpha):
        beta = alpha - 1
        pw = torch.clamp(self.weight, min=0)
        nw = torch.clamp(self.weight, max=0)
        px = torch.clamp(self.X, min=0)
        nx = torch.clamp(self.X, max=0)

        def f(w1, w2, x1, x2):
            Z1 = F.linear(x1, w1)
            Z2 = F.linear(x2, w2)
            S1 = safe_divide(R, Z1)
            S2 = safe_divide(R, Z2)
            C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
            C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]

            return C1 + C2

        activator_relevances = f(pw, nw, px, nx)
        inhibitor_relevances = f(nw, pw, px, nx)

        R = alpha * activator_relevances - beta * inhibitor_relevances

        return R


class Conv2d(nn.Conv2d, RelProp):
    def gradprop2(self, DY, weight):
        Z = self.forward(self.X)

        output_padding = self.X.size()[2] - (
            (Z.size()[2] - 1) * self.stride[0]
            - 2 * self.padding[0]
            + self.kernel_size[0]
        )

        return F.conv_transpose2d(
            DY,
            weight,
            stride=self.stride,
            padding=self.padding,
            output_padding=output_padding,
        )

    def relprop(self, R, alpha):
        if self.X.shape[1] == 3:
            pw = torch.clamp(self.weight, min=0)
            nw = torch.clamp(self.weight, max=0)
            X = self.X
            L = (
                self.X * 0
                + torch.min(
                    torch.min(
                        torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True
                    )[0],
                    dim=3,
                    keepdim=True,
                )[0]
            )
            H = (
                self.X * 0
                + torch.max(
                    torch.max(
                        torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True
                    )[0],
                    dim=3,
                    keepdim=True,
                )[0]
            )
            Za = (
                torch.conv2d(
                    X, self.weight, bias=None, stride=self.stride, padding=self.padding
                )
                - torch.conv2d(
                    L, pw, bias=None, stride=self.stride, padding=self.padding
                )
                - torch.conv2d(
                    H, nw, bias=None, stride=self.stride, padding=self.padding
                )
                + 1e-9
            )

            S = R / Za
            C = (
                X * self.gradprop2(S, self.weight)
                - L * self.gradprop2(S, pw)
                - H * self.gradprop2(S, nw)
            )
            R = C
        else:
            beta = alpha - 1
            pw = torch.clamp(self.weight, min=0)
            nw = torch.clamp(self.weight, max=0)
            px = torch.clamp(self.X, min=0)
            nx = torch.clamp(self.X, max=0)

            def f(w1, w2, x1, x2):
                Z1 = F.conv2d(
                    x1, w1, bias=None, stride=self.stride, padding=self.padding
                )
                Z2 = F.conv2d(
                    x2, w2, bias=None, stride=self.stride, padding=self.padding
                )
                S1 = safe_divide(R, Z1)
                S2 = safe_divide(R, Z2)
                C1 = x1 * self.gradprop(Z1, x1, S1)[0]
                C2 = x2 * self.gradprop(Z2, x2, S2)[0]
                return C1 + C2

            activator_relevances = f(pw, nw, px, nx)
            inhibitor_relevances = f(nw, pw, px, nx)

            R = alpha * activator_relevances - beta * inhibitor_relevances
        return R
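Editor's note: a small numeric sketch (editor-added, not part of the committed file) of how RelPropSimple redistributes incoming relevance through Add in proportion to each input's contribution, using only the layers defined above:

# Illustrative sketch (editor-added).
import torch
add = Add()
x1 = torch.tensor([[1.0, 3.0]])
x2 = torch.tensor([[1.0, 1.0]])
out = add([x1, x2])                  # forward hook caches detached, grad-enabled copies of the inputs
R = torch.tensor([[1.0, 1.0]])       # relevance arriving from the layer above
R1, R2 = add.relprop(R, alpha=1)
print(R1, R2, R1 + R2)               # each unit's relevance splits as x_i / (x1 + x2); the sum equals R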
Transformer-Explainability/BERT_explainability/modules/layers_ours.py
ADDED
@@ -0,0 +1,373 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "forward_hook",
    "Clone",
    "Add",
    "Cat",
    "ReLU",
    "GELU",
    "Dropout",
    "BatchNorm2d",
    "Linear",
    "MaxPool2d",
    "AdaptiveAvgPool2d",
    "AvgPool2d",
    "Conv2d",
    "Sequential",
    "safe_divide",
    "einsum",
    "Softmax",
    "IndexSelect",
    "LayerNorm",
    "AddEye",
    "Tanh",
    "MatMul",
    "Mul",
]


def safe_divide(a, b):
    den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
    den = den + den.eq(0).type(den.type()) * 1e-9
    return a / den * b.ne(0).type(b.type())


def forward_hook(self, input, output):
    if type(input[0]) in (list, tuple):
        self.X = []
        for i in input[0]:
            x = i.detach()
            x.requires_grad = True
            self.X.append(x)
    else:
        self.X = input[0].detach()
        self.X.requires_grad = True

    self.Y = output


def backward_hook(self, grad_input, grad_output):
    self.grad_input = grad_input
    self.grad_output = grad_output


class RelProp(nn.Module):
    def __init__(self):
        super(RelProp, self).__init__()
        # if not self.training:
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        C = torch.autograd.grad(Z, X, S, retain_graph=True)
        return C

    def relprop(self, R, alpha):
        return R


class RelPropSimple(RelProp):
    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs


class AddEye(RelPropSimple):
    # input of shape B, C, seq_len, seq_len
    def forward(self, input):
        return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)


class ReLU(nn.ReLU, RelProp):
    pass


class GELU(nn.GELU, RelProp):
    pass


class Softmax(nn.Softmax, RelProp):
    pass


class Mul(RelPropSimple):
    def forward(self, inputs):
        return torch.mul(*inputs)


class Tanh(nn.Tanh, RelProp):
    pass


class LayerNorm(nn.LayerNorm, RelProp):
    pass


class Dropout(nn.Dropout, RelProp):
    pass


class MatMul(RelPropSimple):
    def forward(self, inputs):
        return torch.matmul(*inputs)


class MaxPool2d(nn.MaxPool2d, RelPropSimple):
    pass


class LayerNorm(nn.LayerNorm, RelProp):
    pass


class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
    pass


class AvgPool2d(nn.AvgPool2d, RelPropSimple):
    pass


class Add(RelPropSimple):
    def forward(self, inputs):
        return torch.add(*inputs)

    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        a = self.X[0] * C[0]
        b = self.X[1] * C[1]

        a_sum = a.sum()
        b_sum = b.sum()

        a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
        b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()

        a = a * safe_divide(a_fact, a.sum())
        b = b * safe_divide(b_fact, b.sum())

        outputs = [a, b]

        return outputs


class einsum(RelPropSimple):
    def __init__(self, equation):
        super().__init__()
        self.equation = equation

    def forward(self, *operands):
        return torch.einsum(self.equation, *operands)


class IndexSelect(RelProp):
    def forward(self, inputs, dim, indices):
        self.__setattr__("dim", dim)
        self.__setattr__("indices", indices)

        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs


class Clone(RelProp):
    def forward(self, input, num):
        self.__setattr__("num", num)
        outputs = []
        for _ in range(num):
            outputs.append(input)

        return outputs

    def relprop(self, R, alpha):
        Z = []
        for _ in range(self.num):
            Z.append(self.X)
        S = [safe_divide(r, z) for r, z in zip(R, Z)]
        C = self.gradprop(Z, self.X, S)[0]

        R = self.X * C

        return R


class Cat(RelProp):
    def forward(self, inputs, dim):
        self.__setattr__("dim", dim)
        return torch.cat(inputs, dim)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        outputs = []
        for x, c in zip(self.X, C):
            outputs.append(x * c)

        return outputs


class Sequential(nn.Sequential):
    def relprop(self, R, alpha):
        for m in reversed(self._modules.values()):
            R = m.relprop(R, alpha)
        return R


class BatchNorm2d(nn.BatchNorm2d, RelProp):
    def relprop(self, R, alpha):
        X = self.X
        beta = 1 - alpha
        weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
            (
                self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2)
                + self.eps
            ).pow(0.5)
        )
        Z = X * weight + 1e-9
        S = R / Z
        Ca = S * weight
        R = self.X * (Ca)
        return R


class Linear(nn.Linear, RelProp):
    def relprop(self, R, alpha):
        beta = alpha - 1
        pw = torch.clamp(self.weight, min=0)
        nw = torch.clamp(self.weight, max=0)
        px = torch.clamp(self.X, min=0)
        nx = torch.clamp(self.X, max=0)

        def f(w1, w2, x1, x2):
            Z1 = F.linear(x1, w1)
            Z2 = F.linear(x2, w2)
            S1 = safe_divide(R, Z1 + Z2)
            S2 = safe_divide(R, Z1 + Z2)
            C1 = x1 * self.gradprop(Z1, x1, S1)[0]
            C2 = x2 * self.gradprop(Z2, x2, S2)[0]

            return C1 + C2

        activator_relevances = f(pw, nw, px, nx)
        inhibitor_relevances = f(nw, pw, px, nx)

        R = alpha * activator_relevances - beta * inhibitor_relevances

        return R


class Conv2d(nn.Conv2d, RelProp):
    def gradprop2(self, DY, weight):
        Z = self.forward(self.X)

        output_padding = self.X.size()[2] - (
            (Z.size()[2] - 1) * self.stride[0]
            - 2 * self.padding[0]
            + self.kernel_size[0]
        )

        return F.conv_transpose2d(
            DY,
            weight,
            stride=self.stride,
            padding=self.padding,
            output_padding=output_padding,
        )

    def relprop(self, R, alpha):
        if self.X.shape[1] == 3:
            pw = torch.clamp(self.weight, min=0)
            nw = torch.clamp(self.weight, max=0)
            X = self.X
            L = (
                self.X * 0
                + torch.min(
                    torch.min(
|
313 |
+
torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True
|
314 |
+
)[0],
|
315 |
+
dim=3,
|
316 |
+
keepdim=True,
|
317 |
+
)[0]
|
318 |
+
)
|
319 |
+
H = (
|
320 |
+
self.X * 0
|
321 |
+
+ torch.max(
|
322 |
+
torch.max(
|
323 |
+
torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True
|
324 |
+
)[0],
|
325 |
+
dim=3,
|
326 |
+
keepdim=True,
|
327 |
+
)[0]
|
328 |
+
)
|
329 |
+
Za = (
|
330 |
+
torch.conv2d(
|
331 |
+
X, self.weight, bias=None, stride=self.stride, padding=self.padding
|
332 |
+
)
|
333 |
+
- torch.conv2d(
|
334 |
+
L, pw, bias=None, stride=self.stride, padding=self.padding
|
335 |
+
)
|
336 |
+
- torch.conv2d(
|
337 |
+
H, nw, bias=None, stride=self.stride, padding=self.padding
|
338 |
+
)
|
339 |
+
+ 1e-9
|
340 |
+
)
|
341 |
+
|
342 |
+
S = R / Za
|
343 |
+
C = (
|
344 |
+
X * self.gradprop2(S, self.weight)
|
345 |
+
- L * self.gradprop2(S, pw)
|
346 |
+
- H * self.gradprop2(S, nw)
|
347 |
+
)
|
348 |
+
R = C
|
349 |
+
else:
|
350 |
+
beta = alpha - 1
|
351 |
+
pw = torch.clamp(self.weight, min=0)
|
352 |
+
nw = torch.clamp(self.weight, max=0)
|
353 |
+
px = torch.clamp(self.X, min=0)
|
354 |
+
nx = torch.clamp(self.X, max=0)
|
355 |
+
|
356 |
+
def f(w1, w2, x1, x2):
|
357 |
+
Z1 = F.conv2d(
|
358 |
+
x1, w1, bias=None, stride=self.stride, padding=self.padding
|
359 |
+
)
|
360 |
+
Z2 = F.conv2d(
|
361 |
+
x2, w2, bias=None, stride=self.stride, padding=self.padding
|
362 |
+
)
|
363 |
+
S1 = safe_divide(R, Z1)
|
364 |
+
S2 = safe_divide(R, Z2)
|
365 |
+
C1 = x1 * self.gradprop(Z1, x1, S1)[0]
|
366 |
+
C2 = x2 * self.gradprop(Z2, x2, S2)[0]
|
367 |
+
return C1 + C2
|
368 |
+
|
369 |
+
activator_relevances = f(pw, nw, px, nx)
|
370 |
+
inhibitor_relevances = f(nw, pw, px, nx)
|
371 |
+
|
372 |
+
R = alpha * activator_relevances - beta * inhibitor_relevances
|
373 |
+
return R
|
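The file above wraps standard torch.nn modules so that each one caches its forward input and exposes a relprop(R, alpha) method, with Sequential.relprop walking the modules in reverse. The following is a minimal usage sketch, not part of the committed files; it assumes the forward hook defined earlier in layers_ours.py stores each module's input as self.X, and it uses alpha=1 (so beta = alpha - 1 = 0 in the Linear rule above).

# Hypothetical illustration only -- not part of the added files.
import torch
from BERT_explainability.modules.layers_ours import Linear, ReLU, Sequential

model = Sequential(Linear(4, 8), ReLU(), Linear(8, 2))
x = torch.randn(1, 4, requires_grad=True)
logits = model(x)                    # forward pass; the hook caches each layer's input

R = torch.zeros_like(logits)
R[0, logits.argmax()] = 1.0          # seed relevance at the predicted class
R_input = model.relprop(R, alpha=1)  # propagate relevance back through the reversed layers
print(R_input.shape)                 # per-feature relevance for the input, here (1, 4)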
Transformer-Explainability/BERT_params/boolq.json
ADDED
@@ -0,0 +1,26 @@
{
  "embeddings": {
    "embedding_file": "model_components/glove.6B.200d.txt",
    "dropout": 0.05
  },
  "evidence_identifier": {
    "mlp_size": 128,
    "dropout": 0.2,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "False", "True" ],
    "mlp_size": 128,
    "dropout": 0.2,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "sampling_method": "everything"
  }
}
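boolq.json, like the other BERT_params files that follow, is a plain JSON dictionary of training hyperparameters, grouped per model component. A small sketch of reading it (illustration only, not part of the committed files; the relative path assumes the repository root as the working directory):

# Hypothetical illustration only -- not part of the added files.
import json

with open("Transformer-Explainability/BERT_params/boolq.json") as f:
    params = json.load(f)

print(params["evidence_classifier"]["classes"])  # ['False', 'True']
print(params["evidence_identifier"]["lr"])       # 0.001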
Transformer-Explainability/BERT_params/boolq_baas.json
ADDED
@@ -0,0 +1,26 @@
{
  "start_server": 0,
  "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
  "max_length": 512,
  "pooling_strategy": "CLS_TOKEN",
  "evidence_identifier": {
    "batch_size": 64,
    "epochs": 3,
    "patience": 10,
    "lr": 1e-3,
    "max_grad_norm": 1.0,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "False", "True" ],
    "batch_size": 64,
    "epochs": 3,
    "patience": 10,
    "lr": 1e-3,
    "max_grad_norm": 1.0,
    "sampling_method": "everything"
  }
}
Transformer-Explainability/BERT_params/boolq_bert.json
ADDED
@@ -0,0 +1,32 @@
{
  "max_length": 512,
  "bert_vocab": "bert-base-uncased",
  "bert_dir": "bert-base-uncased",
  "use_evidence_sentence_identifier": 1,
  "use_evidence_token_identifier": 0,
  "evidence_identifier": {
    "batch_size": 10,
    "epochs": 10,
    "patience": 10,
    "warmup_steps": 50,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "random",
    "sampling_ratio": 1,
    "use_half_precision": 0
  },
  "evidence_classifier": {
    "classes": [
      "False",
      "True"
    ],
    "batch_size": 10,
    "warmup_steps": 50,
    "epochs": 10,
    "patience": 10,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "everything",
    "use_half_precision": 0
  }
}
Transformer-Explainability/BERT_params/boolq_soft.json
ADDED
@@ -0,0 +1,21 @@
{
  "embeddings": {
    "embedding_file": "model_components/glove.6B.200d.txt",
    "dropout": 0.2
  },
  "classifier": {
    "classes": [ "False", "True" ],
    "has_query": 1,
    "hidden_size": 32,
    "mlp_size": 128,
    "dropout": 0.2,
    "batch_size": 16,
    "epochs": 50,
    "attention_epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "dropout": 0.2,
    "k_fraction": 0.07,
    "threshold": 0.1
  }
}
Transformer-Explainability/BERT_params/cose_bert.json
ADDED
@@ -0,0 +1,30 @@
{
  "max_length": 512,
  "bert_vocab": "bert-base-uncased",
  "bert_dir": "bert-base-uncased",
  "use_evidence_sentence_identifier": 0,
  "use_evidence_token_identifier": 1,
  "evidence_token_identifier": {
    "batch_size": 32,
    "epochs": 10,
    "patience": 10,
    "warmup_steps": 10,
    "lr": 1e-05,
    "max_grad_norm": 0.5,
    "sampling_method": "everything",
    "use_half_precision": 0,
    "cose_data_hack": 1
  },
  "evidence_classifier": {
    "classes": [ "false", "true" ],
    "batch_size": 32,
    "warmup_steps": 10,
    "epochs": 10,
    "patience": 10,
    "lr": 1e-05,
    "max_grad_norm": 0.5,
    "sampling_method": "everything",
    "use_half_precision": 0,
    "cose_data_hack": 1
  }
}
Transformer-Explainability/BERT_params/cose_multiclass.json
ADDED
@@ -0,0 +1,35 @@
{
  "max_length": 512,
  "bert_vocab": "bert-base-uncased",
  "bert_dir": "bert-base-uncased",
  "use_evidence_sentence_identifier": 1,
  "use_evidence_token_identifier": 0,
  "evidence_identifier": {
    "batch_size": 32,
    "epochs": 10,
    "patience": 10,
    "warmup_steps": 50,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "random",
    "sampling_ratio": 1,
    "use_half_precision": 0
  },
  "evidence_classifier": {
    "classes": [
      "A",
      "B",
      "C",
      "D",
      "E"
    ],
    "batch_size": 10,
    "warmup_steps": 50,
    "epochs": 10,
    "patience": 10,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "everything",
    "use_half_precision": 0
  }
}
Transformer-Explainability/BERT_params/esnli_bert.json
ADDED
@@ -0,0 +1,28 @@
{
  "max_length": 512,
  "bert_vocab": "bert-base-uncased",
  "bert_dir": "bert-base-uncased",
  "use_evidence_sentence_identifier": 0,
  "use_evidence_token_identifier": 1,
  "evidence_token_identifier": {
    "batch_size": 32,
    "epochs": 10,
    "patience": 10,
    "warmup_steps": 10,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "everything",
    "use_half_precision": 0
  },
  "evidence_classifier": {
    "classes": [ "contradiction", "neutral", "entailment" ],
    "batch_size": 32,
    "warmup_steps": 10,
    "epochs": 10,
    "patience": 10,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "everything",
    "use_half_precision": 0
  }
}
Transformer-Explainability/BERT_params/evidence_inference.json
ADDED
@@ -0,0 +1,26 @@
{
  "embeddings": {
    "embedding_file": "model_components/PubMed-w2v.bin",
    "dropout": 0.05
  },
  "evidence_identifier": {
    "mlp_size": 128,
    "dropout": 0.05,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "significantly decreased", "no significant difference", "significantly increased" ],
    "mlp_size": 128,
    "dropout": 0.05,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "sampling_method": "everything"
  }
}
Transformer-Explainability/BERT_params/evidence_inference_bert.json
ADDED
@@ -0,0 +1,33 @@
{
  "max_length": 512,
  "bert_vocab": "allenai/scibert_scivocab_uncased",
  "bert_dir": "allenai/scibert_scivocab_uncased",
  "use_evidence_sentence_identifier": 1,
  "use_evidence_token_identifier": 0,
  "evidence_identifier": {
    "batch_size": 10,
    "epochs": 10,
    "patience": 10,
    "warmup_steps": 10,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "random",
    "use_half_precision": 0,
    "sampling_ratio": 1
  },
  "evidence_classifier": {
    "classes": [
      "significantly decreased",
      "no significant difference",
      "significantly increased"
    ],
    "batch_size": 10,
    "warmup_steps": 10,
    "epochs": 10,
    "patience": 10,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "everything",
    "use_half_precision": 0
  }
}
Transformer-Explainability/BERT_params/evidence_inference_soft.json
ADDED
@@ -0,0 +1,22 @@
{
  "embeddings": {
    "embedding_file": "model_components/PubMed-w2v.bin",
    "dropout": 0.2
  },
  "classifier": {
    "classes": [ "significantly decreased", "no significant difference", "significantly increased" ],
    "use_token_selection": 1,
    "has_query": 1,
    "hidden_size": 32,
    "mlp_size": 128,
    "dropout": 0.2,
    "batch_size": 16,
    "epochs": 50,
    "attention_epochs": 0,
    "patience": 10,
    "lr": 1e-3,
    "dropout": 0.2,
    "k_fraction": 0.013,
    "threshold": 0.1
  }
}
Transformer-Explainability/BERT_params/fever.json
ADDED
@@ -0,0 +1,26 @@
{
  "embeddings": {
    "embedding_file": "model_components/glove.6B.200d.txt",
    "dropout": 0.05
  },
  "evidence_identifier": {
    "mlp_size": 128,
    "dropout": 0.05,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "SUPPORTS", "REFUTES" ],
    "mlp_size": 128,
    "dropout": 0.05,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-5,
    "sampling_method": "everything"
  }
}
Transformer-Explainability/BERT_params/fever_baas.json
ADDED
@@ -0,0 +1,25 @@
{
  "start_server": 0,
  "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
  "max_length": 512,
  "pooling_strategy": "CLS_TOKEN",
  "evidence_identifier": {
    "batch_size": 64,
    "epochs": 3,
    "patience": 10,
    "lr": 1e-3,
    "max_grad_norm": 1.0,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "SUPPORTS", "REFUTES" ],
    "batch_size": 64,
    "epochs": 3,
    "patience": 10,
    "lr": 1e-3,
    "max_grad_norm": 1.0,
    "sampling_method": "everything"
  }
}
Transformer-Explainability/BERT_params/fever_bert.json
ADDED
@@ -0,0 +1,32 @@
{
  "max_length": 512,
  "bert_vocab": "bert-base-uncased",
  "bert_dir": "bert-base-uncased",
  "use_evidence_sentence_identifier": 1,
  "use_evidence_token_identifier": 0,
  "evidence_identifier": {
    "batch_size": 16,
    "epochs": 10,
    "patience": 10,
    "warmup_steps": 10,
    "lr": 1e-05,
    "max_grad_norm": 1.0,
    "sampling_method": "random",
    "sampling_ratio": 1.0,
    "use_half_precision": 0
  },
  "evidence_classifier": {
    "classes": [
      "SUPPORTS",
      "REFUTES"
    ],
    "batch_size": 10,
    "warmup_steps": 10,
    "epochs": 10,
    "patience": 10,
    "lr": 1e-05,
    "max_grad_norm": 1.0,
    "sampling_method": "everything",
    "use_half_precision": 0
  }
}
Transformer-Explainability/BERT_params/fever_soft.json
ADDED
@@ -0,0 +1,21 @@
{
  "embeddings": {
    "embedding_file": "model_components/glove.6B.200d.txt",
    "dropout": 0.2
  },
  "classifier": {
    "classes": [ "SUPPORTS", "REFUTES" ],
    "has_query": 1,
    "hidden_size": 32,
    "mlp_size": 128,
    "dropout": 0.2,
    "batch_size": 128,
    "epochs": 50,
    "attention_epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "dropout": 0.2,
    "k_fraction": 0.07,
    "threshold": 0.1
  }
}
Transformer-Explainability/BERT_params/movies.json
ADDED
@@ -0,0 +1,26 @@
{
  "embeddings": {
    "embedding_file": "model_components/glove.6B.200d.txt",
    "dropout": 0.05
  },
  "evidence_identifier": {
    "mlp_size": 128,
    "dropout": 0.05,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-4,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "NEG", "POS" ],
    "mlp_size": 128,
    "dropout": 0.05,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "sampling_method": "everything"
  }
}
Transformer-Explainability/BERT_params/movies_baas.json
ADDED
@@ -0,0 +1,26 @@
{
  "start_server": 0,
  "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
  "max_length": 512,
  "pooling_strategy": "CLS_TOKEN",
  "evidence_identifier": {
    "batch_size": 64,
    "epochs": 3,
    "patience": 10,
    "lr": 1e-3,
    "max_grad_norm": 1.0,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "NEG", "POS" ],
    "batch_size": 64,
    "epochs": 3,
    "patience": 10,
    "lr": 1e-3,
    "max_grad_norm": 1.0,
    "sampling_method": "everything"
  }
}
Transformer-Explainability/BERT_params/movies_bert.json
ADDED
@@ -0,0 +1,32 @@
{
  "max_length": 512,
  "bert_vocab": "bert-base-uncased",
  "bert_dir": "bert-base-uncased",
  "use_evidence_sentence_identifier": 1,
  "use_evidence_token_identifier": 0,
  "evidence_identifier": {
    "batch_size": 16,
    "epochs": 10,
    "patience": 10,
    "warmup_steps": 50,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "random",
    "sampling_ratio": 1,
    "use_half_precision": 0
  },
  "evidence_classifier": {
    "classes": [
      "NEG",
      "POS"
    ],
    "batch_size": 10,
    "warmup_steps": 50,
    "epochs": 10,
    "patience": 10,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "everything",
    "use_half_precision": 0
  }
}
Transformer-Explainability/BERT_params/movies_soft.json
ADDED
@@ -0,0 +1,21 @@
{
  "embeddings": {
    "embedding_file": "model_components/glove.6B.200d.txt",
    "dropout": 0.2
  },
  "classifier": {
    "classes": [ "NEG", "POS" ],
    "has_query": 0,
    "hidden_size": 32,
    "mlp_size": 128,
    "dropout": 0.2,
    "batch_size": 16,
    "epochs": 50,
    "attention_epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "dropout": 0.2,
    "k_fraction": 0.07,
    "threshold": 0.1
  }
}
Transformer-Explainability/BERT_params/multirc.json
ADDED
@@ -0,0 +1,26 @@
{
  "embeddings": {
    "embedding_file": "model_components/glove.6B.200d.txt",
    "dropout": 0.05
  },
  "evidence_identifier": {
    "mlp_size": 128,
    "dropout": 0.05,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "False", "True" ],
    "mlp_size": 128,
    "dropout": 0.05,
    "batch_size": 768,
    "epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "sampling_method": "everything"
  }
}
Transformer-Explainability/BERT_params/multirc_baas.json
ADDED
@@ -0,0 +1,26 @@
{
  "start_server": 0,
  "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
  "max_length": 512,
  "pooling_strategy": "CLS_TOKEN",
  "evidence_identifier": {
    "batch_size": 64,
    "epochs": 3,
    "patience": 10,
    "lr": 1e-3,
    "max_grad_norm": 1.0,
    "sampling_method": "random",
    "sampling_ratio": 1.0
  },
  "evidence_classifier": {
    "classes": [ "False", "True" ],
    "batch_size": 64,
    "epochs": 3,
    "patience": 10,
    "lr": 1e-3,
    "max_grad_norm": 1.0,
    "sampling_method": "everything"
  }
}
Transformer-Explainability/BERT_params/multirc_bert.json
ADDED
@@ -0,0 +1,32 @@
{
  "max_length": 512,
  "bert_vocab": "bert-base-uncased",
  "bert_dir": "bert-base-uncased",
  "use_evidence_sentence_identifier": 1,
  "use_evidence_token_identifier": 0,
  "evidence_identifier": {
    "batch_size": 32,
    "epochs": 10,
    "patience": 10,
    "warmup_steps": 50,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "random",
    "sampling_ratio": 1,
    "use_half_precision": 0
  },
  "evidence_classifier": {
    "classes": [
      "False",
      "True"
    ],
    "batch_size": 32,
    "warmup_steps": 50,
    "epochs": 10,
    "patience": 10,
    "lr": 1e-05,
    "max_grad_norm": 1,
    "sampling_method": "everything",
    "use_half_precision": 0
  }
}
Transformer-Explainability/BERT_params/multirc_soft.json
ADDED
@@ -0,0 +1,21 @@
{
  "embeddings": {
    "embedding_file": "model_components/glove.6B.200d.txt",
    "dropout": 0.2
  },
  "classifier": {
    "classes": [ "False", "True" ],
    "has_query": 1,
    "hidden_size": 32,
    "mlp_size": 128,
    "dropout": 0.2,
    "batch_size": 16,
    "epochs": 50,
    "attention_epochs": 50,
    "patience": 10,
    "lr": 1e-3,
    "dropout": 0.2,
    "k_fraction": 0.07,
    "threshold": 0.1
  }
}
Transformer-Explainability/BERT_rationale_benchmark/__init__.py
ADDED
File without changes
Transformer-Explainability/BERT_rationale_benchmark/metrics.py
ADDED
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import pprint
|
6 |
+
from collections import Counter, defaultdict, namedtuple
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from itertools import chain
|
9 |
+
from typing import Any, Callable, Dict, List, Set, Tuple
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from BERT_rationale_benchmark.utils import (Annotation, Evidence,
|
14 |
+
annotations_from_jsonl,
|
15 |
+
load_documents,
|
16 |
+
load_flattened_documents,
|
17 |
+
load_jsonl)
|
18 |
+
from scipy.stats import entropy
|
19 |
+
from sklearn.metrics import (accuracy_score, auc, average_precision_score,
|
20 |
+
classification_report, precision_recall_curve,
|
21 |
+
roc_auc_score)
|
22 |
+
|
23 |
+
logging.basicConfig(
|
24 |
+
level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
# start_token is inclusive, end_token is exclusive
|
29 |
+
@dataclass(eq=True, frozen=True)
|
30 |
+
class Rationale:
|
31 |
+
ann_id: str
|
32 |
+
docid: str
|
33 |
+
start_token: int
|
34 |
+
end_token: int
|
35 |
+
|
36 |
+
def to_token_level(self) -> List["Rationale"]:
|
37 |
+
ret = []
|
38 |
+
for t in range(self.start_token, self.end_token):
|
39 |
+
ret.append(Rationale(self.ann_id, self.docid, t, t + 1))
|
40 |
+
return ret
|
41 |
+
|
42 |
+
@classmethod
|
43 |
+
def from_annotation(cls, ann: Annotation) -> List["Rationale"]:
|
44 |
+
ret = []
|
45 |
+
for ev_group in ann.evidences:
|
46 |
+
for ev in ev_group:
|
47 |
+
ret.append(
|
48 |
+
Rationale(ann.annotation_id, ev.docid, ev.start_token, ev.end_token)
|
49 |
+
)
|
50 |
+
return ret
|
51 |
+
|
52 |
+
@classmethod
|
53 |
+
def from_instance(cls, inst: dict) -> List["Rationale"]:
|
54 |
+
ret = []
|
55 |
+
for rat in inst["rationales"]:
|
56 |
+
for pred in rat.get("hard_rationale_predictions", []):
|
57 |
+
ret.append(
|
58 |
+
Rationale(
|
59 |
+
inst["annotation_id"],
|
60 |
+
rat["docid"],
|
61 |
+
pred["start_token"],
|
62 |
+
pred["end_token"],
|
63 |
+
)
|
64 |
+
)
|
65 |
+
return ret
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass(eq=True, frozen=True)
|
69 |
+
class PositionScoredDocument:
|
70 |
+
ann_id: str
|
71 |
+
docid: str
|
72 |
+
scores: Tuple[float]
|
73 |
+
truths: Tuple[bool]
|
74 |
+
|
75 |
+
@classmethod
|
76 |
+
def from_results(
|
77 |
+
cls,
|
78 |
+
instances: List[dict],
|
79 |
+
annotations: List[Annotation],
|
80 |
+
docs: Dict[str, List[Any]],
|
81 |
+
use_tokens: bool = True,
|
82 |
+
) -> List["PositionScoredDocument"]:
|
83 |
+
"""Creates a paired list of annotation ids/docids/predictions/truth values"""
|
84 |
+
key_to_annotation = dict()
|
85 |
+
for ann in annotations:
|
86 |
+
for ev in chain.from_iterable(ann.evidences):
|
87 |
+
key = (ann.annotation_id, ev.docid)
|
88 |
+
if key not in key_to_annotation:
|
89 |
+
key_to_annotation[key] = [False for _ in docs[ev.docid]]
|
90 |
+
if use_tokens:
|
91 |
+
start, end = ev.start_token, ev.end_token
|
92 |
+
else:
|
93 |
+
start, end = ev.start_sentence, ev.end_sentence
|
94 |
+
for t in range(start, end):
|
95 |
+
key_to_annotation[key][t] = True
|
96 |
+
ret = []
|
97 |
+
if use_tokens:
|
98 |
+
field = "soft_rationale_predictions"
|
99 |
+
else:
|
100 |
+
field = "soft_sentence_predictions"
|
101 |
+
for inst in instances:
|
102 |
+
for rat in inst["rationales"]:
|
103 |
+
docid = rat["docid"]
|
104 |
+
scores = rat[field]
|
105 |
+
key = (inst["annotation_id"], docid)
|
106 |
+
assert len(scores) == len(docs[docid])
|
107 |
+
if key in key_to_annotation:
|
108 |
+
assert len(scores) == len(key_to_annotation[key])
|
109 |
+
else:
|
110 |
+
# In case model makes a prediction on docuemnt(s) for which ground truth evidence is not present
|
111 |
+
key_to_annotation[key] = [False for _ in docs[docid]]
|
112 |
+
ret.append(
|
113 |
+
PositionScoredDocument(
|
114 |
+
inst["annotation_id"],
|
115 |
+
docid,
|
116 |
+
tuple(scores),
|
117 |
+
tuple(key_to_annotation[key]),
|
118 |
+
)
|
119 |
+
)
|
120 |
+
return ret
|
121 |
+
|
122 |
+
|
123 |
+
def _f1(_p, _r):
|
124 |
+
if _p == 0 or _r == 0:
|
125 |
+
return 0
|
126 |
+
return 2 * _p * _r / (_p + _r)
|
127 |
+
|
128 |
+
|
129 |
+
def _keyed_rationale_from_list(
|
130 |
+
rats: List[Rationale],
|
131 |
+
) -> Dict[Tuple[str, str], Rationale]:
|
132 |
+
ret = defaultdict(set)
|
133 |
+
for r in rats:
|
134 |
+
ret[(r.ann_id, r.docid)].add(r)
|
135 |
+
return ret
|
136 |
+
|
137 |
+
|
138 |
+
def partial_match_score(
|
139 |
+
truth: List[Rationale], pred: List[Rationale], thresholds: List[float]
|
140 |
+
) -> List[Dict[str, Any]]:
|
141 |
+
"""Computes a partial match F1
|
142 |
+
|
143 |
+
Computes an instance-level (annotation) micro- and macro-averaged F1 score.
|
144 |
+
True Positives are computed by using intersection-over-union and
|
145 |
+
thresholding the resulting intersection-over-union fraction.
|
146 |
+
|
147 |
+
Micro-average results are computed by ignoring instance level distinctions
|
148 |
+
in the TP calculation (and recall, and precision, and finally the F1 of
|
149 |
+
those numbers). Macro-average results are computed first by measuring
|
150 |
+
instance (annotation + document) precisions and recalls, averaging those,
|
151 |
+
and finally computing an F1 of the resulting average.
|
152 |
+
"""
|
153 |
+
|
154 |
+
ann_to_rat = _keyed_rationale_from_list(truth)
|
155 |
+
pred_to_rat = _keyed_rationale_from_list(pred)
|
156 |
+
|
157 |
+
num_classifications = {k: len(v) for k, v in pred_to_rat.items()}
|
158 |
+
num_truth = {k: len(v) for k, v in ann_to_rat.items()}
|
159 |
+
ious = defaultdict(dict)
|
160 |
+
for k in set(ann_to_rat.keys()) | set(pred_to_rat.keys()):
|
161 |
+
for p in pred_to_rat.get(k, []):
|
162 |
+
best_iou = 0.0
|
163 |
+
for t in ann_to_rat.get(k, []):
|
164 |
+
num = len(
|
165 |
+
set(range(p.start_token, p.end_token))
|
166 |
+
& set(range(t.start_token, t.end_token))
|
167 |
+
)
|
168 |
+
denom = len(
|
169 |
+
set(range(p.start_token, p.end_token))
|
170 |
+
| set(range(t.start_token, t.end_token))
|
171 |
+
)
|
172 |
+
iou = 0 if denom == 0 else num / denom
|
173 |
+
if iou > best_iou:
|
174 |
+
best_iou = iou
|
175 |
+
ious[k][p] = best_iou
|
176 |
+
scores = []
|
177 |
+
for threshold in thresholds:
|
178 |
+
threshold_tps = dict()
|
179 |
+
for k, vs in ious.items():
|
180 |
+
threshold_tps[k] = sum(int(x >= threshold) for x in vs.values())
|
181 |
+
micro_r = (
|
182 |
+
sum(threshold_tps.values()) / sum(num_truth.values())
|
183 |
+
if sum(num_truth.values()) > 0
|
184 |
+
else 0
|
185 |
+
)
|
186 |
+
micro_p = (
|
187 |
+
sum(threshold_tps.values()) / sum(num_classifications.values())
|
188 |
+
if sum(num_classifications.values()) > 0
|
189 |
+
else 0
|
190 |
+
)
|
191 |
+
micro_f1 = _f1(micro_r, micro_p)
|
192 |
+
macro_rs = list(
|
193 |
+
threshold_tps.get(k, 0.0) / n if n > 0 else 0 for k, n in num_truth.items()
|
194 |
+
)
|
195 |
+
macro_ps = list(
|
196 |
+
threshold_tps.get(k, 0.0) / n if n > 0 else 0
|
197 |
+
for k, n in num_classifications.items()
|
198 |
+
)
|
199 |
+
macro_r = sum(macro_rs) / len(macro_rs) if len(macro_rs) > 0 else 0
|
200 |
+
macro_p = sum(macro_ps) / len(macro_ps) if len(macro_ps) > 0 else 0
|
201 |
+
macro_f1 = _f1(macro_r, macro_p)
|
202 |
+
scores.append(
|
203 |
+
{
|
204 |
+
"threshold": threshold,
|
205 |
+
"micro": {"p": micro_p, "r": micro_r, "f1": micro_f1},
|
206 |
+
"macro": {"p": macro_p, "r": macro_r, "f1": macro_f1},
|
207 |
+
}
|
208 |
+
)
|
209 |
+
return scores
|
210 |
+
|
211 |
+
|
212 |
+
def score_hard_rationale_predictions(
|
213 |
+
truth: List[Rationale], pred: List[Rationale]
|
214 |
+
) -> Dict[str, Dict[str, float]]:
|
215 |
+
"""Computes instance (annotation)-level micro/macro averaged F1s"""
|
216 |
+
scores = dict()
|
217 |
+
truth = set(truth)
|
218 |
+
pred = set(pred)
|
219 |
+
micro_prec = len(truth & pred) / len(pred)
|
220 |
+
micro_rec = len(truth & pred) / len(truth)
|
221 |
+
micro_f1 = _f1(micro_prec, micro_rec)
|
222 |
+
scores["instance_micro"] = {
|
223 |
+
"p": micro_prec,
|
224 |
+
"r": micro_rec,
|
225 |
+
"f1": micro_f1,
|
226 |
+
}
|
227 |
+
|
228 |
+
ann_to_rat = _keyed_rationale_from_list(truth)
|
229 |
+
pred_to_rat = _keyed_rationale_from_list(pred)
|
230 |
+
instances_to_scores = dict()
|
231 |
+
for k in set(ann_to_rat.keys()) | (pred_to_rat.keys()):
|
232 |
+
if len(pred_to_rat.get(k, set())) > 0:
|
233 |
+
instance_prec = len(
|
234 |
+
ann_to_rat.get(k, set()) & pred_to_rat.get(k, set())
|
235 |
+
) / len(pred_to_rat[k])
|
236 |
+
else:
|
237 |
+
instance_prec = 0
|
238 |
+
if len(ann_to_rat.get(k, set())) > 0:
|
239 |
+
instance_rec = len(
|
240 |
+
ann_to_rat.get(k, set()) & pred_to_rat.get(k, set())
|
241 |
+
) / len(ann_to_rat[k])
|
242 |
+
else:
|
243 |
+
instance_rec = 0
|
244 |
+
instance_f1 = _f1(instance_prec, instance_rec)
|
245 |
+
instances_to_scores[k] = {
|
246 |
+
"p": instance_prec,
|
247 |
+
"r": instance_rec,
|
248 |
+
"f1": instance_f1,
|
249 |
+
}
|
250 |
+
# these are calculated as sklearn would
|
251 |
+
macro_prec = sum(instance["p"] for instance in instances_to_scores.values()) / len(
|
252 |
+
instances_to_scores
|
253 |
+
)
|
254 |
+
macro_rec = sum(instance["r"] for instance in instances_to_scores.values()) / len(
|
255 |
+
instances_to_scores
|
256 |
+
)
|
257 |
+
macro_f1 = sum(instance["f1"] for instance in instances_to_scores.values()) / len(
|
258 |
+
instances_to_scores
|
259 |
+
)
|
260 |
+
|
261 |
+
f1_scores = [instance["f1"] for instance in instances_to_scores.values()]
|
262 |
+
print(macro_f1, np.argsort(f1_scores)[::-1])
|
263 |
+
|
264 |
+
scores["instance_macro"] = {
|
265 |
+
"p": macro_prec,
|
266 |
+
"r": macro_rec,
|
267 |
+
"f1": macro_f1,
|
268 |
+
}
|
269 |
+
return scores
|
270 |
+
|
271 |
+
|
272 |
+
def _auprc(truth: Dict[Any, List[bool]], preds: Dict[Any, List[float]]) -> float:
|
273 |
+
if len(preds) == 0:
|
274 |
+
return 0.0
|
275 |
+
assert len(truth.keys() and preds.keys()) == len(truth.keys())
|
276 |
+
aucs = []
|
277 |
+
for k, true in truth.items():
|
278 |
+
pred = preds[k]
|
279 |
+
true = [int(t) for t in true]
|
280 |
+
precision, recall, _ = precision_recall_curve(true, pred)
|
281 |
+
aucs.append(auc(recall, precision))
|
282 |
+
return np.average(aucs)
|
283 |
+
|
284 |
+
|
285 |
+
def _score_aggregator(
|
286 |
+
truth: Dict[Any, List[bool]],
|
287 |
+
preds: Dict[Any, List[float]],
|
288 |
+
score_function: Callable[[List[float], List[float]], float],
|
289 |
+
discard_single_class_answers: bool,
|
290 |
+
) -> float:
|
291 |
+
if len(preds) == 0:
|
292 |
+
return 0.0
|
293 |
+
assert len(truth.keys() and preds.keys()) == len(truth.keys())
|
294 |
+
scores = []
|
295 |
+
for k, true in truth.items():
|
296 |
+
pred = preds[k]
|
297 |
+
if (all(true) or all(not x for x in true)) and discard_single_class_answers:
|
298 |
+
continue
|
299 |
+
true = [int(t) for t in true]
|
300 |
+
scores.append(score_function(true, pred))
|
301 |
+
return np.average(scores)
|
302 |
+
|
303 |
+
|
304 |
+
def score_soft_tokens(paired_scores: List[PositionScoredDocument]) -> Dict[str, float]:
|
305 |
+
truth = {(ps.ann_id, ps.docid): ps.truths for ps in paired_scores}
|
306 |
+
pred = {(ps.ann_id, ps.docid): ps.scores for ps in paired_scores}
|
307 |
+
auprc_score = _auprc(truth, pred)
|
308 |
+
ap = _score_aggregator(truth, pred, average_precision_score, True)
|
309 |
+
roc_auc = _score_aggregator(truth, pred, roc_auc_score, True)
|
310 |
+
|
311 |
+
return {
|
312 |
+
"auprc": auprc_score,
|
313 |
+
"average_precision": ap,
|
314 |
+
"roc_auc_score": roc_auc,
|
315 |
+
}
|
316 |
+
|
317 |
+
|
318 |
+
def _instances_aopc(
|
319 |
+
instances: List[dict], thresholds: List[float], key: str
|
320 |
+
) -> Tuple[float, List[float]]:
|
321 |
+
dataset_scores = []
|
322 |
+
for inst in instances:
|
323 |
+
kls = inst["classification"]
|
324 |
+
beta_0 = inst["classification_scores"][kls]
|
325 |
+
instance_scores = []
|
326 |
+
for score in filter(
|
327 |
+
lambda x: x["threshold"] in thresholds,
|
328 |
+
sorted(inst["thresholded_scores"], key=lambda x: x["threshold"]),
|
329 |
+
):
|
330 |
+
beta_k = score[key][kls]
|
331 |
+
delta = beta_0 - beta_k
|
332 |
+
instance_scores.append(delta)
|
333 |
+
assert len(instance_scores) == len(thresholds)
|
334 |
+
dataset_scores.append(instance_scores)
|
335 |
+
dataset_scores = np.array(dataset_scores)
|
336 |
+
# a careful reading of Samek, et al. "Evaluating the Visualization of What a Deep Neural Network Has Learned"
|
337 |
+
# and some algebra will show the reader that we can average in any of several ways and get the same result:
|
338 |
+
# over a flattened array, within an instance and then between instances, or over instances (by position) an
|
339 |
+
# then across them.
|
340 |
+
final_score = np.average(dataset_scores)
|
341 |
+
position_scores = np.average(dataset_scores, axis=0).tolist()
|
342 |
+
|
343 |
+
return final_score, position_scores
|
344 |
+
|
345 |
+
|
346 |
+
def compute_aopc_scores(instances: List[dict], aopc_thresholds: List[float]):
|
347 |
+
if aopc_thresholds is None:
|
348 |
+
aopc_thresholds = sorted(
|
349 |
+
set(
|
350 |
+
chain.from_iterable(
|
351 |
+
[x["threshold"] for x in y["thresholded_scores"]] for y in instances
|
352 |
+
)
|
353 |
+
)
|
354 |
+
)
|
355 |
+
aopc_comprehensiveness_score, aopc_comprehensiveness_points = _instances_aopc(
|
356 |
+
instances, aopc_thresholds, "comprehensiveness_classification_scores"
|
357 |
+
)
|
358 |
+
aopc_sufficiency_score, aopc_sufficiency_points = _instances_aopc(
|
359 |
+
instances, aopc_thresholds, "sufficiency_classification_scores"
|
360 |
+
)
|
361 |
+
return (
|
362 |
+
aopc_thresholds,
|
363 |
+
aopc_comprehensiveness_score,
|
364 |
+
aopc_comprehensiveness_points,
|
365 |
+
aopc_sufficiency_score,
|
366 |
+
aopc_sufficiency_points,
|
367 |
+
)
|
368 |
+
|
369 |
+
|
370 |
+
def score_classifications(
|
371 |
+
instances: List[dict],
|
372 |
+
annotations: List[Annotation],
|
373 |
+
docs: Dict[str, List[str]],
|
374 |
+
aopc_thresholds: List[float],
|
375 |
+
) -> Dict[str, float]:
|
376 |
+
def compute_kl(cls_scores_, faith_scores_):
|
377 |
+
keys = list(cls_scores_.keys())
|
378 |
+
cls_scores_ = [cls_scores_[k] for k in keys]
|
379 |
+
faith_scores_ = [faith_scores_[k] for k in keys]
|
380 |
+
return entropy(faith_scores_, cls_scores_)
|
381 |
+
|
382 |
+
labels = list(set(x.classification for x in annotations))
|
383 |
+
label_to_int = {l: i for i, l in enumerate(labels)}
|
384 |
+
key_to_instances = {inst["annotation_id"]: inst for inst in instances}
|
385 |
+
truth = []
|
386 |
+
predicted = []
|
387 |
+
for ann in annotations:
|
388 |
+
truth.append(label_to_int[ann.classification])
|
389 |
+
inst = key_to_instances[ann.annotation_id]
|
390 |
+
predicted.append(label_to_int[inst["classification"]])
|
391 |
+
classification_scores = classification_report(
|
392 |
+
truth, predicted, output_dict=True, target_names=labels, digits=3
|
393 |
+
)
|
394 |
+
accuracy = accuracy_score(truth, predicted)
|
395 |
+
if "comprehensiveness_classification_scores" in instances[0]:
|
396 |
+
comprehensiveness_scores = [
|
397 |
+
x["classification_scores"][x["classification"]]
|
398 |
+
- x["comprehensiveness_classification_scores"][x["classification"]]
|
399 |
+
for x in instances
|
400 |
+
]
|
401 |
+
comprehensiveness_score = np.average(comprehensiveness_scores)
|
402 |
+
else:
|
403 |
+
comprehensiveness_score = None
|
404 |
+
comprehensiveness_scores = None
|
405 |
+
|
406 |
+
if "sufficiency_classification_scores" in instances[0]:
|
407 |
+
sufficiency_scores = [
|
408 |
+
x["classification_scores"][x["classification"]]
|
409 |
+
- x["sufficiency_classification_scores"][x["classification"]]
|
410 |
+
for x in instances
|
411 |
+
]
|
412 |
+
sufficiency_score = np.average(sufficiency_scores)
|
413 |
+
else:
|
414 |
+
sufficiency_score = None
|
415 |
+
sufficiency_scores = None
|
416 |
+
|
417 |
+
if "comprehensiveness_classification_scores" in instances[0]:
|
418 |
+
comprehensiveness_entropies = [
|
419 |
+
entropy(list(x["classification_scores"].values()))
|
420 |
+
- entropy(list(x["comprehensiveness_classification_scores"].values()))
|
421 |
+
for x in instances
|
422 |
+
]
|
423 |
+
comprehensiveness_entropy = np.average(comprehensiveness_entropies)
|
424 |
+
comprehensiveness_kl = np.average(
|
425 |
+
list(
|
426 |
+
compute_kl(
|
427 |
+
x["classification_scores"],
|
428 |
+
x["comprehensiveness_classification_scores"],
|
429 |
+
)
|
430 |
+
for x in instances
|
431 |
+
)
|
432 |
+
)
|
433 |
+
else:
|
434 |
+
comprehensiveness_entropies = None
|
435 |
+
comprehensiveness_kl = None
|
436 |
+
comprehensiveness_entropy = None
|
437 |
+
|
438 |
+
if "sufficiency_classification_scores" in instances[0]:
|
439 |
+
sufficiency_entropies = [
|
440 |
+
entropy(list(x["classification_scores"].values()))
|
441 |
+
- entropy(list(x["sufficiency_classification_scores"].values()))
|
442 |
+
for x in instances
|
443 |
+
]
|
444 |
+
sufficiency_entropy = np.average(sufficiency_entropies)
|
445 |
+
sufficiency_kl = np.average(
|
446 |
+
list(
|
447 |
+
compute_kl(
|
448 |
+
x["classification_scores"], x["sufficiency_classification_scores"]
|
449 |
+
)
|
450 |
+
for x in instances
|
451 |
+
)
|
452 |
+
)
|
453 |
+
else:
|
454 |
+
sufficiency_entropies = None
|
455 |
+
sufficiency_kl = None
|
456 |
+
sufficiency_entropy = None
|
457 |
+
|
458 |
+
if "thresholded_scores" in instances[0]:
|
459 |
+
(
|
460 |
+
aopc_thresholds,
|
461 |
+
aopc_comprehensiveness_score,
|
462 |
+
aopc_comprehensiveness_points,
|
463 |
+
aopc_sufficiency_score,
|
464 |
+
aopc_sufficiency_points,
|
465 |
+
) = compute_aopc_scores(instances, aopc_thresholds)
|
466 |
+
else:
|
467 |
+
(
|
468 |
+
aopc_thresholds,
|
469 |
+
aopc_comprehensiveness_score,
|
470 |
+
aopc_comprehensiveness_points,
|
471 |
+
aopc_sufficiency_score,
|
472 |
+
aopc_sufficiency_points,
|
473 |
+
) = (None, None, None, None, None)
|
474 |
+
if "tokens_to_flip" in instances[0]:
|
475 |
+
token_percentages = []
|
476 |
+
for ann in annotations:
|
477 |
+
# in practice, this is of size 1 for everything except e-snli
|
478 |
+
docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
|
479 |
+
inst = key_to_instances[ann.annotation_id]
|
480 |
+
tokens = inst["tokens_to_flip"]
|
481 |
+
doc_lengths = sum(len(docs[d]) for d in docids)
|
482 |
+
token_percentages.append(tokens / doc_lengths)
|
483 |
+
token_percentages = np.average(token_percentages)
|
484 |
+
else:
|
485 |
+
token_percentages = None
|
486 |
+
|
487 |
+
return {
|
488 |
+
"accuracy": accuracy,
|
489 |
+
"prf": classification_scores,
|
490 |
+
"comprehensiveness": comprehensiveness_score,
|
491 |
+
"sufficiency": sufficiency_score,
|
492 |
+
"comprehensiveness_entropy": comprehensiveness_entropy,
|
493 |
+
"comprehensiveness_kl": comprehensiveness_kl,
|
494 |
+
"sufficiency_entropy": sufficiency_entropy,
|
495 |
+
"sufficiency_kl": sufficiency_kl,
|
496 |
+
"aopc_thresholds": aopc_thresholds,
|
497 |
+
"comprehensiveness_aopc": aopc_comprehensiveness_score,
|
498 |
+
"comprehensiveness_aopc_points": aopc_comprehensiveness_points,
|
499 |
+
"sufficiency_aopc": aopc_sufficiency_score,
|
500 |
+
"sufficiency_aopc_points": aopc_sufficiency_points,
|
501 |
+
}
|
502 |
+
|
503 |
+
|
504 |
+
def verify_instance(instance: dict, docs: Dict[str, list], thresholds: Set[float]):
|
505 |
+
error = False
|
506 |
+
docids = []
|
507 |
+
# verify the internal structure of these instances is correct:
|
508 |
+
# * hard predictions are present
|
509 |
+
# * start and end tokens are valid
|
510 |
+
# * soft rationale predictions, if present, must have the same document length
|
511 |
+
|
512 |
+
for rat in instance["rationales"]:
|
513 |
+
docid = rat["docid"]
|
514 |
+
if docid not in docid:
|
515 |
+
error = True
|
516 |
+
logging.info(
|
517 |
+
f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} could not be found as a preprocessed document! Gave up on additional processing.'
|
518 |
+
)
|
519 |
+
continue
|
520 |
+
doc_length = len(docs[docid])
|
521 |
+
for h1 in rat.get("hard_rationale_predictions", []):
|
522 |
+
# verify that each token is valid
|
523 |
+
# verify that no annotations overlap
|
524 |
+
for h2 in rat.get("hard_rationale_predictions", []):
|
525 |
+
if h1 == h2:
|
526 |
+
continue
|
527 |
+
if (
|
528 |
+
len(
|
529 |
+
set(range(h1["start_token"], h1["end_token"]))
|
530 |
+
& set(range(h2["start_token"], h2["end_token"]))
|
531 |
+
)
|
532 |
+
> 0
|
533 |
+
):
|
534 |
+
logging.info(
|
535 |
+
f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} {h1} and {h2} overlap!'
|
536 |
+
)
|
537 |
+
error = True
|
538 |
+
if h1["start_token"] > doc_length:
|
539 |
+
logging.info(
|
540 |
+
f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}'
|
541 |
+
)
|
542 |
+
error = True
|
543 |
+
if h1["end_token"] > doc_length:
|
544 |
+
logging.info(
|
545 |
+
f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}'
|
546 |
+
)
|
547 |
+
error = True
|
548 |
+
# length check for soft rationale
|
549 |
+
# note that either flattened_documents or sentence-broken documents must be passed in depending on result
|
550 |
+
soft_rationale_predictions = rat.get("soft_rationale_predictions", [])
|
551 |
+
if (
|
552 |
+
len(soft_rationale_predictions) > 0
|
553 |
+
and len(soft_rationale_predictions) != doc_length
|
554 |
+
):
|
555 |
+
logging.info(
|
556 |
+
f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} expected classifications for {doc_length} tokens but have them for {len(soft_rationale_predictions)} tokens instead!'
|
557 |
+
)
|
558 |
+
error = True
|
559 |
+
|
560 |
+
# count that one appears per-document
|
561 |
+
docids = Counter(docids)
|
562 |
+
for docid, count in docids.items():
|
563 |
+
if count > 1:
|
564 |
+
error = True
|
565 |
+
logging.info(
|
566 |
+
'Error! For instance annotation={instance["annotation_id"]}, docid={docid} appear {count} times, may only appear once!'
|
567 |
+
)
|
568 |
+
|
569 |
+
classification = instance.get("classification", "")
|
570 |
+
if not isinstance(classification, str):
|
571 |
+
logging.info(
|
572 |
+
f'Error! For instance annotation={instance["annotation_id"]}, classification field {classification} is not a string!'
|
573 |
+
)
|
574 |
+
error = True
|
575 |
+
classification_scores = instance.get("classification_scores", dict())
|
576 |
+
if not isinstance(classification_scores, dict):
|
577 |
+
logging.info(
|
578 |
+
f'Error! For instance annotation={instance["annotation_id"]}, classification_scores field {classification_scores} is not a dict!'
|
579 |
+
)
|
580 |
+
error = True
|
581 |
+
comprehensiveness_classification_scores = instance.get(
|
582 |
+
"comprehensiveness_classification_scores", dict()
|
583 |
+
)
|
584 |
+
if not isinstance(comprehensiveness_classification_scores, dict):
|
585 |
+
logging.info(
|
586 |
+
f'Error! For instance annotation={instance["annotation_id"]}, comprehensiveness_classification_scores field {comprehensiveness_classification_scores} is not a dict!'
|
587 |
+
)
|
588 |
+
error = True
|
589 |
+
sufficiency_classification_scores = instance.get(
|
590 |
+
"sufficiency_classification_scores", dict()
|
591 |
+
)
|
592 |
+
if not isinstance(sufficiency_classification_scores, dict):
|
593 |
+
logging.info(
|
594 |
+
f'Error! For instance annotation={instance["annotation_id"]}, sufficiency_classification_scores field {sufficiency_classification_scores} is not a dict!'
|
595 |
+
)
|
596 |
+
error = True
|
597 |
+
if ("classification" in instance) != ("classification_scores" in instance):
|
598 |
+
logging.info(
|
599 |
+
f'Error! For instance annotation={instance["annotation_id"]}, when providing a classification, you must also provide classification scores!'
|
600 |
+
)
|
601 |
+
error = True
|
602 |
+
if ("comprehensiveness_classification_scores" in instance) and not (
|
603 |
+
"classification" in instance
|
604 |
+
):
|
605 |
+
logging.info(
|
606 |
+
f'Error! For instance annotation={instance["annotation_id"]}, when providing a classification, you must also provide a comprehensiveness_classification_score'
|
607 |
+
)
|
608 |
+
error = True
|
609 |
+
if ("sufficiency_classification_scores" in instance) and not (
|
610 |
+
"classification_scores" in instance
|
611 |
+
):
|
612 |
+
logging.info(
|
613 |
+
f'Error! For instance annotation={instance["annotation_id"]}, when providing a sufficiency_classification_score, you must also provide a classification score!'
|
614 |
+
)
|
615 |
+
error = True
|
616 |
+
if "thresholded_scores" in instance:
|
617 |
+
instance_thresholds = set(
|
618 |
+
x["threshold"] for x in instance["thresholded_scores"]
|
619 |
+
)
|
620 |
+
if instance_thresholds != thresholds:
|
621 |
+
error = True
|
622 |
+
logging.info(
|
623 |
+
'Error: {instance["thresholded_scores"]} has thresholds that differ from previous thresholds: {thresholds}'
|
624 |
+
)
|
625 |
+
if (
|
626 |
+
"comprehensiveness_classification_scores" not in instance
|
627 |
+
or "sufficiency_classification_scores" not in instance
|
628 |
+
or "classification" not in instance
|
629 |
+
or "classification_scores" not in instance
|
630 |
+
):
|
631 |
+
error = True
|
632 |
+
logging.info(
|
633 |
+
"Error: {instance} must have comprehensiveness_classification_scores, sufficiency_classification_scores, classification, and classification_scores defined when including thresholded scores"
|
634 |
+
)
|
635 |
+
if not all(
|
636 |
+
"sufficiency_classification_scores" in x
|
637 |
+
for x in instance["thresholded_scores"]
|
638 |
+
):
|
639 |
+
error = True
|
640 |
+
logging.info(
|
641 |
+
"Error: {instance} must have sufficiency_classification_scores for every threshold"
|
642 |
+
)
|
643 |
+
if not all(
|
644 |
+
"comprehensiveness_classification_scores" in x
|
645 |
+
for x in instance["thresholded_scores"]
|
646 |
+
):
|
647 |
+
error = True
|
648 |
+
logging.info(
|
649 |
+
"Error: {instance} must have comprehensiveness_classification_scores for every threshold"
|
650 |
+
)
|
651 |
+
return error
|
652 |
+
|
653 |
+
|
654 |
+
def verify_instances(instances: List[dict], docs: Dict[str, list]):
|
655 |
+
annotation_ids = list(x["annotation_id"] for x in instances)
|
656 |
+
key_counter = Counter(annotation_ids)
|
657 |
+
multi_occurrence_annotation_ids = list(
|
658 |
+
filter(lambda kv: kv[1] > 1, key_counter.items())
|
659 |
+
)
|
660 |
+
error = False
|
661 |
+
if len(multi_occurrence_annotation_ids) > 0:
|
662 |
+
error = True
|
663 |
+
logging.info(
|
664 |
+
f"Error in instances: {len(multi_occurrence_annotation_ids)} appear multiple times in the annotations file: {multi_occurrence_annotation_ids}"
|
665 |
+
)
|
666 |
+
failed_validation = set()
|
667 |
+
instances_with_classification = list()
|
668 |
+
instances_with_soft_rationale_predictions = list()
|
669 |
+
instances_with_soft_sentence_predictions = list()
|
670 |
+
instances_with_comprehensiveness_classifications = list()
|
671 |
+
instances_with_sufficiency_classifications = list()
|
672 |
+
instances_with_thresholded_scores = list()
|
673 |
+
if "thresholded_scores" in instances[0]:
|
674 |
+
thresholds = set(x["threshold"] for x in instances[0]["thresholded_scores"])
|
675 |
+
else:
|
676 |
+
thresholds = None
|
677 |
+
for instance in instances:
|
678 |
+
instance_error = verify_instance(instance, docs, thresholds)
|
679 |
+
if instance_error:
|
680 |
+
error = True
|
681 |
+
failed_validation.add(instance["annotation_id"])
|
682 |
+
if instance.get("classification", None) != None:
|
683 |
+
instances_with_classification.append(instance)
|
684 |
+
if instance.get("comprehensiveness_classification_scores", None) != None:
|
685 |
+
instances_with_comprehensiveness_classifications.append(instance)
|
686 |
+
if instance.get("sufficiency_classification_scores", None) != None:
|
687 |
+
instances_with_sufficiency_classifications.append(instance)
|
688 |
+
has_soft_rationales = []
|
689 |
+
has_soft_sentences = []
|
690 |
+
for rat in instance["rationales"]:
|
691 |
+
if rat.get("soft_rationale_predictions", None) != None:
|
692 |
+
has_soft_rationales.append(rat)
|
693 |
+
if rat.get("soft_sentence_predictions", None) != None:
|
694 |
+
has_soft_sentences.append(rat)
|
695 |
+
if len(has_soft_rationales) > 0:
|
696 |
+
instances_with_soft_rationale_predictions.append(instance)
|
697 |
+
if len(has_soft_rationales) != len(instance["rationales"]):
|
698 |
+
error = True
|
699 |
+
logging.info(
|
700 |
+
f'Error: instance {instance["annotation"]} has soft rationales for some but not all reported documents!'
|
701 |
+
)
|
702 |
+
if len(has_soft_sentences) > 0:
|
703 |
+
instances_with_soft_sentence_predictions.append(instance)
|
704 |
+
if len(has_soft_sentences) != len(instance["rationales"]):
|
705 |
+
error = True
|
706 |
+
logging.info(
|
707 |
+
f'Error: instance {instance["annotation"]} has soft sentences for some but not all reported documents!'
|
708 |
+
)
|
709 |
+
if "thresholded_scores" in instance:
|
710 |
+
instances_with_thresholded_scores.append(instance)
|
711 |
+
logging.info(
|
712 |
+
f"Error in instances: {len(failed_validation)} instances fail validation: {failed_validation}"
|
713 |
+
)
|
714 |
+
if len(instances_with_classification) != 0 and len(
|
715 |
+
instances_with_classification
|
716 |
+
) != len(instances):
|
717 |
+
logging.info(
|
718 |
+
f"Either all {len(instances)} must have a classification or none may, instead {len(instances_with_classification)} do!"
|
719 |
+
)
|
720 |
+
error = True
|
721 |
+
if len(instances_with_soft_sentence_predictions) != 0 and len(
|
722 |
+
instances_with_soft_sentence_predictions
|
723 |
+
) != len(instances):
|
724 |
+
logging.info(
|
725 |
+
f"Either all {len(instances)} must have a sentence prediction or none may, instead {len(instances_with_soft_sentence_predictions)} do!"
|
726 |
+
)
|
727 |
+
error = True
|
728 |
+
if len(instances_with_soft_rationale_predictions) != 0 and len(
|
729 |
+
instances_with_soft_rationale_predictions
|
730 |
+
) != len(instances):
|
731 |
+
logging.info(
|
732 |
+
f"Either all {len(instances)} must have a soft rationale prediction or none may, instead {len(instances_with_soft_rationale_predictions)} do!"
|
733 |
+
)
|
734 |
+
error = True
|
735 |
+
if len(instances_with_comprehensiveness_classifications) != 0 and len(
|
736 |
+
instances_with_comprehensiveness_classifications
|
737 |
+
) != len(instances):
|
738 |
+
error = True
|
739 |
+
logging.info(
|
740 |
+
f"Either all {len(instances)} must have a comprehensiveness classification or none may, instead {len(instances_with_comprehensiveness_classifications)} do!"
|
741 |
+
)
|
742 |
+
if len(instances_with_sufficiency_classifications) != 0 and len(
|
743 |
+
instances_with_sufficiency_classifications
|
744 |
+
) != len(instances):
|
745 |
+
error = True
|
746 |
+
logging.info(
|
747 |
+
f"Either all {len(instances)} must have a sufficiency classification or none may, instead {len(instances_with_sufficiency_classifications)} do!"
|
748 |
+
)
|
749 |
+
if len(instances_with_thresholded_scores) != 0 and len(
|
750 |
+
instances_with_thresholded_scores
|
751 |
+
) != len(instances):
|
752 |
+
error = True
|
753 |
+
logging.info(
|
754 |
+
f"Either all {len(instances)} must have thresholded scores or none may, instead {len(instances_with_thresholded_scores)} do!"
|
755 |
+
)
|
756 |
+
if error:
|
757 |
+
raise ValueError(
|
758 |
+
"Some instances are invalid, please fix your formatting and try again"
|
759 |
+
)
|
760 |
+
|
761 |
+
|
762 |
+
def _has_hard_predictions(results: List[dict]) -> bool:
|
763 |
+
# assumes that we have run "verification" over the inputs
|
764 |
+
return (
|
765 |
+
"rationales" in results[0]
|
766 |
+
and len(results[0]["rationales"]) > 0
|
767 |
+
and "hard_rationale_predictions" in results[0]["rationales"][0]
|
768 |
+
and results[0]["rationales"][0]["hard_rationale_predictions"] is not None
|
769 |
+
and len(results[0]["rationales"][0]["hard_rationale_predictions"]) > 0
|
770 |
+
)
|
771 |
+
|
772 |
+
|
773 |
+
def _has_soft_predictions(results: List[dict]) -> bool:
|
774 |
+
# assumes that we have run "verification" over the inputs
|
775 |
+
return (
|
776 |
+
"rationales" in results[0]
|
777 |
+
and len(results[0]["rationales"]) > 0
|
778 |
+
and "soft_rationale_predictions" in results[0]["rationales"][0]
|
779 |
+
and results[0]["rationales"][0]["soft_rationale_predictions"] is not None
|
780 |
+
)
|
781 |
+
|
782 |
+
|
783 |
+
def _has_soft_sentence_predictions(results: List[dict]) -> bool:
|
784 |
+
# assumes that we have run "verification" over the inputs
|
785 |
+
return (
|
786 |
+
"rationales" in results[0]
|
787 |
+
and len(results[0]["rationales"]) > 0
|
788 |
+
and "soft_sentence_predictions" in results[0]["rationales"][0]
|
789 |
+
and results[0]["rationales"][0]["soft_sentence_predictions"] is not None
|
790 |
+
)
|
791 |
+
|
792 |
+
|
793 |
+
def _has_classifications(results: List[dict]) -> bool:
|
794 |
+
# assumes that we have run "verification" over the inputs
|
795 |
+
return "classification" in results[0] and results[0]["classification"] is not None
|
796 |
+
|
797 |
+
|
798 |
+
def main():
|
799 |
+
parser = argparse.ArgumentParser(
|
800 |
+
description="""Computes rationale and final class classification scores""",
|
801 |
+
formatter_class=argparse.RawTextHelpFormatter,
|
802 |
+
)
|
803 |
+
parser.add_argument(
|
804 |
+
"--data_dir",
|
805 |
+
dest="data_dir",
|
806 |
+
required=True,
|
807 |
+
help="Which directory contains a {train,val,test}.jsonl file?",
|
808 |
+
)
|
809 |
+
parser.add_argument(
|
810 |
+
"--split",
|
811 |
+
dest="split",
|
812 |
+
required=True,
|
813 |
+
help="Which of {train,val,test} are we scoring on?",
|
814 |
+
)
|
815 |
+
parser.add_argument(
|
816 |
+
"--strict",
|
817 |
+
dest="strict",
|
818 |
+
required=False,
|
819 |
+
action="store_true",
|
820 |
+
default=False,
|
821 |
+
help="Do we perform strict scoring?",
|
822 |
+
)
|
823 |
+
parser.add_argument(
|
824 |
+
"--results",
|
825 |
+
dest="results",
|
826 |
+
required=True,
|
827 |
+
help="""Results File
|
828 |
+
Contents are expected to be jsonl of:
|
829 |
+
{
|
830 |
+
"annotation_id": str, required
|
831 |
+
# these classifications *must not* overlap
|
832 |
+
"rationales": List[
|
833 |
+
{
|
834 |
+
"docid": str, required
|
835 |
+
"hard_rationale_predictions": List[{
|
836 |
+
"start_token": int, inclusive, required
|
837 |
+
"end_token": int, exclusive, required
|
838 |
+
}], optional,
|
839 |
+
# token level classifications, a value must be provided per-token
|
840 |
+
# in an ideal world, these correspond to the hard-decoding above.
|
841 |
+
"soft_rationale_predictions": List[float], optional.
|
842 |
+
# sentence level classifications, a value must be provided for every
|
843 |
+
# sentence in each document, or not at all
|
844 |
+
"soft_sentence_predictions": List[float], optional.
|
845 |
+
}
|
846 |
+
],
|
847 |
+
# the classification the model made for the overall classification task
|
848 |
+
"classification": str, optional
|
849 |
+
# A probability distribution output by the model. We require this to be normalized.
|
850 |
+
"classification_scores": Dict[str, float], optional
|
851 |
+
# The next two fields are measures for how faithful your model is (the
|
852 |
+
# rationales it predicts are in some sense causal of the prediction), and
|
853 |
+
# how sufficient they are. We approximate a measure for comprehensiveness by
|
854 |
+
# asking that you remove the top k%% of tokens from your documents,
|
855 |
+
# running your models again, and reporting the score distribution in the
|
856 |
+
# "comprehensiveness_classification_scores" field.
|
857 |
+
# We approximate a measure of sufficiency by asking exactly the converse
|
858 |
+
# - that you provide model distributions on the removed k%% tokens.
|
859 |
+
# 'k' is determined by human rationales, and is documented in our paper.
|
860 |
+
# You should determine which of these tokens to remove based on some kind
|
861 |
+
# of information about your model: gradient based, attention based, other
|
862 |
+
# interpretability measures, etc.
|
863 |
+
# scores per class having removed k%% of the data, where k is determined by human comprehensive rationales
|
864 |
+
"comprehensiveness_classification_scores": Dict[str, float], optional
|
865 |
+
# scores per class having access to only k%% of the data, where k is determined by human comprehensive rationales
|
866 |
+
"sufficiency_classification_scores": Dict[str, float], optional
|
867 |
+
# the number of tokens required to flip the prediction - see "Is Attention Interpretable" by Serrano and Smith.
|
868 |
+
"tokens_to_flip": int, optional
|
869 |
+
"thresholded_scores": List[{
|
870 |
+
"threshold": float, required,
|
871 |
+
"comprehensiveness_classification_scores": like "classification_scores"
|
872 |
+
"sufficiency_classification_scores": like "classification_scores"
|
873 |
+
}], optional. if present, then "classification" and "classification_scores" must be present
|
874 |
+
}
|
875 |
+
When providing one of the optional fields, it must be provided for *every* instance.
|
876 |
+
The classification, classification_score, and comprehensiveness_classification_scores
|
877 |
+
must together be present for every instance or absent for every instance.
|
878 |
+
""",
|
879 |
+
)
|
880 |
+
parser.add_argument(
|
881 |
+
"--iou_thresholds",
|
882 |
+
dest="iou_thresholds",
|
883 |
+
required=False,
|
884 |
+
nargs="+",
|
885 |
+
type=float,
|
886 |
+
default=[0.5],
|
887 |
+
help="""Thresholds for IOU scoring.
|
888 |
+
|
889 |
+
These are used for "soft" or partial match scoring of rationale spans.
|
890 |
+
A span is considered a match if the size of the intersection of the prediction
|
891 |
+
and the annotation, divided by the union of the two spans, is larger than
|
892 |
+
the IOU threshold. This score can be computed for arbitrary thresholds.
|
893 |
+
""",
|
894 |
+
)
|
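To make the IOU rule described in the help text above concrete, here is a tiny worked example with invented spans (token positions are 0-based, end exclusive); it is only an illustration, not code from the repository:

# Hypothetical predicted rationale span [2, 7) compared against annotated span [4, 9).
pred = set(range(2, 7))   # predicted token positions {2, 3, 4, 5, 6}
gold = set(range(4, 9))   # annotated token positions {4, 5, 6, 7, 8}

iou = len(pred & gold) / len(pred | gold)   # 3 / 7 ≈ 0.43
print(iou > 0.5)   # False: below the default threshold of 0.5, so not a partial match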
895 |
+
parser.add_argument(
|
896 |
+
"--score_file",
|
897 |
+
dest="score_file",
|
898 |
+
required=False,
|
899 |
+
default=None,
|
900 |
+
help="Where to write results?",
|
901 |
+
)
|
902 |
+
parser.add_argument(
|
903 |
+
"--aopc_thresholds",
|
904 |
+
nargs="+",
|
905 |
+
required=False,
|
906 |
+
type=float,
|
907 |
+
default=[0.01, 0.05, 0.1, 0.2, 0.5],
|
908 |
+
help="Thresholds for AOPC Thresholds",
|
909 |
+
)
|
910 |
+
args = parser.parse_args()
|
911 |
+
results = load_jsonl(args.results)
|
912 |
+
docids = set(
|
913 |
+
chain.from_iterable(
|
914 |
+
[rat["docid"] for rat in res["rationales"]] for res in results
|
915 |
+
)
|
916 |
+
)
|
917 |
+
docs = load_flattened_documents(args.data_dir, docids)
|
918 |
+
verify_instances(results, docs)
|
919 |
+
# load truth
|
920 |
+
annotations = annotations_from_jsonl(
|
921 |
+
os.path.join(args.data_dir, args.split + ".jsonl")
|
922 |
+
)
|
923 |
+
docids |= set(
|
924 |
+
chain.from_iterable(
|
925 |
+
(ev.docid for ev in chain.from_iterable(ann.evidences))
|
926 |
+
for ann in annotations
|
927 |
+
)
|
928 |
+
)
|
929 |
+
|
930 |
+
has_final_predictions = _has_classifications(results)
|
931 |
+
scores = dict()
|
932 |
+
if args.strict:
|
933 |
+
if not args.iou_thresholds:
|
934 |
+
raise ValueError(
|
935 |
+
"--iou_thresholds must be provided when running strict scoring"
|
936 |
+
)
|
937 |
+
if not has_final_predictions:
|
938 |
+
raise ValueError(
|
939 |
+
"We must have a 'classification', 'classification_score', and 'comprehensiveness_classification_score' field in order to perform scoring!"
|
940 |
+
)
|
941 |
+
# TODO think about offering a sentence level version of these scores.
|
942 |
+
if _has_hard_predictions(results):
|
943 |
+
truth = list(
|
944 |
+
chain.from_iterable(Rationale.from_annotation(ann) for ann in annotations)
|
945 |
+
)
|
946 |
+
pred = list(
|
947 |
+
chain.from_iterable(Rationale.from_instance(inst) for inst in results)
|
948 |
+
)
|
949 |
+
if args.iou_thresholds is not None:
|
950 |
+
iou_scores = partial_match_score(truth, pred, args.iou_thresholds)
|
951 |
+
scores["iou_scores"] = iou_scores
|
952 |
+
# NER style scoring
|
953 |
+
rationale_level_prf = score_hard_rationale_predictions(truth, pred)
|
954 |
+
scores["rationale_prf"] = rationale_level_prf
|
955 |
+
token_level_truth = list(
|
956 |
+
chain.from_iterable(rat.to_token_level() for rat in truth)
|
957 |
+
)
|
958 |
+
token_level_pred = list(
|
959 |
+
chain.from_iterable(rat.to_token_level() for rat in pred)
|
960 |
+
)
|
961 |
+
token_level_prf = score_hard_rationale_predictions(
|
962 |
+
token_level_truth, token_level_pred
|
963 |
+
)
|
964 |
+
scores["token_prf"] = token_level_prf
|
965 |
+
else:
|
966 |
+
logging.info("No hard predictions detected, skipping rationale scoring")
|
967 |
+
|
968 |
+
if _has_soft_predictions(results):
|
969 |
+
flattened_documents = load_flattened_documents(args.data_dir, docids)
|
970 |
+
paired_scoring = PositionScoredDocument.from_results(
|
971 |
+
results, annotations, flattened_documents, use_tokens=True
|
972 |
+
)
|
973 |
+
token_scores = score_soft_tokens(paired_scoring)
|
974 |
+
scores["token_soft_metrics"] = token_scores
|
975 |
+
else:
|
976 |
+
logging.info("No soft predictions detected, skipping rationale scoring")
|
977 |
+
|
978 |
+
if _has_soft_sentence_predictions(results):
|
979 |
+
documents = load_documents(args.data_dir, docids)
|
980 |
+
paired_scoring = PositionScoredDocument.from_results(
|
981 |
+
results, annotations, documents, use_tokens=False
|
982 |
+
)
|
983 |
+
sentence_scores = score_soft_tokens(paired_scoring)
|
984 |
+
scores["sentence_soft_metrics"] = sentence_scores
|
985 |
+
else:
|
986 |
+
logging.info(
|
987 |
+
"No sentence level predictions detected, skipping sentence-level diagnostic"
|
988 |
+
)
|
989 |
+
|
990 |
+
if has_final_predictions:
|
991 |
+
flattened_documents = load_flattened_documents(args.data_dir, docids)
|
992 |
+
class_results = score_classifications(
|
993 |
+
results, annotations, flattened_documents, args.aopc_thresholds
|
994 |
+
)
|
995 |
+
scores["classification_scores"] = class_results
|
996 |
+
else:
|
997 |
+
logging.info("No classification scores detected, skipping classification")
|
998 |
+
|
999 |
+
pprint.pprint(scores)
|
1000 |
+
|
1001 |
+
if args.score_file:
|
1002 |
+
with open(args.score_file, "w") as of:
|
1003 |
+
json.dump(scores, of, indent=4, sort_keys=True)
|
1004 |
+
|
1005 |
+
|
1006 |
+
if __name__ == "__main__":
|
1007 |
+
main()
|
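As a reading aid for the --results format documented in the help text above, the sketch below builds one minimal line of such a jsonl file; the annotation_id, docid, span, labels, and scores are all hypothetical, not taken from any dataset:

import json

# One hypothetical instance of the results jsonl consumed by this script.
# "annotation_id" and the per-rationale "docid" are required; the classification
# fields are optional, but when provided they must be provided for every instance.
example_instance = {
    "annotation_id": "example_annotation_0",
    "rationales": [
        {
            "docid": "example_doc_0",
            "hard_rationale_predictions": [{"start_token": 3, "end_token": 7}],
        }
    ],
    "classification": "POS",
    "classification_scores": {"POS": 0.9, "NEG": 0.1},
}
print(json.dumps(example_instance))  # one such line per instance in the file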
Transformer-Explainability/BERT_rationale_benchmark/models/model_utils.py
ADDED
@@ -0,0 +1,186 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, List, Set
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from gensim.models import KeyedVectors
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn.utils.rnn import (PackedSequence, pack_padded_sequence,
|
9 |
+
pad_packed_sequence, pad_sequence)
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass(eq=True, frozen=True)
|
13 |
+
class PaddedSequence:
|
14 |
+
"""A utility class for padding variable length sequences mean for RNN input
|
15 |
+
This class is in the style of PackedSequence from the PyTorch RNN Utils,
|
16 |
+
but is somewhat more manual in approach. It provides the ability to generate masks
|
17 |
+
for outputs of the same input dimensions.
|
18 |
+
The constructor should never be called directly and should only be called via
|
19 |
+
the autopad classmethod.
|
20 |
+
|
21 |
+
We'd love to delete this, but we pad_sequence, pack_padded_sequence, and
|
22 |
+
pad_packed_sequence all require shuffling around tuples of information, and some
|
23 |
+
convenience methods using these are nice to have.
|
24 |
+
"""
|
25 |
+
|
26 |
+
data: torch.Tensor
|
27 |
+
batch_sizes: torch.Tensor
|
28 |
+
batch_first: bool = False
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def autopad(
|
32 |
+
cls, data, batch_first: bool = False, padding_value=0, device=None
|
33 |
+
) -> "PaddedSequence":
|
34 |
+
# handle tensors of size 0 (single item)
|
35 |
+
data_ = []
|
36 |
+
for d in data:
|
37 |
+
if len(d.size()) == 0:
|
38 |
+
d = d.unsqueeze(0)
|
39 |
+
data_.append(d)
|
40 |
+
padded = pad_sequence(
|
41 |
+
data_, batch_first=batch_first, padding_value=padding_value
|
42 |
+
)
|
43 |
+
if batch_first:
|
44 |
+
batch_lengths = torch.LongTensor([len(x) for x in data_])
|
45 |
+
if any([x == 0 for x in batch_lengths]):
|
46 |
+
raise ValueError(
|
47 |
+
"Found a 0 length batch element, this can't possibly be right: {}".format(
|
48 |
+
batch_lengths
|
49 |
+
)
|
50 |
+
)
|
51 |
+
else:
|
52 |
+
# TODO actually test this codepath
|
53 |
+
batch_lengths = torch.LongTensor([len(x) for x in data])
|
54 |
+
return PaddedSequence(padded, batch_lengths, batch_first).to(device=device)
|
55 |
+
|
56 |
+
def pack_other(self, data: torch.Tensor):
|
57 |
+
return pack_padded_sequence(
|
58 |
+
data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False
|
59 |
+
)
|
60 |
+
|
61 |
+
@classmethod
|
62 |
+
def from_packed_sequence(
|
63 |
+
cls, ps: PackedSequence, batch_first: bool, padding_value=0
|
64 |
+
) -> "PaddedSequence":
|
65 |
+
padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value)
|
66 |
+
return PaddedSequence(padded, batch_sizes, batch_first)
|
67 |
+
|
68 |
+
def cuda(self) -> "PaddedSequence":
|
69 |
+
return PaddedSequence(
|
70 |
+
self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first
|
71 |
+
)
|
72 |
+
|
73 |
+
def to(
|
74 |
+
self, dtype=None, device=None, copy=False, non_blocking=False
|
75 |
+
) -> "PaddedSequence":
|
76 |
+
# TODO make to() support all of the torch.Tensor to() variants
|
77 |
+
return PaddedSequence(
|
78 |
+
self.data.to(
|
79 |
+
dtype=dtype, device=device, copy=copy, non_blocking=non_blocking
|
80 |
+
),
|
81 |
+
self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking),
|
82 |
+
batch_first=self.batch_first,
|
83 |
+
)
|
84 |
+
|
85 |
+
def mask(
|
86 |
+
self, on=int(0), off=int(0), device="cpu", size=None, dtype=None
|
87 |
+
) -> torch.Tensor:
|
88 |
+
if size is None:
|
89 |
+
size = self.data.size()
|
90 |
+
out_tensor = torch.zeros(*size, dtype=dtype)
|
91 |
+
# TODO this can be done more efficiently
|
92 |
+
out_tensor.fill_(off)
|
93 |
+
# note to self: these are probably less efficient than explicitly populating the off values instead of the on values.
|
94 |
+
if self.batch_first:
|
95 |
+
for i, bl in enumerate(self.batch_sizes):
|
96 |
+
out_tensor[i, :bl] = on
|
97 |
+
else:
|
98 |
+
for i, bl in enumerate(self.batch_sizes):
|
99 |
+
out_tensor[:bl, i] = on
|
100 |
+
return out_tensor.to(device)
|
101 |
+
|
102 |
+
def unpad(self, other: torch.Tensor) -> List[torch.Tensor]:
|
103 |
+
out = []
|
104 |
+
for o, bl in zip(other, self.batch_sizes):
|
105 |
+
out.append(o[:bl])
|
106 |
+
return out
|
107 |
+
|
108 |
+
def flip(self) -> "PaddedSequence":
|
109 |
+
return PaddedSequence(
|
110 |
+
self.data.transpose(0, 1), self.batch_sizes, batch_first=not self.batch_first
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
def extract_embeddings(
|
115 |
+
vocab: Set[str], embedding_file: str, unk_token: str = "UNK", pad_token: str = "PAD"
|
116 |
+
) -> (nn.Embedding, Dict[str, int], List[str]):
|
117 |
+
vocab = vocab | set([unk_token, pad_token])
|
118 |
+
if embedding_file.endswith(".bin"):
|
119 |
+
WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True)
|
120 |
+
|
121 |
+
word_to_vector = dict()
|
122 |
+
WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()])
|
123 |
+
|
124 |
+
if unk_token not in WVs:
|
125 |
+
mean_vector = np.mean(WV_matrix, axis=0)
|
126 |
+
word_to_vector[unk_token] = mean_vector
|
127 |
+
if pad_token not in WVs:
|
128 |
+
word_to_vector[pad_token] = np.zeros(WVs.vector_size)
|
129 |
+
|
130 |
+
for v in vocab:
|
131 |
+
if v in WVs:
|
132 |
+
word_to_vector[v] = WVs[v]
|
133 |
+
|
134 |
+
interner = dict()
|
135 |
+
deinterner = list()
|
136 |
+
vectors = []
|
137 |
+
count = 0
|
138 |
+
for word in [pad_token, unk_token] + sorted(
|
139 |
+
list(word_to_vector.keys() - {unk_token, pad_token})
|
140 |
+
):
|
141 |
+
vector = word_to_vector[word]
|
142 |
+
vectors.append(np.array(vector))
|
143 |
+
interner[word] = count
|
144 |
+
deinterner.append(word)
|
145 |
+
count += 1
|
146 |
+
vectors = torch.FloatTensor(np.array(vectors))
|
147 |
+
embedding = nn.Embedding.from_pretrained(
|
148 |
+
vectors, padding_idx=interner[pad_token]
|
149 |
+
)
|
150 |
+
embedding.weight.requires_grad = False
|
151 |
+
return embedding, interner, deinterner
|
152 |
+
elif embedding_file.endswith(".txt"):
|
153 |
+
word_to_vector = dict()
|
154 |
+
vector = []
|
155 |
+
with open(embedding_file, "r") as inf:
|
156 |
+
for line in inf:
|
157 |
+
contents = line.strip().split()
|
158 |
+
word = contents[0]
|
159 |
+
vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0)
|
160 |
+
word_to_vector[word] = vector
|
161 |
+
embed_size = vector.size()
|
162 |
+
if unk_token not in word_to_vector:
|
163 |
+
mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0)
|
164 |
+
word_to_vector[unk_token] = mean_vector.unsqueeze(0)
|
165 |
+
if pad_token not in word_to_vector:
|
166 |
+
word_to_vector[pad_token] = torch.zeros(embed_size)
|
167 |
+
interner = dict()
|
168 |
+
deinterner = list()
|
169 |
+
vectors = []
|
170 |
+
count = 0
|
171 |
+
for word in [pad_token, unk_token] + sorted(
|
172 |
+
list(word_to_vector.keys() - {unk_token, pad_token})
|
173 |
+
):
|
174 |
+
vector = word_to_vector[word]
|
175 |
+
vectors.append(vector)
|
176 |
+
interner[word] = count
|
177 |
+
deinterner.append(word)
|
178 |
+
count += 1
|
179 |
+
vectors = torch.cat(vectors, dim=0)
|
180 |
+
embedding = nn.Embedding.from_pretrained(
|
181 |
+
vectors, padding_idx=interner[pad_token]
|
182 |
+
)
|
183 |
+
embedding.weight.requires_grad = False
|
184 |
+
return embedding, interner, deinterner
|
185 |
+
else:
|
186 |
+
raise ValueError("Unable to open embeddings file {}".format(embedding_file))
|
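A minimal usage sketch of the PaddedSequence helper added above, with made-up inputs; it assumes only that torch is installed and that the class is importable from the path shown in this diff:

import torch
from BERT_rationale_benchmark.models.model_utils import PaddedSequence

# Three variable-length sequences of token ids.
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

padded = PaddedSequence.autopad(seqs, batch_first=True, padding_value=0, device="cpu")
print(padded.data)          # a 3 x 3 tensor; shorter rows are padded with 0
print(padded.batch_sizes)   # tensor([3, 2, 1]): the original lengths
print(padded.mask(on=1, off=0, dtype=torch.long))  # 1 over real tokens, 0 over padding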
Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/__init__.py
ADDED
File without changes
|
Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/bert_pipeline.py
ADDED
@@ -0,0 +1,852 @@
1 |
+
# TODO consider if this can be collapsed back down into the pipeline_train.py
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from collections import OrderedDict
|
8 |
+
from itertools import chain
|
9 |
+
from typing import List, Tuple
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from BERT_explainability.modules.BERT.BERT_cls_lrp import \
|
15 |
+
BertForSequenceClassification as BertForClsOrigLrp
|
16 |
+
from BERT_explainability.modules.BERT.BertForSequenceClassification import \
|
17 |
+
BertForSequenceClassification as BertForSequenceClassificationTest
|
18 |
+
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
|
19 |
+
from BERT_rationale_benchmark.utils import (Annotation, Evidence,
|
20 |
+
load_datasets, load_documents,
|
21 |
+
write_jsonl)
|
22 |
+
from sklearn.metrics import accuracy_score
|
23 |
+
from transformers import BertForSequenceClassification, BertTokenizer
|
24 |
+
|
25 |
+
logging.basicConfig(
|
26 |
+
level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
|
27 |
+
)
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
# let's make this more or less deterministic (not resistant to restarts)
|
30 |
+
random.seed(12345)
|
31 |
+
np.random.seed(67890)
|
32 |
+
torch.manual_seed(10111213)
|
33 |
+
torch.backends.cudnn.deterministic = True
|
34 |
+
torch.backends.cudnn.benchmark = False
|
35 |
+
|
36 |
+
|
37 |
+
import numpy as np
|
38 |
+
|
39 |
+
latex_special_token = ["!@#$%^&*()"]
|
40 |
+
|
41 |
+
|
42 |
+
def generate(text_list, attention_list, latex_file, color="red", rescale_value=False):
|
43 |
+
attention_list = attention_list[: len(text_list)]
|
44 |
+
if attention_list.max() == attention_list.min():
|
45 |
+
attention_list = torch.zeros_like(attention_list)
|
46 |
+
else:
|
47 |
+
attention_list = (
|
48 |
+
100
|
49 |
+
* (attention_list - attention_list.min())
|
50 |
+
/ (attention_list.max() - attention_list.min())
|
51 |
+
)
|
52 |
+
attention_list[attention_list < 1] = 0
|
53 |
+
attention_list = attention_list.tolist()
|
54 |
+
text_list = [text_list[i].replace("$", "") for i in range(len(text_list))]
|
55 |
+
if rescale_value:
|
56 |
+
attention_list = rescale(attention_list)
|
57 |
+
word_num = len(text_list)
|
58 |
+
text_list = clean_word(text_list)
|
59 |
+
with open(latex_file, "w") as f:
|
60 |
+
f.write(
|
61 |
+
r"""\documentclass[varwidth=150mm]{standalone}
|
62 |
+
\special{papersize=210mm,297mm}
|
63 |
+
\usepackage{color}
|
64 |
+
\usepackage{tcolorbox}
|
65 |
+
\usepackage{CJK}
|
66 |
+
\usepackage{adjustbox}
|
67 |
+
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
|
68 |
+
\begin{document}
|
69 |
+
\begin{CJK*}{UTF8}{gbsn}"""
|
70 |
+
+ "\n"
|
71 |
+
)
|
72 |
+
string = (
|
73 |
+
r"""{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{"""
|
74 |
+
+ "\n"
|
75 |
+
)
|
76 |
+
for idx in range(word_num):
|
77 |
+
# string += "\\colorbox{%s!%s}{"%(color, attention_list[idx])+"\\strut " + text_list[idx]+"} "
|
78 |
+
# print(text_list[idx])
|
79 |
+
if "\#\#" in text_list[idx]:
|
80 |
+
token = text_list[idx].replace("\#\#", "")
|
81 |
+
string += (
|
82 |
+
"\\colorbox{%s!%s}{" % (color, attention_list[idx])
|
83 |
+
+ "\\strut "
|
84 |
+
+ token
|
85 |
+
+ "}"
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
string += (
|
89 |
+
" "
|
90 |
+
+ "\\colorbox{%s!%s}{" % (color, attention_list[idx])
|
91 |
+
+ "\\strut "
|
92 |
+
+ text_list[idx]
|
93 |
+
+ "}"
|
94 |
+
)
|
95 |
+
string += "\n}}}"
|
96 |
+
f.write(string + "\n")
|
97 |
+
f.write(
|
98 |
+
r"""\end{CJK*}
|
99 |
+
\end{document}"""
|
100 |
+
)
|
101 |
+
|
102 |
+
|
103 |
+
def clean_word(word_list):
|
104 |
+
new_word_list = []
|
105 |
+
for word in word_list:
|
106 |
+
for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
|
107 |
+
if latex_sensitive in word:
|
108 |
+
word = word.replace(latex_sensitive, "\\" + latex_sensitive)
|
109 |
+
new_word_list.append(word)
|
110 |
+
return new_word_list
|
111 |
+
|
112 |
+
|
113 |
+
def scores_per_word_from_scores_per_token(input, tokenizer, input_ids, scores_per_id):
|
114 |
+
words = tokenizer.convert_ids_to_tokens(input_ids)
|
115 |
+
words = [word.replace("##", "") for word in words]
|
116 |
+
score_per_char = []
|
117 |
+
|
118 |
+
# TODO: DELETE
|
119 |
+
input_ids_chars = []
|
120 |
+
for word in words:
|
121 |
+
if word in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
|
122 |
+
continue
|
123 |
+
input_ids_chars += list(word)
|
124 |
+
# TODO: DELETE
|
125 |
+
|
126 |
+
for i in range(len(scores_per_id)):
|
127 |
+
if words[i] in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
|
128 |
+
continue
|
129 |
+
score_per_char += [scores_per_id[i]] * len(words[i])
|
130 |
+
|
131 |
+
score_per_word = []
|
132 |
+
start_idx = 0
|
133 |
+
end_idx = 0
|
134 |
+
# TODO: DELETE
|
135 |
+
words_from_chars = []
|
136 |
+
for inp in input:
|
137 |
+
if start_idx >= len(score_per_char):
|
138 |
+
break
|
139 |
+
end_idx = end_idx + len(inp)
|
140 |
+
score_per_word.append(np.max(score_per_char[start_idx:end_idx]))
|
141 |
+
|
142 |
+
# TODO: DELETE
|
143 |
+
words_from_chars.append("".join(input_ids_chars[start_idx:end_idx]))
|
144 |
+
|
145 |
+
start_idx = end_idx
|
146 |
+
|
147 |
+
if words_from_chars[:-1] != input[: len(words_from_chars) - 1]:
|
148 |
+
print(words_from_chars)
|
149 |
+
print(input[: len(words_from_chars)])
|
150 |
+
print(words)
|
151 |
+
print(tokenizer.convert_ids_to_tokens(input_ids))
|
152 |
+
assert False
|
153 |
+
|
154 |
+
return torch.tensor(score_per_word)
|
155 |
+
|
156 |
+
|
157 |
+
def get_input_words(input, tokenizer, input_ids):
|
158 |
+
words = tokenizer.convert_ids_to_tokens(input_ids)
|
159 |
+
words = [word.replace("##", "") for word in words]
|
160 |
+
|
161 |
+
input_ids_chars = []
|
162 |
+
for word in words:
|
163 |
+
if word in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
|
164 |
+
continue
|
165 |
+
input_ids_chars += list(word)
|
166 |
+
|
167 |
+
start_idx = 0
|
168 |
+
end_idx = 0
|
169 |
+
words_from_chars = []
|
170 |
+
for inp in input:
|
171 |
+
if start_idx >= len(input_ids_chars):
|
172 |
+
break
|
173 |
+
end_idx = end_idx + len(inp)
|
174 |
+
words_from_chars.append("".join(input_ids_chars[start_idx:end_idx]))
|
175 |
+
start_idx = end_idx
|
176 |
+
|
177 |
+
if words_from_chars[:-1] != input[: len(words_from_chars) - 1]:
|
178 |
+
print(words_from_chars)
|
179 |
+
print(input[: len(words_from_chars)])
|
180 |
+
print(words)
|
181 |
+
print(tokenizer.convert_ids_to_tokens(input_ids))
|
182 |
+
assert False
|
183 |
+
return words_from_chars
|
184 |
+
|
185 |
+
|
186 |
+
def bert_tokenize_doc(
|
187 |
+
doc: List[List[str]], tokenizer, special_token_map
|
188 |
+
) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]:
|
189 |
+
"""Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words"""
|
190 |
+
sents = []
|
191 |
+
sent_token_spans = []
|
192 |
+
for sent in doc:
|
193 |
+
tokens = []
|
194 |
+
spans = []
|
195 |
+
start = 0
|
196 |
+
for w in sent:
|
197 |
+
if w in special_token_map:
|
198 |
+
tokens.append(w)
|
199 |
+
else:
|
200 |
+
tokens.extend(tokenizer.tokenize(w))
|
201 |
+
end = len(tokens)
|
202 |
+
spans.append((start, end))
|
203 |
+
start = end
|
204 |
+
sents.append(tokens)
|
205 |
+
sent_token_spans.append(spans)
|
206 |
+
return sents, sent_token_spans
|
207 |
+
|
208 |
+
|
209 |
+
def initialize_models(params: dict, batch_first: bool, use_half_precision=False):
|
210 |
+
assert batch_first
|
211 |
+
max_length = params["max_length"]
|
212 |
+
tokenizer = BertTokenizer.from_pretrained(params["bert_vocab"])
|
213 |
+
pad_token_id = tokenizer.pad_token_id
|
214 |
+
cls_token_id = tokenizer.cls_token_id
|
215 |
+
sep_token_id = tokenizer.sep_token_id
|
216 |
+
bert_dir = params["bert_dir"]
|
217 |
+
evidence_classes = dict(
|
218 |
+
(y, x) for (x, y) in enumerate(params["evidence_classifier"]["classes"])
|
219 |
+
)
|
220 |
+
evidence_classifier = BertForSequenceClassification.from_pretrained(
|
221 |
+
bert_dir, num_labels=len(evidence_classes)
|
222 |
+
)
|
223 |
+
word_interner = tokenizer.vocab
|
224 |
+
de_interner = tokenizer.ids_to_tokens
|
225 |
+
return evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer
|
226 |
+
|
227 |
+
|
228 |
+
BATCH_FIRST = True
|
229 |
+
|
230 |
+
|
231 |
+
def extract_docid_from_dataset_element(element):
|
232 |
+
return next(iter(element.evidences))[0].docid
|
233 |
+
|
234 |
+
|
235 |
+
def extract_evidence_from_dataset_element(element):
|
236 |
+
return next(iter(element.evidences))
|
237 |
+
|
238 |
+
|
239 |
+
def main():
|
240 |
+
parser = argparse.ArgumentParser(
|
241 |
+
description="""Trains a pipeline model.
|
242 |
+
|
243 |
+
Step 1 is evidence identification, that is identify if a given sentence is evidence or not
|
244 |
+
Step 2 is evidence classification, that is given an evidence sentence, classify the final outcome for the final task
|
245 |
+
(e.g. sentiment or significance).
|
246 |
+
|
247 |
+
These models should be separated into two separate steps, but at the moment:
|
248 |
+
* prep data (load, intern documents, load json)
|
249 |
+
* convert data for evidence identification - in the case of training data we take all the positives and sample some
|
250 |
+
negatives
|
251 |
+
* side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a
|
252 |
+
broader sampling of negative values.
|
253 |
+
* train evidence identification
|
254 |
+
* convert data for evidence classification - take all rationales + decisions and use this as input
|
255 |
+
* train evidence classification
|
256 |
+
* decode first the evidence, then run classification for each split
|
257 |
+
|
258 |
+
""",
|
259 |
+
formatter_class=argparse.RawTextHelpFormatter,
|
260 |
+
)
|
261 |
+
parser.add_argument(
|
262 |
+
"--data_dir",
|
263 |
+
dest="data_dir",
|
264 |
+
required=True,
|
265 |
+
help="Which directory contains a {train,val,test}.jsonl file?",
|
266 |
+
)
|
267 |
+
parser.add_argument(
|
268 |
+
"--output_dir",
|
269 |
+
dest="output_dir",
|
270 |
+
required=True,
|
271 |
+
help="Where shall we write intermediate models + final data to?",
|
272 |
+
)
|
273 |
+
parser.add_argument(
|
274 |
+
"--model_params",
|
275 |
+
dest="model_params",
|
276 |
+
required=True,
|
277 |
+
help="JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.",
|
278 |
+
)
|
279 |
+
args = parser.parse_args()
|
280 |
+
assert BATCH_FIRST
|
281 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
282 |
+
|
283 |
+
with open(args.model_params, "r") as fp:
|
284 |
+
logger.info(f"Loading model parameters from {args.model_params}")
|
285 |
+
model_params = json.load(fp)
|
286 |
+
logger.info(f"Params: {json.dumps(model_params, indent=2, sort_keys=True)}")
|
287 |
+
train, val, test = load_datasets(args.data_dir)
|
288 |
+
docids = set(
|
289 |
+
e.docid
|
290 |
+
for e in chain.from_iterable(
|
291 |
+
chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test)))
|
292 |
+
)
|
293 |
+
)
|
294 |
+
documents = load_documents(args.data_dir, docids)
|
295 |
+
logger.info(f"Loaded {len(documents)} documents")
|
296 |
+
(
|
297 |
+
evidence_classifier,
|
298 |
+
word_interner,
|
299 |
+
de_interner,
|
300 |
+
evidence_classes,
|
301 |
+
tokenizer,
|
302 |
+
) = initialize_models(model_params, batch_first=BATCH_FIRST)
|
303 |
+
logger.info(f"We have {len(word_interner)} wordpieces")
|
304 |
+
cache = os.path.join(args.output_dir, "preprocessed.pkl")
|
305 |
+
if os.path.exists(cache):
|
306 |
+
logger.info(f"Loading interned documents from {cache}")
|
307 |
+
(interned_documents) = torch.load(cache)
|
308 |
+
else:
|
309 |
+
logger.info(f"Interning documents")
|
310 |
+
interned_documents = {}
|
311 |
+
for d, doc in documents.items():
|
312 |
+
encoding = tokenizer.encode_plus(
|
313 |
+
doc,
|
314 |
+
add_special_tokens=True,
|
315 |
+
max_length=model_params["max_length"],
|
316 |
+
return_token_type_ids=False,
|
317 |
+
pad_to_max_length=False,
|
318 |
+
return_attention_mask=True,
|
319 |
+
return_tensors="pt",
|
320 |
+
truncation=True,
|
321 |
+
)
|
322 |
+
interned_documents[d] = encoding
|
323 |
+
torch.save((interned_documents), cache)
|
324 |
+
|
325 |
+
evidence_classifier = evidence_classifier.cuda()
|
326 |
+
optimizer = None
|
327 |
+
scheduler = None
|
328 |
+
|
329 |
+
save_dir = args.output_dir
|
330 |
+
|
331 |
+
logging.info(f"Beginning training classifier")
|
332 |
+
evidence_classifier_output_dir = os.path.join(save_dir, "classifier")
|
333 |
+
os.makedirs(save_dir, exist_ok=True)
|
334 |
+
os.makedirs(evidence_classifier_output_dir, exist_ok=True)
|
335 |
+
model_save_file = os.path.join(evidence_classifier_output_dir, "classifier.pt")
|
336 |
+
epoch_save_file = os.path.join(
|
337 |
+
evidence_classifier_output_dir, "classifier_epoch_data.pt"
|
338 |
+
)
|
339 |
+
|
340 |
+
device = next(evidence_classifier.parameters()).device
|
341 |
+
if optimizer is None:
|
342 |
+
optimizer = torch.optim.Adam(
|
343 |
+
evidence_classifier.parameters(),
|
344 |
+
lr=model_params["evidence_classifier"]["lr"],
|
345 |
+
)
|
346 |
+
criterion = nn.CrossEntropyLoss(reduction="none")
|
347 |
+
batch_size = model_params["evidence_classifier"]["batch_size"]
|
348 |
+
epochs = model_params["evidence_classifier"]["epochs"]
|
349 |
+
patience = model_params["evidence_classifier"]["patience"]
|
350 |
+
max_grad_norm = model_params["evidence_classifier"].get("max_grad_norm", None)
|
351 |
+
|
352 |
+
class_labels = [k for k, v in sorted(evidence_classes.items())]
|
353 |
+
|
354 |
+
results = {
|
355 |
+
"train_loss": [],
|
356 |
+
"train_f1": [],
|
357 |
+
"train_acc": [],
|
358 |
+
"val_loss": [],
|
359 |
+
"val_f1": [],
|
360 |
+
"val_acc": [],
|
361 |
+
}
|
362 |
+
best_epoch = -1
|
363 |
+
best_val_acc = 0
|
364 |
+
best_val_loss = float("inf")
|
365 |
+
best_model_state_dict = None
|
366 |
+
start_epoch = 0
|
367 |
+
epoch_data = {}
|
368 |
+
if os.path.exists(epoch_save_file):
|
369 |
+
logging.info(f"Restoring model from {model_save_file}")
|
370 |
+
evidence_classifier.load_state_dict(torch.load(model_save_file))
|
371 |
+
epoch_data = torch.load(epoch_save_file)
|
372 |
+
start_epoch = epoch_data["epoch"] + 1
|
373 |
+
# handle finishing because patience was exceeded or we didn't get the best final epoch
|
374 |
+
if bool(epoch_data.get("done", 0)):
|
375 |
+
start_epoch = epochs
|
376 |
+
results = epoch_data["results"]
|
377 |
+
best_epoch = start_epoch
|
378 |
+
best_model_state_dict = OrderedDict(
|
379 |
+
{k: v.cpu() for k, v in evidence_classifier.state_dict().items()}
|
380 |
+
)
|
381 |
+
logging.info(f"Restoring training from epoch {start_epoch}")
|
382 |
+
logging.info(
|
383 |
+
f"Training evidence classifier from epoch {start_epoch} until epoch {epochs}"
|
384 |
+
)
|
385 |
+
optimizer.zero_grad()
|
386 |
+
for epoch in range(start_epoch, epochs):
|
387 |
+
epoch_train_data = random.sample(train, k=len(train))
|
388 |
+
epoch_train_loss = 0
|
389 |
+
epoch_training_acc = 0
|
390 |
+
evidence_classifier.train()
|
391 |
+
logging.info(
|
392 |
+
f"Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples"
|
393 |
+
)
|
394 |
+
for batch_start in range(0, len(epoch_train_data), batch_size):
|
395 |
+
batch_elements = epoch_train_data[
|
396 |
+
batch_start : min(batch_start + batch_size, len(epoch_train_data))
|
397 |
+
]
|
398 |
+
targets = [evidence_classes[s.classification] for s in batch_elements]
|
399 |
+
targets = torch.tensor(targets, dtype=torch.long, device=device)
|
400 |
+
samples_encoding = [
|
401 |
+
interned_documents[extract_docid_from_dataset_element(s)]
|
402 |
+
for s in batch_elements
|
403 |
+
]
|
404 |
+
input_ids = (
|
405 |
+
torch.stack(
|
406 |
+
[
|
407 |
+
samples_encoding[i]["input_ids"]
|
408 |
+
for i in range(len(samples_encoding))
|
409 |
+
]
|
410 |
+
)
|
411 |
+
.squeeze(1)
|
412 |
+
.to(device)
|
413 |
+
)
|
414 |
+
attention_masks = (
|
415 |
+
torch.stack(
|
416 |
+
[
|
417 |
+
samples_encoding[i]["attention_mask"]
|
418 |
+
for i in range(len(samples_encoding))
|
419 |
+
]
|
420 |
+
)
|
421 |
+
.squeeze(1)
|
422 |
+
.to(device)
|
423 |
+
)
|
424 |
+
preds = evidence_classifier(
|
425 |
+
input_ids=input_ids, attention_mask=attention_masks
|
426 |
+
)[0]
|
427 |
+
epoch_training_acc += accuracy_score(
|
428 |
+
preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False
|
429 |
+
)
|
430 |
+
loss = criterion(preds, targets.to(device=preds.device)).sum()
|
431 |
+
epoch_train_loss += loss.item()
|
432 |
+
loss.backward()
|
433 |
+
assert loss == loss # for nans
|
434 |
+
if max_grad_norm:
|
435 |
+
torch.nn.utils.clip_grad_norm_(
|
436 |
+
evidence_classifier.parameters(), max_grad_norm
|
437 |
+
)
|
438 |
+
optimizer.step()
|
439 |
+
if scheduler:
|
440 |
+
scheduler.step()
|
441 |
+
optimizer.zero_grad()
|
442 |
+
epoch_train_loss /= len(epoch_train_data)
|
443 |
+
epoch_training_acc /= len(epoch_train_data)
|
444 |
+
assert epoch_train_loss == epoch_train_loss # for nans
|
445 |
+
results["train_loss"].append(epoch_train_loss)
|
446 |
+
logging.info(f"Epoch {epoch} training loss {epoch_train_loss}")
|
447 |
+
logging.info(f"Epoch {epoch} training accuracy {epoch_training_acc}")
|
448 |
+
|
449 |
+
with torch.no_grad():
|
450 |
+
epoch_val_loss = 0
|
451 |
+
epoch_val_acc = 0
|
452 |
+
epoch_val_data = random.sample(val, k=len(val))
|
453 |
+
evidence_classifier.eval()
|
454 |
+
val_batch_size = 32
|
455 |
+
logging.info(
|
456 |
+
f"Validating with {len(epoch_val_data) // val_batch_size} batches with {len(epoch_val_data)} examples"
|
457 |
+
)
|
458 |
+
for batch_start in range(0, len(epoch_val_data), val_batch_size):
|
459 |
+
batch_elements = epoch_val_data[
|
460 |
+
batch_start : min(batch_start + val_batch_size, len(epoch_val_data))
|
461 |
+
]
|
462 |
+
targets = [evidence_classes[s.classification] for s in batch_elements]
|
463 |
+
targets = torch.tensor(targets, dtype=torch.long, device=device)
|
464 |
+
samples_encoding = [
|
465 |
+
interned_documents[extract_docid_from_dataset_element(s)]
|
466 |
+
for s in batch_elements
|
467 |
+
]
|
468 |
+
input_ids = (
|
469 |
+
torch.stack(
|
470 |
+
[
|
471 |
+
samples_encoding[i]["input_ids"]
|
472 |
+
for i in range(len(samples_encoding))
|
473 |
+
]
|
474 |
+
)
|
475 |
+
.squeeze(1)
|
476 |
+
.to(device)
|
477 |
+
)
|
478 |
+
attention_masks = (
|
479 |
+
torch.stack(
|
480 |
+
[
|
481 |
+
samples_encoding[i]["attention_mask"]
|
482 |
+
for i in range(len(samples_encoding))
|
483 |
+
]
|
484 |
+
)
|
485 |
+
.squeeze(1)
|
486 |
+
.to(device)
|
487 |
+
)
|
488 |
+
preds = evidence_classifier(
|
489 |
+
input_ids=input_ids, attention_mask=attention_masks
|
490 |
+
)[0]
|
491 |
+
epoch_val_acc += accuracy_score(
|
492 |
+
preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False
|
493 |
+
)
|
494 |
+
loss = criterion(preds, targets.to(device=preds.device)).sum()
|
495 |
+
epoch_val_loss += loss.item()
|
496 |
+
|
497 |
+
epoch_val_loss /= len(val)
|
498 |
+
epoch_val_acc /= len(val)
|
499 |
+
results["val_acc"].append(epoch_val_acc)
|
500 |
+
results["val_loss"] = epoch_val_loss
|
501 |
+
|
502 |
+
logging.info(f"Epoch {epoch} val loss {epoch_val_loss}")
|
503 |
+
logging.info(f"Epoch {epoch} val acc {epoch_val_acc}")
|
504 |
+
|
505 |
+
if epoch_val_acc > best_val_acc or (
|
506 |
+
epoch_val_acc == best_val_acc and epoch_val_loss < best_val_loss
|
507 |
+
):
|
508 |
+
best_model_state_dict = OrderedDict(
|
509 |
+
{k: v.cpu() for k, v in evidence_classifier.state_dict().items()}
|
510 |
+
)
|
511 |
+
best_epoch = epoch
|
512 |
+
best_val_acc = epoch_val_acc
|
513 |
+
best_val_loss = epoch_val_loss
|
514 |
+
epoch_data = {
|
515 |
+
"epoch": epoch,
|
516 |
+
"results": results,
|
517 |
+
"best_val_acc": best_val_acc,
|
518 |
+
"done": 0,
|
519 |
+
}
|
520 |
+
torch.save(evidence_classifier.state_dict(), model_save_file)
|
521 |
+
torch.save(epoch_data, epoch_save_file)
|
522 |
+
logging.debug(
|
523 |
+
f"Epoch {epoch} new best model with val accuracy {epoch_val_acc}"
|
524 |
+
)
|
525 |
+
if epoch - best_epoch > patience:
|
526 |
+
logging.info(f"Exiting after epoch {epoch} due to no improvement")
|
527 |
+
epoch_data["done"] = 1
|
528 |
+
torch.save(epoch_data, epoch_save_file)
|
529 |
+
break
|
530 |
+
|
531 |
+
epoch_data["done"] = 1
|
532 |
+
epoch_data["results"] = results
|
533 |
+
torch.save(epoch_data, epoch_save_file)
|
534 |
+
evidence_classifier.load_state_dict(best_model_state_dict)
|
535 |
+
evidence_classifier = evidence_classifier.to(device=device)
|
536 |
+
evidence_classifier.eval()
|
537 |
+
|
538 |
+
# test
|
539 |
+
|
540 |
+
test_classifier = BertForSequenceClassificationTest.from_pretrained(
|
541 |
+
model_params["bert_dir"], num_labels=len(evidence_classes)
|
542 |
+
).to(device)
|
543 |
+
orig_lrp_classifier = BertForClsOrigLrp.from_pretrained(
|
544 |
+
model_params["bert_dir"], num_labels=len(evidence_classes)
|
545 |
+
).to(device)
|
546 |
+
if os.path.exists(epoch_save_file):
|
547 |
+
logging.info(f"Restoring model from {model_save_file}")
|
548 |
+
test_classifier.load_state_dict(torch.load(model_save_file))
|
549 |
+
orig_lrp_classifier.load_state_dict(torch.load(model_save_file))
|
550 |
+
test_classifier.eval()
|
551 |
+
orig_lrp_classifier.eval()
|
552 |
+
test_batch_size = 1
|
553 |
+
logging.info(
|
554 |
+
f"Testing with {len(test) // test_batch_size} batches with {len(test)} examples"
|
555 |
+
)
|
556 |
+
|
557 |
+
# explainability
|
558 |
+
explanations = Generator(test_classifier)
|
559 |
+
explanations_orig_lrp = Generator(orig_lrp_classifier)
|
560 |
+
method = "transformer_attribution"
|
561 |
+
method_folder = {
|
562 |
+
"transformer_attribution": "ours",
|
563 |
+
"partial_lrp": "partial_lrp",
|
564 |
+
"last_attn": "last_attn",
|
565 |
+
"attn_gradcam": "attn_gradcam",
|
566 |
+
"lrp": "lrp",
|
567 |
+
"rollout": "rollout",
|
568 |
+
"ground_truth": "ground_truth",
|
569 |
+
"generate_all": "generate_all",
|
570 |
+
}
|
571 |
+
method_expl = {
|
572 |
+
"transformer_attribution": explanations.generate_LRP,
|
573 |
+
"partial_lrp": explanations_orig_lrp.generate_LRP_last_layer,
|
574 |
+
"last_attn": explanations_orig_lrp.generate_attn_last_layer,
|
575 |
+
"attn_gradcam": explanations_orig_lrp.generate_attn_gradcam,
|
576 |
+
"lrp": explanations_orig_lrp.generate_full_lrp,
|
577 |
+
"rollout": explanations_orig_lrp.generate_rollout,
|
578 |
+
}
|
579 |
+
|
580 |
+
os.makedirs(os.path.join(args.output_dir, method_folder[method]), exist_ok=True)
|
581 |
+
|
582 |
+
result_files = []
|
583 |
+
for i in range(5, 85, 5):
|
584 |
+
result_files.append(
|
585 |
+
open(
|
586 |
+
os.path.join(
|
587 |
+
args.output_dir, "{0}/identifier_results_{1}.json"
|
588 |
+
).format(method_folder[method], i),
|
589 |
+
"w",
|
590 |
+
)
|
591 |
+
)
|
592 |
+
|
593 |
+
j = 0
|
594 |
+
for batch_start in range(0, len(test), test_batch_size):
|
595 |
+
batch_elements = test[
|
596 |
+
batch_start : min(batch_start + test_batch_size, len(test))
|
597 |
+
]
|
598 |
+
targets = [evidence_classes[s.classification] for s in batch_elements]
|
599 |
+
targets = torch.tensor(targets, dtype=torch.long, device=device)
|
600 |
+
samples_encoding = [
|
601 |
+
interned_documents[extract_docid_from_dataset_element(s)]
|
602 |
+
for s in batch_elements
|
603 |
+
]
|
604 |
+
input_ids = (
|
605 |
+
torch.stack(
|
606 |
+
[
|
607 |
+
samples_encoding[i]["input_ids"]
|
608 |
+
for i in range(len(samples_encoding))
|
609 |
+
]
|
610 |
+
)
|
611 |
+
.squeeze(1)
|
612 |
+
.to(device)
|
613 |
+
)
|
614 |
+
attention_masks = (
|
615 |
+
torch.stack(
|
616 |
+
[
|
617 |
+
samples_encoding[i]["attention_mask"]
|
618 |
+
for i in range(len(samples_encoding))
|
619 |
+
]
|
620 |
+
)
|
621 |
+
.squeeze(1)
|
622 |
+
.to(device)
|
623 |
+
)
|
624 |
+
preds = test_classifier(
|
625 |
+
input_ids=input_ids, attention_mask=attention_masks
|
626 |
+
)[0]
|
627 |
+
|
628 |
+
for s in batch_elements:
|
629 |
+
doc_name = extract_docid_from_dataset_element(s)
|
630 |
+
inp = documents[doc_name].split()
|
631 |
+
classification = "neg" if targets.item() == 0 else "pos"
|
632 |
+
is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
|
633 |
+
if method == "generate_all":
|
634 |
+
file_name = "{0}_{1}_{2}.tex".format(
|
635 |
+
j, classification, is_classification_correct
|
636 |
+
)
|
637 |
+
GT_global = os.path.join(
|
638 |
+
args.output_dir, "{0}/visual_results_{1}.pdf"
|
639 |
+
).format(method_folder["ground_truth"], j)
|
640 |
+
GT_ours = os.path.join(
|
641 |
+
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
|
642 |
+
).format(
|
643 |
+
method_folder["transformer_attribution"],
|
644 |
+
j,
|
645 |
+
classification,
|
646 |
+
is_classification_correct,
|
647 |
+
)
|
648 |
+
CF_ours = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
|
649 |
+
method_folder["transformer_attribution"], j
|
650 |
+
)
|
651 |
+
GT_partial = os.path.join(
|
652 |
+
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
|
653 |
+
).format(
|
654 |
+
method_folder["partial_lrp"],
|
655 |
+
j,
|
656 |
+
classification,
|
657 |
+
is_classification_correct,
|
658 |
+
)
|
659 |
+
CF_partial = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
|
660 |
+
method_folder["partial_lrp"], j
|
661 |
+
)
|
662 |
+
GT_gradcam = os.path.join(
|
663 |
+
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
|
664 |
+
).format(
|
665 |
+
method_folder["attn_gradcam"],
|
666 |
+
j,
|
667 |
+
classification,
|
668 |
+
is_classification_correct,
|
669 |
+
)
|
670 |
+
CF_gradcam = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
|
671 |
+
method_folder["attn_gradcam"], j
|
672 |
+
)
|
673 |
+
GT_lrp = os.path.join(
|
674 |
+
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
|
675 |
+
).format(
|
676 |
+
method_folder["lrp"],
|
677 |
+
j,
|
678 |
+
classification,
|
679 |
+
is_classification_correct,
|
680 |
+
)
|
681 |
+
CF_lrp = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
|
682 |
+
method_folder["lrp"], j
|
683 |
+
)
|
684 |
+
GT_lastattn = os.path.join(
|
685 |
+
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
|
686 |
+
).format(
|
687 |
+
method_folder["last_attn"],
|
688 |
+
j,
|
689 |
+
classification,
|
690 |
+
is_classification_correct,
|
691 |
+
)
|
692 |
+
GT_rollout = os.path.join(
|
693 |
+
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
|
694 |
+
).format(
|
695 |
+
method_folder["rollout"],
|
696 |
+
j,
|
697 |
+
classification,
|
698 |
+
is_classification_correct,
|
699 |
+
)
|
700 |
+
with open(file_name, "w") as f:
|
701 |
+
f.write(
|
702 |
+
r"""\documentclass[varwidth]{standalone}
|
703 |
+
\usepackage{color}
|
704 |
+
\usepackage{tcolorbox}
|
705 |
+
\usepackage{CJK}
|
706 |
+
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
|
707 |
+
\begin{document}
|
708 |
+
\begin{CJK*}{UTF8}{gbsn}
|
709 |
+
{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{
|
710 |
+
\setlength{\tabcolsep}{2pt} % Default value: 6pt
|
711 |
+
\begin{tabular}{ccc}
|
712 |
+
\includegraphics[width=0.32\linewidth]{"""
|
713 |
+
+ GT_global
|
714 |
+
+ """}&
|
715 |
+
\includegraphics[width=0.32\linewidth]{"""
|
716 |
+
+ GT_ours
|
717 |
+
+ """}&
|
718 |
+
\includegraphics[width=0.32\linewidth]{"""
|
719 |
+
+ CF_ours
|
720 |
+
+ """}\\\\
|
721 |
+
(a) & (b) & (c)\\\\
|
722 |
+
\includegraphics[width=0.32\linewidth]{"""
|
723 |
+
+ GT_partial
|
724 |
+
+ """}&
|
725 |
+
\includegraphics[width=0.32\linewidth]{"""
|
726 |
+
+ CF_partial
|
727 |
+
+ """}&
|
728 |
+
\includegraphics[width=0.32\linewidth]{"""
|
729 |
+
+ GT_gradcam
|
730 |
+
+ """}\\\\
|
731 |
+
(d) & (e) & (f)\\\\
|
732 |
+
\includegraphics[width=0.32\linewidth]{"""
|
733 |
+
+ CF_gradcam
|
734 |
+
+ """}&
|
735 |
+
\includegraphics[width=0.32\linewidth]{"""
|
736 |
+
+ GT_lrp
|
737 |
+
+ """}&
|
738 |
+
\includegraphics[width=0.32\linewidth]{"""
|
739 |
+
+ CF_lrp
|
740 |
+
+ """}\\\\
|
741 |
+
(g) & (h) & (i)\\\\
|
742 |
+
\includegraphics[width=0.32\linewidth]{"""
|
743 |
+
+ GT_lastattn
|
744 |
+
+ """}&
|
745 |
+
\includegraphics[width=0.32\linewidth]{"""
|
746 |
+
+ GT_rollout
|
747 |
+
+ """}&\\\\
|
748 |
+
(j) & (k)&\\\\
|
749 |
+
\end{tabular}
|
750 |
+
}}}
|
751 |
+
\end{CJK*}
|
752 |
+
\end{document}
|
753 |
+
)"""
|
754 |
+
)
|
755 |
+
j += 1
|
756 |
+
break
|
757 |
+
|
758 |
+
if method == "ground_truth":
|
759 |
+
inp_cropped = get_input_words(inp, tokenizer, input_ids[0])
|
760 |
+
cam = torch.zeros(len(inp_cropped))
|
761 |
+
for evidence in extract_evidence_from_dataset_element(s):
|
762 |
+
start_idx = evidence.start_token
|
763 |
+
if start_idx >= len(cam):
|
764 |
+
break
|
765 |
+
end_idx = evidence.end_token
|
766 |
+
cam[start_idx:end_idx] = 1
|
767 |
+
generate(
|
768 |
+
inp_cropped,
|
769 |
+
cam,
|
770 |
+
(
|
771 |
+
os.path.join(
|
772 |
+
args.output_dir, "{0}/visual_results_{1}.tex"
|
773 |
+
).format(method_folder[method], j)
|
774 |
+
),
|
775 |
+
color="green",
|
776 |
+
)
|
777 |
+
j = j + 1
|
778 |
+
break
|
779 |
+
text = tokenizer.convert_ids_to_tokens(input_ids[0])
|
780 |
+
classification = "neg" if targets.item() == 0 else "pos"
|
781 |
+
is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
|
782 |
+
target_idx = targets.item()
|
783 |
+
cam_target = method_expl[method](
|
784 |
+
input_ids=input_ids,
|
785 |
+
attention_mask=attention_masks,
|
786 |
+
index=target_idx,
|
787 |
+
)[0]
|
788 |
+
cam_target = cam_target.clamp(min=0)
|
789 |
+
generate(
|
790 |
+
text,
|
791 |
+
cam_target,
|
792 |
+
(
|
793 |
+
os.path.join(args.output_dir, "{0}/{1}_GT_{2}_{3}.tex").format(
|
794 |
+
method_folder[method],
|
795 |
+
j,
|
796 |
+
classification,
|
797 |
+
is_classification_correct,
|
798 |
+
)
|
799 |
+
),
|
800 |
+
)
|
801 |
+
if method in [
|
802 |
+
"transformer_attribution",
|
803 |
+
"partial_lrp",
|
804 |
+
"attn_gradcam",
|
805 |
+
"lrp",
|
806 |
+
]:
|
807 |
+
cam_false_class = method_expl[method](
|
808 |
+
input_ids=input_ids,
|
809 |
+
attention_mask=attention_masks,
|
810 |
+
index=1 - target_idx,
|
811 |
+
)[0]
|
812 |
+
cam_false_class = cam_false_class.clamp(min=0)
|
813 |
+
generate(
|
814 |
+
text,
|
815 |
+
cam_false_class,
|
816 |
+
(
|
817 |
+
os.path.join(args.output_dir, "{0}/{1}_CF.tex").format(
|
818 |
+
method_folder[method], j
|
819 |
+
)
|
820 |
+
),
|
821 |
+
)
|
822 |
+
cam = cam_target
|
823 |
+
cam = scores_per_word_from_scores_per_token(
|
824 |
+
inp, tokenizer, input_ids[0], cam
|
825 |
+
)
|
826 |
+
j = j + 1
|
827 |
+
doc_name = extract_docid_from_dataset_element(s)
|
828 |
+
hard_rationales = []
|
829 |
+
for res, i in enumerate(range(5, 85, 5)):
|
830 |
+
print("calculating top ", i)
|
831 |
+
_, indices = cam.topk(k=i)
|
832 |
+
for index in indices.tolist():
|
833 |
+
hard_rationales.append(
|
834 |
+
{"start_token": index, "end_token": index + 1}
|
835 |
+
)
|
836 |
+
result_dict = {
|
837 |
+
"annotation_id": doc_name,
|
838 |
+
"rationales": [
|
839 |
+
{
|
840 |
+
"docid": doc_name,
|
841 |
+
"hard_rationale_predictions": hard_rationales,
|
842 |
+
}
|
843 |
+
],
|
844 |
+
}
|
845 |
+
result_files[res].write(json.dumps(result_dict) + "\n")
|
846 |
+
|
847 |
+
for i in range(len(result_files)):
|
848 |
+
result_files[i].close()
|
849 |
+
|
850 |
+
|
851 |
+
if __name__ == "__main__":
|
852 |
+
main()
|
Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_train.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
from itertools import chain
|
7 |
+
from typing import Set
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from rationale_benchmark.models.mlp import (AttentiveClassifier,
|
12 |
+
BahadanauAttention, RNNEncoder,
|
13 |
+
WordEmbedder)
|
14 |
+
from rationale_benchmark.models.model_utils import extract_embeddings
|
15 |
+
from rationale_benchmark.models.pipeline.evidence_classifier import \
|
16 |
+
train_evidence_classifier
|
17 |
+
from rationale_benchmark.models.pipeline.evidence_identifier import \
|
18 |
+
train_evidence_identifier
|
19 |
+
from rationale_benchmark.models.pipeline.pipeline_utils import decode
|
20 |
+
from rationale_benchmark.utils import (intern_annotations, intern_documents,
|
21 |
+
load_datasets, load_documents,
|
22 |
+
write_jsonl)
|
23 |
+
|
24 |
+
logging.basicConfig(
|
25 |
+
level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
|
26 |
+
)
|
27 |
+
# let's make this more or less deterministic (not resistant to restarts)
|
28 |
+
random.seed(12345)
|
29 |
+
np.random.seed(67890)
|
30 |
+
torch.manual_seed(10111213)
|
31 |
+
torch.backends.cudnn.deterministic = True
|
32 |
+
torch.backends.cudnn.benchmark = False
|
33 |
+
|
34 |
+
|
35 |
+
def initialize_models(
|
36 |
+
params: dict, vocab: Set[str], batch_first: bool, unk_token="UNK"
|
37 |
+
):
|
38 |
+
# TODO this is obviously asking for some sort of dependency injection. implement if it saves me time.
|
39 |
+
if "embedding_file" in params["embeddings"]:
|
40 |
+
embeddings, word_interner, de_interner = extract_embeddings(
|
41 |
+
vocab, params["embeddings"]["embedding_file"], unk_token=unk_token
|
42 |
+
)
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
embeddings = embeddings.cuda()
|
45 |
+
else:
|
46 |
+
raise ValueError("No 'embedding_file' found in params!")
|
47 |
+
word_embedder = WordEmbedder(embeddings, params["embeddings"]["dropout"])
|
48 |
+
query_encoder = RNNEncoder(
|
49 |
+
word_embedder,
|
50 |
+
batch_first=batch_first,
|
51 |
+
condition=False,
|
52 |
+
attention_mechanism=BahadanauAttention(word_embedder.output_dimension),
|
53 |
+
)
|
54 |
+
document_encoder = RNNEncoder(
|
55 |
+
word_embedder,
|
56 |
+
batch_first=batch_first,
|
57 |
+
condition=True,
|
58 |
+
attention_mechanism=BahadanauAttention(
|
59 |
+
word_embedder.output_dimension, query_size=query_encoder.output_dimension
|
60 |
+
),
|
61 |
+
)
|
62 |
+
evidence_identifier = AttentiveClassifier(
|
63 |
+
document_encoder,
|
64 |
+
query_encoder,
|
65 |
+
2,
|
66 |
+
params["evidence_identifier"]["mlp_size"],
|
67 |
+
params["evidence_identifier"]["dropout"],
|
68 |
+
)
|
69 |
+
query_encoder = RNNEncoder(
|
70 |
+
word_embedder,
|
71 |
+
batch_first=batch_first,
|
72 |
+
condition=False,
|
73 |
+
attention_mechanism=BahadanauAttention(word_embedder.output_dimension),
|
74 |
+
)
|
75 |
+
document_encoder = RNNEncoder(
|
76 |
+
word_embedder,
|
77 |
+
batch_first=batch_first,
|
78 |
+
condition=True,
|
79 |
+
attention_mechanism=BahadanauAttention(
|
80 |
+
word_embedder.output_dimension, query_size=query_encoder.output_dimension
|
81 |
+
),
|
82 |
+
)
|
83 |
+
evidence_classes = dict(
|
84 |
+
(y, x) for (x, y) in enumerate(params["evidence_classifier"]["classes"])
|
85 |
+
)
|
86 |
+
evidence_classifier = AttentiveClassifier(
|
87 |
+
document_encoder,
|
88 |
+
query_encoder,
|
89 |
+
len(evidence_classes),
|
90 |
+
params["evidence_classifier"]["mlp_size"],
|
91 |
+
params["evidence_classifier"]["dropout"],
|
92 |
+
)
|
93 |
+
return (
|
94 |
+
evidence_identifier,
|
95 |
+
evidence_classifier,
|
96 |
+
word_interner,
|
97 |
+
de_interner,
|
98 |
+
evidence_classes,
|
99 |
+
)
|
100 |
+
|
101 |
+
|
102 |
+
def main():
|
103 |
+
parser = argparse.ArgumentParser(
|
104 |
+
description="""Trains a pipeline model.
|
105 |
+
|
106 |
+
Step 1 is evidence identification, that is identify if a given sentence is evidence or not
|
107 |
+
Step 2 is evidence classification, that is given an evidence sentence, classify the final outcome for the final task (e.g. sentiment or significance).
|
108 |
+
|
109 |
+
These models should be separated into two separate steps, but at the moment:
|
110 |
+
* prep data (load, intern documents, load json)
|
111 |
+
* convert data for evidence identification - in the case of training data we take all the positives and sample some negatives
|
112 |
+
* side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a broader sampling of negative values.
|
113 |
+
* train evidence identification
|
114 |
+
* convert data for evidence classification - take all rationales + decisions and use this as input
|
115 |
+
* train evidence classification
|
116 |
+
* decode first the evidence, then run classification for each split
|
117 |
+
|
118 |
+
""",
|
119 |
+
formatter_class=argparse.RawTextHelpFormatter,
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--data_dir",
|
123 |
+
dest="data_dir",
|
124 |
+
required=True,
|
125 |
+
help="Which directory contains a {train,val,test}.jsonl file?",
|
126 |
+
)
|
127 |
+
parser.add_argument(
|
128 |
+
"--output_dir",
|
129 |
+
dest="output_dir",
|
130 |
+
required=True,
|
131 |
+
help="Where shall we write intermediate models + final data to?",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--model_params",
|
135 |
+
dest="model_params",
|
136 |
+
required=True,
|
137 |
+
help="JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.",
|
138 |
+
)
|
139 |
+
args = parser.parse_args()
|
140 |
+
BATCH_FIRST = True
|
141 |
+
|
142 |
+
with open(args.model_params, "r") as fp:
|
143 |
+
logging.debug(f"Loading model parameters from {args.model_params}")
|
144 |
+
model_params = json.load(fp)
|
145 |
+
train, val, test = load_datasets(args.data_dir)
|
146 |
+
docids = set(
|
147 |
+
e.docid
|
148 |
+
for e in chain.from_iterable(
|
149 |
+
chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test)))
|
150 |
+
)
|
151 |
+
)
|
152 |
+
documents = load_documents(args.data_dir, docids)
|
153 |
+
document_vocab = set(chain.from_iterable(chain.from_iterable(documents.values())))
|
154 |
+
annotation_vocab = set(
|
155 |
+
chain.from_iterable(e.query.split() for e in chain(train, val, test))
|
156 |
+
)
|
157 |
+
logging.debug(
|
158 |
+
f"Loaded {len(documents)} documents with {len(document_vocab)} unique words"
|
159 |
+
)
|
160 |
+
# this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important
|
161 |
+
vocab = document_vocab | annotation_vocab
|
162 |
+
unk_token = "UNK"
|
163 |
+
(
|
164 |
+
evidence_identifier,
|
165 |
+
evidence_classifier,
|
166 |
+
word_interner,
|
167 |
+
de_interner,
|
168 |
+
evidence_classes,
|
169 |
+
) = initialize_models(
|
170 |
+
model_params, vocab, batch_first=BATCH_FIRST, unk_token=unk_token
|
171 |
+
)
|
172 |
+
logging.debug(
|
173 |
+
f"Including annotations, we have {len(vocab)} total words in the data, with embeddings for {len(word_interner)}"
|
174 |
+
)
|
175 |
+
interned_documents = intern_documents(documents, word_interner, unk_token)
|
176 |
+
interned_train = intern_annotations(train, word_interner, unk_token)
|
177 |
+
interned_val = intern_annotations(val, word_interner, unk_token)
|
178 |
+
interned_test = intern_annotations(test, word_interner, unk_token)
|
179 |
+
assert BATCH_FIRST # for correctness of the split dimension for DataParallel
|
180 |
+
evidence_identifier, evidence_ident_results = train_evidence_identifier(
|
181 |
+
evidence_identifier.cuda(),
|
182 |
+
args.output_dir,
|
183 |
+
interned_train,
|
184 |
+
interned_val,
|
185 |
+
interned_documents,
|
186 |
+
model_params,
|
187 |
+
tensorize_model_inputs=True,
|
188 |
+
)
|
189 |
+
evidence_classifier, evidence_class_results = train_evidence_classifier(
|
190 |
+
evidence_classifier.cuda(),
|
191 |
+
args.output_dir,
|
192 |
+
interned_train,
|
193 |
+
interned_val,
|
194 |
+
interned_documents,
|
195 |
+
model_params,
|
196 |
+
class_interner=evidence_classes,
|
197 |
+
tensorize_model_inputs=True,
|
198 |
+
)
|
199 |
+
pipeline_batch_size = min(
|
200 |
+
[
|
201 |
+
model_params["evidence_classifier"]["batch_size"],
|
202 |
+
model_params["evidence_identifier"]["batch_size"],
|
203 |
+
]
|
204 |
+
)
|
205 |
+
pipeline_results, train_decoded, val_decoded, test_decoded = decode(
|
206 |
+
evidence_identifier,
|
207 |
+
evidence_classifier,
|
208 |
+
interned_train,
|
209 |
+
interned_val,
|
210 |
+
interned_test,
|
211 |
+
interned_documents,
|
212 |
+
evidence_classes,
|
213 |
+
pipeline_batch_size,
|
214 |
+
tensorize_model_inputs=True,
|
215 |
+
)
|
216 |
+
write_jsonl(train_decoded, os.path.join(args.output_dir, "train_decoded.jsonl"))
|
217 |
+
write_jsonl(val_decoded, os.path.join(args.output_dir, "val_decoded.jsonl"))
|
218 |
+
write_jsonl(test_decoded, os.path.join(args.output_dir, "test_decoded.jsonl"))
|
219 |
+
with open(
|
220 |
+
os.path.join(args.output_dir, "identifier_results.json"), "w"
|
221 |
+
) as ident_output, open(
|
222 |
+
os.path.join(args.output_dir, "classifier_results.json"), "w"
|
223 |
+
) as class_output:
|
224 |
+
ident_output.write(json.dumps(evidence_ident_results))
|
225 |
+
class_output.write(json.dumps(evidence_class_results))
|
226 |
+
for k, v in pipeline_results.items():
|
227 |
+
if type(v) is dict:
|
228 |
+
for k1, v1 in v.items():
|
229 |
+
logging.info(f"Pipeline results for {k}, {k1}={v1}")
|
230 |
+
else:
|
231 |
+
logging.info(f"Pipeline results {k}\t={v}")
|
232 |
+
|
233 |
+
|
234 |
+
if __name__ == "__main__":
|
235 |
+
main()
|
Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_utils.py
ADDED
@@ -0,0 +1,1045 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import logging
|
3 |
+
from collections import defaultdict, namedtuple
|
4 |
+
from itertools import chain
|
5 |
+
from typing import Any, Dict, List, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from rationale_benchmark.metrics import (PositionScoredDocument, Rationale,
|
11 |
+
partial_match_score,
|
12 |
+
score_hard_rationale_predictions,
|
13 |
+
score_soft_tokens)
|
14 |
+
from rationale_benchmark.models.model_utils import PaddedSequence
|
15 |
+
from rationale_benchmark.utils import Annotation
|
16 |
+
from sklearn.metrics import accuracy_score, classification_report
|
17 |
+
|
18 |
+
SentenceEvidence = namedtuple(
|
19 |
+
"SentenceEvidence", "kls ann_id query docid index sentence"
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def token_annotations_to_evidence_classification(
|
24 |
+
annotations: List[Annotation],
|
25 |
+
documents: Dict[str, List[List[Any]]],
|
26 |
+
class_interner: Dict[str, int],
|
27 |
+
) -> List[SentenceEvidence]:
|
28 |
+
ret = []
|
29 |
+
for ann in annotations:
|
30 |
+
docid_to_ev = defaultdict(list)
|
31 |
+
for evidence in ann.all_evidences():
|
32 |
+
docid_to_ev[evidence.docid].append(evidence)
|
33 |
+
for docid, evidences in docid_to_ev.items():
|
34 |
+
evidences = sorted(evidences, key=lambda ev: ev.start_token)
|
35 |
+
text = []
|
36 |
+
covered_tokens = set()
|
37 |
+
doc = list(chain.from_iterable(documents[docid]))
|
38 |
+
for evidence in evidences:
|
39 |
+
assert (
|
40 |
+
evidence.start_token >= 0
|
41 |
+
and evidence.end_token > evidence.start_token
|
42 |
+
)
|
43 |
+
assert evidence.start_token < len(doc) and evidence.end_token <= len(
|
44 |
+
doc
|
45 |
+
)
|
46 |
+
text.extend(evidence.text)
|
47 |
+
new_tokens = set(range(evidence.start_token, evidence.end_token))
|
48 |
+
if len(new_tokens & covered_tokens) > 0:
|
49 |
+
raise ValueError(
|
50 |
+
"Have overlapping token ranges covered in the evidence spans and the implementer was lazy; deal with it"
|
51 |
+
)
|
52 |
+
covered_tokens |= new_tokens
|
53 |
+
assert len(text) > 0
|
54 |
+
ret.append(
|
55 |
+
SentenceEvidence(
|
56 |
+
kls=class_interner[ann.classification],
|
57 |
+
query=ann.query,
|
58 |
+
ann_id=ann.annotation_id,
|
59 |
+
docid=docid,
|
60 |
+
index=-1,
|
61 |
+
sentence=tuple(text),
|
62 |
+
)
|
63 |
+
)
|
64 |
+
return ret
|
65 |
+
|
66 |
+
|
67 |
+
def annotations_to_evidence_classification(
|
68 |
+
annotations: List[Annotation],
|
69 |
+
documents: Dict[str, List[List[Any]]],
|
70 |
+
class_interner: Dict[str, int],
|
71 |
+
include_all: bool,
|
72 |
+
) -> List[SentenceEvidence]:
|
73 |
+
"""Converts Corpus-Level annotations to Sentence Level relevance judgments.
|
74 |
+
|
75 |
+
As this module is about a pipelined approach for evidence identification,
|
76 |
+
inputs to both an evidence identifier and evidence classifier need to be to
|
77 |
+
be on a sentence level, this module converts data to be that form.
|
78 |
+
|
79 |
+
The return type is of the form
|
80 |
+
annotation id -> docid -> [sentence level annotations]
|
81 |
+
"""
|
82 |
+
ret = []
|
83 |
+
for ann in annotations:
|
84 |
+
ann_id = ann.annotation_id
|
85 |
+
docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
|
86 |
+
annotations_for_doc = defaultdict(list)
|
87 |
+
for d in docids:
|
88 |
+
for index, sent in enumerate(documents[d]):
|
89 |
+
annotations_for_doc[d].append(
|
90 |
+
SentenceEvidence(
|
91 |
+
kls=class_interner[ann.classification],
|
92 |
+
query=ann.query,
|
93 |
+
ann_id=ann.annotation_id,
|
94 |
+
docid=d,
|
95 |
+
index=index,
|
96 |
+
sentence=tuple(sent),
|
97 |
+
)
|
98 |
+
)
|
99 |
+
if include_all:
|
100 |
+
ret.extend(chain.from_iterable(annotations_for_doc.values()))
|
101 |
+
else:
|
102 |
+
contributes = set()
|
103 |
+
for ev in chain.from_iterable(ann.evidences):
|
104 |
+
for index in range(ev.start_sentence, ev.end_sentence):
|
105 |
+
contributes.add(annotations_for_doc[ev.docid][index])
|
106 |
+
ret.extend(contributes)
|
107 |
+
assert len(ret) > 0
|
108 |
+
return ret
|
109 |
+
|
110 |
+
|
111 |
+
def annotations_to_evidence_identification(
|
112 |
+
annotations: List[Annotation], documents: Dict[str, List[List[Any]]]
|
113 |
+
) -> Dict[str, Dict[str, List[SentenceEvidence]]]:
|
114 |
+
"""Converts Corpus-Level annotations to Sentence Level relevance judgments.
|
115 |
+
|
116 |
+
As this module is about a pipelined approach for evidence identification,
|
117 |
+
inputs to both an evidence identifier and evidence classifier need to be to
|
118 |
+
be on a sentence level, this module converts data to be that form.
|
119 |
+
|
120 |
+
The return type is of the form
|
121 |
+
annotation id -> docid -> [sentence level annotations]
|
122 |
+
"""
|
123 |
+
ret = defaultdict(dict) # annotation id -> docid -> sentences
|
124 |
+
for ann in annotations:
|
125 |
+
ann_id = ann.annotation_id
|
126 |
+
for ev_group in ann.evidences:
|
127 |
+
for ev in ev_group:
|
128 |
+
if len(ev.text) == 0:
|
129 |
+
continue
|
130 |
+
if ev.docid not in ret[ann_id]:
|
131 |
+
ret[ann.annotation_id][ev.docid] = []
|
132 |
+
# populate the document with "not evidence"; to be filled in later
|
133 |
+
for index, sent in enumerate(documents[ev.docid]):
|
134 |
+
ret[ann.annotation_id][ev.docid].append(
|
135 |
+
SentenceEvidence(
|
136 |
+
kls=0,
|
137 |
+
query=ann.query,
|
138 |
+
ann_id=ann.annotation_id,
|
139 |
+
docid=ev.docid,
|
140 |
+
index=index,
|
141 |
+
sentence=sent,
|
142 |
+
)
|
143 |
+
)
|
144 |
+
# define the evidence sections of the document
|
145 |
+
for s in range(ev.start_sentence, ev.end_sentence):
|
146 |
+
ret[ann.annotation_id][ev.docid][s] = SentenceEvidence(
|
147 |
+
kls=1,
|
148 |
+
ann_id=ann.annotation_id,
|
149 |
+
query=ann.query,
|
150 |
+
docid=ev.docid,
|
151 |
+
index=ret[ann.annotation_id][ev.docid][s].index,
|
152 |
+
sentence=ret[ann.annotation_id][ev.docid][s].sentence,
|
153 |
+
)
|
154 |
+
return ret
|
155 |
+
|
156 |
+
|
157 |
+
def annotations_to_evidence_token_identification(
|
158 |
+
annotations: List[Annotation],
|
159 |
+
source_documents: Dict[str, List[List[str]]],
|
160 |
+
interned_documents: Dict[str, List[List[int]]],
|
161 |
+
token_mapping: Dict[str, List[List[Tuple[int, int]]]],
|
162 |
+
) -> Dict[str, Dict[str, List[SentenceEvidence]]]:
|
163 |
+
# TODO document
|
164 |
+
# TODO should we simplify to use only source text?
|
165 |
+
ret = defaultdict(lambda: defaultdict(list)) # annotation id -> docid -> sentences
|
166 |
+
positive_tokens = 0
|
167 |
+
negative_tokens = 0
|
168 |
+
for ann in annotations:
|
169 |
+
annid = ann.annotation_id
|
170 |
+
docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
|
171 |
+
sentence_offsets = defaultdict(list) # docid -> [(start, end)]
|
172 |
+
classes = defaultdict(list) # docid -> [token is yea or nay]
|
173 |
+
for docid in docids:
|
174 |
+
start = 0
|
175 |
+
assert len(source_documents[docid]) == len(interned_documents[docid])
|
176 |
+
for whole_token_sent, wordpiece_sent in zip(
|
177 |
+
source_documents[docid], interned_documents[docid]
|
178 |
+
):
|
179 |
+
classes[docid].extend([0 for _ in wordpiece_sent])
|
180 |
+
end = start + len(wordpiece_sent)
|
181 |
+
sentence_offsets[docid].append((start, end))
|
182 |
+
start = end
|
183 |
+
for ev in chain.from_iterable(ann.evidences):
|
184 |
+
if len(ev.text) == 0:
|
185 |
+
continue
|
186 |
+
flat_token_map = list(chain.from_iterable(token_mapping[ev.docid]))
|
187 |
+
if ev.start_token != -1:
|
188 |
+
# start, end = token_mapping[ev.docid][ev.start_token][0], token_mapping[ev.docid][ev.end_token][1]
|
189 |
+
start, end = (
|
190 |
+
flat_token_map[ev.start_token][0],
|
191 |
+
flat_token_map[ev.end_token - 1][1],
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
start = flat_token_map[sentence_offsets[ev.start_sentence][0]][0]
|
195 |
+
end = flat_token_map[sentence_offsets[ev.end_sentence - 1][1]][1]
|
196 |
+
for i in range(start, end):
|
197 |
+
classes[ev.docid][i] = 1
|
198 |
+
for docid, offsets in sentence_offsets.items():
|
199 |
+
token_assignments = classes[docid]
|
200 |
+
positive_tokens += sum(token_assignments)
|
201 |
+
negative_tokens += len(token_assignments) - sum(token_assignments)
|
202 |
+
for s, (start, end) in enumerate(offsets):
|
203 |
+
sent = interned_documents[docid][s]
|
204 |
+
ret[annid][docid].append(
|
205 |
+
SentenceEvidence(
|
206 |
+
kls=tuple(token_assignments[start:end]),
|
207 |
+
query=ann.query,
|
208 |
+
ann_id=ann.annotation_id,
|
209 |
+
docid=docid,
|
210 |
+
index=s,
|
211 |
+
sentence=sent,
|
212 |
+
)
|
213 |
+
)
|
214 |
+
logging.info(
|
215 |
+
f"Have {positive_tokens} positive wordpiece tokens, {negative_tokens} negative wordpiece tokens"
|
216 |
+
)
|
217 |
+
return ret
|
218 |
+
|
219 |
+
|
220 |
+
def make_preds_batch(
|
221 |
+
classifier: nn.Module,
|
222 |
+
batch_elements: List[SentenceEvidence],
|
223 |
+
device=None,
|
224 |
+
criterion: nn.Module = None,
|
225 |
+
tensorize_model_inputs: bool = True,
|
226 |
+
) -> Tuple[float, List[float], List[int], List[int]]:
|
227 |
+
"""Batch predictions
|
228 |
+
|
229 |
+
Args:
|
230 |
+
classifier: a module that looks like an AttentiveClassifier
|
231 |
+
batch_elements: a list of elements to make predictions over. These must be SentenceEvidence objects.
|
232 |
+
device: Optional; what compute device this should run on
|
233 |
+
criterion: Optional; a loss function
|
234 |
+
tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
|
235 |
+
"""
|
236 |
+
# delete any "None" padding, if any (imposed by the use of the "grouper")
|
237 |
+
batch_elements = filter(lambda x: x is not None, batch_elements)
|
238 |
+
targets, queries, sentences = zip(
|
239 |
+
*[(s.kls, s.query, s.sentence) for s in batch_elements]
|
240 |
+
)
|
241 |
+
ids = [(s.ann_id, s.docid, s.index) for s in batch_elements]
|
242 |
+
targets = torch.tensor(targets, dtype=torch.long, device=device)
|
243 |
+
if tensorize_model_inputs:
|
244 |
+
queries = [torch.tensor(q, dtype=torch.long) for q in queries]
|
245 |
+
sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
|
246 |
+
preds = classifier(queries, ids, sentences)
|
247 |
+
targets = targets.to(device=preds.device)
|
248 |
+
if criterion:
|
249 |
+
loss = criterion(preds, targets)
|
250 |
+
else:
|
251 |
+
loss = None
|
252 |
+
# .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16
|
253 |
+
hard_preds = torch.argmax(preds.float(), dim=-1)
|
254 |
+
return loss, preds, hard_preds, targets
|
255 |
+
|
256 |
+
|
257 |
+
def make_preds_epoch(
|
258 |
+
classifier: nn.Module,
|
259 |
+
data: List[SentenceEvidence],
|
260 |
+
batch_size: int,
|
261 |
+
device=None,
|
262 |
+
criterion: nn.Module = None,
|
263 |
+
tensorize_model_inputs: bool = True,
|
264 |
+
):
|
265 |
+
"""Predictions for more than one batch.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
classifier: a module that looks like an AttentiveClassifier
|
269 |
+
data: a list of elements to make predictions over. These must be SentenceEvidence objects.
|
270 |
+
batch_size: the biggest chunk we can fit in one batch.
|
271 |
+
device: Optional; what compute device this should run on
|
272 |
+
criterion: Optional; a loss function
|
273 |
+
tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
|
274 |
+
"""
|
275 |
+
epoch_loss = 0
|
276 |
+
epoch_soft_pred = []
|
277 |
+
epoch_hard_pred = []
|
278 |
+
epoch_truth = []
|
279 |
+
batches = _grouper(data, batch_size)
|
280 |
+
classifier.eval()
|
281 |
+
for batch in batches:
|
282 |
+
loss, soft_preds, hard_preds, targets = make_preds_batch(
|
283 |
+
classifier,
|
284 |
+
batch,
|
285 |
+
device,
|
286 |
+
criterion=criterion,
|
287 |
+
tensorize_model_inputs=tensorize_model_inputs,
|
288 |
+
)
|
289 |
+
if loss is not None:
|
290 |
+
epoch_loss += loss.sum().item()
|
291 |
+
epoch_hard_pred.extend(hard_preds)
|
292 |
+
epoch_soft_pred.extend(soft_preds.cpu())
|
293 |
+
epoch_truth.extend(targets)
|
294 |
+
epoch_loss /= len(data)
|
295 |
+
epoch_hard_pred = [x.item() for x in epoch_hard_pred]
|
296 |
+
epoch_truth = [x.item() for x in epoch_truth]
|
297 |
+
return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth
|
298 |
+
|
299 |
+
|
300 |
+
def make_token_preds_batch(
|
301 |
+
classifier: nn.Module,
|
302 |
+
batch_elements: List[SentenceEvidence],
|
303 |
+
token_mapping: Dict[str, List[List[Tuple[int, int]]]],
|
304 |
+
device=None,
|
305 |
+
criterion: nn.Module = None,
|
306 |
+
tensorize_model_inputs: bool = True,
|
307 |
+
) -> Tuple[float, List[float], List[int], List[int]]:
|
308 |
+
"""Batch predictions
|
309 |
+
|
310 |
+
Args:
|
311 |
+
classifier: a module that looks like an AttentiveClassifier
|
312 |
+
batch_elements: a list of elements to make predictions over. These must be SentenceEvidence objects.
|
313 |
+
device: Optional; what compute device this should run on
|
314 |
+
criterion: Optional; a loss function
|
315 |
+
tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
|
316 |
+
"""
|
317 |
+
# delete any "None" padding, if any (imposed by the use of the "grouper")
|
318 |
+
batch_elements = filter(lambda x: x is not None, batch_elements)
|
319 |
+
targets, queries, sentences = zip(
|
320 |
+
*[(s.kls, s.query, s.sentence) for s in batch_elements]
|
321 |
+
)
|
322 |
+
ids = [(s.ann_id, s.docid, s.index) for s in batch_elements]
|
323 |
+
targets = PaddedSequence.autopad(
|
324 |
+
[torch.tensor(t, dtype=torch.long, device=device) for t in targets],
|
325 |
+
batch_first=True,
|
326 |
+
device=device,
|
327 |
+
)
|
328 |
+
aggregate_spans = [token_mapping[s.docid][s.index] for s in batch_elements]
|
329 |
+
if tensorize_model_inputs:
|
330 |
+
queries = [torch.tensor(q, dtype=torch.long) for q in queries]
|
331 |
+
sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
|
332 |
+
preds = classifier(queries, ids, sentences, aggregate_spans)
|
333 |
+
targets = targets.to(device=preds.device)
|
334 |
+
mask = targets.mask(on=1, off=0, device=preds.device, dtype=torch.float)
|
335 |
+
if criterion:
|
336 |
+
loss = criterion(
|
337 |
+
preds, (targets.data.to(device=preds.device) * mask).squeeze()
|
338 |
+
).sum()
|
339 |
+
else:
|
340 |
+
loss = None
|
341 |
+
hard_preds = [
|
342 |
+
torch.round(x).to(dtype=torch.int).cpu() for x in targets.unpad(preds)
|
343 |
+
]
|
344 |
+
targets = [[y.item() for y in x] for x in targets.unpad(targets.data.cpu())]
|
345 |
+
return loss, preds, hard_preds, targets # targets.unpad(targets.data.cpu())
|
346 |
+
|
347 |
+
|
348 |
+
# TODO fix the arguments
|
349 |
+
def make_token_preds_epoch(
|
350 |
+
classifier: nn.Module,
|
351 |
+
data: List[SentenceEvidence],
|
352 |
+
token_mapping: Dict[str, List[List[Tuple[int, int]]]],
|
353 |
+
batch_size: int,
|
354 |
+
device=None,
|
355 |
+
criterion: nn.Module = None,
|
356 |
+
tensorize_model_inputs: bool = True,
|
357 |
+
):
|
358 |
+
"""Predictions for more than one batch.
|
359 |
+
|
360 |
+
Args:
|
361 |
+
classifier: a module that looks like an AttentiveClassifier
|
362 |
+
data: a list of elements to make predictions over. These must be SentenceEvidence objects.
|
363 |
+
batch_size: the biggest chunk we can fit in one batch.
|
364 |
+
device: Optional; what compute device this should run on
|
365 |
+
criterion: Optional; a loss function
|
366 |
+
tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
|
367 |
+
"""
|
368 |
+
epoch_loss = 0
|
369 |
+
epoch_soft_pred = []
|
370 |
+
epoch_hard_pred = []
|
371 |
+
epoch_truth = []
|
372 |
+
batches = _grouper(data, batch_size)
|
373 |
+
classifier.eval()
|
374 |
+
for batch in batches:
|
375 |
+
loss, soft_preds, hard_preds, targets = make_token_preds_batch(
|
376 |
+
classifier,
|
377 |
+
batch,
|
378 |
+
token_mapping,
|
379 |
+
device,
|
380 |
+
criterion=criterion,
|
381 |
+
tensorize_model_inputs=tensorize_model_inputs,
|
382 |
+
)
|
383 |
+
if loss is not None:
|
384 |
+
epoch_loss += loss.sum().item()
|
385 |
+
epoch_hard_pred.extend(hard_preds)
|
386 |
+
epoch_soft_pred.extend(soft_preds.cpu().tolist())
|
387 |
+
epoch_truth.extend(targets)
|
388 |
+
epoch_loss /= len(data)
|
389 |
+
return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth
|
390 |
+
|
391 |
+
|
392 |
+
# copied from https://docs.python.org/3/library/itertools.html#itertools-recipes
|
393 |
+
def _grouper(iterable, n, fillvalue=None):
|
394 |
+
"Collect data into fixed-length chunks or blocks"
|
395 |
+
# grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
|
396 |
+
args = [iter(iterable)] * n
|
397 |
+
return itertools.zip_longest(*args, fillvalue=fillvalue)
|
398 |
+
|
399 |
+
|
400 |
+
def score_rationales(
|
401 |
+
truth: List[Annotation],
|
402 |
+
documents: Dict[str, List[List[int]]],
|
403 |
+
input_data: List[SentenceEvidence],
|
404 |
+
scores: List[float],
|
405 |
+
) -> dict:
|
406 |
+
results = {}
|
407 |
+
doc_to_sent_scores = dict() # (annid, docid) -> [sentence scores]
|
408 |
+
for sent, score in zip(input_data, scores):
|
409 |
+
k = (sent.ann_id, sent.docid)
|
410 |
+
if k not in doc_to_sent_scores:
|
411 |
+
doc_to_sent_scores[k] = [0.0 for _ in range(len(documents[sent.docid]))]
|
412 |
+
if not isinstance(score[1], float):
|
413 |
+
score[1] = score[1].item()
|
414 |
+
doc_to_sent_scores[(sent.ann_id, sent.docid)][sent.index] = score[1]
|
415 |
+
# hard rationale scoring
|
416 |
+
best_sentence = {k: np.argmax(np.array(v)) for k, v in doc_to_sent_scores.items()}
|
417 |
+
predicted_rationales = []
|
418 |
+
for (ann_id, docid), sent_idx in best_sentence.items():
|
419 |
+
start_token = sum(len(s) for s in documents[docid][:sent_idx])
|
420 |
+
end_token = start_token + len(documents[docid][sent_idx])
|
421 |
+
predicted_rationales.append(Rationale(ann_id, docid, start_token, end_token))
|
422 |
+
true_rationales = list(
|
423 |
+
chain.from_iterable(Rationale.from_annotation(rat) for rat in truth)
|
424 |
+
)
|
425 |
+
|
426 |
+
results["hard_rationale_scores"] = score_hard_rationale_predictions(
|
427 |
+
true_rationales, predicted_rationales
|
428 |
+
)
|
429 |
+
results["hard_rationale_partial_match_scores"] = partial_match_score(
|
430 |
+
true_rationales, predicted_rationales, [0.5]
|
431 |
+
)
|
432 |
+
|
433 |
+
# soft rationale scoring
|
434 |
+
instance_format = []
|
435 |
+
for (ann_id, docid), sentences in doc_to_sent_scores.items():
|
436 |
+
soft_token_predictions = []
|
437 |
+
for sent_score, sent_text in zip(sentences, documents[docid]):
|
438 |
+
soft_token_predictions.extend(sent_score for _ in range(len(sent_text)))
|
439 |
+
instance_format.append(
|
440 |
+
{
|
441 |
+
"annotation_id": ann_id,
|
442 |
+
"rationales": [
|
443 |
+
{
|
444 |
+
"docid": docid,
|
445 |
+
"soft_rationale_predictions": soft_token_predictions,
|
446 |
+
"soft_sentence_predictions": sentences,
|
447 |
+
}
|
448 |
+
],
|
449 |
+
}
|
450 |
+
)
|
451 |
+
flattened_documents = {
|
452 |
+
k: list(chain.from_iterable(v)) for k, v in documents.items()
|
453 |
+
}
|
454 |
+
token_scoring_format = PositionScoredDocument.from_results(
|
455 |
+
instance_format, truth, flattened_documents, use_tokens=True
|
456 |
+
)
|
457 |
+
results["soft_token_scores"] = score_soft_tokens(token_scoring_format)
|
458 |
+
sentence_scoring_format = PositionScoredDocument.from_results(
|
459 |
+
instance_format, truth, documents, use_tokens=False
|
460 |
+
)
|
461 |
+
results["soft_sentence_scores"] = score_soft_tokens(sentence_scoring_format)
|
462 |
+
return results
|
463 |
+
|
464 |
+
|
465 |
+
def decode(
|
466 |
+
evidence_identifier: nn.Module,
|
467 |
+
evidence_classifier: nn.Module,
|
468 |
+
train: List[Annotation],
|
469 |
+
val: List[Annotation],
|
470 |
+
test: List[Annotation],
|
471 |
+
docs: Dict[str, List[List[int]]],
|
472 |
+
class_interner: Dict[str, int],
|
473 |
+
batch_size: int,
|
474 |
+
tensorize_model_inputs: bool,
|
475 |
+
decoding_docs: Dict[str, List[Any]] = None,
|
476 |
+
) -> dict:
|
477 |
+
"""Identifies and then classifies evidence
|
478 |
+
|
479 |
+
Args:
|
480 |
+
evidence_identifier: a module for identifying evidence statements
|
481 |
+
evidence_classifier: a module for making a classification based on evidence statements
|
482 |
+
train: A List of interned Annotations
|
483 |
+
val: A List of interned Annotations
|
484 |
+
test: A List of interned Annotations
|
485 |
+
docs: A Dict of Documents, which are interned sentences.
|
486 |
+
class_interner: Converts an Annotation's final class into ints
|
487 |
+
batch_size: how big should our batches be?
|
488 |
+
tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
|
489 |
+
"""
|
490 |
+
device = None
|
491 |
+
class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])]
|
492 |
+
if decoding_docs is None:
|
493 |
+
decoding_docs = docs
|
494 |
+
|
495 |
+
def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]:
|
496 |
+
"""Prepares data for evidence identification and classification.
|
497 |
+
|
498 |
+
Creates paired evaluation data, wherein each (annotation, docid, sentence, kls)
|
499 |
+
tuplet appears first as the kls determining if the sentence is evidence, and
|
500 |
+
secondarily what the overall classification for the (annotation/docid) pair is.
|
501 |
+
This allows selection based on model scores of the evidence_identifier for
|
502 |
+
input to the evidence_classifier.
|
503 |
+
"""
|
504 |
+
identification_data = annotations_to_evidence_identification(data, docs)
|
505 |
+
classification_data = annotations_to_evidence_classification(
|
506 |
+
data, docs, class_interner, include_all=True
|
507 |
+
)
|
508 |
+
ann_doc_sents = defaultdict(
|
509 |
+
lambda: defaultdict(dict)
|
510 |
+
) # ann id -> docid -> sent idx -> sent data
|
511 |
+
ret = []
|
512 |
+
for sent_ev in classification_data:
|
513 |
+
id_data = identification_data[sent_ev.ann_id][sent_ev.docid][sent_ev.index]
|
514 |
+
ret.append((id_data, sent_ev))
|
515 |
+
assert id_data.ann_id == sent_ev.ann_id
|
516 |
+
assert id_data.docid == sent_ev.docid
|
517 |
+
assert id_data.index == sent_ev.index
|
518 |
+
assert len(ret) == len(classification_data)
|
519 |
+
return ret
|
520 |
+
|
521 |
+
def decode_batch(
|
522 |
+
data: List[Tuple[SentenceEvidence, SentenceEvidence]],
|
523 |
+
name: str,
|
524 |
+
score: bool = False,
|
525 |
+
annotations: List[Annotation] = None,
|
526 |
+
) -> dict:
|
527 |
+
"""Identifies evidence statements and then makes classifications based on it.
|
528 |
+
|
529 |
+
Args:
|
530 |
+
data: a paired list of SentenceEvidences, differing only in the kls field.
|
531 |
+
The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class
|
532 |
+
name: a name for a results dict
|
533 |
+
"""
|
534 |
+
|
535 |
+
num_uniques = len(set((x.ann_id, x.docid) for x, _ in data))
|
536 |
+
logging.info(
|
537 |
+
f"Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations"
|
538 |
+
)
|
539 |
+
identifier_data, classifier_data = zip(*data)
|
540 |
+
results = dict()
|
541 |
+
IdentificationClassificationResult = namedtuple(
|
542 |
+
"IdentificationClassificationResult",
|
543 |
+
"identification_data classification_data soft_identification hard_identification soft_classification hard_classification",
|
544 |
+
)
|
545 |
+
with torch.no_grad():
|
546 |
+
# make predictions for the evidence_identifier
|
547 |
+
evidence_identifier.eval()
|
548 |
+
evidence_classifier.eval()
|
549 |
+
|
550 |
+
(
|
551 |
+
_,
|
552 |
+
soft_identification_preds,
|
553 |
+
hard_identification_preds,
|
554 |
+
_,
|
555 |
+
) = make_preds_epoch(
|
556 |
+
evidence_identifier,
|
557 |
+
identifier_data,
|
558 |
+
batch_size,
|
559 |
+
device,
|
560 |
+
tensorize_model_inputs=tensorize_model_inputs,
|
561 |
+
)
|
562 |
+
assert len(soft_identification_preds) == len(data)
|
563 |
+
identification_results = defaultdict(list)
|
564 |
+
for id_data, cls_data, soft_id_pred, hard_id_pred in zip(
|
565 |
+
identifier_data,
|
566 |
+
classifier_data,
|
567 |
+
soft_identification_preds,
|
568 |
+
hard_identification_preds,
|
569 |
+
):
|
570 |
+
res = IdentificationClassificationResult(
|
571 |
+
identification_data=id_data,
|
572 |
+
classification_data=cls_data,
|
573 |
+
# 1 is p(evidence|sent,query)
|
574 |
+
soft_identification=soft_id_pred[1].float().item(),
|
575 |
+
hard_identification=hard_id_pred,
|
576 |
+
soft_classification=None,
|
577 |
+
hard_classification=False,
|
578 |
+
)
|
579 |
+
identification_results[(id_data.ann_id, id_data.docid)].append(res)
|
580 |
+
|
581 |
+
best_identification_results = {
|
582 |
+
key: max(value, key=lambda x: x.soft_identification)
|
583 |
+
for key, value in identification_results.items()
|
584 |
+
}
|
585 |
+
logging.info(
|
586 |
+
f"Selected the best sentence for {len(identification_results)} examples from a total of {len(soft_identification_preds)} sentences"
|
587 |
+
)
|
588 |
+
ids, classification_data = zip(
|
589 |
+
*[
|
590 |
+
(k, v.classification_data)
|
591 |
+
for k, v in best_identification_results.items()
|
592 |
+
]
|
593 |
+
)
|
594 |
+
(
|
595 |
+
_,
|
596 |
+
soft_classification_preds,
|
597 |
+
hard_classification_preds,
|
598 |
+
classification_truth,
|
599 |
+
) = make_preds_epoch(
|
600 |
+
evidence_classifier,
|
601 |
+
classification_data,
|
602 |
+
batch_size,
|
603 |
+
device,
|
604 |
+
tensorize_model_inputs=tensorize_model_inputs,
|
605 |
+
)
|
606 |
+
classification_results = dict()
|
607 |
+
for eyeD, soft_class, hard_class in zip(
|
608 |
+
ids, soft_classification_preds, hard_classification_preds
|
609 |
+
):
|
610 |
+
input_id_result = best_identification_results[eyeD]
|
611 |
+
res = IdentificationClassificationResult(
|
612 |
+
identification_data=input_id_result.identification_data,
|
613 |
+
classification_data=input_id_result.classification_data,
|
614 |
+
soft_identification=input_id_result.soft_identification,
|
615 |
+
hard_identification=input_id_result.hard_identification,
|
616 |
+
soft_classification=soft_class,
|
617 |
+
hard_classification=hard_class,
|
618 |
+
)
|
619 |
+
classification_results[eyeD] = res
|
620 |
+
|
621 |
+
if score:
|
622 |
+
truth = []
|
623 |
+
pred = []
|
624 |
+
for res in classification_results.values():
|
625 |
+
truth.append(res.classification_data.kls)
|
626 |
+
pred.append(res.hard_classification)
|
627 |
+
# results[f'{name}_f1'] = classification_report(classification_truth, pred, target_names=class_labels, output_dict=True)
|
628 |
+
results[f"{name}_f1"] = classification_report(
|
629 |
+
classification_truth,
|
630 |
+
hard_classification_preds,
|
631 |
+
target_names=class_labels,
|
632 |
+
output_dict=True,
|
633 |
+
)
|
634 |
+
results[f"{name}_acc"] = accuracy_score(
|
635 |
+
classification_truth, hard_classification_preds
|
636 |
+
)
|
637 |
+
results[f"{name}_rationale"] = score_rationales(
|
638 |
+
annotations,
|
639 |
+
decoding_docs,
|
640 |
+
identifier_data,
|
641 |
+
soft_identification_preds,
|
642 |
+
)
|
643 |
+
|
644 |
+
# turn the above results into a format suitable for scoring via the rationale scorer
|
645 |
+
# n.b. the sentence-level evidence predictions (hard and soft) are
|
646 |
+
# broadcast to the token level for scoring. The comprehensiveness class
|
647 |
+
# score is also a lie since the pipeline model above is faithful by
|
648 |
+
# design.
|
649 |
+
decoded = dict()
|
650 |
+
decoded_scores = defaultdict(list)
|
651 |
+
for (ann_id, docid), pred in classification_results.items():
|
652 |
+
sentence_prediction_scores = [
|
653 |
+
x.soft_identification
|
654 |
+
for x in identification_results[(ann_id, docid)]
|
655 |
+
]
|
656 |
+
sentence_start_token = sum(
|
657 |
+
len(s)
|
658 |
+
for s in decoding_docs[docid][: pred.identification_data.index]
|
659 |
+
)
|
660 |
+
sentence_end_token = sentence_start_token + len(
|
661 |
+
decoding_docs[docid][pred.classification_data.index]
|
662 |
+
)
|
663 |
+
hard_rationale_predictions = [
|
664 |
+
{
|
665 |
+
"start_token": sentence_start_token,
|
666 |
+
"end_token": sentence_end_token,
|
667 |
+
}
|
668 |
+
]
|
669 |
+
soft_rationale_predictions = []
|
670 |
+
for sent_result in sorted(
|
671 |
+
identification_results[(ann_id, docid)],
|
672 |
+
key=lambda x: x.identification_data.index,
|
673 |
+
):
|
674 |
+
soft_rationale_predictions.extend(
|
675 |
+
sent_result.soft_identification
|
676 |
+
for _ in range(
|
677 |
+
len(
|
678 |
+
decoding_docs[sent_result.identification_data.docid][
|
679 |
+
sent_result.identification_data.index
|
680 |
+
]
|
681 |
+
)
|
682 |
+
)
|
683 |
+
)
|
684 |
+
if ann_id not in decoded:
|
685 |
+
decoded[ann_id] = {
|
686 |
+
"annotation_id": ann_id,
|
687 |
+
"rationales": [],
|
688 |
+
"classification": class_labels[pred.hard_classification],
|
689 |
+
"classification_scores": {
|
690 |
+
class_labels[i]: s.item()
|
691 |
+
for i, s in enumerate(pred.soft_classification)
|
692 |
+
},
|
693 |
+
# TODO this should turn into the data distribution for the predicted class
|
694 |
+
# "comprehensiveness_classification_scores": 0.0,
|
695 |
+
"truth": pred.classification_data.kls,
|
696 |
+
}
|
697 |
+
decoded[ann_id]["rationales"].append(
|
698 |
+
{
|
699 |
+
"docid": docid,
|
700 |
+
"hard_rationale_predictions": hard_rationale_predictions,
|
701 |
+
"soft_rationale_predictions": soft_rationale_predictions,
|
702 |
+
"soft_sentence_predictions": sentence_prediction_scores,
|
703 |
+
}
|
704 |
+
)
|
705 |
+
decoded_scores[ann_id].append(pred.soft_classification)
|
706 |
+
|
707 |
+
# in practice, this is always a single element operation:
|
708 |
+
# in evidence inference (prompt is really a prompt + document), fever (we split documents into two classifications), movies (you only have one opinion about a movie), or boolQ (single document prompts)
|
709 |
+
# this exists to support weird models we *might* implement for cose/esnli
|
710 |
+
for ann_id, scores_list in decoded_scores.items():
|
711 |
+
scores = torch.stack(scores_list)
|
712 |
+
score_avg = torch.mean(scores, dim=0)
|
713 |
+
# .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16
|
714 |
+
hard_pred = torch.argmax(score_avg.float()).item()
|
715 |
+
decoded[ann_id]["classification"] = class_labels[hard_pred]
|
716 |
+
decoded[ann_id]["classification_scores"] = {
|
717 |
+
class_labels[i]: s.item() for i, s in enumerate(score_avg)
|
718 |
+
}
|
719 |
+
return results, list(decoded.values())
|
720 |
+
|
721 |
+
test_results, test_decoded = decode_batch(prep(test), "test", score=False)
|
722 |
+
val_results, val_decoded = dict(), []
|
723 |
+
train_results, train_decoded = dict(), []
|
724 |
+
# val_results, val_decoded = decode_batch(prep(val), 'val', score=True, annotations=val)
|
725 |
+
# train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train)
|
726 |
+
return (
|
727 |
+
dict(**train_results, **val_results, **test_results),
|
728 |
+
train_decoded,
|
729 |
+
val_decoded,
|
730 |
+
test_decoded,
|
731 |
+
)
|
732 |
+
|
733 |
+
|
734 |
+
def decode_evidence_tokens_and_classify(
|
735 |
+
evidence_token_identifier: nn.Module,
|
736 |
+
evidence_classifier: nn.Module,
|
737 |
+
train: List[Annotation],
|
738 |
+
val: List[Annotation],
|
739 |
+
test: List[Annotation],
|
740 |
+
docs: Dict[str, List[List[int]]],
|
741 |
+
source_documents: Dict[str, List[List[str]]],
|
742 |
+
token_mapping: Dict[str, List[List[Tuple[int, int]]]],
|
743 |
+
class_interner: Dict[str, int],
|
744 |
+
batch_size: int,
|
745 |
+
decoding_docs: Dict[str, List[Any]],
|
746 |
+
use_cose_hack: bool = False,
|
747 |
+
) -> dict:
|
748 |
+
"""Identifies and then classifies evidence
|
749 |
+
|
750 |
+
Args:
|
751 |
+
evidence_token_identifier: a module for identifying evidence statements
|
752 |
+
evidence_classifier: a module for making a classification based on evidence statements
|
753 |
+
train: A List of interned Annotations
|
754 |
+
val: A List of interned Annotations
|
755 |
+
test: A List of interned Annotations
|
756 |
+
docs: A Dict of Documents, which are interned sentences.
|
757 |
+
class_interner: Converts an Annotation's final class into ints
|
758 |
+
batch_size: how big should our batches be?
|
759 |
+
"""
|
760 |
+
device = None
|
761 |
+
class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])]
|
762 |
+
if decoding_docs is None:
|
763 |
+
decoding_docs = docs
|
764 |
+
|
765 |
+
def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]:
|
766 |
+
"""Prepares data for evidence identification and classification.
|
767 |
+
|
768 |
+
Creates paired evaluation data, wherein each (annotation, docid, sentence, kls)
|
769 |
+
tuplet appears first as the kls determining if the sentence is evidence, and
|
770 |
+
secondarily what the overall classification for the (annotation/docid) pair is.
|
771 |
+
This allows selection based on model scores of the evidence_token_identifier for
|
772 |
+
input to the evidence_classifier.
|
773 |
+
"""
|
774 |
+
# identification_data = annotations_to_evidence_identification(data, docs)
|
775 |
+
classification_data = token_annotations_to_evidence_classification(
|
776 |
+
data, docs, class_interner
|
777 |
+
)
|
778 |
+
# annotation id -> docid -> [SentenceEvidence])
|
779 |
+
identification_data = annotations_to_evidence_token_identification(
|
780 |
+
data,
|
781 |
+
source_documents=decoding_docs,
|
782 |
+
interned_documents=docs,
|
783 |
+
token_mapping=token_mapping,
|
784 |
+
)
|
785 |
+
ann_doc_sents = defaultdict(
|
786 |
+
lambda: defaultdict(dict)
|
787 |
+
) # ann id -> docid -> sent idx -> sent data
|
788 |
+
ret = []
|
789 |
+
for sent_ev in classification_data:
|
790 |
+
id_data = identification_data[sent_ev.ann_id][sent_ev.docid][sent_ev.index]
|
791 |
+
ret.append((id_data, sent_ev))
|
792 |
+
assert id_data.ann_id == sent_ev.ann_id
|
793 |
+
assert id_data.docid == sent_ev.docid
|
794 |
+
# assert id_data.index == sent_ev.index
|
795 |
+
assert len(ret) == len(classification_data)
|
796 |
+
return ret
|
797 |
+
|
798 |
+
def decode_batch(
|
799 |
+
data: List[Tuple[SentenceEvidence, SentenceEvidence]],
|
800 |
+
name: str,
|
801 |
+
score: bool = False,
|
802 |
+
annotations: List[Annotation] = None,
|
803 |
+
class_labels: dict = class_labels,
|
804 |
+
) -> dict:
|
805 |
+
"""Identifies evidence statements and then makes classifications based on it.
|
806 |
+
|
807 |
+
Args:
|
808 |
+
data: a paired list of SentenceEvidences, differing only in the kls field.
|
809 |
+
The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class
|
810 |
+
name: a name for a results dict
|
811 |
+
"""
|
812 |
+
|
813 |
+
num_uniques = len(set((x.ann_id, x.docid) for x, _ in data))
|
814 |
+
logging.info(
|
815 |
+
f"Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations"
|
816 |
+
)
|
817 |
+
identifier_data, classifier_data = zip(*data)
|
818 |
+
results = dict()
|
819 |
+
with torch.no_grad():
|
820 |
+
# make predictions for the evidence_token_identifier
|
821 |
+
evidence_token_identifier.eval()
|
822 |
+
evidence_classifier.eval()
|
823 |
+
|
824 |
+
(
|
825 |
+
_,
|
826 |
+
soft_identification_preds,
|
827 |
+
hard_identification_preds,
|
828 |
+
id_preds_truth,
|
829 |
+
) = make_token_preds_epoch(
|
830 |
+
evidence_token_identifier,
|
831 |
+
identifier_data,
|
832 |
+
token_mapping,
|
833 |
+
batch_size,
|
834 |
+
device,
|
835 |
+
tensorize_model_inputs=True,
|
836 |
+
)
|
837 |
+
assert len(soft_identification_preds) == len(data)
|
838 |
+
evidence_only_cls = []
|
839 |
+
for id_data, cls_data, soft_id_pred, hard_id_pred in zip(
|
840 |
+
identifier_data,
|
841 |
+
classifier_data,
|
842 |
+
soft_identification_preds,
|
843 |
+
hard_identification_preds,
|
844 |
+
):
|
845 |
+
            assert cls_data.ann_id == id_data.ann_id
            sent = []
            for start, end in token_mapping[cls_data.docid][0]:
                if bool(hard_id_pred[start]):
                    sent.extend(id_data.sentence[start:end])
            # assert len(sent) > 0
            new_cls_data = SentenceEvidence(
                cls_data.kls,
                cls_data.ann_id,
                cls_data.query,
                cls_data.docid,
                cls_data.index,
                tuple(sent),
            )
            evidence_only_cls.append(new_cls_data)
        (
            _,
            soft_classification_preds,
            hard_classification_preds,
            classification_truth,
        ) = make_preds_epoch(
            evidence_classifier,
            evidence_only_cls,
            batch_size,
            device,
            tensorize_model_inputs=True,
        )

        if use_cose_hack:
            logging.info(
                "Reformatting identification and classification results to fit COS-E"
            )
            grouping = 5
            new_soft_identification_preds = []
            new_hard_identification_preds = []
            new_id_preds_truth = []
            new_soft_classification_preds = []
            new_hard_classification_preds = []
            new_classification_truth = []
            new_identifier_data = []
            class_labels = []

            # TODO fix the labels for COS-E
            for i in range(0, len(soft_identification_preds), grouping):
                cls_scores = torch.stack(
                    soft_classification_preds[i : i + grouping]
                )
                cls_scores = nn.functional.softmax(cls_scores, dim=-1)
                cls_scores = cls_scores[:, 1]
                choice = torch.argmax(cls_scores)
                cls_labels = [
                    x.ann_id.split("_")[-1]
                    for x in evidence_only_cls[i : i + grouping]
                ]
                class_labels = cls_labels  # we need to update the class labels because of the terrible hackery used to train this
                cls_truths = [x.kls for x in evidence_only_cls[i : i + grouping]]
                # cls_choice = evidence_only_cls[i + choice].ann_id.split('_')[-1]
                cls_truth = np.argmax(cls_truths)
                new_soft_identification_preds.append(
                    soft_identification_preds[i + choice]
                )
                new_hard_identification_preds.append(
                    hard_identification_preds[i + choice]
                )
                new_id_preds_truth.append(id_preds_truth[i + choice])
                new_soft_classification_preds.append(
                    soft_classification_preds[i + choice]
                )
                new_hard_classification_preds.append(choice)
                new_identifier_data.append(identifier_data[i + choice])
                # new_hard_classification_preds.append(hard_classification_preds[i + choice])
                # new_classification_truth.append(classification_truth[i + choice])
                new_classification_truth.append(cls_truth)

            soft_identification_preds = new_soft_identification_preds
            hard_identification_preds = new_hard_identification_preds
            id_preds_truth = new_id_preds_truth
            soft_classification_preds = new_soft_classification_preds
            hard_classification_preds = new_hard_classification_preds
            classification_truth = new_classification_truth
            identifier_data = new_identifier_data
        if score:
            results[f"{name}_f1"] = classification_report(
                classification_truth,
                hard_classification_preds,
                target_names=class_labels,
                output_dict=True,
            )
            results[f"{name}_acc"] = accuracy_score(
                classification_truth, hard_classification_preds
            )
            results[f"{name}_token_pred_acc"] = accuracy_score(
                list(chain.from_iterable(id_preds_truth)),
                list(chain.from_iterable(hard_identification_preds)),
            )
            results[f"{name}_token_pred_f1"] = classification_report(
                list(chain.from_iterable(id_preds_truth)),
                list(chain.from_iterable(hard_identification_preds)),
                output_dict=True,
            )
            # TODO for token level stuff!
            soft_id_scores = [
                [1 - x, x] for x in chain.from_iterable(soft_identification_preds)
            ]
            results[f"{name}_rationale"] = score_rationales(
                annotations, decoding_docs, identifier_data, soft_id_scores
            )
            logging.info(f"Results: {results}")

        # turn the above results into a format suitable for scoring via the rationale scorer
        # n.b. the sentence-level evidence predictions (hard and soft) are
        # broadcast to the token level for scoring. The comprehensiveness class
        # score is also a lie since the pipeline model above is faithful by
        # design.
        decoded = dict()
        scores = []
        assert len(identifier_data) == len(soft_identification_preds)
        for (
            id_data,
            soft_id_pred,
            hard_id_pred,
            soft_cls_preds,
            hard_cls_pred,
        ) in zip(
            identifier_data,
            soft_identification_preds,
            hard_identification_preds,
            soft_classification_preds,
            hard_classification_preds,
        ):
            docid = id_data.docid
            if use_cose_hack:
                docid = "_".join(docid.split("_")[0:-1])
            assert len(docid) > 0
            rationales = {
                "docid": docid,
                "hard_rationale_predictions": [],
                # token level classifications, a value must be provided per-token
                # in an ideal world, these correspond to the hard-decoding above.
                "soft_rationale_predictions": [],
                # sentence level classifications, a value must be provided for every
                # sentence in each document, or not at all
                "soft_sentence_predictions": [1.0],
            }
            last = -1
            start_span = -1
            for pos, (start, _) in enumerate(token_mapping[id_data.docid][0]):
                rationales["soft_rationale_predictions"].append(soft_id_pred[start])
                if bool(hard_id_pred[start]):
                    if start_span == -1:
                        start_span = pos
                    last = pos
                else:
                    if start_span != -1:
                        rationales["hard_rationale_predictions"].append(
                            {
                                "start_token": start_span,
                                "end_token": last + 1,
                            }
                        )
                    last = -1
                    start_span = -1
            if start_span != -1:
                rationales["hard_rationale_predictions"].append(
                    {
                        "start_token": start_span,
                        "end_token": last + 1,
                    }
                )

            ann_id = id_data.ann_id
            if use_cose_hack:
                ann_id = "_".join(ann_id.split("_")[0:-1])
            soft_cls_preds = nn.functional.softmax(soft_cls_preds)
            decoded[id_data.ann_id] = {
                "annotation_id": ann_id,
                "rationales": [rationales],
                "classification": class_labels[hard_cls_pred],
                "classification_scores": {
                    class_labels[i]: score.item()
                    for i, score in enumerate(soft_cls_preds)
                },
            }
        return results, list(decoded.values())

    # test_results, test_decoded = dict(), []
    # val_results, val_decoded = dict(), []
    train_results, train_decoded = dict(), []
    val_results, val_decoded = decode_batch(
        prep(val), "val", score=True, annotations=val, class_labels=class_labels
    )
    test_results, test_decoded = decode_batch(
        prep(test), "test", score=False, class_labels=class_labels
    )
    # train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train, class_labels=class_labels)
    return (
        dict(**train_results, **val_results, **test_results),
        train_decoded,
        val_decoded,
        test_decoded,
    )
Transformer-Explainability/BERT_rationale_benchmark/models/sequence_taggers.py
ADDED
@@ -0,0 +1,78 @@
from typing import Any, List, Tuple

import torch
import torch.nn as nn
from rationale_benchmark.models.model_utils import PaddedSequence
from transformers import BertModel


class BertTagger(nn.Module):
    def __init__(
        self,
        bert_dir: str,
        pad_token_id: int,
        cls_token_id: int,
        sep_token_id: int,
        max_length: int = 512,
        use_half_precision=True,
    ):
        super(BertTagger, self).__init__()
        self.sep_token_id = sep_token_id
        self.cls_token_id = cls_token_id
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        bert = BertModel.from_pretrained(bert_dir)
        if use_half_precision:
            import apex

            bert = bert.half()
        self.bert = bert
        self.relevance_tagger = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 1), nn.Sigmoid()
        )

    def forward(
        self,
        query: List[torch.tensor],
        docids: List[Any],
        document_batch: List[torch.tensor],
        aggregate_spans: List[Tuple[int, int]],
    ):
        assert len(query) == len(document_batch)
        # note about device management: since distributed training is enabled, the inputs to this module can be on
        # *any* device (preferably cpu, since we wrap and unwrap the module) we want to keep these params on the
        # input device (assuming CPU) for as long as possible for cheap memory access
        target_device = next(self.parameters()).device
        # cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
        sep_token = torch.tensor([self.sep_token_id]).to(
            device=document_batch[0].device
        )
        input_tensors = []
        query_lengths = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 1 > self.max_length:
                d = d[: (self.max_length - len(q) - 1)]
            input_tensors.append(torch.cat([q, sep_token, d]))
            query_lengths.append(q.size()[0])
        bert_input = PaddedSequence.autopad(
            input_tensors,
            batch_first=True,
            padding_value=self.pad_token_id,
            device=target_device,
        )
        outputs = self.bert(
            bert_input.data,
            attention_mask=bert_input.mask(
                on=0.0, off=float("-inf"), dtype=torch.float, device=target_device
            ),
        )
        hidden = outputs[0]
        classes = self.relevance_tagger(hidden)
        ret = []
        for ql, cls, doc in zip(query_lengths, classes, document_batch):
            start = ql + 1
            end = start + len(doc)
            ret.append(cls[ql + 1 : end])
        return PaddedSequence.autopad(
            ret, batch_first=True, padding_value=0, device=target_device
        ).data.squeeze(dim=-1)
Transformer-Explainability/BERT_rationale_benchmark/utils.py
ADDED
@@ -0,0 +1,251 @@
import json
import os
from dataclasses import asdict, dataclass, is_dataclass
from itertools import chain
from typing import Dict, FrozenSet, List, Set, Tuple, Union


@dataclass(eq=True, frozen=True)
class Evidence:
    """
    (docid, start_token, end_token) form the only official Evidence; sentence level annotations are for convenience.
    Args:
        text: Some representation of the evidence text
        docid: Some identifier for the document
        start_token: The canonical start token, inclusive
        end_token: The canonical end token, exclusive
        start_sentence: Best guess start sentence, inclusive
        end_sentence: Best guess end sentence, exclusive
    """

    text: Union[str, Tuple[int], Tuple[str]]
    docid: str
    start_token: int = -1
    end_token: int = -1
    start_sentence: int = -1
    end_sentence: int = -1


@dataclass(eq=True, frozen=True)
class Annotation:
    """
    Args:
        annotation_id: unique ID for this annotation element
        query: some representation of a query string
        evidences: a set of "evidence groups".
            Each evidence group is:
                * sufficient to respond to the query (or justify an answer)
                * composed of one or more Evidences
                * may have multiple documents in it (depending on the dataset)
                    - e-snli has multiple documents
                    - other datasets do not
        classification: str
        query_type: Optional str, additional information about the query
        docids: a set of docids in which one may find evidence.
    """

    annotation_id: str
    query: Union[str, Tuple[int]]
    evidences: Union[Set[Tuple[Evidence]], FrozenSet[Tuple[Evidence]]]
    classification: str
    query_type: str = None
    docids: Set[str] = None

    def all_evidences(self) -> Tuple[Evidence]:
        return tuple(list(chain.from_iterable(self.evidences)))


def annotations_to_jsonl(annotations, output_file):
    with open(output_file, "w") as of:
        for ann in sorted(annotations, key=lambda x: x.annotation_id):
            as_json = _annotation_to_dict(ann)
            as_str = json.dumps(as_json, sort_keys=True)
            of.write(as_str)
            of.write("\n")


def _annotation_to_dict(dc):
    # convenience method
    if is_dataclass(dc):
        d = asdict(dc)
        ret = dict()
        for k, v in d.items():
            ret[k] = _annotation_to_dict(v)
        return ret
    elif isinstance(dc, dict):
        ret = dict()
        for k, v in dc.items():
            k = _annotation_to_dict(k)
            v = _annotation_to_dict(v)
            ret[k] = v
        return ret
    elif isinstance(dc, str):
        return dc
    elif isinstance(dc, (set, frozenset, list, tuple)):
        ret = []
        for x in dc:
            ret.append(_annotation_to_dict(x))
        return tuple(ret)
    else:
        return dc


def load_jsonl(fp: str) -> List[dict]:
    ret = []
    with open(fp, "r") as inf:
        for line in inf:
            content = json.loads(line)
            ret.append(content)
    return ret


def write_jsonl(jsonl, output_file):
    with open(output_file, "w") as of:
        for js in jsonl:
            as_str = json.dumps(js, sort_keys=True)
            of.write(as_str)
            of.write("\n")


def annotations_from_jsonl(fp: str) -> List[Annotation]:
    ret = []
    with open(fp, "r") as inf:
        for line in inf:
            content = json.loads(line)
            ev_groups = []
            for ev_group in content["evidences"]:
                ev_group = tuple([Evidence(**ev) for ev in ev_group])
                ev_groups.append(ev_group)
            content["evidences"] = frozenset(ev_groups)
            ret.append(Annotation(**content))
    return ret


def load_datasets(
    data_dir: str,
) -> Tuple[List[Annotation], List[Annotation], List[Annotation]]:
    """Loads a training, validation, and test dataset

    Each dataset is assumed to have been serialized by annotations_to_jsonl,
    that is it is a list of json-serialized Annotation instances.
    """
    train_data = annotations_from_jsonl(os.path.join(data_dir, "train.jsonl"))
    val_data = annotations_from_jsonl(os.path.join(data_dir, "val.jsonl"))
    test_data = annotations_from_jsonl(os.path.join(data_dir, "test.jsonl"))
    return train_data, val_data, test_data


def load_documents(
    data_dir: str, docids: Set[str] = None
) -> Dict[str, List[List[str]]]:
    """Loads a subset of available documents from disk.

    Each document is assumed to be serialized as newline ('\n') separated sentences.
    Each sentence is assumed to be space (' ') joined tokens.
    """
    if os.path.exists(os.path.join(data_dir, "docs.jsonl")):
        assert not os.path.exists(os.path.join(data_dir, "docs"))
        return load_documents_from_file(data_dir, docids)

    docs_dir = os.path.join(data_dir, "docs")
    res = dict()
    if docids is None:
        docids = sorted(os.listdir(docs_dir))
    else:
        docids = sorted(set(str(d) for d in docids))
    for d in docids:
        with open(os.path.join(docs_dir, d), "r") as inf:
            res[d] = inf.read()
    return res


def load_flattened_documents(data_dir: str, docids: Set[str]) -> Dict[str, List[str]]:
    """Loads a subset of available documents from disk.

    Returns a tokenized version of the document.
    """
    unflattened_docs = load_documents(data_dir, docids)
    flattened_docs = dict()
    for doc, unflattened in unflattened_docs.items():
        flattened_docs[doc] = list(chain.from_iterable(unflattened))
    return flattened_docs


def intern_documents(
    documents: Dict[str, List[List[str]]], word_interner: Dict[str, int], unk_token: str
):
    """
    Replaces every word with its index in an embeddings file.

    If a word is not found, uses the unk_token instead
    """
    ret = dict()
    unk = word_interner[unk_token]
    for docid, sentences in documents.items():
        ret[docid] = [[word_interner.get(w, unk) for w in s] for s in sentences]
    return ret


def intern_annotations(
    annotations: List[Annotation], word_interner: Dict[str, int], unk_token: str
):
    ret = []
    for ann in annotations:
        ev_groups = []
        for ev_group in ann.evidences:
            evs = []
            for ev in ev_group:
                evs.append(
                    Evidence(
                        text=tuple(
                            [
                                word_interner.get(t, word_interner[unk_token])
                                for t in ev.text.split()
                            ]
                        ),
                        docid=ev.docid,
                        start_token=ev.start_token,
                        end_token=ev.end_token,
                        start_sentence=ev.start_sentence,
                        end_sentence=ev.end_sentence,
                    )
                )
            ev_groups.append(tuple(evs))
        ret.append(
            Annotation(
                annotation_id=ann.annotation_id,
                query=tuple(
                    [
                        word_interner.get(t, word_interner[unk_token])
                        for t in ann.query.split()
                    ]
                ),
                evidences=frozenset(ev_groups),
                classification=ann.classification,
                query_type=ann.query_type,
            )
        )
    return ret


def load_documents_from_file(
    data_dir: str, docids: Set[str] = None
) -> Dict[str, List[List[str]]]:
    """Loads a subset of available documents from 'docs.jsonl' file on disk.

    Each document is assumed to be serialized as newline ('\n') separated sentences.
    Each sentence is assumed to be space (' ') joined tokens.
    """
    docs_file = os.path.join(data_dir, "docs.jsonl")
    documents = load_jsonl(docs_file)
    documents = {doc["docid"]: doc["document"] for doc in documents}
    # res = dict()
    # if docids is None:
    #     docids = sorted(list(documents.keys()))
    # else:
    #     docids = sorted(set(str(d) for d in docids))
    # for d in docids:
    #     lines = documents[d].split('\n')
    #     tokenized = [line.strip().split(' ') for line in lines]
    #     res[d] = tokenized
    return documents
Transformer-Explainability/DeiT.PNG
ADDED
Transformer-Explainability/DeiT_example.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
Transformer-Explainability/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Hila Chefer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Transformer-Explainability/README.md
ADDED
@@ -0,0 +1,153 @@
# PyTorch Implementation of [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838) [CVPR 2021]

#### Check out our new advancements: [Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers](https://github.com/hila-chefer/Transformer-MM-Explainability)!
Faster, more general, and can be applied to *any* type of attention!
Among the features:
* We remove LRP for a simple and quick solution, and show that the strong results from our first paper still hold!
* We extend our work to *any* type of Transformer: not just self-attention-based encoders, but also co-attention encoders and encoder-decoders!
* We show that VQA models can actually understand both image and text and make connections between them!
* We use a DETR object detector and create segmentation masks from our explanations!
* We provide a Colab notebook with all the examples. You can very easily add images and questions of your own!

<p align="center">
  <img width="400" height="450" src="new_work.jpg">
</p>

---
## ViT explainability notebook:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/Transformer_explainability.ipynb)

## BERT explainability notebook:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb)
---

## Updates
April 5 2021: Check out this new [post](https://analyticsindiamag.com/compute-relevancy-of-transformer-networks-via-novel-interpretable-transformer/) about our paper! A great resource for understanding the main concepts behind our work.

March 15 2021: [A Colab notebook for BERT sentiment analysis added!](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb)

Feb 28 2021: Our paper was accepted to CVPR 2021!

Feb 17 2021: [A Colab notebook with all examples added!](https://github.com/hila-chefer/Transformer-Explainability/blob/main/Transformer_explainability.ipynb)

Jan 5 2021: [A Jupyter notebook for DeiT added!](https://github.com/hila-chefer/Transformer-Explainability/blob/main/DeiT_example.ipynb)


<p align="center">
  <img width="300" height="460" src="https://github.com/hila-chefer/Transformer-Explainability/blob/main/DeiT.PNG">
</p>


## Introduction
Official implementation of [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838).

We introduce a novel method for visualizing the classifications made by a Transformer-based model, for both vision and NLP tasks.
Our method also allows visualizing explanations per class.

<p align="center">
  <img width="600" height="200" src="https://github.com/hila-chefer/Transformer-Explainability/blob/main/method-page-001.jpg">
</p>
The method consists of 3 phases:

1. Calculating relevance for each attention matrix using our novel formulation of LRP.

2. Backpropagation of gradients for each attention matrix w.r.t. the visualized class. The gradients are used to average the attention heads.

3. Layer aggregation with rollout (a condensed sketch of how the three phases combine is shown below).
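
The following is a minimal sketch, not code from this repository, of how the three phases fit together for ViT. It assumes a model built from `baselines/ViT/ViT_LRP.py`, whose attention blocks expose `get_attn_cam()` and `get_attn_gradients()`; the free function `transformer_attribution_sketch` is purely illustrative.

```
import torch

def transformer_attribution_sketch(blocks, start_layer=0):
    # Phase 1 relevance and phase 2 gradients are read from each attention block,
    # combined per head, and clamped so that only positive evidence is kept.
    cams = []
    for blk in blocks:
        grad = blk.attn.get_attn_gradients()
        cam = blk.attn.get_attn_cam()
        cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = (grad * cam).clamp(min=0).mean(dim=0)  # gradient-weighted head average
        cams.append(cam.unsqueeze(0))
    # Phase 3: rollout across layers, adding the identity to account for residual connections.
    num_tokens = cams[0].shape[-1]
    eye = torch.eye(num_tokens, device=cams[0].device)
    joint = cams[start_layer] + eye
    for cam in cams[start_layer + 1:]:
        joint = (cam + eye).bmm(joint)
    return joint[:, 0, 1:]  # relevance of every patch token w.r.t. the CLS token
```

This mirrors the `transformer_attribution` branch of `VisionTransformer.relprop` and `compute_rollout_attention` in `baselines/ViT/ViT_LRP.py`, which are the authoritative implementations.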

Please notice our [Jupyter notebook](https://github.com/hila-chefer/Transformer-Explainability/blob/main/example.ipynb) where you can run the two class-specific examples from the paper.



To add another input image, simply add the image to the [samples folder](https://github.com/hila-chefer/Transformer-Explainability/tree/main/samples) and use the `generate_visualization` function for your selected class of interest (using `class_index={class_idx}`); if no index is specified, the top class is visualized.
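
As a rough illustration only: `generate_visualization` is defined inside the notebooks, so its exact signature may differ, and the preprocessing, sample filename, and class index below are assumptions.

```
from PIL import Image
from torchvision import transforms

# Assumed preprocessing for the pretrained ViT (mean/std 0.5, 224x224 input).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
image = transform(Image.open("samples/your_image.png"))  # hypothetical file name

vis_top = generate_visualization(image)                   # visualize the top predicted class
vis_cls = generate_visualization(image, class_index=243)  # visualize a chosen class index
```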

## Credits
ViT implementation is based on:
- https://github.com/rwightman/pytorch-image-models
- https://github.com/lucidrains/vit-pytorch
- pretrained weights from: https://github.com/google-research/vision_transformer

BERT implementation is taken from the huggingface Transformers library:
https://huggingface.co/transformers/

ERASER benchmark code adapted from the ERASER GitHub implementation: https://github.com/jayded/eraserbenchmark

Text visualizations in the supplementary were created using the TAHV heatmap generator for text: https://github.com/jiesutd/Text-Attention-Heatmap-Visualization

## Reproducing results on ViT

### Section A. Segmentation Results

Example:
```
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/imagenet_seg_eval.py --method transformer_attribution --imagenet-seg-path /path/to/gtsegs_ijcv.mat
```
[Link to download dataset](http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat).

In the example above we run a segmentation test with our method. Notice that you can choose which method you wish to run using the `--method` argument.
You must provide a path to the ImageNet segmentation data in `--imagenet-seg-path`.

### Section B. Perturbation Results

Example:
```
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/generate_visualizations.py --method transformer_attribution --imagenet-validation-path /path/to/imagenet_validation_directory
```

Notice that you can choose to visualize by target or top class by using the `--vis-cls` argument.

Now, to run the perturbation test, run the following command:
```
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/pertubation_eval_from_hdf5.py --method transformer_attribution
```

Notice that you can use the `--neg` argument to run either positive or negative perturbation.

## Reproducing results on BERT

1. Download the pretrained weights:

- Download `classifier.zip` from https://drive.google.com/file/d/1kGMTr69UWWe70i-o2_JfjmWDQjT66xwQ/view?usp=sharing
- mkdir -p `./bert_models/movies`
- unzip classifier.zip -d ./bert_models/movies/

2. Download the dataset pkl file:

- Download `preprocessed.pkl` from https://drive.google.com/file/d/1-gfbTj6D87KIm_u1QMHGLKSL3e93hxBH/view?usp=sharing
- mv preprocessed.pkl ./bert_models/movies

3. Download the dataset:

- Download `movies.zip` from https://drive.google.com/file/d/11faFLGkc0hkw3wrGTYJBr1nIvkRb189F/view?usp=sharing
- unzip movies.zip -d ./data/

4. Now you can run the model.

Example:
```
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/movies/ --output_dir bert_models/movies/ --model_params BERT_params/movies_bert.json
```
To control which algorithm is used for explanations, change the `method` variable in `BERT_rationale_benchmark/models/pipeline/bert_pipeline.py` (defaults to 'transformer_attribution', which is our method).
Running this command will create a directory for the method in `bert_models/movies/<method_name>`.

To run the F1 test with k, run the following command:
```
PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/metrics.py --data_dir data/movies/ --split test --results bert_models/movies/<method_name>/identifier_results_k.json
```

In addition, the method directory will contain generated `.tex` files with the explanations extracted for each example. These correspond to our visualizations in the supplementary.

## Citing our paper
If you make use of our work, please cite our paper:
```
@InProceedings{Chefer_2021_CVPR,
    author    = {Chefer, Hila and Gur, Shir and Wolf, Lior},
    title     = {Transformer Interpretability Beyond Attention Visualization},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {782-791}
}
```
Transformer-Explainability/Transformer_explainability.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
Transformer-Explainability/baselines/ViT/ViT_LRP.py
ADDED
@@ -0,0 +1,535 @@
""" Vision Transformer (ViT) in PyTorch
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from baselines.ViT.helpers import load_pretrained
from baselines.ViT.layer_helpers import to_2tuple
from baselines.ViT.weight_init import trunc_normal_
from einops import rearrange
from modules.layers_ours import *


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "pool_size": None,
        "crop_pct": 0.9,
        "interpolation": "bicubic",
        "first_conv": "patch_embed.proj",
        "classifier": "head",
        **kwargs,
    }


default_cfgs = {
    # patch models
    "vit_small_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth",
    ),
    "vit_base_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
    "vit_large_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
}


def compute_rollout_attention(all_layer_matrices, start_layer=0):
    # adding residual consideration
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    eye = (
        torch.eye(num_tokens)
        .expand(batch_size, num_tokens, num_tokens)
        .to(all_layer_matrices[0].device)
    )
    all_layer_matrices = [
        all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
    ]
    # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
    #                       for i in range(len(all_layer_matrices))]
    joint_attention = all_layer_matrices[start_layer]
    for i in range(start_layer + 1, len(all_layer_matrices)):
        joint_attention = all_layer_matrices[i].bmm(joint_attention)
    return joint_attention


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = Linear(in_features, hidden_features)
        self.act = GELU()
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

    def relprop(self, cam, **kwargs):
        cam = self.drop.relprop(cam, **kwargs)
        cam = self.fc2.relprop(cam, **kwargs)
        cam = self.act.relprop(cam, **kwargs)
        cam = self.fc1.relprop(cam, **kwargs)
        return cam


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = head_dim**-0.5

        # A = Q*K^T
        self.matmul1 = einsum("bhid,bhjd->bhij")
        # attn = A*V
        self.matmul2 = einsum("bhij,bhjd->bhid")

        self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = Dropout(attn_drop)
        self.proj = Linear(dim, dim)
        self.proj_drop = Dropout(proj_drop)
        self.softmax = Softmax(dim=-1)

        self.attn_cam = None
        self.attn = None
        self.v = None
        self.v_cam = None
        self.attn_gradients = None

    def get_attn(self):
        return self.attn

    def save_attn(self, attn):
        self.attn = attn

    def save_attn_cam(self, cam):
        self.attn_cam = cam

    def get_attn_cam(self):
        return self.attn_cam

    def get_v(self):
        return self.v

    def save_v(self, v):
        self.v = v

    def save_v_cam(self, cam):
        self.v_cam = cam

    def get_v_cam(self):
        return self.v_cam

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def forward(self, x):
        b, n, _, h = *x.shape, self.num_heads
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)

        self.save_v(v)

        dots = self.matmul1([q, k]) * self.scale

        attn = self.softmax(dots)
        attn = self.attn_drop(attn)

        self.save_attn(attn)
        attn.register_hook(self.save_attn_gradients)

        out = self.matmul2([attn, v])
        out = rearrange(out, "b h n d -> b n (h d)")

        out = self.proj(out)
        out = self.proj_drop(out)
        return out

    def relprop(self, cam, **kwargs):
        cam = self.proj_drop.relprop(cam, **kwargs)
        cam = self.proj.relprop(cam, **kwargs)
        cam = rearrange(cam, "b n (h d) -> b h n d", h=self.num_heads)

        # attn = A*V
        (cam1, cam_v) = self.matmul2.relprop(cam, **kwargs)
        cam1 /= 2
        cam_v /= 2

        self.save_v_cam(cam_v)
        self.save_attn_cam(cam1)

        cam1 = self.attn_drop.relprop(cam1, **kwargs)
        cam1 = self.softmax.relprop(cam1, **kwargs)

        # A = Q*K^T
        (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
        cam_q /= 2
        cam_k /= 2

        cam_qkv = rearrange(
            [cam_q, cam_k, cam_v],
            "qkv b h n d -> b n (qkv h d)",
            qkv=3,
            h=self.num_heads,
        )

        return self.qkv.relprop(cam_qkv, **kwargs)


class Block(nn.Module):
    def __init__(
        self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0
    ):
        super().__init__()
        self.norm1 = LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = LayerNorm(dim, eps=1e-6)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

        self.add1 = Add()
        self.add2 = Add()
        self.clone1 = Clone()
        self.clone2 = Clone()

    def forward(self, x):
        x1, x2 = self.clone1(x, 2)
        x = self.add1([x1, self.attn(self.norm1(x2))])
        x1, x2 = self.clone2(x, 2)
        x = self.add2([x1, self.mlp(self.norm2(x2))])
        return x

    def relprop(self, cam, **kwargs):
        (cam1, cam2) = self.add2.relprop(cam, **kwargs)
        cam2 = self.mlp.relprop(cam2, **kwargs)
        cam2 = self.norm2.relprop(cam2, **kwargs)
        cam = self.clone2.relprop((cam1, cam2), **kwargs)

        (cam1, cam2) = self.add1.relprop(cam, **kwargs)
        cam2 = self.attn.relprop(cam2, **kwargs)
        cam2 = self.norm1.relprop(cam2, **kwargs)
        cam = self.clone1.relprop((cam1, cam2), **kwargs)
        return cam


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

    def relprop(self, cam, **kwargs):
        cam = cam.transpose(1, 2)
        cam = cam.reshape(
            cam.shape[0],
            cam.shape[1],
            (self.img_size[0] // self.patch_size[0]),
            (self.img_size[1] // self.patch_size[1]),
        )
        return self.proj.relprop(cam, **kwargs)


class VisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        mlp_head=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                )
                for i in range(depth)
            ]
        )

        self.norm = LayerNorm(embed_dim)
        if mlp_head:
            # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
            self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
        else:
            # with a single Linear layer as head, the param count within rounding of paper
            self.head = Linear(embed_dim, num_classes)

        # FIXME not quite sure what the proper weight init is supposed to be,
        # normal / trunc normal w/ std == .02 similar to other Bert like transformers
        trunc_normal_(self.pos_embed, std=0.02)  # embeddings same as weights?
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

        self.pool = IndexSelect()
        self.add = Add()

        self.inp_grad = None

    def save_inp_grad(self, grad):
        self.inp_grad = grad

    def get_inp_grad(self):
        return self.inp_grad

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @property
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.add([x, self.pos_embed])

        x.register_hook(self.save_inp_grad)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
        x = x.squeeze(1)
        x = self.head(x)
        return x

    def relprop(
        self,
        cam=None,
        method="transformer_attribution",
        is_ablation=False,
        start_layer=0,
        **kwargs,
    ):
        # print(kwargs)
        # print("conservation 1", cam.sum())
        cam = self.head.relprop(cam, **kwargs)
        cam = cam.unsqueeze(1)
        cam = self.pool.relprop(cam, **kwargs)
        cam = self.norm.relprop(cam, **kwargs)
        for blk in reversed(self.blocks):
            cam = blk.relprop(cam, **kwargs)

        # print("conservation 2", cam.sum())
        # print("min", cam.min())

        if method == "full":
            (cam, _) = self.add.relprop(cam, **kwargs)
            cam = cam[:, 1:]
            cam = self.patch_embed.relprop(cam, **kwargs)
            # sum on channels
            cam = cam.sum(dim=1)
            return cam

        elif method == "rollout":
            # cam rollout
            attn_cams = []
            for blk in self.blocks:
                attn_heads = blk.attn.get_attn_cam().clamp(min=0)
                avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
                attn_cams.append(avg_heads)
            cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
            cam = cam[:, 0, 1:]
            return cam

        # our method, method name grad is legacy
        elif method == "transformer_attribution" or method == "grad":
            cams = []
            for blk in self.blocks:
                grad = blk.attn.get_attn_gradients()
                cam = blk.attn.get_attn_cam()
                cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
                cam = cam.clamp(min=0).mean(dim=0)
                cams.append(cam.unsqueeze(0))
            rollout = compute_rollout_attention(cams, start_layer=start_layer)
            cam = rollout[:, 0, 1:]
            return cam

        elif method == "last_layer":
            cam = self.blocks[-1].attn.get_attn_cam()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            if is_ablation:
                grad = self.blocks[-1].attn.get_attn_gradients()
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam

        elif method == "last_layer_attn":
            cam = self.blocks[-1].attn.get_attn()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam

        elif method == "second_layer":
            cam = self.blocks[1].attn.get_attn_cam()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            if is_ablation:
                grad = self.blocks[1].attn.get_attn_gradients()
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam


def _conv_filter(state_dict, patch_size=16):
    """convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if "patch_embed.proj.weight" in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict


def vit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        **kwargs,
    )
    model.default_cfg = default_cfgs["vit_base_patch16_224"]
    if pretrained:
        load_pretrained(
            model,
            num_classes=model.num_classes,
            in_chans=kwargs.get("in_chans", 3),
            filter_fn=_conv_filter,
        )
    return model


def vit_large_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        **kwargs,
    )
    model.default_cfg = default_cfgs["vit_large_patch16_224"]
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get("in_chans", 3)
        )
    return model


def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
Transformer-Explainability/baselines/ViT/ViT_explanation_generator.py
ADDED
@@ -0,0 +1,107 @@
import argparse

import numpy as np
import torch
from numpy import *


# compute rollout between attention layers
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    eye = (
        torch.eye(num_tokens)
        .expand(batch_size, num_tokens, num_tokens)
        .to(all_layer_matrices[0].device)
    )
    all_layer_matrices = [
        all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
    ]
    matrices_aug = [
        all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
        for i in range(len(all_layer_matrices))
    ]
    joint_attention = matrices_aug[start_layer]
    for i in range(start_layer + 1, len(matrices_aug)):
        joint_attention = matrices_aug[i].bmm(joint_attention)
    return joint_attention


class LRP:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def generate_LRP(
        self,
        input,
        index=None,
        method="transformer_attribution",
        is_ablation=False,
        start_layer=0,
    ):
        output = self.model(input)
        kwargs = {"alpha": 1}
        if index == None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0, index] = 1
        one_hot_vector = one_hot
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.cuda() * output)

        self.model.zero_grad()
        one_hot.backward(retain_graph=True)

        return self.model.relprop(
            torch.tensor(one_hot_vector).to(input.device),
            method=method,
            is_ablation=is_ablation,
            start_layer=start_layer,
            **kwargs
        )


class Baselines:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def generate_cam_attn(self, input, index=None):
        output = self.model(input.cuda(), register_hook=True)
        if index == None:
            index = np.argmax(output.cpu().data.numpy())

        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0][index] = 1
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.cuda() * output)

        self.model.zero_grad()
        one_hot.backward(retain_graph=True)
        #################### attn
        grad = self.model.blocks[-1].attn.get_attn_gradients()
        cam = self.model.blocks[-1].attn.get_attention_map()
        cam = cam[0, :, 0, 1:].reshape(-1, 14, 14)
        grad = grad[0, :, 0, 1:].reshape(-1, 14, 14)
        grad = grad.mean(dim=[1, 2], keepdim=True)
        cam = (cam * grad).mean(0).clamp(min=0)
        cam = (cam - cam.min()) / (cam.max() - cam.min())

        return cam
        #################### attn

    def generate_rollout(self, input, start_layer=0):
        self.model(input)
        blocks = self.model.blocks
        all_layer_attentions = []
        for blk in blocks:
            attn_heads = blk.attn.get_attention_map()
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
            all_layer_attentions.append(avg_heads)
        rollout = compute_rollout_attention(
            all_layer_attentions, start_layer=start_layer
        )
        return rollout[:, 0, 1:]
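A short, illustrative usage sketch for the generators defined above; the random tensor stands in for a preprocessed image batch, and a CUDA device is assumed because `generate_LRP` moves the one-hot target to the GPU.
```
import torch
from baselines.ViT.ViT_LRP import vit_base_patch16_224
from baselines.ViT.ViT_explanation_generator import LRP

# Build the LRP-enabled ViT and wrap it in the explanation generator.
model = vit_base_patch16_224(pretrained=True).cuda()
explainer = LRP(model)

image = torch.randn(1, 3, 224, 224).cuda()  # stand-in for a preprocessed 224x224 image
# Per-token relevance for the top predicted class; pass index=... for a specific class.
relevance = explainer.generate_LRP(image, method="transformer_attribution", start_layer=0)
heatmap = relevance.reshape(1, 1, 14, 14)   # 224 / 16 = 14 patches per side
```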
Transformer-Explainability/baselines/ViT/ViT_new.py
ADDED
@@ -0,0 +1,329 @@
""" Vision Transformer (ViT) in PyTorch
Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial

import torch
import torch.nn as nn
from baselines.ViT.helpers import load_pretrained
from baselines.ViT.layer_helpers import to_2tuple
from baselines.ViT.weight_init import trunc_normal_
from einops import rearrange


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "pool_size": None,
        "crop_pct": 0.9,
        "interpolation": "bicubic",
        "first_conv": "patch_embed.proj",
        "classifier": "head",
        **kwargs,
    }


default_cfgs = {
    # patch models
    "vit_small_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth",
    ),
    "vit_base_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
    "vit_large_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
}


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        b, n, _, h = *x.shape, self.num_heads

        # self.save_output(x)
        # x.register_hook(self.save_output_grad)

        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)

        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale

        attn = dots.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = torch.einsum("bhij,bhjd->bhid", attn, v)

        self.save_attention_map(attn)
        if register_hook:
            attn.register_hook(self.save_attn_gradients)

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.proj(out)
        out = self.proj_drop(out)
        return out


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, register_hook=False):
        x = x + self.attn(self.norm1(x), register_hook=register_hook)
        x = x + self.mlp(self.norm2(x))
        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def forward(self, x, register_hook=False):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x, register_hook=register_hook)

        x = self.norm(x)
        x = x[:, 0]
        x = self.head(x)
        return x


def _conv_filter(state_dict, patch_size=16):
    """convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if "patch_embed.proj.weight" in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict


def vit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = default_cfgs["vit_base_patch16_224"]
    if pretrained:
        load_pretrained(
            model,
            num_classes=model.num_classes,
            in_chans=kwargs.get("in_chans", 3),
            filter_fn=_conv_filter,
        )
    return model


def vit_large_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = default_cfgs["vit_large_patch16_224"]
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get("in_chans", 3)
        )
    return model
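ViT_new.py exposes the raw hooks used by the gradient-based baselines: Attention.forward stores the softmaxed attention map and, when register_hook=True, registers a hook that captures its gradient. The sketch below shows one way to read those buffers back after a backward pass; it illustrates the hook API only and is not the repository's own explanation routine.

import torch

from baselines.ViT.ViT_new import vit_base_patch16_224

model = vit_base_patch16_224(pretrained=True)
model.eval()

image = torch.randn(1, 3, 224, 224, requires_grad=True)   # stand-in input
logits = model(image, register_hook=True)                 # hooks save attention + gradients
logits[0, logits.argmax()].backward()                     # backprop the predicted class

last_attn = model.blocks[-1].attn
attn = last_attn.get_attention_map()           # (1, heads, tokens, tokens)
grad = last_attn.get_attn_gradients()          # same shape, filled by the registered hook
cam = (grad * attn).clamp(min=0).mean(dim=1)   # gradient-weighted, averaged over heads
cls_to_patches = cam[0, 0, 1:]                 # CLS row without the CLS column: 196 patch scores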
Transformer-Explainability/baselines/ViT/ViT_orig_LRP.py
ADDED
@@ -0,0 +1,508 @@
""" Vision Transformer (ViT) in PyTorch
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from baselines.ViT.helpers import load_pretrained
from baselines.ViT.layer_helpers import to_2tuple
from baselines.ViT.weight_init import trunc_normal_
from einops import rearrange
from modules.layers_lrp import *


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "pool_size": None,
        "crop_pct": 0.9,
        "interpolation": "bicubic",
        "first_conv": "patch_embed.proj",
        "classifier": "head",
        **kwargs,
    }


default_cfgs = {
    # patch models
    "vit_small_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth",
    ),
    "vit_base_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
    "vit_large_patch16_224": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
}


def compute_rollout_attention(all_layer_matrices, start_layer=0):
    # adding residual consideration
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    eye = (
        torch.eye(num_tokens)
        .expand(batch_size, num_tokens, num_tokens)
        .to(all_layer_matrices[0].device)
    )
    all_layer_matrices = [
        all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
    ]
    # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
    #                       for i in range(len(all_layer_matrices))]
    joint_attention = all_layer_matrices[start_layer]
    for i in range(start_layer + 1, len(all_layer_matrices)):
        joint_attention = all_layer_matrices[i].bmm(joint_attention)
    return joint_attention


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = Linear(in_features, hidden_features)
        self.act = GELU()
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

    def relprop(self, cam, **kwargs):
        cam = self.drop.relprop(cam, **kwargs)
        cam = self.fc2.relprop(cam, **kwargs)
        cam = self.act.relprop(cam, **kwargs)
        cam = self.fc1.relprop(cam, **kwargs)
        return cam


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = head_dim**-0.5

        # A = Q*K^T
        self.matmul1 = einsum("bhid,bhjd->bhij")
        # attn = A*V
        self.matmul2 = einsum("bhij,bhjd->bhid")

        self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = Dropout(attn_drop)
        self.proj = Linear(dim, dim)
        self.proj_drop = Dropout(proj_drop)
        self.softmax = Softmax(dim=-1)

        self.attn_cam = None
        self.attn = None
        self.v = None
        self.v_cam = None
        self.attn_gradients = None

    def get_attn(self):
        return self.attn

    def save_attn(self, attn):
        self.attn = attn

    def save_attn_cam(self, cam):
        self.attn_cam = cam

    def get_attn_cam(self):
        return self.attn_cam

    def get_v(self):
        return self.v

    def save_v(self, v):
        self.v = v

    def save_v_cam(self, cam):
        self.v_cam = cam

    def get_v_cam(self):
        return self.v_cam

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def forward(self, x):
        b, n, _, h = *x.shape, self.num_heads
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)

        self.save_v(v)

        dots = self.matmul1([q, k]) * self.scale

        attn = self.softmax(dots)
        attn = self.attn_drop(attn)

        self.save_attn(attn)
        attn.register_hook(self.save_attn_gradients)

        out = self.matmul2([attn, v])
        out = rearrange(out, "b h n d -> b n (h d)")

        out = self.proj(out)
        out = self.proj_drop(out)
        return out

    def relprop(self, cam, **kwargs):
        cam = self.proj_drop.relprop(cam, **kwargs)
        cam = self.proj.relprop(cam, **kwargs)
        cam = rearrange(cam, "b n (h d) -> b h n d", h=self.num_heads)

        # attn = A*V
        (cam1, cam_v) = self.matmul2.relprop(cam, **kwargs)
        cam1 /= 2
        cam_v /= 2

        self.save_v_cam(cam_v)
        self.save_attn_cam(cam1)

        cam1 = self.attn_drop.relprop(cam1, **kwargs)
        cam1 = self.softmax.relprop(cam1, **kwargs)

        # A = Q*K^T
        (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
        cam_q /= 2
        cam_k /= 2

        cam_qkv = rearrange(
            [cam_q, cam_k, cam_v],
            "qkv b h n d -> b n (qkv h d)",
            qkv=3,
            h=self.num_heads,
        )

        return self.qkv.relprop(cam_qkv, **kwargs)


class Block(nn.Module):
    def __init__(
        self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0
    ):
        super().__init__()
        self.norm1 = LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = LayerNorm(dim, eps=1e-6)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

        self.add1 = Add()
        self.add2 = Add()
        self.clone1 = Clone()
        self.clone2 = Clone()

    def forward(self, x):
        x1, x2 = self.clone1(x, 2)
        x = self.add1([x1, self.attn(self.norm1(x2))])
        x1, x2 = self.clone2(x, 2)
        x = self.add2([x1, self.mlp(self.norm2(x2))])
        return x

    def relprop(self, cam, **kwargs):
        (cam1, cam2) = self.add2.relprop(cam, **kwargs)
        cam2 = self.mlp.relprop(cam2, **kwargs)
        cam2 = self.norm2.relprop(cam2, **kwargs)
        cam = self.clone2.relprop((cam1, cam2), **kwargs)

        (cam1, cam2) = self.add1.relprop(cam, **kwargs)
        cam2 = self.attn.relprop(cam2, **kwargs)
        cam2 = self.norm1.relprop(cam2, **kwargs)
        cam = self.clone1.relprop((cam1, cam2), **kwargs)
        return cam


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

    def relprop(self, cam, **kwargs):
        cam = cam.transpose(1, 2)
        cam = cam.reshape(
            cam.shape[0],
            cam.shape[1],
            (self.img_size[0] // self.patch_size[0]),
            (self.img_size[1] // self.patch_size[1]),
        )
        return self.proj.relprop(cam, **kwargs)


class VisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        mlp_head=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                )
                for i in range(depth)
            ]
        )

        self.norm = LayerNorm(embed_dim)
        if mlp_head:
            # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
            self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
        else:
            # with a single Linear layer as head, the param count within rounding of paper
            self.head = Linear(embed_dim, num_classes)

        # FIXME not quite sure what the proper weight init is supposed to be,
        # normal / trunc normal w/ std == .02 similar to other Bert like transformers
        trunc_normal_(self.pos_embed, std=0.02)  # embeddings same as weights?
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

        self.pool = IndexSelect()
        self.add = Add()

        self.inp_grad = None

    def save_inp_grad(self, grad):
        self.inp_grad = grad

    def get_inp_grad(self):
        return self.inp_grad

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @property
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.add([x, self.pos_embed])

        x.register_hook(self.save_inp_grad)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
        x = x.squeeze(1)
        x = self.head(x)
        return x

    def relprop(
        self, cam=None, method="grad", is_ablation=False, start_layer=0, **kwargs
    ):
        # print(kwargs)
        # print("conservation 1", cam.sum())
        cam = self.head.relprop(cam, **kwargs)
        cam = cam.unsqueeze(1)
        cam = self.pool.relprop(cam, **kwargs)
        cam = self.norm.relprop(cam, **kwargs)
        for blk in reversed(self.blocks):
            cam = blk.relprop(cam, **kwargs)

        # print("conservation 2", cam.sum())
        # print("min", cam.min())

        if method == "full":
            (cam, _) = self.add.relprop(cam, **kwargs)
            cam = cam[:, 1:]
            cam = self.patch_embed.relprop(cam, **kwargs)
            # sum on channels
            cam = cam.sum(dim=1)
            return cam

        elif method == "rollout":
            # cam rollout
            attn_cams = []
            for blk in self.blocks:
                attn_heads = blk.attn.get_attn_cam().clamp(min=0)
                avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
                attn_cams.append(avg_heads)
            cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
            cam = cam[:, 0, 1:]
            return cam

        elif method == "grad":
            cams = []
            for blk in self.blocks:
                grad = blk.attn.get_attn_gradients()
                cam = blk.attn.get_attn_cam()
                cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
                cam = cam.clamp(min=0).mean(dim=0)
                cams.append(cam.unsqueeze(0))
            rollout = compute_rollout_attention(cams, start_layer=start_layer)
            cam = rollout[:, 0, 1:]
            return cam

        elif method == "last_layer":
            cam = self.blocks[-1].attn.get_attn_cam()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            if is_ablation:
                grad = self.blocks[-1].attn.get_attn_gradients()
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam

        elif method == "last_layer_attn":
            cam = self.blocks[-1].attn.get_attn()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam

        elif method == "second_layer":
            cam = self.blocks[1].attn.get_attn_cam()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            if is_ablation:
                grad = self.blocks[1].attn.get_attn_gradients()
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam


def _conv_filter(state_dict, patch_size=16):
    """convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if "patch_embed.proj.weight" in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict


def vit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        **kwargs,
    )
    model.default_cfg = default_cfgs["vit_base_patch16_224"]
    if pretrained:
        load_pretrained(
            model,
            num_classes=model.num_classes,
            in_chans=kwargs.get("in_chans", 3),
            filter_fn=_conv_filter,
        )
    return model


def vit_large_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        **kwargs,
    )
    model.default_cfg = default_cfgs["vit_large_patch16_224"]
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get("in_chans", 3)
        )
    return model
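VisionTransformer.relprop in ViT_orig_LRP.py expects the relevance of the classifier output (typically a one-hot vector for the target class) and assumes a backward pass has already populated the attention gradients; the ExplanationGenerator module normally orchestrates this. The following is a hand-rolled sketch under those assumptions, with alpha=1.0 mirroring what the generator appears to pass as its default kwargs:

import torch

from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224

model = vit_base_patch16_224(pretrained=True)
model.eval()

image = torch.randn(1, 3, 224, 224, requires_grad=True)   # stand-in input
logits = model(image)
target = logits.argmax(dim=-1).item()

# Backpropagate the chosen class so the attention gradients used by method="grad" exist.
one_hot = torch.zeros_like(logits)
one_hot[0, target] = 1.0
model.zero_grad()
(one_hot * logits).sum().backward(retain_graph=True)

# Relevance propagation followed by gradient-weighted rollout over all blocks.
cam = model.relprop(cam=one_hot, method="grad", start_layer=0, alpha=1.0)
# cam has shape (1, 196): one relevance score per 16x16 patch of the 224x224 input.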