Spaces:
Running
Running
PFEemp2024
committed on
Commit
·
4a1df2e
1
Parent(s):
b161bb4
solving GPU error for previous version
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +7 -6
- anonymous_demo/__init__.py +5 -0
- anonymous_demo/core/__init__.py +0 -0
- anonymous_demo/core/tad/__init__.py +0 -0
- anonymous_demo/core/tad/classic/__bert__/README.MD +3 -0
- anonymous_demo/core/tad/classic/__bert__/__init__.py +1 -0
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py +0 -0
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py +121 -0
- anonymous_demo/core/tad/classic/__bert__/models/__init__.py +1 -0
- anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py +46 -0
- anonymous_demo/core/tad/classic/__init__.py +0 -0
- anonymous_demo/core/tad/models/__init__.py +9 -0
- anonymous_demo/core/tad/prediction/__init__.py +0 -0
- anonymous_demo/core/tad/prediction/tad_classifier.py +518 -0
- anonymous_demo/functional/__init__.py +3 -0
- anonymous_demo/functional/checkpoint/__init__.py +1 -0
- anonymous_demo/functional/checkpoint/checkpoint_manager.py +19 -0
- anonymous_demo/functional/config/__init__.py +1 -0
- anonymous_demo/functional/config/config_manager.py +64 -0
- anonymous_demo/functional/config/tad_config_manager.py +229 -0
- anonymous_demo/functional/dataset/__init__.py +1 -0
- anonymous_demo/functional/dataset/dataset_manager.py +45 -0
- anonymous_demo/network/__init__.py +0 -0
- anonymous_demo/network/lcf_pooler.py +28 -0
- anonymous_demo/network/lsa.py +73 -0
- anonymous_demo/network/sa_encoder.py +199 -0
- anonymous_demo/utils/__init__.py +0 -0
- anonymous_demo/utils/demo_utils.py +247 -0
- anonymous_demo/utils/logger.py +38 -0
- app.py +360 -0
- checkpoints.zip +3 -0
- flow correction 30%.ipynb +516 -0
- flow_correction_ag_news.py +388 -0
- flow_correction_imdb.py +388 -0
- gitignore +143 -0
- main_correction.py +89 -0
- requirements.txt +28 -0
- text_defense/201.SST2/stsa.binary.dev.dat +0 -0
- text_defense/201.SST2/stsa.binary.test.dat +0 -0
- text_defense/201.SST2/stsa.binary.train.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.test.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.train.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.valid.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.test.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.train.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.valid.dat +0 -0
- text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat +0 -0
- text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat +0 -0
- textattack/__init__.py +39 -0
- textattack/__main__.py +6 -0
README.md
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
---
|
2 |
-
title: DCWIR
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.31.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
11 |
---
|
12 |
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: DCWIR Demo
|
3 |
+
emoji: 🛡️
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.31.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
+
short_description: 'A demo of the DCWIR method in the SAE setup.'
|
12 |
---
|
13 |
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
anonymous_demo/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "1.0.0"
|
2 |
+
|
3 |
+
__name__ = "anonymous_demo"
|
4 |
+
|
5 |
+
from anonymous_demo.functional import TADCheckpointManager
|
anonymous_demo/core/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/classic/__bert__/README.MD
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
## This is a simple migration from ABSA-PyTorch, which is released under the MIT license
|
2 |
+
|
3 |
+
Project Address: https://github.com/songyouwei/ABSA-PyTorch
|
anonymous_demo/core/tad/classic/__bert__/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .models import *
|
anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tqdm
|
2 |
+
from findfile import find_cwd_dir
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
|
6 |
+
|
7 |
+
class Tokenizer4Pretraining:
    """Wrapper around a Hugging Face ``AutoTokenizer`` with a fixed output length."""

    def __init__(self, max_seq_len, opt, **kwargs):
        """Load the tokenizer named by ``opt.pretrained_bert``.

        :param max_seq_len: value passed as ``max_length`` when encoding text.
        :param opt: configuration object exposing ``pretrained_bert``.
        :param kwargs: ``offline=True`` loads the tokenizer from the directory
            that ``find_cwd_dir`` returns for the last ``/``-separated
            component of the model name instead of from the model name itself.
        """
        model_name = opt.pretrained_bert
        if kwargs.pop("offline", False):
            source = find_cwd_dir(model_name.split("/")[-1])
        else:
            source = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(
            source, do_lower_case="uncased" in model_name
        )
        self.max_seq_len = max_seq_len

    def text_to_sequence(self, text, reverse=False, padding="post", truncating="post"):
        """Encode ``text``, truncated and padded to ``max_seq_len``.

        ``reverse``, ``padding`` and ``truncating`` are accepted but not used
        by this implementation.

        :return: whatever ``tokenizer.encode`` returns for
            ``return_tensors="pt"``.
        """
        encode_options = {
            "truncation": True,
            "padding": "max_length",
            "max_length": self.max_seq_len,
            "return_tensors": "pt",
        }
        return self.tokenizer.encode(text, **encode_options)
|
28 |
+
|
29 |
+
|
30 |
+
class BERTTADDataset(Dataset):
    """In-memory inference dataset for the BERT-based TAD models.

    A sample is a text line that may carry reference labels after a ``!ref!``
    marker, in the form ``text !ref! label, is_adv, adv_train_label``.
    Trailing fields may be omitted; a missing field becomes ``-100``.
    """

    def __init__(self, tokenizer, opt):
        """Store the tokenizer and configuration; the dataset starts empty."""
        self.bert_baseline_input_colses = {"bert": ["text_bert_indices"]}

        self.tokenizer = tokenizer
        self.opt = opt
        self.all_data = []

    def parse_sample(self, text):
        """Wrap a single text in a list so it can be fed to ``process_data``."""
        return [text]

    def prepare_infer_sample(self, text: str, ignore_error):
        """Replace the dataset content with the single sample ``text``."""
        self.process_data(self.parse_sample(text), ignore_error=ignore_error)

    def process_data(self, samples, ignore_error=True):
        """Tokenize ``samples`` and store them as the dataset content.

        :param samples: sized iterable of raw text lines.
        :param ignore_error: when true, a sample that fails to process is
            reported with ``print`` and skipped; otherwise the exception is
            re-raised.
        :return: the list of processed records (also kept in ``all_data``).
        """
        records = []
        iterator = samples
        # A progress bar is only shown for larger inputs.
        if len(samples) > 100:
            iterator = tqdm.tqdm(
                samples, postfix="preparing text classification inference dataloader..."
            )

        for text in iterator:
            try:
                # handle for empty lines in inference datasets
                if text is None or "" == text.strip():
                    raise RuntimeError("Invalid Input!")

                if "!ref!" in text:
                    text, _, labels = text.strip().partition("!ref!")
                    text = text.strip()

                    # One to three comma-separated fields are accepted; more
                    # than three means every reference falls back to -100.
                    fields = [field.strip() for field in labels.strip().split(",")]
                    if len(fields) > 3:
                        fields = []
                    fields = fields + ["-100"] * (3 - len(fields))
                    label, is_adv, adv_train_label = fields

                    label = int(label)
                    adv_train_label = int(adv_train_label)
                    is_adv = int(is_adv)
                else:
                    text = text.strip()
                    label = adv_train_label = is_adv = -100

                text_indices = self.tokenizer.text_to_sequence("{}".format(text))

                records.append(
                    {
                        "text_bert_indices": text_indices[0],
                        "text_raw": text,
                        "label": label,
                        "adv_train_label": adv_train_label,
                        "is_adv": is_adv,
                    }
                )

            except Exception as e:
                if not ignore_error:
                    raise e
                print("Ignore error while processing:", text)

        self.all_data = records
        return self.all_data

    def __getitem__(self, index):
        """Return the processed record at ``index``."""
        return self.all_data[index]

    def __len__(self):
        """Return the number of processed records."""
        return len(self.all_data)
|
anonymous_demo/core/tad/classic/__bert__/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .tad_bert import TADBERT
|
anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers.models.bert.modeling_bert import BertPooler
|
4 |
+
|
5 |
+
from anonymous_demo.network.sa_encoder import Encoder
|
6 |
+
|
7 |
+
|
8 |
+
class TADBERT(nn.Module):
    """BERT backbone with three linear heads on top of one pooled representation.

    The heads produce logits for standard classification (``dense1``),
    adversarial detection (``dense2``) and adversarial-training classification
    (``dense3``).
    """

    inputs = ["text_bert_indices"]

    def __init__(self, bert, opt):
        """Build the pooler, heads and encoders around ``bert``.

        :param bert: backbone module exposing ``config``; calling it must
            return a mapping with a ``"last_hidden_state"`` entry.
        :param opt: configuration exposing ``hidden_dim``, ``class_dim`` and
            ``adv_det_dim``.
        """
        super().__init__()
        self.opt = opt
        self.bert = bert
        self.pooler = BertPooler(bert.config)
        self.dense1 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)
        self.dense2 = nn.Linear(self.opt.hidden_dim, self.opt.adv_det_dim)
        self.dense3 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)

        # These encoders are not used by forward(); they are kept so the set
        # of registered sub-modules (and the state-dict keys) stays unchanged.
        self.encoder1 = Encoder(self.bert.config, opt=opt)
        self.encoder2 = Encoder(self.bert.config, opt=opt)
        self.encoder3 = Encoder(self.bert.config, opt=opt)

    def forward(self, inputs):
        """Run the backbone and the three heads.

        :param inputs: sequence whose first element is the token-id tensor.
        :return: dict with ``sent_logits``, ``advdet_logits``,
            ``adv_tr_logits``, ``last_hidden_state`` and ``att_score``.
        """
        text_raw_indices = inputs[0]
        last_hidden_state = self.bert(text_raw_indices)["last_hidden_state"]

        # All three heads share the same pooler and the same input, so the
        # pooled representation is computed once instead of three times.
        pooled = self.pooler(last_hidden_state)
        sent_logits = self.dense1(pooled)
        advdet_logits = self.dense2(pooled)
        adv_tr_logits = self.dense3(pooled)

        # NOTE(review): the first operand drops dim 1 (keepdim=False) while
        # the second keeps it (keepdim=True), so the subtraction relies on
        # broadcasting between tensors of different rank - confirm this is
        # the intended attention score before relying on its shape.
        att_score = torch.nn.functional.normalize(
            last_hidden_state.abs().sum(dim=1, keepdim=False)
            - last_hidden_state.abs().min(dim=1, keepdim=True)[0],
            p=1,
            dim=1,
        )

        outputs = {
            "sent_logits": sent_logits,
            "advdet_logits": advdet_logits,
            "adv_tr_logits": adv_tr_logits,
            "last_hidden_state": last_hidden_state,
            "att_score": att_score,
        }
        return outputs
|
anonymous_demo/core/tad/classic/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/models/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import anonymous_demo.core.tad.classic.__bert__.models
|
2 |
+
|
3 |
+
|
4 |
+
class BERTTADModelList(list):
    """List of the available BERT-based TAD model classes.

    Each model class is also exposed as a class attribute under its own name.
    """

    TADBERT = anonymous_demo.core.tad.classic.__bert__.TADBERT

    def __init__(self):
        """Populate the list with every registered model class."""
        super().__init__([self.TADBERT])
|
anonymous_demo/core/tad/prediction/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/prediction/tad_classifier.py
ADDED
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
import time
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import tqdm
|
8 |
+
from findfile import find_file, find_cwd_dir
|
9 |
+
from termcolor import colored
|
10 |
+
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from transformers import (
|
13 |
+
AutoTokenizer,
|
14 |
+
AutoModel,
|
15 |
+
AutoConfig,
|
16 |
+
DebertaV2ForMaskedLM,
|
17 |
+
RobertaForMaskedLM,
|
18 |
+
BertForMaskedLM,
|
19 |
+
)
|
20 |
+
|
21 |
+
from ....functional.dataset.dataset_manager import detect_infer_dataset
|
22 |
+
|
23 |
+
from ..models import BERTTADModelList
|
24 |
+
from ..classic.__bert__.dataset_utils.data_utils_for_inference import (
|
25 |
+
BERTTADDataset,
|
26 |
+
Tokenizer4Pretraining,
|
27 |
+
)
|
28 |
+
|
29 |
+
from ....utils.demo_utils import (
|
30 |
+
print_args,
|
31 |
+
TransformerConnectionError,
|
32 |
+
get_device,
|
33 |
+
build_embedding_matrix,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def init_attacker(tad_classifier, defense):
    """Build a TextAttack attacker that queries ``tad_classifier``.

    This is best-effort: any failure (TextAttack cannot be imported, unknown
    recipe name, recipe construction error) is printed and ``None`` is
    returned instead of raising.

    :param tad_classifier: object whose ``infer(text, print_result=False)``
        returns a mapping containing a ``"probs"`` entry.
    :param defense: recipe key, one of ``bae``, ``pwws``, ``textfooler``,
        ``pso``, ``iga``, ``ga`` or ``wordbugger``.
    :return: an object whose ``attacker`` attribute is the TextAttack
        ``Attacker``, or ``None`` on failure.
    """
    try:
        from textattack import Attacker
        from textattack.attack_recipes import (
            BAEGarg2019,
            PWWSRen2019,
            TextFoolerJin2019,
            PSOZang2020,
            IGAWang2019,
            GeneticAlgorithmAlzantot2018,
            DeepWordBugGao2018,
        )
        from textattack.datasets import Dataset
        from textattack.models.wrappers import HuggingFaceModelWrapper

        class DemoModelWrapper(HuggingFaceModelWrapper):
            """Adapter exposing the classifier through the TextAttack wrapper API."""

            def __init__(self, model):
                # The parent initializer is not invoked; only the model is stored.
                self.model = model  # pipeline = pipeline

            def __call__(self, text_inputs, **kwargs):
                """Return the class probabilities for each input text."""
                outputs = []
                for text_input in text_inputs:
                    raw_outputs = self.model.infer(
                        text_input, print_result=False, **kwargs
                    )
                    outputs.append(raw_outputs["probs"])
                return outputs

        class SentAttacker:
            """Holds an ``Attacker`` built from a recipe and a placeholder dataset."""

            def __init__(self, model, recipe_class=BAEGarg2019):
                model_wrapper = DemoModelWrapper(model)

                recipe = recipe_class.build(model_wrapper)

                # The attacker is constructed with a one-item placeholder dataset.
                _dataset = Dataset([("", 0)])

                self.attacker = Attacker(recipe, _dataset)

        attackers = {
            "bae": BAEGarg2019,
            "pwws": PWWSRen2019,
            "textfooler": TextFoolerJin2019,
            "pso": PSOZang2020,
            "iga": IGAWang2019,
            "ga": GeneticAlgorithmAlzantot2018,
            "wordbugger": DeepWordBugGao2018,
        }
        if defense not in attackers:
            # Name the supported keys instead of surfacing a bare KeyError.
            raise ValueError(
                "Unknown attack recipe '{}', expected one of: {}".format(
                    defense, ", ".join(sorted(attackers))
                )
            )
        return SentAttacker(tad_classifier, attackers[defense])
    except Exception as e:
        print("Original error:", e)
        return None
|
89 |
+
|
90 |
+
|
91 |
+
def get_mlm_and_tokenizer(text_classifier, config):
    """Build a masked-LM model sharing the classifier's backbone, plus a tokenizer.

    The masked-LM class is chosen from the name in ``config.pretrained_bert``
    (``deberta-v3``, then ``roberta``, otherwise BERT). It is constructed from
    the pretrained configuration only, and its backbone attribute is then
    replaced by the classifier's ``base_model``.

    :param text_classifier: a ``TADTextClassifier`` or a model exposing ``bert``.
    :param config: configuration exposing ``pretrained_bert``.
    :return: ``(mlm_model, tokenizer)``.
    """
    if isinstance(text_classifier, TADTextClassifier):
        backbone = text_classifier.model.bert.base_model
    else:
        backbone = text_classifier.bert.base_model

    pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)

    if "deberta-v3" in config.pretrained_bert:
        mlm_class, backbone_attr = DebertaV2ForMaskedLM, "deberta"
    elif "roberta" in config.pretrained_bert:
        mlm_class, backbone_attr = RobertaForMaskedLM, "roberta"
    else:
        mlm_class, backbone_attr = BertForMaskedLM, "bert"

    mlm = mlm_class(pretrained_config)
    setattr(mlm, backbone_attr, backbone)
    return mlm, AutoTokenizer.from_pretrained(config.pretrained_bert)
|
107 |
+
|
108 |
+
|
109 |
+
class TADTextClassifier:
|
110 |
+
def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
|
111 |
+
"""
|
112 |
+
from_train_model: load inference model from trained model
|
113 |
+
"""
|
114 |
+
self.cal_perplexity = cal_perplexity
|
115 |
+
# load from a training
|
116 |
+
if not isinstance(model_arg, str):
|
117 |
+
print("Load text classifier from training")
|
118 |
+
self.model = model_arg[0]
|
119 |
+
self.opt = model_arg[1]
|
120 |
+
self.tokenizer = model_arg[2]
|
121 |
+
else:
|
122 |
+
try:
|
123 |
+
if "fine-tuned" in model_arg:
|
124 |
+
raise ValueError(
|
125 |
+
"Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!"
|
126 |
+
)
|
127 |
+
print("Load text classifier from", model_arg)
|
128 |
+
state_dict_path = find_file(
|
129 |
+
model_arg, key=".state_dict", exclude_key=["__MACOSX"]
|
130 |
+
)
|
131 |
+
model_path = find_file(
|
132 |
+
model_arg, key=".model", exclude_key=["__MACOSX"]
|
133 |
+
)
|
134 |
+
tokenizer_path = find_file(
|
135 |
+
model_arg, key=".tokenizer", exclude_key=["__MACOSX"]
|
136 |
+
)
|
137 |
+
config_path = find_file(
|
138 |
+
model_arg, key=".config", exclude_key=["__MACOSX"]
|
139 |
+
)
|
140 |
+
|
141 |
+
print("config: {}".format(config_path))
|
142 |
+
print("state_dict: {}".format(state_dict_path))
|
143 |
+
print("model: {}".format(model_path))
|
144 |
+
print("tokenizer: {}".format(tokenizer_path))
|
145 |
+
|
146 |
+
with open(config_path, mode="rb") as f:
|
147 |
+
self.opt = pickle.load(f)
|
148 |
+
self.opt.device = get_device(kwargs.pop("auto_device", True))[0]
|
149 |
+
|
150 |
+
if state_dict_path or model_path:
|
151 |
+
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
152 |
+
if state_dict_path:
|
153 |
+
if kwargs.pop("offline", False):
|
154 |
+
self.bert = AutoModel.from_pretrained(
|
155 |
+
find_cwd_dir(
|
156 |
+
self.opt.pretrained_bert.split("/")[-1]
|
157 |
+
)
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
self.bert = AutoModel.from_pretrained(
|
161 |
+
self.opt.pretrained_bert
|
162 |
+
)
|
163 |
+
self.model = self.opt.model(self.bert, self.opt)
|
164 |
+
self.model.load_state_dict(
|
165 |
+
torch.load(state_dict_path, map_location="cpu")
|
166 |
+
)
|
167 |
+
elif model_path:
|
168 |
+
self.model = torch.load(model_path, map_location="cpu")
|
169 |
+
|
170 |
+
try:
|
171 |
+
self.tokenizer = Tokenizer4Pretraining(
|
172 |
+
max_seq_len=self.opt.max_seq_len, opt=self.opt, **kwargs
|
173 |
+
)
|
174 |
+
except ValueError:
|
175 |
+
if tokenizer_path:
|
176 |
+
with open(tokenizer_path, mode="rb") as f:
|
177 |
+
self.tokenizer = pickle.load(f)
|
178 |
+
else:
|
179 |
+
raise TransformerConnectionError()
|
180 |
+
|
181 |
+
except Exception as e:
|
182 |
+
raise RuntimeError(
|
183 |
+
"Exception: {} Fail to load the model from {}! ".format(
|
184 |
+
e, model_arg
|
185 |
+
)
|
186 |
+
)
|
187 |
+
|
188 |
+
self.infer_dataloader = None
|
189 |
+
self.opt.eval_batch_size = kwargs.pop("eval_batch_size", 128)
|
190 |
+
|
191 |
+
self.opt.initializer = self.opt.initializer
|
192 |
+
|
193 |
+
if self.cal_perplexity:
|
194 |
+
try:
|
195 |
+
self.MLM, self.MLM_tokenizer = get_mlm_and_tokenizer(self, self.opt)
|
196 |
+
except Exception as e:
|
197 |
+
self.MLM, self.MLM_tokenizer = None, None
|
198 |
+
|
199 |
+
self.to(self.opt.device)
|
200 |
+
|
201 |
+
def to(self, device=None):
|
202 |
+
self.opt.device = device
|
203 |
+
self.model.to(device)
|
204 |
+
if hasattr(self, "MLM"):
|
205 |
+
self.MLM.to(self.opt.device)
|
206 |
+
|
207 |
+
def cpu(self):
|
208 |
+
self.opt.device = "cpu"
|
209 |
+
self.model.to("cpu")
|
210 |
+
if hasattr(self, "MLM"):
|
211 |
+
self.MLM.to("cpu")
|
212 |
+
|
213 |
+
def cuda(self, device="cuda:0"):
|
214 |
+
self.opt.device = device
|
215 |
+
self.model.to(device)
|
216 |
+
if hasattr(self, "MLM"):
|
217 |
+
self.MLM.to(device)
|
218 |
+
|
219 |
+
def _log_write_args(self):
|
220 |
+
n_trainable_params, n_nontrainable_params = 0, 0
|
221 |
+
for p in self.model.parameters():
|
222 |
+
n_params = torch.prod(torch.tensor(p.shape))
|
223 |
+
if p.requires_grad:
|
224 |
+
n_trainable_params += n_params
|
225 |
+
else:
|
226 |
+
n_nontrainable_params += n_params
|
227 |
+
print(
|
228 |
+
"n_trainable_params: {0}, n_nontrainable_params: {1}".format(
|
229 |
+
n_trainable_params, n_nontrainable_params
|
230 |
+
)
|
231 |
+
)
|
232 |
+
for arg in vars(self.opt):
|
233 |
+
if getattr(self.opt, arg) is not None:
|
234 |
+
print(">>> {0}: {1}".format(arg, getattr(self.opt, arg)))
|
235 |
+
|
236 |
+
def batch_infer(
    self,
    target_file=None,
    print_result=True,
    save_result=False,
    ignore_error=True,
    defense: str = None,
):
    """Run inference on every sample of an inference dataset.

    :param target_file: dataset reference resolved by ``detect_infer_dataset``.
    :param print_result: print each prediction.
    :param save_result: write the results to
        ``tad_text_classification.result.json`` in the working directory.
    :param ignore_error: skip samples that fail to process instead of raising.
    :param defense: optional attack-recipe name used to repair samples that
        are detected as adversarial.
    :return: list of result dicts produced by ``_infer``.
    :raises FileNotFoundError: if no inference dataset can be found.
    :raises RuntimeError: if the configured model is not a known TAD model.
    """
    save_path = os.path.join(os.getcwd(), "tad_text_classification.result.json")

    target_file = detect_infer_dataset(target_file, task="text_defense")
    if not target_file:
        raise FileNotFoundError("Can not find inference datasets!")

    if not hasattr(BERTTADModelList, self.opt.model.__name__):
        # Previously this fell through to an unbound ``dataset`` variable.
        raise RuntimeError(
            "Unsupported model for inference: {}".format(self.opt.model.__name__)
        )
    dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)

    if hasattr(dataset, "prepare_infer_dataset"):
        dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
    else:
        # BERTTADDataset only provides process_data(), so the files are read
        # here. Assumes detect_infer_dataset returns one path or an iterable
        # of paths to UTF-8 text files with one sample per line - TODO confirm.
        if isinstance(target_file, str):
            paths = [target_file]
        else:
            paths = list(target_file)
        samples = []
        for path in paths:
            with open(path, mode="r", encoding="utf8") as fin:
                samples.extend(line for line in fin if line.strip())
        dataset.process_data(samples, ignore_error=ignore_error)

    self.infer_dataloader = DataLoader(
        dataset=dataset,
        batch_size=self.opt.eval_batch_size,
        pin_memory=True,
        shuffle=False,
    )
    return self._infer(
        save_path=save_path if save_result else None,
        print_result=print_result,
        defense=defense,
    )
|
265 |
+
|
266 |
+
def infer(
    self,
    text: str = None,
    print_result=True,
    ignore_error=True,
    defense: str = None,
):
    """Run inference on a single text and return its result dict.

    :param text: the sample; it may carry reference labels after ``!ref!``.
    :param print_result: print the prediction.
    :param ignore_error: skip the sample instead of raising when it cannot
        be processed.
    :param defense: optional attack-recipe name forwarded to ``_infer``.
    :return: the first (and only) entry of the ``_infer`` result list.
    :raises RuntimeError: if ``text`` is empty or ``None``.
    """
    if hasattr(BERTTADModelList, self.opt.model.__name__):
        dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)

    if not text:
        raise RuntimeError("Please specify your datasets path!")
    dataset.prepare_infer_sample(text, ignore_error=ignore_error)

    self.infer_dataloader = DataLoader(
        dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False
    )
    results = self._infer(print_result=print_result, defense=defense)
    return results[0]
|
284 |
+
|
285 |
+
def _infer(self, save_path=None, print_result=True, defense=None):
|
286 |
+
_params = filter(lambda p: p.requires_grad, self.model.parameters())
|
287 |
+
|
288 |
+
correct = {True: "Correct", False: "Wrong"}
|
289 |
+
results = []
|
290 |
+
|
291 |
+
with torch.no_grad():
|
292 |
+
self.model.eval()
|
293 |
+
n_correct = 0
|
294 |
+
n_labeled = 0
|
295 |
+
|
296 |
+
n_advdet_correct = 0
|
297 |
+
n_advdet_labeled = 0
|
298 |
+
if len(self.infer_dataloader.dataset) >= 100:
|
299 |
+
it = tqdm.tqdm(self.infer_dataloader, postfix="inferring...")
|
300 |
+
else:
|
301 |
+
it = self.infer_dataloader
|
302 |
+
for _, sample in enumerate(it):
|
303 |
+
inputs = [
|
304 |
+
sample[col].to(self.opt.device) for col in self.opt.inputs_cols
|
305 |
+
]
|
306 |
+
outputs = self.model(inputs)
|
307 |
+
logits, advdet_logits, adv_tr_logits = (
|
308 |
+
outputs["sent_logits"],
|
309 |
+
outputs["advdet_logits"],
|
310 |
+
outputs["adv_tr_logits"],
|
311 |
+
)
|
312 |
+
probs, advdet_probs, adv_tr_probs = (
|
313 |
+
torch.softmax(logits, dim=-1),
|
314 |
+
torch.softmax(advdet_logits, dim=-1),
|
315 |
+
torch.softmax(adv_tr_logits, dim=-1),
|
316 |
+
)
|
317 |
+
|
318 |
+
for i, (prob, advdet_prob, adv_tr_prob) in enumerate(
|
319 |
+
zip(probs, advdet_probs, adv_tr_probs)
|
320 |
+
):
|
321 |
+
text_raw = sample["text_raw"][i]
|
322 |
+
|
323 |
+
pred_label = int(prob.argmax(axis=-1))
|
324 |
+
pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
|
325 |
+
pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
|
326 |
+
ref_label = (
|
327 |
+
int(sample["label"][i])
|
328 |
+
if int(sample["label"][i]) in self.opt.index_to_label
|
329 |
+
else ""
|
330 |
+
)
|
331 |
+
ref_is_adv_label = (
|
332 |
+
int(sample["is_adv"][i])
|
333 |
+
if int(sample["is_adv"][i]) in self.opt.index_to_is_adv
|
334 |
+
else ""
|
335 |
+
)
|
336 |
+
ref_adv_tr_label = (
|
337 |
+
int(sample["adv_train_label"][i])
|
338 |
+
if int(sample["adv_train_label"][i])
|
339 |
+
in self.opt.index_to_adv_train_label
|
340 |
+
else ""
|
341 |
+
)
|
342 |
+
|
343 |
+
if self.cal_perplexity:
|
344 |
+
ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
|
345 |
+
ids["labels"] = ids["input_ids"].clone()
|
346 |
+
ids = ids.to(self.opt.device)
|
347 |
+
loss = self.MLM(**ids)["loss"]
|
348 |
+
perplexity = float(torch.exp(loss / ids["input_ids"].size(1)))
|
349 |
+
else:
|
350 |
+
perplexity = "N.A."
|
351 |
+
|
352 |
+
result = {
|
353 |
+
"text": text_raw,
|
354 |
+
"label": self.opt.index_to_label[pred_label],
|
355 |
+
"probs": prob.cpu().numpy(),
|
356 |
+
"confidence": float(max(prob)),
|
357 |
+
"ref_label": self.opt.index_to_label[ref_label]
|
358 |
+
if isinstance(ref_label, int)
|
359 |
+
else ref_label,
|
360 |
+
"ref_label_check": correct[pred_label == ref_label]
|
361 |
+
if ref_label != -100
|
362 |
+
else "",
|
363 |
+
"is_fixed": False,
|
364 |
+
"is_adv_label": self.opt.index_to_is_adv[pred_is_adv_label],
|
365 |
+
"is_adv_probs": advdet_prob.cpu().numpy(),
|
366 |
+
"is_adv_confidence": float(max(advdet_prob)),
|
367 |
+
"ref_is_adv_label": self.opt.index_to_is_adv[ref_is_adv_label]
|
368 |
+
if isinstance(ref_is_adv_label, int)
|
369 |
+
else ref_is_adv_label,
|
370 |
+
"ref_is_adv_check": correct[
|
371 |
+
pred_is_adv_label == ref_is_adv_label
|
372 |
+
]
|
373 |
+
if ref_is_adv_label != -100
|
374 |
+
and isinstance(ref_is_adv_label, int)
|
375 |
+
else "",
|
376 |
+
"pred_adv_tr_label": self.opt.index_to_label[pred_adv_tr_label],
|
377 |
+
"ref_adv_tr_label": self.opt.index_to_label[ref_adv_tr_label],
|
378 |
+
"perplexity": perplexity,
|
379 |
+
}
|
380 |
+
if defense:
|
381 |
+
try:
|
382 |
+
if not hasattr(self, "sent_attacker"):
|
383 |
+
self.sent_attacker = init_attacker(
|
384 |
+
self, defense.lower()
|
385 |
+
)
|
386 |
+
if result["is_adv_label"] == "1":
|
387 |
+
res = self.sent_attacker.attacker.simple_attack(
|
388 |
+
text_raw, int(result["label"])
|
389 |
+
)
|
390 |
+
new_infer_res = self.infer(
|
391 |
+
res.perturbed_result.attacked_text.text,
|
392 |
+
print_result=False,
|
393 |
+
)
|
394 |
+
result["perturbed_label"] = result["label"]
|
395 |
+
result["label"] = new_infer_res["label"]
|
396 |
+
result["probs"] = new_infer_res["probs"]
|
397 |
+
result["ref_label_check"] = (
|
398 |
+
correct[int(result["label"]) == ref_label]
|
399 |
+
if ref_label != -100
|
400 |
+
else ""
|
401 |
+
)
|
402 |
+
result[
|
403 |
+
"restored_text"
|
404 |
+
] = res.perturbed_result.attacked_text.text
|
405 |
+
result["is_fixed"] = True
|
406 |
+
else:
|
407 |
+
result["restored_text"] = ""
|
408 |
+
result["is_fixed"] = False
|
409 |
+
|
410 |
+
except Exception as e:
|
411 |
+
print(
|
412 |
+
"Error:{}, try install TextAttack and tensorflow_text after 10 seconds...".format(
|
413 |
+
e
|
414 |
+
)
|
415 |
+
)
|
416 |
+
time.sleep(10)
|
417 |
+
raise RuntimeError("Installation done, please run again...")
|
418 |
+
|
419 |
+
if ref_label != -100:
|
420 |
+
n_labeled += 1
|
421 |
+
|
422 |
+
if result["label"] == result["ref_label"]:
|
423 |
+
n_correct += 1
|
424 |
+
|
425 |
+
if ref_is_adv_label != -100:
|
426 |
+
n_advdet_labeled += 1
|
427 |
+
if ref_is_adv_label == pred_is_adv_label:
|
428 |
+
n_advdet_correct += 1
|
429 |
+
|
430 |
+
results.append(result)
|
431 |
+
|
432 |
+
try:
|
433 |
+
if print_result:
|
434 |
+
for ex_id, result in enumerate(results):
|
435 |
+
text_printing = result["text"][:]
|
436 |
+
text_info = ""
|
437 |
+
if result["label"] != "-100":
|
438 |
+
if not result["ref_label"]:
|
439 |
+
text_info += " -> <CLS:{}(ref:{} confidence:{})>".format(
|
440 |
+
result["label"],
|
441 |
+
result["ref_label"],
|
442 |
+
result["confidence"],
|
443 |
+
)
|
444 |
+
elif result["label"] == result["ref_label"]:
|
445 |
+
text_info += colored(
|
446 |
+
" -> <CLS:{}(ref:{} confidence:{})>".format(
|
447 |
+
result["label"],
|
448 |
+
result["ref_label"],
|
449 |
+
result["confidence"],
|
450 |
+
),
|
451 |
+
"green",
|
452 |
+
)
|
453 |
+
else:
|
454 |
+
text_info += colored(
|
455 |
+
" -> <CLS:{}(ref:{} confidence:{})>".format(
|
456 |
+
result["label"],
|
457 |
+
result["ref_label"],
|
458 |
+
result["confidence"],
|
459 |
+
),
|
460 |
+
"red",
|
461 |
+
)
|
462 |
+
|
463 |
+
# AdvDet
|
464 |
+
if result["is_adv_label"] != "-100":
|
465 |
+
if not result["ref_is_adv_label"]:
|
466 |
+
text_info += " -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
467 |
+
result["is_adv_label"],
|
468 |
+
result["ref_is_adv_check"],
|
469 |
+
result["is_adv_confidence"],
|
470 |
+
)
|
471 |
+
elif result["is_adv_label"] == result["ref_is_adv_label"]:
|
472 |
+
text_info += colored(
|
473 |
+
" -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
474 |
+
result["is_adv_label"],
|
475 |
+
result["ref_is_adv_label"],
|
476 |
+
result["is_adv_confidence"],
|
477 |
+
),
|
478 |
+
"green",
|
479 |
+
)
|
480 |
+
else:
|
481 |
+
text_info += colored(
|
482 |
+
" -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
483 |
+
result["is_adv_label"],
|
484 |
+
result["ref_is_adv_label"],
|
485 |
+
result["is_adv_confidence"],
|
486 |
+
),
|
487 |
+
"red",
|
488 |
+
)
|
489 |
+
text_printing += text_info
|
490 |
+
if self.cal_perplexity:
|
491 |
+
text_printing += colored(
|
492 |
+
" --> <perplexity:{}>".format(result["perplexity"]),
|
493 |
+
"yellow",
|
494 |
+
)
|
495 |
+
print("Example {}: {}".format(ex_id, text_printing))
|
496 |
+
if save_path:
|
497 |
+
with open(save_path, "w", encoding="utf8") as fout:
|
498 |
+
json.dump(str(results), fout, ensure_ascii=False)
|
499 |
+
print("inference result saved in: {}".format(save_path))
|
500 |
+
except Exception as e:
|
501 |
+
print("Can not save result: {}, Exception: {}".format(text_raw, e))
|
502 |
+
|
503 |
+
if len(results) > 1:
|
504 |
+
print(
|
505 |
+
"CLS Acc:{}%".format(100 * n_correct / n_labeled if n_labeled else "")
|
506 |
+
)
|
507 |
+
print(
|
508 |
+
"AdvDet Acc:{}%".format(
|
509 |
+
100 * n_advdet_correct / n_advdet_labeled
|
510 |
+
if n_advdet_labeled
|
511 |
+
else ""
|
512 |
+
)
|
513 |
+
)
|
514 |
+
|
515 |
+
return results
|
516 |
+
|
517 |
+
def clear_input_samples(self):
|
518 |
+
self.dataset.all_data = []
|
anonymous_demo/functional/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from anonymous_demo.functional.checkpoint.checkpoint_manager import TADCheckpointManager
|
2 |
+
|
3 |
+
from anonymous_demo.functional.config import TADConfigManager
|
anonymous_demo/functional/checkpoint/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .checkpoint_manager import TADCheckpointManager
|
anonymous_demo/functional/checkpoint/checkpoint_manager.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from findfile import find_file
|
3 |
+
|
4 |
+
from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier
|
5 |
+
from anonymous_demo.utils.demo_utils import retry
|
6 |
+
|
7 |
+
|
8 |
+
class CheckpointManager:
    """Empty base class for the checkpoint managers of this package."""
|
10 |
+
|
11 |
+
|
12 |
+
class TADCheckpointManager(CheckpointManager):
    """Factory that builds ``TADTextClassifier`` instances from a checkpoint."""

    @staticmethod
    @retry
    def get_tad_text_classifier(checkpoint: str = None, eval_batch_size=128, **kwargs):
        """Instantiate a ``TADTextClassifier``.

        The call is wrapped by the ``retry`` decorator imported from
        ``demo_utils``; its exact retry policy is not visible from here.

        :param checkpoint: forwarded as the first positional argument
        :param eval_batch_size: forwarded as the ``eval_batch_size`` keyword
        :param kwargs: any further keyword arguments, forwarded unchanged
        :return: the constructed classifier
        """
        return TADTextClassifier(
            checkpoint,
            eval_batch_size=eval_batch_size,
            **kwargs,
        )
|
anonymous_demo/functional/config/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .tad_config_manager import TADConfigManager
|
anonymous_demo/functional/config/config_manager.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import Namespace
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
one_shot_messages = set()
|
6 |
+
|
7 |
+
|
8 |
+
def config_check(args):
    """Validation hook for a parameter dict; currently a no-op that accepts anything."""
    return None
|
10 |
+
|
11 |
+
|
12 |
+
class ConfigManager(Namespace):
    """``argparse.Namespace`` backed by a parameter dict with per-parameter read counters.

    Parameters live in ``self.args``; ``self.args_call_count`` records how many
    times each of them has been read through attribute access.
    """

    def __init__(self, args=None, **kwargs):
        """
        :param args: a parameter dict or an ``argparse.Namespace``; a falsy
            value is replaced by a new empty dict
        :param kwargs: same keyword arguments as ``Namespace``; they are kept
            as ordinary attributes and are not tracked in ``args``
        """
        if not args:
            args = {}
        # Namespace.__init__ assigns each kwarg through __setattr__ *before*
        # ``args`` exists, so __setattr__ must cope with the missing dicts.
        super().__init__(**kwargs)

        if isinstance(args, Namespace):
            self.args = vars(args)
            self.args_call_count = {arg: 0 for arg in vars(args)}
        else:
            self.args = args
            self.args_call_count = {arg: 0 for arg in args}

    def __getattribute__(self, arg_name):
        """Return a tracked parameter (bumping its read counter) or a regular attribute."""
        if arg_name in ("args", "args_call_count"):
            return super().__getattribute__(arg_name)
        try:
            value = super().__getattribute__("args")[arg_name]
            # The counter dict is mutated in place; no re-assignment is needed.
            super().__getattribute__("args_call_count")[arg_name] += 1
            return value
        except (AttributeError, KeyError, TypeError):
            # Not a tracked parameter (or the bookkeeping dicts are not set up
            # yet): fall back to normal attribute lookup.
            return super().__getattribute__(arg_name)

    def __setattr__(self, arg_name, value):
        """Store *value* as a tracked parameter, or as a plain attribute when tracking is unavailable."""
        if arg_name in ("args", "args_call_count"):
            super().__setattr__(arg_name, value)
            return
        try:
            args = super().__getattribute__("args")
            args_call_count = super().__getattribute__("args_call_count")
            args[arg_name] = value
            # New parameters start with zero reads; existing counters are kept.
            args_call_count.setdefault(arg_name, 0)
        except (AttributeError, TypeError):
            # Reached e.g. while Namespace.__init__ assigns **kwargs: there is
            # no parameter dict to validate, so only store the attribute.
            # (Previously this path went on to use an unbound ``args``.)
            super().__setattr__(arg_name, value)
            return

        config_check(args)
|
anonymous_demo/functional/config/tad_config_manager.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
from anonymous_demo.functional.config.config_manager import ConfigManager
|
4 |
+
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
|
5 |
+
|
6 |
+
_tad_config_template = {
|
7 |
+
"model": TADBERT,
|
8 |
+
"optimizer": "adamw",
|
9 |
+
"learning_rate": 0.00002,
|
10 |
+
"patience": 99999,
|
11 |
+
"pretrained_bert": "microsoft/mdeberta-v3-base",
|
12 |
+
"cache_dataset": True,
|
13 |
+
"warmup_step": -1,
|
14 |
+
"show_metric": False,
|
15 |
+
"max_seq_len": 80,
|
16 |
+
"dropout": 0,
|
17 |
+
"l2reg": 0.000001,
|
18 |
+
"num_epoch": 10,
|
19 |
+
"batch_size": 16,
|
20 |
+
"initializer": "xavier_uniform_",
|
21 |
+
"seed": 52,
|
22 |
+
"polarities_dim": 3,
|
23 |
+
"log_step": 10,
|
24 |
+
"evaluate_begin": 0,
|
25 |
+
"cross_validate_fold": -1,
|
26 |
+
"use_amp": False,
|
27 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
28 |
+
}
|
29 |
+
|
30 |
+
_tad_config_base = {
|
31 |
+
"model": TADBERT,
|
32 |
+
"optimizer": "adamw",
|
33 |
+
"learning_rate": 0.00002,
|
34 |
+
"pretrained_bert": "microsoft/deberta-v3-base",
|
35 |
+
"cache_dataset": True,
|
36 |
+
"warmup_step": -1,
|
37 |
+
"show_metric": False,
|
38 |
+
"max_seq_len": 80,
|
39 |
+
"patience": 99999,
|
40 |
+
"dropout": 0,
|
41 |
+
"l2reg": 0.000001,
|
42 |
+
"num_epoch": 10,
|
43 |
+
"batch_size": 16,
|
44 |
+
"initializer": "xavier_uniform_",
|
45 |
+
"seed": 52,
|
46 |
+
"polarities_dim": 3,
|
47 |
+
"log_step": 10,
|
48 |
+
"evaluate_begin": 0,
|
49 |
+
"cross_validate_fold": -1
|
50 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
51 |
+
}
|
52 |
+
|
53 |
+
_tad_config_english = {
|
54 |
+
"model": TADBERT,
|
55 |
+
"optimizer": "adamw",
|
56 |
+
"learning_rate": 0.00002,
|
57 |
+
"patience": 99999,
|
58 |
+
"pretrained_bert": "microsoft/deberta-v3-base",
|
59 |
+
"cache_dataset": True,
|
60 |
+
"warmup_step": -1,
|
61 |
+
"show_metric": False,
|
62 |
+
"max_seq_len": 80,
|
63 |
+
"dropout": 0,
|
64 |
+
"l2reg": 0.000001,
|
65 |
+
"num_epoch": 10,
|
66 |
+
"batch_size": 16,
|
67 |
+
"initializer": "xavier_uniform_",
|
68 |
+
"seed": 52,
|
69 |
+
"polarities_dim": 3,
|
70 |
+
"log_step": 10,
|
71 |
+
"evaluate_begin": 0,
|
72 |
+
"cross_validate_fold": -1
|
73 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
74 |
+
}
|
75 |
+
|
76 |
+
_tad_config_multilingual = {
|
77 |
+
"model": TADBERT,
|
78 |
+
"optimizer": "adamw",
|
79 |
+
"learning_rate": 0.00002,
|
80 |
+
"patience": 99999,
|
81 |
+
"pretrained_bert": "microsoft/mdeberta-v3-base",
|
82 |
+
"cache_dataset": True,
|
83 |
+
"warmup_step": -1,
|
84 |
+
"show_metric": False,
|
85 |
+
"max_seq_len": 80,
|
86 |
+
"dropout": 0,
|
87 |
+
"l2reg": 0.000001,
|
88 |
+
"num_epoch": 10,
|
89 |
+
"batch_size": 16,
|
90 |
+
"initializer": "xavier_uniform_",
|
91 |
+
"seed": 52,
|
92 |
+
"polarities_dim": 3,
|
93 |
+
"log_step": 10,
|
94 |
+
"evaluate_begin": 0,
|
95 |
+
"cross_validate_fold": -1
|
96 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
97 |
+
}
|
98 |
+
|
99 |
+
_tad_config_chinese = {
|
100 |
+
"model": TADBERT,
|
101 |
+
"optimizer": "adamw",
|
102 |
+
"learning_rate": 0.00002,
|
103 |
+
"patience": 99999,
|
104 |
+
"cache_dataset": True,
|
105 |
+
"warmup_step": -1,
|
106 |
+
"show_metric": False,
|
107 |
+
"pretrained_bert": "bert-base-chinese",
|
108 |
+
"max_seq_len": 80,
|
109 |
+
"dropout": 0,
|
110 |
+
"l2reg": 0.000001,
|
111 |
+
"num_epoch": 10,
|
112 |
+
"batch_size": 16,
|
113 |
+
"initializer": "xavier_uniform_",
|
114 |
+
"seed": 52,
|
115 |
+
"polarities_dim": 3,
|
116 |
+
"log_step": 10,
|
117 |
+
"evaluate_begin": 0,
|
118 |
+
"cross_validate_fold": -1
|
119 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
120 |
+
}
|
121 |
+
|
122 |
+
|
123 |
+
class TADConfigManager(ConfigManager):
    """ConfigManager plus accessors for the TAD presets declared in this module.

    Presets are the module-level ``_tad_config_<name>`` dicts.  The
    ``get_tad_config_*`` helpers return a fresh config made of the template
    preset overlaid with the requested preset; the ``set_tad_config_*``
    helpers update a preset in place.
    """

    def __init__(self, args, **kwargs):
        """
        Parameters listed by the original author as available::

            {'model': BERT,
             'optimizer': "adamw",
             'learning_rate': 0.00002,
             'pretrained_bert': "roberta-base",
             'cache_dataset': True,
             'warmup_step': -1,
             'show_metric': False,
             'max_seq_len': 80,
             'patience': 99999,
             'dropout': 0,
             'l2reg': 0.000001,
             'num_epoch': 10,
             'batch_size': 16,
             'initializer': 'xavier_uniform_',
             'seed': {52, 25},
             'embed_dim': 768,
             'hidden_dim': 768,
             'polarities_dim': 3,
             'log_step': 10,
             'evaluate_begin': 0,
             'cross_validate_fold': -1}

        :param args: parameter dict (or Namespace) handed to ``ConfigManager``
        :param kwargs: forwarded to ``ConfigManager``
        """
        super().__init__(args, **kwargs)

    @staticmethod
    def _preset(configType: str) -> dict:
        """Return the module-level preset dict registered under *configType*.

        :raises ValueError: for an unknown preset name, or for a known name
            whose ``_tad_config_<name>`` dict is not defined in this module
        """
        known_types = ("template", "base", "english", "chinese", "multilingual", "glove")
        if configType not in known_types:
            raise ValueError(
                "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
            )
        # Looked up by name so that a preset missing from the module produces
        # an explicit error instead of a NameError.
        preset = globals().get("_tad_config_{}".format(configType))
        if preset is None:
            raise ValueError(
                "The '{}' config preset is not defined in this module".format(configType)
            )
        return preset

    @staticmethod
    def _build_config(configType: str) -> ConfigManager:
        """Create a config from a copy of the template overlaid with *configType*.

        The shared template dict is never modified here, so requesting one
        preset cannot leak its values into a later request for another.
        """
        preset = TADConfigManager._preset(configType)
        merged = copy.deepcopy(_tad_config_template)
        if preset is not _tad_config_template:
            merged.update(copy.deepcopy(preset))
        return TADConfigManager(merged)

    @staticmethod
    def set_tad_config(configType: str, newitem: dict):
        """Update the preset named *configType* in place with *newitem*.

        :raises TypeError: if *newitem* is not a dict
        :raises ValueError: if *configType* does not name an available preset
        """
        if not isinstance(newitem, dict):
            raise TypeError(
                "Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
            )
        TADConfigManager._preset(configType).update(newitem)

    @staticmethod
    def set_tad_config_template(newitem):
        """Update the ``template`` preset."""
        TADConfigManager.set_tad_config("template", newitem)

    @staticmethod
    def set_tad_config_base(newitem):
        """Update the ``base`` preset."""
        TADConfigManager.set_tad_config("base", newitem)

    @staticmethod
    def set_tad_config_english(newitem):
        """Update the ``english`` preset."""
        TADConfigManager.set_tad_config("english", newitem)

    @staticmethod
    def set_tad_config_chinese(newitem):
        """Update the ``chinese`` preset."""
        TADConfigManager.set_tad_config("chinese", newitem)

    @staticmethod
    def set_tad_config_multilingual(newitem):
        """Update the ``multilingual`` preset."""
        TADConfigManager.set_tad_config("multilingual", newitem)

    @staticmethod
    def set_tad_config_glove(newitem):
        """Update the ``glove`` preset (raises ValueError while that preset is undefined)."""
        TADConfigManager.set_tad_config("glove", newitem)

    @staticmethod
    def get_tad_config_template() -> ConfigManager:
        """Return a config holding a deep copy of the ``template`` preset."""
        return TADConfigManager._build_config("template")

    @staticmethod
    def get_tad_config_base() -> ConfigManager:
        """Return the template overlaid with the ``base`` preset."""
        return TADConfigManager._build_config("base")

    @staticmethod
    def get_tad_config_english() -> ConfigManager:
        """Return the template overlaid with the ``english`` preset."""
        return TADConfigManager._build_config("english")

    @staticmethod
    def get_tad_config_chinese() -> ConfigManager:
        """Return the template overlaid with the ``chinese`` preset."""
        return TADConfigManager._build_config("chinese")

    @staticmethod
    def get_tad_config_multilingual() -> ConfigManager:
        """Return the template overlaid with the ``multilingual`` preset."""
        return TADConfigManager._build_config("multilingual")

    @staticmethod
    def get_tad_config_glove() -> ConfigManager:
        """Return the template overlaid with the ``glove`` preset (raises ValueError while undefined)."""
        return TADConfigManager._build_config("glove")
|
anonymous_demo/functional/dataset/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from anonymous_demo.functional.dataset.dataset_manager import detect_infer_dataset
|
anonymous_demo/functional/dataset/dataset_manager.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from findfile import find_files, find_dir
|
3 |
+
|
4 |
+
filter_key_words = [
|
5 |
+
".py",
|
6 |
+
".md",
|
7 |
+
"readme",
|
8 |
+
"log",
|
9 |
+
"result",
|
10 |
+
"zip",
|
11 |
+
".state_dict",
|
12 |
+
".model",
|
13 |
+
".png",
|
14 |
+
"acc_",
|
15 |
+
"f1_",
|
16 |
+
".backup",
|
17 |
+
".bak",
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def detect_infer_dataset(dataset_path, task="apc"):
    """Collect the inference files that belong to *dataset_path*.

    :param dataset_path: either one path/name given as a string, or an
        iterable of paths/names
    :param task: keyword passed to the ``findfile`` searches
    :return: list of matching file paths; a string naming an existing file is
        returned as a one-element list without any search
    """
    if isinstance(dataset_path, str):
        if os.path.isfile(dataset_path):
            return [dataset_path]
        # A lone string is a single location; iterating it directly would
        # search for each of its characters instead.
        dataset_path = [dataset_path]

    dataset_file = []
    for d in dataset_path:
        if os.path.exists(d):
            dataset_file += find_files(
                d, [".inference", task], exclude_key=["train."] + filter_key_words
            )
        else:
            # Unknown path: treat it as a name and look for a matching
            # directory below the current working directory.
            search_path = find_dir(
                os.getcwd(),
                [d, task, "dataset"],
                exclude_key=filter_key_words,
                disable_alert=False,
            )
            dataset_file += find_files(
                search_path,
                [".inference", d],
                exclude_key=["train."] + filter_key_words,
            )

    return dataset_file
|
anonymous_demo/network/__init__.py
ADDED
File without changes
|
anonymous_demo/network/lcf_pooler.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class LCF_Pooler(nn.Module):
    """Pool one hidden vector per sample, chosen via the local-context mask.

    For every sample the positions whose mask row sums to zero after
    subtracting 1 are collected; the hidden state at the middle one of those
    positions goes through a linear layer and ``tanh``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, lcf_vec):
        """
        :param hidden_states: tensor indexed as [sample, position, feature]
        :param lcf_vec: mask tensor indexed as [sample, position, ...]
        :return: pooled tensor of shape (samples, features) on the input device
        """
        target_device = hidden_states.device
        # Selection runs on detached NumPy copies, so it carries no gradient.
        mask = lcf_vec.detach().cpu().numpy()
        selected = numpy.zeros(
            (hidden_states.shape[0], hidden_states.shape[2]), dtype=numpy.float32
        )
        states = hidden_states.detach().cpu().numpy()

        for sample_idx, sample_mask in enumerate(mask):
            focus_positions = [
                pos for pos in range(len(sample_mask)) if sum(sample_mask[pos] - 1.0) == 0
            ]
            middle = focus_positions[len(focus_positions) // 2]
            selected[sample_idx] = states[sample_idx][middle]

        pooled = torch.Tensor(selected).to(target_device)
        return self.activation(self.dense(pooled))
|
anonymous_demo/network/lsa.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from anonymous_demo.network.sa_encoder import Encoder
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class LSA(nn.Module):
    """Fuse centre, left and right local-context features into one representation.

    Three ``Encoder`` instances process the centre, left and right masked
    features; ``opt.window`` ("lr"/"rl", "l" or "r") selects which of them are
    concatenated and projected back to ``opt.embed_dim``.
    """

    def __init__(self, bert, opt):
        super(LSA, self).__init__()
        self.opt = opt

        self.encoder = Encoder(bert.config, opt)
        self.encoder_left = Encoder(bert.config, opt)
        self.encoder_right = Encoder(bert.config, opt)
        self.linear_window_3h = nn.Linear(opt.embed_dim * 3, opt.embed_dim)
        self.linear_window_2h = nn.Linear(opt.embed_dim * 2, opt.embed_dim)
        # Learnable weights of the left / right side features.
        self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
        self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))

    def forward(
        self,
        global_context_features,
        spc_mask_vec,
        lcf_matrix,
        left_lcf_matrix,
        right_lcf_matrix,
    ):
        """Return the fused features for the configured window.

        :raises KeyError: if ``opt.window`` is none of "lr", "rl", "l", "r"
        """
        masked_context = torch.mul(spc_mask_vec, global_context_features)

        # All three branches are always encoded, whatever the window is.
        center = self.encoder(torch.mul(global_context_features, lcf_matrix))
        left = self.encoder_left(torch.mul(masked_context, left_lcf_matrix))
        right = self.encoder_right(torch.mul(masked_context, right_lcf_matrix))

        if "lr" == self.opt.window or "rl" == self.opt.window:
            # A non-positive learned weight is re-drawn, unless eta == -1.
            if self.eta1 <= 0 and self.opt.eta != -1:
                torch.nn.init.uniform_(self.eta1)
                print("reset eta1 to: {}".format(self.eta1.item()))
            if self.eta2 <= 0 and self.opt.eta != -1:
                torch.nn.init.uniform_(self.eta2)
                print("reset eta2 to: {}".format(self.eta2.item()))

            if self.opt.eta >= 0:
                pieces = (center, self.eta1 * left, self.eta2 * right)
            else:
                pieces = (center, left, right)
            return self.linear_window_3h(torch.cat(pieces, -1))

        if "l" == self.opt.window:
            return self.linear_window_2h(torch.cat((center, self.eta1 * left), -1))

        if "r" == self.opt.window:
            return self.linear_window_2h(torch.cat((center, self.eta2 * right), -1))

        raise KeyError("Invalid parameter:", self.opt.window)
|
anonymous_demo/network/sa_encoder.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional relative position scores.

    Handles plain self-attention, cross-attention (when
    ``encoder_hidden_states`` is given) and cached keys/values (when
    ``past_key_value`` is given).
    """

    def __init__(self, config):
        super().__init__()
        divisible = config.hidden_size % config.num_attention_heads == 0
        if not divisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Dropout is disabled when the config does not define a probability.
        self.dropout = nn.Dropout(getattr(config, "attention_probs_dropout_prob", 0))
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to position 1."""
        head_shape = (self.num_attention_heads, self.attention_head_size)
        x = x.view(*(x.size()[:-1] + head_shape))
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention.

        :return: ``(context,)``, plus the attention probabilities when
            ``output_attentions`` is true, plus ``(key, value)`` when the
            module is configured as a decoder
        """
        mixed_query_layer = self.query(hidden_states)

        if encoder_hidden_states is not None:
            # Cross-attention: keys/values come from the encoder (or from the
            # cache) and the encoder mask replaces the attention mask.
            attention_mask = encoder_attention_mask
            if past_key_value is not None:
                key_layer = past_key_value[0]
                value_layer = past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
                value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            if past_key_value is not None:
                # Extend the cached keys/values along the sequence axis.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # Hand the (possibly extended) keys/values back to the caller so
            # that later calls can reuse them.
            past_key_value = (key_layer, value_layer)

        # Raw scores: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            rows = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            cols = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            positional_embedding = self.distance_embedding(
                (rows - cols) + self.max_position_embeddings - 1
            )
            # fp16 compatibility
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            if self.position_embedding_type == "relative_key":
                attention_scores = attention_scores + torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
            elif self.position_embedding_type == "relative_key_query":
                query_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                key_scores = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = attention_scores + query_scores + key_scores

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive.
            attention_scores = attention_scores + attention_mask

        # Normalise to probabilities, then apply dropout on them.
        attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores))

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # Merge the heads back into a single feature dimension.
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        if output_attentions:
            outputs = (context_layer, attention_probs)
        else:
            outputs = (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
|
170 |
+
|
171 |
+
|
172 |
+
class Encoder(nn.Module):
    """Stack of ``SelfAttention`` layers, each followed by ``tanh``."""

    def __init__(self, config, opt, layer_num=1):
        super(Encoder, self).__init__()
        self.opt = opt
        self.config = config
        layers = []
        for _ in range(layer_num):
            layers.append(SelfAttention(config, opt))
        self.encoder = nn.ModuleList(layers)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        """Feed *x* through every layer, keeping the first output of each layer."""
        hidden = x
        for layer in self.encoder:
            hidden = self.tanh(layer(hidden)[0])
        return hidden
|
186 |
+
|
187 |
+
|
188 |
+
class SelfAttention(nn.Module):
    """Thin wrapper that runs ``BertSelfAttention`` with an all-zero additive mask."""

    def __init__(self, config, opt):
        super(SelfAttention, self).__init__()
        self.opt = opt
        self.config = config
        self.SA = BertSelfAttention(config)

    def forward(self, inputs):
        """Apply self-attention to *inputs* and return the attention module's output tuple.

        The mask has shape ``(batch, 1, 1, opt.max_seq_len)`` and is allocated
        directly on the device of *inputs* as float32, avoiding a float64
        NumPy array on the host and a host-to-device copy on every call.
        """
        zero_mask = torch.zeros(
            (inputs.size(0), 1, 1, self.opt.max_seq_len),
            dtype=torch.float,
            device=inputs.device,
        )
        return self.SA(inputs, zero_mask)
|
anonymous_demo/utils/__init__.py
ADDED
File without changes
|
anonymous_demo/utils/demo_utils.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
import signal
|
5 |
+
import threading
|
6 |
+
import time
|
7 |
+
import zipfile
|
8 |
+
|
9 |
+
import gdown
|
10 |
+
import numpy as np
|
11 |
+
import requests
|
12 |
+
import torch
|
13 |
+
import tqdm
|
14 |
+
from autocuda import auto_cuda, auto_cuda_name
|
15 |
+
from findfile import find_files, find_cwd_file, find_file
|
16 |
+
from termcolor import colored
|
17 |
+
from functools import wraps
|
18 |
+
|
19 |
+
from update_checker import parse_version
|
20 |
+
|
21 |
+
from anonymous_demo import __version__
|
22 |
+
|
23 |
+
|
24 |
+
def save_args(config, save_path):
    """Write every parameter that has been read at least once to *save_path*.

    One ``name: value`` line is written per parameter whose entry in
    ``config.args_call_count`` is non-zero.

    :param config: object exposing ``args`` and ``args_call_count`` dicts
    :param save_path: destination file path (overwritten)
    """
    # The context manager closes the file even if a lookup or write fails.
    with open(os.path.join(save_path), mode="w", encoding="utf8") as fout:
        for arg in config.args:
            if config.args_call_count[arg]:
                fout.write("{}: {}\n".format(arg, config.args[arg]))
|
30 |
+
|
31 |
+
|
32 |
+
def print_args(config, logger=None, mode=0):
    """Report every parameter with its value and read count, in sorted name order.

    :param config: object exposing ``args`` and ``args_call_count`` dicts
    :param logger: when truthy, lines go to ``logger.info``; otherwise ``print``
    :param mode: accepted but not used by this function
    """
    for name in sorted(config.args.keys()):
        if logger:
            logger.info(
                "{0}:{1}\t-->\tCalling Count:{2}".format(
                    name, config.args[name], config.args_call_count[name]
                )
            )
            continue
        print(
            "{0}:{1}\t-->\tCalling Count:{2}".format(
                name, config.args[name], config.args_call_count[name]
            )
        )
|
47 |
+
|
48 |
+
|
49 |
+
def check_and_fix_labels(label_set: set, label_name, all_data, opt):
    """Map the raw labels to integer indices, in place, and record the mapping on *opt*.

    Labels are numbered in sorted order.  When ``"-100"`` is among them it
    keeps the index ``-100`` and every other label gets its sorted position
    minus one.  The mappings are stored as ``opt.index_to_label`` /
    ``opt.label_to_index`` (merged into existing ones when they differ), each
    item of *all_data* has its label replaced by the index, and a per-label
    count is printed.

    :param label_set: the distinct raw labels
    :param label_name: key under which an item stores its label
    :param all_data: items to relabel; items that cannot be handled through
        ``item[label_name]`` are handled through ``item.polarity`` instead
    :param opt: config object exposing ``args``
    """
    has_ignore_label = "-100" in label_set
    label_to_index = {}
    index_to_label = {}
    for position, origin_label in enumerate(sorted(label_set)):
        if has_ignore_label:
            index = position - 1 if origin_label != "-100" else -100
        else:
            index = position
        label_to_index[origin_label] = index
        index_to_label[index] = origin_label

    if "index_to_label" not in opt.args:
        opt.index_to_label = index_to_label
        opt.label_to_index = label_to_index

    if opt.index_to_label != index_to_label:
        # A mapping already exists: extend it with the current labels.
        opt.index_to_label.update(index_to_label)
        opt.label_to_index.update(label_to_index)

    num_label = dict.fromkeys(label_set, 0)
    num_label["Sum"] = len(all_data)
    for sample in all_data:
        try:
            num_label[sample[label_name]] += 1
            sample[label_name] = label_to_index[sample[label_name]]
        except Exception:
            # Fallback for items that keep their label in ``.polarity``.
            num_label[sample.polarity] += 1
            sample.polarity = label_to_index[sample.polarity]
    print("Dataset Label Details: {}".format(num_label))
|
86 |
+
|
87 |
+
|
88 |
+
def check_and_fix_IOB_labels(label_map, opt):
    """Store the inverse of ``label_map`` on ``opt.index_to_IOB_label``.

    Args:
        label_map: Mapping from IOB label to an index convertible with ``int``.
        opt: Config object that receives the ``index_to_IOB_label`` attribute.
    """
    inverted = {}
    for origin_label in label_map:
        inverted[int(label_map[origin_label])] = origin_label
    opt.index_to_IOB_label = inverted
|
93 |
+
|
94 |
+
|
95 |
+
def get_device(auto_device):
    """Resolve a device specifier and return ``(device, device_name)``.

    Args:
        auto_device: ``"allcuda"`` maps to ``"cuda"``; any other string is used
            verbatim; ``False`` selects ``"cpu"``; ``True`` and every other
            value delegate to ``auto_cuda()``.

    Returns:
        Tuple of the chosen device and the value of ``auto_cuda_name()``.
    """
    if isinstance(auto_device, str):
        device = "cuda" if auto_device == "allcuda" else auto_device
    elif isinstance(auto_device, bool) and not auto_device:
        device = "cpu"
    else:
        device = auto_cuda()

    # Validate the specifier; an invalid one falls back to the CPU.
    try:
        torch.device(device)
    except RuntimeError as e:
        print(
            colored("Device assignment error: {}, redirect to CPU".format(e), "red")
        )
        device = "cpu"

    return device, auto_cuda_name()
|
113 |
+
|
114 |
+
|
115 |
+
def _load_word_vec(path, word2idx=None, embed_dim=300):
    """Read word vectors from a whitespace-separated text embedding file.

    Each line is split on whitespace; the last ``embed_dim`` tokens form the
    vector and everything before them (re-joined with spaces) is the word.

    Args:
        path: Path of the embedding text file.
        word2idx: Optional vocabulary; when given, only words contained in it
            are kept. When ``None`` every word in the file is loaded.
        embed_dim: Number of trailing tokens per line that form the vector.

    Returns:
        Dict mapping word to a ``float32`` NumPy vector.
    """
    word_vec = {}
    # The context manager closes the handle; iterating the file object streams
    # it line by line instead of materialising the whole file in memory.
    with open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        for line in tqdm.tqdm(fin, postfix="Loading embedding file..."):
            tokens = line.rstrip().split()
            word, vec = " ".join(tokens[:-embed_dim]), tokens[-embed_dim:]
            if word2idx is None or word in word2idx:
                word_vec[word] = np.asarray(vec, dtype="float32")
    return word_vec
|
124 |
+
|
125 |
+
|
126 |
+
def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
    """Build the embedding matrix for ``word2idx``, or load it from the cache.

    The matrix is cached as a pickle at ``run/<opt.dataset_name>/<dat_fname>``.

    Args:
        word2idx: Mapping from word to its row index in the matrix.
        embed_dim: Embedding dimensionality.
        dat_fname: File name of the cached matrix.
        opt: Config object; only ``opt.dataset_name`` is read here.

    Returns:
        Array of shape ``(len(word2idx) + 2, embed_dim)`` when freshly built,
        otherwise whatever object the cache file contains.
    """
    os.makedirs("run", exist_ok=True)
    embed_matrix_path = "run/{}".format(os.path.join(opt.dataset_name, dat_fname))
    if os.path.exists(embed_matrix_path):
        print(
            colored(
                "Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)".format(
                    embed_matrix_path
                ),
                "green",
            )
        )
        # NOTE(review): pickle.load executes arbitrary code from the file; the
        # cache directory must only ever contain files written by this function.
        with open(embed_matrix_path, "rb") as cache_file:
            embedding_matrix = pickle.load(cache_file)
    else:
        glove_path = prepare_glove840_embedding(embed_matrix_path)
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))

        word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)

        for word, i in tqdm.tqdm(
            word2idx.items(),
            postfix=colored("Building embedding_matrix {}".format(dat_fname), "yellow"),
        ):
            vec = word_vec.get(word)
            if vec is not None:
                # Words missing from the embedding file keep an all-zero row.
                embedding_matrix[i] = vec

        # Make sure run/<dataset_name>/ exists before writing the cache file.
        os.makedirs(os.path.dirname(embed_matrix_path), exist_ok=True)
        with open(embed_matrix_path, "wb") as cache_file:
            pickle.dump(embedding_matrix, cache_file)
    return embedding_matrix
|
155 |
+
|
156 |
+
|
157 |
+
def pad_and_truncate(
    sequence, maxlen, dtype="int64", padding="post", truncating="post", value=0
):
    """Return ``sequence`` as a 1-D array of exactly ``maxlen`` elements.

    Args:
        sequence: Sliceable sequence of values.
        maxlen: Length of the returned array.
        dtype: NumPy dtype of the result.
        padding: ``"post"`` writes the values at the start and pads the end;
            any other value pads the start instead.
        truncating: ``"pre"`` keeps the last ``maxlen`` items of an over-long
            sequence; any other value keeps the first ``maxlen`` items.
        value: Fill value used for padding.

    Returns:
        ``numpy.ndarray`` of shape ``(maxlen,)``.
    """
    x = np.full(maxlen, value, dtype=dtype)
    if truncating == "pre":
        # ``sequence[-0:]`` would be the whole sequence, so handle maxlen == 0.
        trunc = sequence[-maxlen:] if maxlen else sequence[:0]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if len(trunc) == 0:
        # Nothing to copy; without this guard ``x[-0:]`` addresses the whole
        # array and the assignment fails to broadcast.
        return x
    if padding == "post":
        x[: len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x
|
171 |
+
|
172 |
+
|
173 |
+
class TransformerConnectionError(ValueError):
    """Connection failure that the ``retry`` decorator treats as retryable."""

    def __init__(self, *args):
        # Forward optional message arguments so str(error) is informative;
        # calling with no arguments keeps working as before.
        super().__init__(*args)
|
176 |
+
|
177 |
+
|
178 |
+
def retry(f):
    """Decorator that re-runs ``f`` up to five times on connection errors.

    Only ``TransformerConnectionError`` and the listed ``requests`` exceptions
    are retried; any other exception propagates immediately. The wrapper waits
    60 seconds between attempts and returns ``None`` once every attempt has
    failed.
    """

    @wraps(f)
    def decorated(*args, **kwargs):
        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                return f(*args, **kwargs)
            except (
                TransformerConnectionError,
                requests.exceptions.RequestException,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.ConnectTimeout,
                requests.exceptions.ProxyError,
                requests.exceptions.SSLError,
                requests.exceptions.BaseHTTPError,
            ) as e:
                print(colored("Training Exception: {}, will retry later".format(e)))
                # No point in waiting a minute after the final attempt.
                if attempt < max_attempts:
                    time.sleep(60)
        print(
            colored(
                "Training Exception: giving up after {} attempts".format(max_attempts)
            )
        )
        return None

    return decorated
|
200 |
+
|
201 |
+
|
202 |
+
def save_json(dic, save_path):
    """Serialise ``dic`` as single-line JSON (non-ASCII kept) to ``save_path``.

    Args:
        dic: Object to serialise, or a string holding a Python expression that
            evaluates to it.
        save_path: Destination file, written as UTF-8.
    """
    # NOTE(review): eval() runs arbitrary code; only pass trusted strings here.
    payload = eval(dic) if isinstance(dic, str) else dic
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(payload, ensure_ascii=False))
|
209 |
+
|
210 |
+
|
211 |
+
def load_json(save_path):
    """Load a JSON document from the UTF-8 file ``save_path``.

    The whole file is parsed, so pretty-printed (multi-line) JSON works. If the
    file as a whole is not a single JSON document, the first line alone is
    parsed instead, which is what this function historically did.

    Args:
        save_path: Path of the JSON file.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If neither the file nor its first line is valid JSON.
    """
    with open(save_path, "r", encoding="utf-8") as f:
        text = f.read()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        first_line = text.split("\n", 1)[0].strip()
        return json.loads(first_line)
|
217 |
+
|
218 |
+
|
219 |
+
def init_optimizer(optimizer):
    """Resolve an optimizer specifier to a ``torch.optim`` optimizer class.

    Args:
        optimizer: Either a case-insensitive optimizer name (``"adam"``,
            ``"AdamW"``, ...) or an optimizer class whose ``__name__`` exists
            in ``torch.optim``.

    Returns:
        The optimizer class (not an instance).

    Raises:
        KeyError: If the specifier cannot be resolved.
    """
    optimizers = {
        "adadelta": torch.optim.Adadelta,  # default lr=1.0
        "adagrad": torch.optim.Adagrad,  # default lr=0.01
        "adam": torch.optim.Adam,  # default lr=0.001
        "adamax": torch.optim.Adamax,  # default lr=0.002
        "asgd": torch.optim.ASGD,  # default lr=0.01
        "rmsprop": torch.optim.RMSprop,  # default lr=0.01
        "sgd": torch.optim.SGD,
        "adamw": torch.optim.AdamW,
    }
    if isinstance(optimizer, str):
        resolved = optimizers.get(optimizer.strip().lower())
        if resolved is not None:
            return resolved
    else:
        # Classes are accepted as-is when torch.optim exposes the same name;
        # getattr avoids an AttributeError for objects without ``__name__``.
        name = getattr(optimizer, "__name__", None)
        if isinstance(name, str) and hasattr(torch.optim, name):
            return optimizer
    raise KeyError(
        "Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer".format(
            optimizer
        )
    )
|
anonymous_demo/utils/logger.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
|
6 |
+
import termcolor
|
7 |
+
|
8 |
+
today = time.strftime("%Y%m%d %H%M%S", time.localtime(time.time()))
|
9 |
+
|
10 |
+
|
11 |
+
def get_logger(log_path, log_name="", log_type="training_log"):
    """Create (or fetch) a logger that writes to a log file and to stdout.

    The log file is ``<log_path or '.'>/logs/<log_name>_<today>/<log_type>.log``,
    where ``today`` is the module-level timestamp.

    Args:
        log_path: Base directory for the ``logs`` folder; the current
            directory is used when it is empty or ``None``.
        log_name: Logger name, also used as the log directory prefix.
        log_type: Base name of the log file.

    Returns:
        The configured ``logging.Logger`` at INFO level.
    """
    # The original test was inverted: a supplied path was ignored and a
    # missing one was passed straight to os.path.join.
    log_dir = os.path.join(log_path if log_path else ".", "logs")

    full_path = os.path.join(log_dir, log_name + "_" + today)
    os.makedirs(full_path, exist_ok=True)
    log_file = os.path.join(full_path, "{}.log".format(log_type))

    logger = logging.getLogger(log_name)
    # Attach handlers only once per logger name to avoid duplicated output.
    if not logger.handlers:
        formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")

        file_handler = logging.FileHandler(log_file, encoding="utf8")
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        console_handler.setLevel(logging.INFO)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    logger.setLevel(logging.INFO)

    return logger
|
app.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import zipfile
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import nltk
|
6 |
+
import pandas as pd
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from pyabsa import TADCheckpointManager
|
10 |
+
from textattack.attack_recipes import (
|
11 |
+
BAEGarg2019,
|
12 |
+
PWWSRen2019,
|
13 |
+
TextFoolerJin2019,
|
14 |
+
PSOZang2020,
|
15 |
+
IGAWang2019,
|
16 |
+
GeneticAlgorithmAlzantot2018,
|
17 |
+
DeepWordBugGao2018,
|
18 |
+
CLARE2020,
|
19 |
+
)
|
20 |
+
from textattack.attack_results import SuccessfulAttackResult
|
21 |
+
from utils import SentAttacker, get_agnews_example, get_sst2_example, get_amazon_example, get_imdb_example, diff_texts
|
22 |
+
# from utils import get_yahoo_example
|
23 |
+
|
24 |
+
sent_attackers = {}
|
25 |
+
tad_classifiers = {}
|
26 |
+
|
27 |
+
attack_recipes = {
|
28 |
+
"bae": BAEGarg2019,
|
29 |
+
"pwws": PWWSRen2019,
|
30 |
+
"textfooler": TextFoolerJin2019,
|
31 |
+
"pso": PSOZang2020,
|
32 |
+
"iga": IGAWang2019,
|
33 |
+
"ga": GeneticAlgorithmAlzantot2018,
|
34 |
+
"deepwordbug": DeepWordBugGao2018,
|
35 |
+
"clare": CLARE2020,
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
def init():
    """Download resources, unpack checkpoints and build classifiers/attackers.

    Fills the module-level ``tad_classifiers`` and ``sent_attackers``
    registries. Keys are lower-cased (``"tad-mr"``, ``"tad-mrpwws"``) because
    lookups lower-case the dataset and attacker names.
    """
    nltk.download("omw-1.4")

    if not os.path.exists("TAD-SST2"):
        # Context manager closes the archive once extraction is done.
        with zipfile.ZipFile("checkpoints.zip", "r") as z:
            z.extractall(os.getcwd())

    for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
        for dataset in [
            "agnews10k",
            "sst2",
            "MR",
            "imdb",
        ]:
            classifier_key = "tad-{}".format(dataset.lower())
            if classifier_key not in tad_classifiers:
                # Checkpoint names are upper-case, e.g. "TAD-SST2".
                tad_classifiers[
                    classifier_key
                ] = TADCheckpointManager.get_tad_text_classifier(
                    "tad-{}".format(dataset).upper()
                )

            sent_attackers[classifier_key + attacker] = SentAttacker(
                tad_classifiers[classifier_key], attack_recipes[attacker]
            )
            # "pwws" is the first attacker built, so this entry always exists.
            tad_classifiers[classifier_key].sent_attacker = sent_attackers[
                classifier_key + "pwws"
            ]
|
66 |
+
|
67 |
+
|
68 |
+
cache = set()
|
69 |
+
|
70 |
+
|
71 |
+
def generate_adversarial_example(dataset, attacker, text=None, label=None):
    """Attack ``text`` with the chosen recipe and run the TAD defense on it.

    Sampling a dataset example when no text is given is disabled in this
    version, so ``text`` and an integer-convertible ``label`` are required.

    Args:
        dataset: Dataset name; lower-cased to build the registry keys.
        attacker: Attack recipe name; lower-cased for the registry lookup and
            forwarded unchanged as the ``defense`` argument.
        text: Natural example to attack.
        label: Ground-truth label of ``text``.

    Returns:
        11-tuple: original text, original label, restored text, restored
        label, adversarial text, three text diffs (original, adversarial,
        restored), adversarial label, classification DataFrame and
        adversarial-detection DataFrame.

    Raises:
        ValueError: If ``text`` is empty or ``label`` is not an integer.
        RuntimeError: If the attack yields no usable adversarial example.
    """
    if not text:
        raise ValueError("A natural example text is required.")
    try:
        label_id = int(label)
    except (TypeError, ValueError):
        raise ValueError(
            "The original label must be an integer, got {!r}.".format(label)
        ) from None

    cache.add(text)

    result = None
    attack_result = sent_attackers[
        "tad-{}{}".format(dataset.lower(), attacker.lower())
    ].attacker.simple_attack(text, label_id)
    if isinstance(attack_result, SuccessfulAttackResult):
        original = attack_result.original_result
        perturbed = attack_result.perturbed_result
        # Only keep attacks that flip a prediction which was correct before.
        if (
            perturbed.output != original.ground_truth_output
            and original.output == original.ground_truth_output
        ):
            # with defense
            result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
                perturbed.attacked_text.text
                + "$LABEL${},{},{}".format(
                    original.ground_truth_output,
                    1,
                    perturbed.output,
                ),
                print_result=True,
                defense=attacker,
            )

    if not result:
        # The previous code retried without text/label, which always crashed
        # in int(None); fail with an explicit message instead.
        raise RuntimeError(
            "The attacker could not generate a valid adversarial example for this input."
        )

    classification_df = {}
    classification_df["is_repaired"] = result["is_fixed"]
    classification_df["pred_label"] = result["label"]
    classification_df["confidence"] = round(result["confidence"], 3)
    # NOTE(review): reads "pred_label" while the line above reads "label" —
    # confirm both keys exist in the infer() result.
    classification_df["is_correct"] = str(result["pred_label"]) == str(label)

    advdetection_df = {}
    if result["is_adv_label"] != "0":
        advdetection_df["is_adversarial"] = {
            "0": False,
            "1": True,
            0: False,
            1: True,
        }[result["is_adv_label"]]
        advdetection_df["perturbed_label"] = result["perturbed_label"]
        advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
        advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
        advdetection_df['is_correct'] = result['ref_is_adv_check']

    adversarial_text = attack_result.perturbed_result.attacked_text.text
    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        adversarial_text,
        diff_texts(text, text),
        diff_texts(text, adversarial_text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )
|
146 |
+
|
147 |
+
|
148 |
+
def run_demo(dataset, attacker, text=None, label=None):
    """Run the demo remotely if possible, otherwise locally.

    The remote API is tried first; on any failure the error is printed and
    ``generate_adversarial_example`` is used instead.

    Args:
        dataset: Selected dataset name.
        attacker: Selected attacker name.
        text: Optional natural example.
        label: Optional original label of ``text``.

    Returns:
        12-tuple matching the outputs bound to the generate button; the last
        element is the message shown in the message box.
    """
    try:
        data = {
            "dataset": dataset,
            "attacker": attacker,
            "text": text,
            "label": label,
        }
        response = requests.post(
            'https://rpddemo.pagekite.me/api/generate_adversarial_example',
            json=data,
            # (connect, read) seconds: without a timeout a dead endpoint
            # blocks the queue worker forever.
            timeout=(5, 600),
        )
        response.raise_for_status()
        result = response.json()
        print(result)
        return (
            result["text"],
            result["label"],
            result["restored_text"],
            result["result_label"],
            result["perturbed_text"],
            result["text_diff"],
            result["perturbed_diff"],
            result["restored_diff"],
            result["output"],
            pd.DataFrame(result["classification_df"]),
            pd.DataFrame(result["advdetection_df"]),
            result["message"]
        )
    except Exception as e:
        # Best-effort remote call: report the problem and fall back to local.
        print(e)

    try:
        local_result = tuple(
            generate_adversarial_example(dataset, attacker, text, label)
        )
    except Exception as e:
        print(e)
        # Show the failure in the message box and leave the other outputs empty.
        return (text, label) + (None,) * 9 + (
            "Failed to generate an adversarial example: {}".format(e),
        )
    if len(local_result) == 11:
        # The local path has no message field; pad to the 12 bound outputs.
        local_result += ("",)
    return local_result
|
176 |
+
|
177 |
+
|
178 |
+
def check_gpu():
    """Report whether the remote endpoint answers within three seconds.

    Returns:
        ``'GPU available'`` when the endpoint replies with a status code below
        500, ``'GPU not available'`` otherwise or on any error.
    """
    try:
        status = requests.post(
            'https://rpddemo.pagekite.me/api/generate_adversarial_example', timeout=3
        ).status_code
        reachable = status < 500
    except Exception:
        reachable = False
    return 'GPU available' if reachable else 'GPU not available'
|
187 |
+
|
188 |
+
|
189 |
+
if __name__ == "__main__":
    # Script entry point: initialise the models, then build and launch the UI.
    # NOTE(review): the nesting of the layout containers below was rebuilt from
    # a flattened source — confirm it against the rendered page.
    try:
        init()
    except Exception as e:
        print(e)
        print("Failed to initialize the demo. Please try again later.")

    demo = gr.Blocks()

    with demo:
        gr.Markdown("<h1 align='center'>Detection and Correction based on Word Importance Ranking (DCWIR) </h1>")
        gr.Markdown("<h2 align='center'>Clarifications</h2>")
        gr.Markdown("""
        - This demo has no mechanism to ensure the adversarial example will be correctly repaired by Rapid. The repair success rate is actually the performance reported in the paper. The user must know the resulted output for sake of demonstration.
        - The adversarial example and corrected adversarial example may be unnatural to read, while it is because the attackers usually generate unnatural perturbations.
        - All the proposed attacks are Black Box attack where the attacker has no access to the model parameters.
        """)
        gr.Markdown("<h2 align='center'>Natural Example Input</h2>")
        with gr.Group():
            with gr.Row():
                input_dataset = gr.Radio(
                    choices=["SST2", "IMDB", "MR", "AGNews10K"],
                    value="SST2",
                    label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
                )
                input_attacker = gr.Radio(
                    choices=["BAE", "PWWS", "TextFooler", "DeepWordBug"],
                    value="TextFooler",
                    label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
                )
        with gr.Group(visible=True):

            with gr.Row():
                input_sentence = gr.Textbox(
                    placeholder="Input a natural example...",
                    label="Alternatively, input a natural example and its original label (from above datasets) to generate an adversarial example.",

                )
                input_label = gr.Textbox(
                    placeholder="Original label, (must be an integer, because we use digits to represent labels in training)",
                    label="Original Label",
                )
            gr.Markdown(
                "<h3 align='center'>Default parameters are set according to the main experiment setup in the report.</h3>",
            )
            # NOTE(review): these three parameter boxes are not passed to
            # run_demo below, so their values are currently unused.
            with gr.Row():
                wir_percentage = gr.Textbox(
                    placeholder="Enter percentage from WIR...",
                    label="Percentage from WIR",
                )
                frequency_threshold = gr.Textbox(
                    placeholder="Enter frequency threshold...",
                    label="Frequency Threshold",
                )
                max_candidates = gr.Textbox(
                    placeholder="Enter maximum number of candidates...",
                    label="Maximum Number of Candidates",
                )
            msg_text = gr.Textbox(
                label="Message",
                placeholder="This is a message box to show any error messages.",
            )
            button_gen = gr.Button(
                "Generate an adversarial example to repair using Rapid (GPU: < 1 minute, CPU: 1-10 minutes)",
                variant="primary",
            )
            gpu_status_text = gr.Textbox(
                label='GPU status',
                placeholder="Please click to check",
            )
            button_check = gr.Button(
                "Check if GPU available",
                variant="primary"
            )

            button_check.click(
                fn=check_gpu,
                inputs=[],
                outputs=[
                    gpu_status_text
                ]
            )

        gr.Markdown("<h2 align='center'>Generated Adversarial Example and Repaired Adversarial Example</h2>")

        with gr.Column():
            with gr.Group():
                with gr.Row():
                    output_original_example = gr.Textbox(label="Original Example")
                    output_original_label = gr.Textbox(label="Original Label")
                with gr.Row():
                    output_adv_example = gr.Textbox(label="Adversarial Example")
                    output_adv_label = gr.Textbox(label="Predicted Label of the Adversarial Example")
                with gr.Row():
                    output_repaired_example = gr.Textbox(
                        label="Repaired Adversarial Example by Rapid"
                    )
                    output_repaired_label = gr.Textbox(label="Predicted Label of the Repaired Adversarial Example")

        gr.Markdown("<h2 align='center'>Example Difference (Comparisons)</h2>")
        gr.Markdown("""
        <p align='center'>The (+) and (-) in the boxes indicate the added and deleted characters in the adversarial example compared to the original input natural example.</p>
        """)
        ori_text_diff = gr.HighlightedText(
            label="The Original Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )
        adv_text_diff = gr.HighlightedText(
            label="Character Editions of Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )

        restored_text_diff = gr.HighlightedText(
            label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )

        gr.Markdown(
            "## <h2 align='center'>The Output of Reactive Perturbation Defocusing</h2>"
        )
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    output_is_adv_df = gr.DataFrame(
                        label="Adversarial Example Detection Result"
                    )
                    gr.Markdown(
                        """
                        - The is_adversarial field indicates if an adversarial example is detected.
                        - The perturbed_label is the predicted label of the adversarial example.
                        - The confidence field represents the ratio of Inverted samples among the total number of generated candidates.
                        """
                    )
            with gr.Column():
                with gr.Group():
                    output_df = gr.DataFrame(
                        label="Correction Classification Result"
                    )
                    gr.Markdown(
                        """
                        - If is_corrected=true, it has been Corrected by DCWIR.
                        - The pred_label field indicates the standard classification result.
                        - The confidence field represents ratio of the dominant class among all Inverted candidates.
                        - The is_correct field indicates whether the predicted label is correct.

                        """
                    )

        # Bind functions to buttons
        button_gen.click(
            fn=run_demo,
            inputs=[input_dataset, input_attacker, input_sentence, input_label],
            outputs=[
                output_original_example,
                output_original_label,
                output_repaired_example,
                output_repaired_label,
                output_adv_example,
                ori_text_diff,
                adv_text_diff,
                restored_text_diff,
                output_adv_label,
                output_df,
                output_is_adv_df,
                msg_text
            ],
        )

    demo.queue(2).launch()
|
checkpoints.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f77ae4a45785183900ee874cb318a16b0e2f173b31749a2555215aca93672f26
|
3 |
+
size 2456834455
|
flow correction 30%.ipynb
ADDED
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"data": {
|
10 |
+
"application/vnd.jupyter.widget-view+json": {
|
11 |
+
"model_id": "f0f495f5278946bebfcef7f58113879b",
|
12 |
+
"version_major": 2,
|
13 |
+
"version_minor": 0
|
14 |
+
},
|
15 |
+
"text/plain": [
|
16 |
+
"pytorch_model.bin: 0%| | 0.00/438M [00:00<?, ?B/s]"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
"metadata": {},
|
20 |
+
"output_type": "display_data"
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"name": "stderr",
|
24 |
+
"output_type": "stream",
|
25 |
+
"text": [
|
26 |
+
"C:\\Users\\Isaac\\AppData\\Roaming\\Python\\Python38\\site-packages\\huggingface_hub\\file_download.py:148: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Isaac\\.cache\\huggingface\\hub\\models--textattack--bert-base-uncased-SST-2. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
|
27 |
+
"To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
|
28 |
+
" warnings.warn(message)\n"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"data": {
|
33 |
+
"application/vnd.jupyter.widget-view+json": {
|
34 |
+
"model_id": "7430033906354f39b753fb46f7113501",
|
35 |
+
"version_major": 2,
|
36 |
+
"version_minor": 0
|
37 |
+
},
|
38 |
+
"text/plain": [
|
39 |
+
"tokenizer_config.json: 0%| | 0.00/48.0 [00:00<?, ?B/s]"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
"metadata": {},
|
43 |
+
"output_type": "display_data"
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"data": {
|
47 |
+
"application/vnd.jupyter.widget-view+json": {
|
48 |
+
"model_id": "bce8defec56e4810ba28e42b2d18538e",
|
49 |
+
"version_major": 2,
|
50 |
+
"version_minor": 0
|
51 |
+
},
|
52 |
+
"text/plain": [
|
53 |
+
"vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
"metadata": {},
|
57 |
+
"output_type": "display_data"
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"data": {
|
61 |
+
"application/vnd.jupyter.widget-view+json": {
|
62 |
+
"model_id": "a4c13d3d415c4ec8adf87cd6efca5989",
|
63 |
+
"version_major": 2,
|
64 |
+
"version_minor": 0
|
65 |
+
},
|
66 |
+
"text/plain": [
|
67 |
+
"special_tokens_map.json: 0%| | 0.00/112 [00:00<?, ?B/s]"
|
68 |
+
]
|
69 |
+
},
|
70 |
+
"metadata": {},
|
71 |
+
"output_type": "display_data"
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"name": "stderr",
|
75 |
+
"output_type": "stream",
|
76 |
+
"text": [
|
77 |
+
"textattack: Unknown if model of class <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.\n"
|
78 |
+
]
|
79 |
+
}
|
80 |
+
],
|
81 |
+
"source": [
|
82 |
+
"import textattack\n",
|
83 |
+
"import transformers\n",
|
84 |
+
"\n",
|
85 |
+
"# Load model, tokenizer, and model_wrapper\n",
|
86 |
+
"model = transformers.AutoModelForSequenceClassification.from_pretrained(\n",
|
87 |
+
" \"textattack/bert-base-uncased-SST-2\"\n",
|
88 |
+
")\n",
|
89 |
+
"tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
|
90 |
+
" \"textattack/bert-base-uncased-SST-2\"\n",
|
91 |
+
")\n",
|
92 |
+
"model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)\n",
|
93 |
+
"\n",
|
94 |
+
"# Construct our four components for `Attack`\n",
|
95 |
+
"from textattack.constraints.pre_transformation import (\n",
|
96 |
+
" RepeatModification,\n",
|
97 |
+
" StopwordModification,\n",
|
98 |
+
")\n",
|
99 |
+
"from textattack.constraints.semantics import WordEmbeddingDistance\n",
|
100 |
+
"from textattack.transformations import WordSwapEmbedding\n",
|
101 |
+
"from textattack.search_methods import GreedyWordSwapWIR\n",
|
102 |
+
"\n",
|
103 |
+
"goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)\n",
|
104 |
+
"constraints = [\n",
|
105 |
+
" RepeatModification(),\n",
|
106 |
+
" StopwordModification(),\n",
|
107 |
+
" WordEmbeddingDistance(min_cos_sim=0.9),\n",
|
108 |
+
"]\n",
|
109 |
+
"transformation = WordSwapEmbedding(max_candidates=50)\n",
|
110 |
+
"# weighted-saliency\n",
|
111 |
+
"search_method = GreedyWordSwapWIR(wir_method=\"weighted-saliency\")\n",
|
112 |
+
"\n",
|
113 |
+
"# Construct the actual attack\n",
|
114 |
+
"attack = textattack.Attack(goal_function, constraints, transformation, search_method)\n",
|
115 |
+
"attack.cuda_()"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "code",
|
120 |
+
"execution_count": 2,
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [],
|
123 |
+
"source": [
|
124 |
+
"import pandas as pd\n",
|
125 |
+
"\n",
|
126 |
+
"results = pd.read_csv(\"ag-news_pwws_bert.csv\")\n",
|
127 |
+
"#results.columns"
|
128 |
+
]
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"cell_type": "code",
|
132 |
+
"execution_count": 25,
|
133 |
+
"metadata": {},
|
134 |
+
"outputs": [],
|
135 |
+
"source": [
|
136 |
+
"\n",
|
137 |
+
"\"\"\"successful_perturbed_texts = results.loc[results[\"result_type\"] == \"Successful\", \"perturbed_text\"].tolist()\n",
|
138 |
+
"failed_perturbed_texts = results.loc[results[\"result_type\"] == \"Failed\", \"perturbed_text\"].tolist()\n",
|
139 |
+
"\n",
|
140 |
+
"failed_perturbed_outputs = results.loc[results[\"result_type\"] == \"Failed\", \"perturbed_output\"].tolist()\n",
|
141 |
+
"successful_perturbed_outputs = results.loc[results[\"result_type\"] == \"Successful\", \"original_output\"].tolist()\"\"\"\n",
|
142 |
+
"\n",
|
143 |
+
"\n",
|
144 |
+
"original_texts = results[\"original_text\"].tolist()\n",
|
145 |
+
"perturbed_texts =results[\"adversarial_text\"].tolist() \n",
|
146 |
+
"\n",
|
147 |
+
"original_outputs = results[\"original_class\"].tolist()\n",
|
148 |
+
"perturbed_outputs =results[\"adversarial_class\"].tolist() "
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "code",
|
153 |
+
"execution_count": 4,
|
154 |
+
"metadata": {},
|
155 |
+
"outputs": [],
|
156 |
+
"source": [
|
157 |
+
"import re\n",
|
158 |
+
"import string\n",
|
159 |
+
"# Clean Text\n",
|
160 |
+
"def remove_brackets(text):\n",
|
161 |
+
" text = text.replace('[[', '')\n",
|
162 |
+
" text = text.replace(']]', '')\n",
|
163 |
+
" return text\n",
|
164 |
+
"\n",
|
165 |
+
"perturbed_texts = [remove_brackets(text) for text in perturbed_texts]\n",
|
166 |
+
"original_texts = [remove_brackets(text) for text in original_texts]\n",
|
167 |
+
"\n",
|
168 |
+
"def clean_text(text):\n",
|
169 |
+
" pattern = \"[\" + re.escape(string.punctuation) + \"]\"\n",
|
170 |
+
" cleaned_text = re.sub(pattern, \" \", text)\n",
|
171 |
+
"\n",
|
172 |
+
" return cleaned_text\n",
|
173 |
+
"\n",
|
174 |
+
"perturbed_texts = [clean_text(text) for text in perturbed_texts]\n",
|
175 |
+
"original_texts = [clean_text(text) for text in original_texts]\n",
|
176 |
+
"\n"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": 5,
|
182 |
+
"metadata": {},
|
183 |
+
"outputs": [],
|
184 |
+
"source": [
|
185 |
+
"perturbed_texts = [text.lower() for text in perturbed_texts]\n",
|
186 |
+
"original_texts = [text.lower() for text in original_texts]"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "code",
|
191 |
+
"execution_count": 6,
|
192 |
+
"metadata": {},
|
193 |
+
"outputs": [],
|
194 |
+
"source": [
|
195 |
+
"from FlowCorrector import Flow_Corrector\n",
|
196 |
+
"\n",
|
197 |
+
"corrector = Flow_Corrector(\n",
|
198 |
+
" attack,\n",
|
199 |
+
" word_rank_file=\"en_full_ranked.json\",\n",
|
200 |
+
" word_freq_file=\"en_full_freq.json\",\n",
|
201 |
+
")\n"
|
202 |
+
]
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"cell_type": "code",
|
206 |
+
"execution_count": 7,
|
207 |
+
"metadata": {},
|
208 |
+
"outputs": [
|
209 |
+
{
|
210 |
+
"data": {
|
211 |
+
"application/vnd.jupyter.widget-view+json": {
|
212 |
+
"model_id": "a1241448bb324872a1da1f2b659150c5",
|
213 |
+
"version_major": 2,
|
214 |
+
"version_minor": 0
|
215 |
+
},
|
216 |
+
"text/plain": [
|
217 |
+
" 0%| | 0/424 [00:00<?, ?it/s]"
|
218 |
+
]
|
219 |
+
},
|
220 |
+
"metadata": {},
|
221 |
+
"output_type": "display_data"
|
222 |
+
}
|
223 |
+
],
|
224 |
+
"source": [
|
225 |
+
"import torch\n",
|
226 |
+
"import torch.nn.functional as F\n",
|
227 |
+
"from tqdm.notebook import tqdm_notebook\n",
|
228 |
+
"\n",
|
229 |
+
"victim_model = attack.goal_function.model\n",
|
230 |
+
"\n",
|
231 |
+
"original_classes = [\n",
|
232 |
+
" torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()\n",
|
233 |
+
" for original_text in tqdm_notebook(original_texts)\n",
|
234 |
+
"]\n"
|
235 |
+
]
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"cell_type": "code",
|
239 |
+
"execution_count": 18,
|
240 |
+
"metadata": {},
|
241 |
+
"outputs": [
|
242 |
+
{
|
243 |
+
"data": {
|
244 |
+
"application/vnd.jupyter.widget-view+json": {
|
245 |
+
"model_id": "d2927a578aa5401796a0c4cd48a11de6",
|
246 |
+
"version_major": 2,
|
247 |
+
"version_minor": 0
|
248 |
+
},
|
249 |
+
"text/plain": [
|
250 |
+
" 0%| | 0/424 [00:00<?, ?it/s]"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
"metadata": {},
|
254 |
+
"output_type": "display_data"
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"name": "stderr",
|
258 |
+
"output_type": "stream",
|
259 |
+
"text": [
|
260 |
+
"c:\\Users\\Isaac\\anaconda3\\envs\\textattackenv\\lib\\site-packages\\torch\\nn\\modules\\module.py:1117: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
|
261 |
+
" warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n"
|
262 |
+
]
|
263 |
+
}
|
264 |
+
],
|
265 |
+
"source": [
|
266 |
+
"\"\"\" 0 :World\n",
|
267 |
+
" 1 : Sports\n",
|
268 |
+
" 2 : Business\n",
|
269 |
+
" 3 : Sci/Tech\"\"\"\n",
|
270 |
+
"\n",
|
271 |
+
"corrected_classes = corrector.correct(perturbed_texts)"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"cell_type": "code",
|
276 |
+
"execution_count": 19,
|
277 |
+
"metadata": {},
|
278 |
+
"outputs": [],
|
279 |
+
"source": [
|
280 |
+
"def count_matching_classes(original, corrected, perturbed_texts=None):\n",
|
281 |
+
" if len(original) != len(corrected):\n",
|
282 |
+
" raise ValueError(\"Arrays must have the same length\")\n",
|
283 |
+
" hard_samples = []\n",
|
284 |
+
" easy_samples = []\n",
|
285 |
+
"\n",
|
286 |
+
" matching_count = 0\n",
|
287 |
+
"\n",
|
288 |
+
" for i in range(len(corrected)):\n",
|
289 |
+
" if original[i] == corrected[i]:\n",
|
290 |
+
" matching_count += 1\n",
|
291 |
+
" easy_samples.append(perturbed_texts[i])\n",
|
292 |
+
" elif perturbed_texts != None:\n",
|
293 |
+
" hard_samples.append(perturbed_texts[i])\n",
|
294 |
+
"\n",
|
295 |
+
" return matching_count, hard_samples, easy_samples\n",
|
296 |
+
"\n",
|
297 |
+
"\n",
|
298 |
+
"match_, hard_samples, easy_samples = count_matching_classes(\n",
|
299 |
+
" original_classes, corrected_classes, perturbed_texts\n",
|
300 |
+
")"
|
301 |
+
]
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "code",
|
305 |
+
"execution_count": 23,
|
306 |
+
"metadata": {},
|
307 |
+
"outputs": [
|
308 |
+
{
|
309 |
+
"data": {
|
310 |
+
"text/plain": [
|
311 |
+
"0.6014150943396226"
|
312 |
+
]
|
313 |
+
},
|
314 |
+
"execution_count": 23,
|
315 |
+
"metadata": {},
|
316 |
+
"output_type": "execute_result"
|
317 |
+
}
|
318 |
+
],
|
319 |
+
"source": [
|
320 |
+
"match_/424"
|
321 |
+
]
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"cell_type": "code",
|
325 |
+
"execution_count": null,
|
326 |
+
"metadata": {},
|
327 |
+
"outputs": [],
|
328 |
+
"source": [
|
329 |
+
"with open(\"detected_samples_ag_news.txt\", \"w\") as f:\n",
|
330 |
+
" for sample in easy_samples:\n",
|
331 |
+
" f.write(sample + \"\\n\")"
|
332 |
+
]
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"cell_type": "code",
|
336 |
+
"execution_count": null,
|
337 |
+
"metadata": {},
|
338 |
+
"outputs": [],
|
339 |
+
"source": [
|
340 |
+
"index_order = wir_gradient(attack, goal_function.model, detected_text)\n",
|
341 |
+
"\n",
|
342 |
+
"# The prblm is tht it exists many rare words with freq from 1 to 10\n",
|
343 |
+
"\n",
|
344 |
+
"# freq_thershold = len(word_ranked_frequence) * 0.01\n",
|
345 |
+
"\n",
|
346 |
+
"# for now i sugest to take the words taht have freq > freq_thershold (200 in paper)\n",
|
347 |
+
"\n",
|
348 |
+
"freq_thershold = 2000\n",
|
349 |
+
"\n",
|
350 |
+
"index_order_1 = [\n",
|
351 |
+
" idx\n",
|
352 |
+
" for idx in index_order\n",
|
353 |
+
"\n",
|
354 |
+
" if detected_text.words[idx] in word_frequence.keys()\n",
|
355 |
+
"\n",
|
356 |
+
" and word_frequence[detected_text.words[idx]] < freq_thershold\n",
|
357 |
+
"\n",
|
358 |
+
"]\n",
|
359 |
+
"\n",
|
360 |
+
"print(\n",
|
361 |
+
"\n",
|
362 |
+
" f\"from {len(index_order)} ranked word it remain only {len(index_order_1)} within frequency theshold = {freq_thershold} \"\n",
|
363 |
+
"\n",
|
364 |
+
")\n",
|
365 |
+
"\n",
|
366 |
+
"# or we take the lowest 30% in the important ranked words ?\n",
|
367 |
+
"index_order = index_order[:int(len(index_order) * 0.3)]\n",
|
368 |
+
"index_order_ = {\n",
|
369 |
+
" idx : word_ranked_frequence[detected_text.words[idx]]\n",
|
370 |
+
" for idx in index_order\n",
|
371 |
+
" if detected_text.words[idx] in word_ranked_frequence.keys()\n",
|
372 |
+
"}\n",
|
373 |
+
"\n",
|
374 |
+
"index_order_ = sorted(index_order_.items(), key=lambda item: item[1], reverse=False)\n",
|
375 |
+
"lowest = 0.15\n",
|
376 |
+
"index_order_ = [idx[0]for idx in index_order_][:int(len(index_order) * lowest)]\n",
|
377 |
+
"\n",
|
378 |
+
"print(f\"from {len(index_order)} ranked word {len(index_order_)} word represent {lowest * 100}% with the lowest frequency\")"
|
379 |
+
]
|
380 |
+
},
|
381 |
+
{
|
382 |
+
"cell_type": "code",
|
383 |
+
"execution_count": null,
|
384 |
+
"metadata": {},
|
385 |
+
"outputs": [],
|
386 |
+
"source": [
|
387 |
+
"def remove_brackets(text):\n",
|
388 |
+
" text = text.replace('[[', '')\n",
|
389 |
+
" text = text.replace(']]', '')\n",
|
390 |
+
" return text\n",
|
391 |
+
"\n",
|
392 |
+
"text = \"Fears for T [[percent]] pension after [[debate]] [[Syndicates]] [[portrayal]] [[worker]] at Turner Newall say they are 'disappointed' after [[chatter]] with [[bereaved]] [[parenting]] [[corporations]] [[Canada]] Mogul.\" \n",
|
393 |
+
"print(remove_brackets(text))\n"
|
394 |
+
]
|
395 |
+
},
|
396 |
+
{
|
397 |
+
"cell_type": "code",
|
398 |
+
"execution_count": null,
|
399 |
+
"metadata": {},
|
400 |
+
"outputs": [],
|
401 |
+
"source": []
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "code",
|
405 |
+
"execution_count": null,
|
406 |
+
"metadata": {},
|
407 |
+
"outputs": [],
|
408 |
+
"source": []
|
409 |
+
},
|
410 |
+
{
|
411 |
+
"cell_type": "code",
|
412 |
+
"execution_count": null,
|
413 |
+
"metadata": {},
|
414 |
+
"outputs": [],
|
415 |
+
"source": [
|
416 |
+
"import json\n",
|
417 |
+
"\n",
|
418 |
+
"with open('en_full.txt', 'r') as f:\n",
|
419 |
+
" lines = f.readlines()\n",
|
420 |
+
"\n",
|
421 |
+
"\n",
|
422 |
+
"freq_dict = {line.split()[0]: int(line.split()[1]) for line in lines}\n",
|
423 |
+
"\n",
|
424 |
+
"\n",
|
425 |
+
"sorted_dict = dict(sorted(freq_dict.items(), key=lambda item: item[1], reverse=True))\n",
|
426 |
+
"\n",
|
427 |
+
"\n",
|
428 |
+
"ranked_dict = {word: freq for word, freq in sorted_dict.items() }\n",
|
429 |
+
"\n",
|
430 |
+
"\n",
|
431 |
+
"with open('en_full_freq.json', 'w') as f:\n",
|
432 |
+
" json.dump(ranked_dict, f)\n",
|
433 |
+
"\n",
|
434 |
+
"print(\"The word frequencies have been successfully ranked and saved to ranked_freq.json file.\")\n"
|
435 |
+
]
|
436 |
+
},
|
437 |
+
{
|
438 |
+
"cell_type": "code",
|
439 |
+
"execution_count": 7,
|
440 |
+
"metadata": {},
|
441 |
+
"outputs": [
|
442 |
+
{
|
443 |
+
"data": {
|
444 |
+
"image/png": "",
|
445 |
+
"text/plain": [
|
446 |
+
"<Figure size 1200x500 with 2 Axes>"
|
447 |
+
]
|
448 |
+
},
|
449 |
+
"metadata": {},
|
450 |
+
"output_type": "display_data"
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"data": {
|
454 |
+
"text/plain": [
|
455 |
+
"<Figure size 640x480 with 0 Axes>"
|
456 |
+
]
|
457 |
+
},
|
458 |
+
"metadata": {},
|
459 |
+
"output_type": "display_data"
|
460 |
+
}
|
461 |
+
],
|
462 |
+
"source": [
|
463 |
+
"import matplotlib.pyplot as plt\n",
|
464 |
+
"import seaborn as sns\n",
|
465 |
+
"# Assuming these are your accuracy and loss values\n",
|
466 |
+
"accuracy = [0.6, 0.65, 0.7, 0.72, 0.74, 0.76, 0.77, 0.78, 0.81, 0.83, 0.83, 0.87,0.88, 0.91, 0.915, 0.924, 0.934, 0.954, 0.957, 0.959, 0.96, 0.959, 0.958, 0.956, 0.957, 0.958]\n",
|
467 |
+
"loss = [0.8, 0.5, 0.45, 0.30, 0.28, 0.22, 0.19, 0.18, 0.18, 0.15, 0.15, 0.15, 0.12, 0.13, 0.11, 0.09, 0.086, 0.083, 0.082, 0.077, 0.076, 0.074, 0.073, 0.072, 0.070, 0.069]\n",
|
468 |
+
"\n",
|
469 |
+
"epochs = range(1, len(accuracy) + 1)\n",
|
470 |
+
"\n",
|
471 |
+
"plt.figure(figsize=(12, 5))\n",
|
472 |
+
"\n",
|
473 |
+
"# Plotting accuracy\n",
|
474 |
+
"plt.subplot(1, 2, 1)\n",
|
475 |
+
"plt.plot(epochs, accuracy, 'bo', label='Training acc')\n",
|
476 |
+
"plt.title('Training accuracy')\n",
|
477 |
+
"plt.xlabel('Epochs')\n",
|
478 |
+
"plt.ylabel('Accuracy')\n",
|
479 |
+
"plt.legend()\n",
|
480 |
+
"\n",
|
481 |
+
"# Plotting loss\n",
|
482 |
+
"plt.subplot(1, 2, 2)\n",
|
483 |
+
"plt.plot(epochs, loss, 'bo', label='Training loss')\n",
|
484 |
+
"plt.title('Training loss')\n",
|
485 |
+
"plt.xlabel('Epochs')\n",
|
486 |
+
"plt.ylabel('Loss')\n",
|
487 |
+
"plt.legend()\n",
|
488 |
+
"plt.tight_layout()\n",
|
489 |
+
"plt.show()\n",
|
490 |
+
"\n",
|
491 |
+
"plt.savefig(\"accuracy loss.pdf\")\n"
|
492 |
+
]
|
493 |
+
}
|
494 |
+
],
|
495 |
+
"metadata": {
|
496 |
+
"kernelspec": {
|
497 |
+
"display_name": "textattackenv",
|
498 |
+
"language": "python",
|
499 |
+
"name": "python3"
|
500 |
+
},
|
501 |
+
"language_info": {
|
502 |
+
"codemirror_mode": {
|
503 |
+
"name": "ipython",
|
504 |
+
"version": 3
|
505 |
+
},
|
506 |
+
"file_extension": ".py",
|
507 |
+
"mimetype": "text/x-python",
|
508 |
+
"name": "python",
|
509 |
+
"nbconvert_exporter": "python",
|
510 |
+
"pygments_lexer": "ipython3",
|
511 |
+
"version": "3.8.18"
|
512 |
+
}
|
513 |
+
},
|
514 |
+
"nbformat": 4,
|
515 |
+
"nbformat_minor": 2
|
516 |
+
}
|
flow_correction_ag_news.py
ADDED
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import textattack
|
2 |
+
import transformers
|
3 |
+
import pandas as pd
|
4 |
+
import csv
|
5 |
+
import string
|
6 |
+
import pickle
|
7 |
+
# Construct our four components for `Attack`
|
8 |
+
from textattack.constraints.pre_transformation import (
|
9 |
+
RepeatModification,
|
10 |
+
StopwordModification,
|
11 |
+
)
|
12 |
+
from textattack.constraints.semantics import WordEmbeddingDistance
|
13 |
+
from textattack.transformations import WordSwapEmbedding
|
14 |
+
from textattack.search_methods import GreedyWordSwapWIR
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import json
|
18 |
+
import random
|
19 |
+
import re
|
20 |
+
import textattack.shared.attacked_text as atk
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import torch
|
23 |
+
|
24 |
+
|
25 |
+
class InvertedText:
    """Record of one candidate text whose predicted class differs from the
    class predicted for the text it was derived from.

    Attributes are plain public fields:
      * ``attacked_text``   - the candidate text object that was classified.
      * ``swapped_indexes`` - dict mapping each swapped word index to the
        synonym that was put there.
      * ``score``           - score stored for this candidate.
      * ``new_class``       - class predicted for the candidate.
    """

    def __init__(
        self,
        swapped_indexes,
        score,
        attacked_text,
        new_class,
    ):
        # class after inversion
        self.new_class = new_class
        # The original author describes this as the "value of original class".
        # NOTE(review): Flow_Corrector.correct passes the probability of the
        # newly predicted class here - confirm the intended meaning.
        self.score = score
        # dict of swapped indexes with their synonym
        self.swapped_indexes = swapped_indexes
        self.attacked_text = attacked_text

    def __repr__(self):
        # new_class is intentionally not part of the representation, exactly
        # as in the original implementation.
        return f"InvertedText:\n attacked_text='{self.attacked_text}', \n swapped_indexes={self.swapped_indexes},\n score={self.score}"
|
43 |
+
|
44 |
+
|
45 |
+
def count_matching_classes(original, corrected, perturbed_texts=None):
    """Count positions where ``original`` and ``corrected`` agree.

    Args:
        original: sequence of reference class labels.
        corrected: sequence of class labels to compare, same length as
            ``original``.
        perturbed_texts: optional sequence of texts aligned with the labels.
            When given, each text is sorted into the "easy" list (labels
            match) or the "hard" list (labels differ). When ``None`` both
            lists stay empty.

    Returns:
        ``(matching_count, hard_samples, easy_samples)``.

    Raises:
        ValueError: if ``original`` and ``corrected`` differ in length.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")

    hard_samples = []
    easy_samples = []
    matching_count = 0

    for position, (expected, predicted) in enumerate(zip(original, corrected)):
        is_match = expected == predicted
        if is_match:
            matching_count += 1
        # The texts are optional: the previous version indexed into
        # ``perturbed_texts`` on every match and crashed with a TypeError
        # when the default ``None`` was used.
        if perturbed_texts is None:
            continue
        bucket = easy_samples if is_match else hard_samples
        bucket.append(perturbed_texts[position])

    return matching_count, hard_samples, easy_samples
|
61 |
+
|
62 |
+
|
63 |
+
class Flow_Corrector:
    """Re-predicts the class of texts by swapping their most important words
    for higher-ranked synonyms and voting over the victim model's predictions
    on the resulting candidate texts.
    """

    def __init__(
        self,
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
        wir_threshold=0.3,
    ):
        """Store the attack, move it to the GPU and load the word tables.

        Args:
            attack: attack object; must provide ``cuda_()``,
                ``get_indices_to_order``, ``get_transformations`` and
                ``goal_function.model``.
            word_rank_file: path of a JSON file mapping word -> rank.
            word_freq_file: path of a JSON file mapping word -> frequency.
            wir_threshold: fraction of the importance-ranked word indices
                that ``correct`` keeps for substitution.
        """
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model

    def wir_gradient(
        self,
        attack,
        victim_model,
        detected_text,
    ):
        """Order the modifiable word indices of ``detected_text`` by
        gradient-based importance, most important first.

        Each word's score is the L1 norm of the mean gradient over the model
        tokens aligned with that word; words with no aligned token score 0.

        Returns:
            numpy array of word indices sorted by descending score.
        """
        _, indices_to_order = attack.get_indices_to_order(detected_text)

        index_scores = np.zeros(len(indices_to_order))
        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        # presumably an array indexable by token position - TODO confirm
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if not matched_tokens:
                index_scores[i] = 0.0
            else:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        # Negate the scores so argsort yields a descending order.
        index_order = np.array(indices_to_order)[(-index_scores).argsort()]
        return index_order

    def get_syn_freq_dict(
        self,
        index_order,
        detected_text,
    ):
        """Collect, for each index, the synonyms that rank better than the
        word currently at that index.

        A synonym is kept when it appears in the rank table, its rank value
        is below one tenth of the table size, and its rank value is smaller
        than that of the word it would replace.

        Returns:
            dict mapping word index -> list of acceptable synonyms. Indices
            without any acceptable synonym are omitted.
        """
        most_frequent_syn_dict = {}
        ranks = self.word_ranked_frequence
        freq_thershold = len(ranks) / 10

        for idx in index_order:
            try:
                # A word without a rank entry can never satisfy the rank
                # comparison below, so skip the costly synonym generation.
                word_rank = ranks.get(detected_text.words[idx])
                if word_rank is None:
                    continue

                # get the synonyms of a specific index
                synonyms = [
                    attacked_text.words[idx]
                    for attacked_text in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                # keep synonyms that exist in the rank table and rank better
                # than the current word
                ranked_synonyms = {
                    syn: ranks[syn]
                    for syn in synonyms
                    if syn in ranks
                    and ranks[syn] < freq_thershold
                    and word_rank > ranks[syn]
                }
            except Exception:
                # Best effort, as before: an index whose synonyms cannot be
                # computed is skipped. Unlike the previous bare ``except:``,
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                continue

            if ranked_synonyms:
                most_frequent_syn_dict[idx] = list(ranked_synonyms)

        return most_frequent_syn_dict

    def build_candidates(
        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
    ):
        """Build up to ``max_attempt`` candidate texts by replacing every
        index in ``most_frequent_syn_dict`` with a randomly chosen synonym.

        Returns:
            dict mapping candidate text -> ``{index: chosen synonym}``.
            Attempts that produce an equal candidate share one entry.
        """
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index in most_frequent_syn_dict.keys():
                syn = random.choice(most_frequent_syn_dict[index])
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)

            candidates[current_text] = syn_dict
        return candidates

    def find_dominant_class(self, inverted_texts):
        """Return the ``new_class`` value that occurs most often among
        ``inverted_texts`` (first one encountered wins ties).

        Raises:
            ValueError: if ``inverted_texts`` is empty.
        """
        class_counts = {}  # Dictionary to store the count of each new class

        for text in inverted_texts:
            new_class = text.new_class
            class_counts[new_class] = class_counts.get(new_class, 0) + 1

        # Find the most dominant class
        most_dominant_class = max(class_counts, key=class_counts.get)

        return most_dominant_class

    def _decide_class(self, original_class, num_classes, inverted_texts, num_candidates):
        """Pick the final class from the candidate votes.

        With more than two classes the flipped candidates must make up at
        least ``1 / num_classes`` of all candidates; otherwise at least half.
        When that quota is met the most frequent flipped class wins, else the
        original prediction is kept.
        """
        divisor = num_classes if num_classes > 2 else 2
        if inverted_texts and len(inverted_texts) >= num_candidates / divisor:
            return self.find_dominant_class(inverted_texts)
        return original_class

    def correct(self, detected_texts):
        """Return one corrected class per text in ``detected_texts``.

        Args:
            detected_texts: iterable of raw text strings.

        Returns:
            list of class indices, in input order.
        """
        corrected_classes = []
        for detected_text in detected_texts:

            # convert to Attacked texts
            detected_text = atk.AttackedText(detected_text)

            # keep the top ``wir_threshold`` share of the most important indexes
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]

            # getting synonyms according to the frequency conditions
            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)

            # generate the candidates
            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
            original_class = torch.argmax(original_probs).item()

            # Classify all candidates in a single batch.
            batch_inputs = [candidate.text for candidate in candidates]
            probabilities = F.softmax(self.victim_model(batch_inputs), dim=1)

            inverted_texts = []  # candidates whose class differs from the original
            for i, (candidate, syn_dict) in enumerate(candidates.items()):
                candidate_class = torch.argmax(probabilities[i]).item()
                if candidate_class != original_class:
                    inverted_texts.append(
                        InvertedText(
                            syn_dict,
                            float(probabilities[i][candidate_class]),
                            candidate,
                            candidate_class,
                        )
                    )

            # The previous two-class branch returned the class of whichever
            # candidate happened to be processed last, which could be the
            # original class even though the majority had flipped.
            corrected_classes.append(
                self._decide_class(
                    original_class,
                    len(original_probs[0]),
                    inverted_texts,
                    len(candidates),
                )
            )

        return corrected_classes
|
235 |
+
|
236 |
+
|
237 |
+
def remove_brackets(text):
    """Return ``text`` with every ``[[`` and then every ``]]`` removed."""
    for marker in ("[[", "]]"):
        text = text.replace(marker, "")
    return text
|
241 |
+
|
242 |
+
|
243 |
+
def clean_text(text):
    """Replace every character of ``string.punctuation`` in ``text`` with a
    single space and return the result."""
    punctuation_class = re.compile("[" + re.escape(string.punctuation) + "]")
    return punctuation_class.sub(" ", text)
|
248 |
+
|
249 |
+
|
250 |
+
# Victim model: the "textattack/bert-base-uncased-ag-news" checkpoint with its
# tokenizer, wrapped so TextAttack can query it.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

# The four components of the attack: goal function, constraints,
# transformation and search method.
goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="gradient")

# Assemble the attack and move it to the GPU.
attack = textattack.Attack(goal_function, constraints, transformation, search_method)
attack.cuda_()
|
272 |
+
|
273 |
+
|
274 |
+
# Load the attack log and keep only the rows whose attack succeeded. A single
# boolean mask replaces the former per-row ``results[col][i]`` lookups, which
# evaluated the filter twice and relied on the index being exactly 0..n-1.
results = pd.read_csv("ag_news_results.csv")
successful = results[results["result_type"] == "Successful"]

# Strip the double-bracket markers first, then replace punctuation by spaces.
perturbed_texts = [
    clean_text(remove_brackets(text)) for text in successful["perturbed_text"]
]
original_texts = [
    clean_text(remove_brackets(text)) for text in successful["original_text"]
]


victim_model = attack.goal_function.model

print("Getting corrected classes")
print("This may take a while ...")
# The classes stored in the CSV file could be used directly instead; here the
# victim model is queried again on the cleaned original texts.
original_classes = [
    torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
    for original_text in original_texts
]
|
302 |
+
|
303 |
+
# Split the three aligned lists into consecutive chunks of ``batch_size``
# items (the last chunk may be shorter).
batch_size = 1000
num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size
batch_starts = range(0, len(perturbed_texts), batch_size)

batched_perturbed_texts = [
    perturbed_texts[start : start + batch_size] for start in batch_starts
]
batched_original_texts = [
    original_texts[start : start + batch_size] for start in batch_starts
]
batched_original_classes = [
    original_classes[start : start + batch_size] for start in batch_starts
]
print(batched_original_classes)

# Accumulators filled by the evaluation loop.
hard_samples_list = []
easy_samples_list = []
|
318 |
+
|
319 |
+
|
320 |
+
# Open a CSV file for writing
csv_filename = "flow_correction_results_ag_news.csv"

# Build the corrector once. Its constructor parses two JSON word tables, so
# recreating it for every batch and threshold repeated that work for nothing;
# ``correct`` reads ``wir_threshold`` at call time, so updating the attribute
# is enough.
corrector = Flow_Corrector(
    attack,
    word_rank_file="en_full_ranked.json",
    word_freq_file="en_full_freq.json",
)

with open(csv_filename, "w", newline="") as csvfile:
    # NOTE: the "freq_threshold" column holds the word-importance (WIR)
    # threshold; the header is kept unchanged for existing consumers.
    fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    # Write the header row
    writer.writeheader()

    # Iterate over batched lists
    batches = zip(
        batched_perturbed_texts, batched_original_texts, batched_original_classes
    )
    for batch_num, (perturbed, original, classes) in enumerate(batches, start=1):
        print(f"Processing batch number: {batch_num}")

        for i in range(2):
            wir_threshold = 0.1 * (i + 1)
            print(f"Setting Word threshold to: {wir_threshold}")
            corrector.wir_threshold = wir_threshold

            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)

            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            # Samples are collected for every threshold, so a text can appear
            # once per threshold in these lists.
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)

            print(f"Number of matching classes (perturbed): {match_perturbed}")

            # Correct original texts
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            # Pair the predictions with the texts they were computed from
            # (the perturbed texts were passed here before); the sample lists
            # of this call are not used.
            match_original, _, _ = count_matching_classes(
                classes, corrected_original_classes, original
            )
            print(f"Number of matching classes (original): {match_original}")

            # Write results to CSV file
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "freq_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed / len(perturbed),
                    "match_original": match_original / len(original),
                }
            )
            print("-" * 20)
|
380 |
+
|
381 |
+
print("savig samples for more statistics studies")

# Pickle the two accumulated sample lists, one file each.
for pickle_path, samples in (
    ("hard_samples.pkl", hard_samples_list),
    ("easy_samples.pkl", easy_samples_list),
):
    with open(pickle_path, "wb") as f:
        pickle.dump(samples, f)
|
flow_correction_imdb.py
ADDED
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import textattack
|
2 |
+
import transformers
|
3 |
+
import pandas as pd
|
4 |
+
import csv
|
5 |
+
import string
|
6 |
+
import pickle
|
7 |
+
# Construct our four components for `Attack`
|
8 |
+
from textattack.constraints.pre_transformation import (
|
9 |
+
RepeatModification,
|
10 |
+
StopwordModification,
|
11 |
+
)
|
12 |
+
from textattack.constraints.semantics import WordEmbeddingDistance
|
13 |
+
from textattack.transformations import WordSwapEmbedding
|
14 |
+
from textattack.search_methods import GreedyWordSwapWIR
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import json
|
18 |
+
import random
|
19 |
+
import re
|
20 |
+
import textattack.shared.attacked_text as atk
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import torch
|
23 |
+
|
24 |
+
|
25 |
+
class InvertedText:
    """Record of one candidate text whose predicted class differs from the
    class predicted for the text it was derived from.

    Attributes are plain public fields:
      * ``attacked_text``   - the candidate text object that was classified.
      * ``swapped_indexes`` - dict mapping each swapped word index to the
        synonym that was put there.
      * ``score``           - score stored for this candidate.
      * ``new_class``       - class predicted for the candidate.
    """

    def __init__(
        self,
        swapped_indexes,
        score,
        attacked_text,
        new_class,
    ):
        # class after inversion
        self.new_class = new_class
        # Described by the original author as the "value of original class".
        # NOTE(review): confirm against the caller which probability is
        # actually passed here.
        self.score = score
        # dict of swapped indexes with their synonym
        self.swapped_indexes = swapped_indexes
        self.attacked_text = attacked_text

    def __repr__(self):
        # new_class is intentionally not part of the representation, exactly
        # as in the original implementation.
        return f"InvertedText:\n attacked_text='{self.attacked_text}', \n swapped_indexes={self.swapped_indexes},\n score={self.score}"
|
43 |
+
|
44 |
+
|
45 |
+
def count_matching_classes(original, corrected, perturbed_texts=None):
    """Count positions where ``original`` and ``corrected`` agree.

    Args:
        original: sequence of reference class labels.
        corrected: sequence of class labels to compare, same length as
            ``original``.
        perturbed_texts: optional sequence of texts aligned with the labels.
            When given, each text is sorted into the "easy" list (labels
            match) or the "hard" list (labels differ). When ``None`` both
            lists stay empty.

    Returns:
        ``(matching_count, hard_samples, easy_samples)``.

    Raises:
        ValueError: if ``original`` and ``corrected`` differ in length.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")

    hard_samples = []
    easy_samples = []
    matching_count = 0

    for position, (expected, predicted) in enumerate(zip(original, corrected)):
        is_match = expected == predicted
        if is_match:
            matching_count += 1
        # The texts are optional: the previous version indexed into
        # ``perturbed_texts`` on every match and crashed with a TypeError
        # when the default ``None`` was used.
        if perturbed_texts is None:
            continue
        bucket = easy_samples if is_match else hard_samples
        bucket.append(perturbed_texts[position])

    return matching_count, hard_samples, easy_samples
|
61 |
+
|
62 |
+
|
63 |
+
class Flow_Corrector:
    """Re-predict the class of suspicious texts using frequency-guided synonym swaps.

    For each text, the most important words (ranked by gradient) are replaced
    by synonyms that rank better in a word-frequency table. Many randomised
    candidates are classified, and the label is switched when enough of them
    leave the originally predicted class.
    """

    def __init__(
        self,
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
        wir_threshold=0.3,
    ):
        """Store the attack, move it to the GPU and load the frequency tables.

        Args:
            attack: attack object providing ``cuda_()``, ``goal_function.model``,
                ``get_indices_to_order`` and ``get_transformations``.
            word_rank_file: path of a JSON file mapping word -> frequency rank.
            word_freq_file: path of a JSON file with word frequencies.
            wir_threshold: fraction of the ranked word indices that are kept
                for substitution.
        """
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r", encoding="utf-8") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r", encoding="utf-8") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model

    def wir_gradient(
        self,
        attack,
        victim_model,
        detected_text,
    ):
        """Return the modifiable word indices sorted by decreasing gradient score.

        The score of a word is the L1 norm of the mean gradient over the model
        tokens aligned with it; words with no aligned token score 0.
        """
        _, indices_to_order = attack.get_indices_to_order(detected_text)

        # Starts at zero, so words without matched tokens keep a score of 0.
        index_scores = np.zeros(len(indices_to_order))
        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if matched_tokens:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        return np.array(indices_to_order)[(-index_scores).argsort()]

    def get_syn_freq_dict(
        self,
        index_order,
        detected_text,
    ):
        """Map each word index to the synonyms allowed by the frequency rules.

        A synonym is kept when it is present in the rank table, its rank is
        below one tenth of the table size, and its rank is lower than the rank
        of the word currently at that index. Indices with no such synonym are
        left out of the result.

        Returns:
            dict mapping word index -> list of acceptable synonyms.
        """
        ranks = self.word_ranked_frequence
        freq_threshold = len(ranks) / 10
        most_frequent_syn_dict = {}

        for idx in index_order:
            try:
                # Synonyms proposed by the attack's transformation for this index.
                synonyms = [
                    transformed.words[idx]
                    for transformed in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                ranked_synonyms = {
                    syn: ranks[syn]
                    for syn in synonyms
                    if syn in ranks
                    and ranks[syn] < freq_threshold
                    and ranks[detected_text.words[idx]] > ranks[syn]
                }
            except Exception:
                # Best effort per index: e.g. the current word is missing from
                # the rank table (KeyError). Skip the index, but no longer
                # swallow KeyboardInterrupt/SystemExit as a bare except did.
                continue

            if ranked_synonyms:
                most_frequent_syn_dict[idx] = list(ranked_synonyms)

        return most_frequent_syn_dict

    def build_candidates(
        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
    ):
        """Draw up to *max_attempt* random candidate texts.

        Each attempt replaces every index of *most_frequent_syn_dict* with one
        randomly chosen synonym. Candidates are stored as dict keys, so
        attempts that compare equal collapse into a single entry.

        Returns:
            dict mapping candidate text -> {index: synonym used}.
        """
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index, synonyms in most_frequent_syn_dict.items():
                syn = random.choice(synonyms)
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)

            candidates[current_text] = syn_dict
        return candidates

    def find_dominant_class(self, inverted_texts):
        """Return the ``new_class`` occurring most often in *inverted_texts*.

        Raises:
            ValueError: if *inverted_texts* is empty.
        """
        class_counts = {}
        for text in inverted_texts:
            class_counts[text.new_class] = class_counts.get(text.new_class, 0) + 1

        return max(class_counts, key=class_counts.get)

    def correct(self, detected_texts):
        """Return one corrected class id per text in *detected_texts*.

        The label is changed only when enough candidates leave the originally
        predicted class: at least ``len(candidates) / num_classes`` of them for
        more than two classes, at least half otherwise. The new label is then
        the most frequent class among the flipped candidates.
        """
        corrected_classes = []
        for detected_text in detected_texts:
            # Convert to an AttackedText so words can be indexed and swapped.
            detected_text = atk.AttackedText(detected_text)

            # Keep only the top `wir_threshold` fraction of important indices.
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]

            # Synonyms that satisfy the frequency conditions.
            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)

            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
            original_class = torch.argmax(original_probs).item()
            num_classes = len(original_probs[0])

            # Classify every candidate in one batch.
            batch_inputs = [candidate.text for candidate in candidates]
            probabilities = F.softmax(self.victim_model(batch_inputs), dim=1)

            inverted_texts = []
            for i, (candidate, syn_dict) in enumerate(candidates.items()):
                candidate_class = torch.argmax(probabilities[i]).item()
                if candidate_class != original_class:
                    inverted_texts.append(
                        InvertedText(
                            syn_dict,
                            float(probabilities[i][candidate_class]),
                            candidate,
                            candidate_class,
                        )
                    )

            required = len(candidates) / (num_classes if num_classes > 2 else 2)
            if inverted_texts and len(inverted_texts) >= required:
                # Take the label from the flipped candidates themselves. The
                # previous binary branch reused the class of whichever
                # candidate happened to be classified last, which may not
                # have flipped at all.
                dominant_class = self.find_dominant_class(inverted_texts)
            else:
                dominant_class = original_class

            corrected_classes.append(dominant_class)

        return corrected_classes
|
235 |
+
|
236 |
+
|
237 |
+
def remove_brackets(text):
    """Drop every ``[[`` and ``]]`` marker from *text*, keeping the rest."""
    # Opening markers are removed first, then closing ones.
    for marker in ("[[", "]]"):
        text = text.replace(marker, "")
    return text
|
241 |
+
|
242 |
+
|
243 |
+
def clean_text(text):
    """Return *text* with each ``string.punctuation`` character replaced by a space."""
    punctuation_class = f"[{re.escape(string.punctuation)}]"
    return re.sub(punctuation_class, " ", text)
|
248 |
+
|
249 |
+
|
250 |
+
# Load model, tokenizer, and model_wrapper
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

# The four components of the attack: goal, constraints, transformation, search.
goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="gradient")

# Construct the actual attack
attack = textattack.Attack(goal_function, constraints, transformation, search_method)
attack.cuda_()

# Keep only the rows of successful attacks. One boolean mask selects both
# columns, so the two lists stay aligned row by row.
results = pd.read_csv("IMDB_results.csv")
successful = results[results["result_type"] == "Successful"]

# Strip the [[...]] markers first, then replace punctuation with spaces.
perturbed_texts = [
    clean_text(remove_brackets(text)) for text in successful["perturbed_text"]
]
original_texts = [
    clean_text(remove_brackets(text)) for text in successful["original_text"]
]

victim_model = attack.goal_function.model

print("Getting corrected classes")
print("This may take a while ...")
# The labels could also be read directly from the results CSV file.
original_classes = [
    torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
    for original_text in original_texts
]

# Split the three aligned lists into batches of at most `batch_size` items.
batch_size = 1000
num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size
batch_starts = range(0, len(perturbed_texts), batch_size)
batched_perturbed_texts = [
    perturbed_texts[start : start + batch_size] for start in batch_starts
]
batched_original_texts = [
    original_texts[start : start + batch_size] for start in batch_starts
]
batched_original_classes = [
    original_classes[start : start + batch_size] for start in batch_starts
]
# NOTE(review): indentation is lost in this view; this print is assumed to run
# once after batching rather than once per batch -- confirm.
print(batched_original_classes)
hard_samples_list = []
easy_samples_list = []

# Build the corrector once: its constructor reads both JSON tables, which was
# previously repeated for every batch and every threshold. Only the threshold
# attribute changes between runs.
corrector = Flow_Corrector(
    attack,
    word_rank_file="en_full_ranked.json",
    word_freq_file="en_full_freq.json",
)

# Open a CSV file for writing
csv_filename = "flow_correction_results_imdb.csv"
with open(csv_filename, "w", newline="") as csvfile:
    fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    # Write the header row
    writer.writeheader()

    # Iterate over batched lists
    batches = zip(
        batched_perturbed_texts, batched_original_texts, batched_original_classes
    )
    for batch_num, (perturbed, original, classes) in enumerate(batches, start=1):
        print(f"Processing batch number: {batch_num}")

        for i in range(2):
            wir_threshold = 0.1 * (i + 1)
            print(f"Setting Word threshold to: {wir_threshold}")
            corrector.wir_threshold = wir_threshold

            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)

            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)

            print(f"Number of matching classes (perturbed): {match_perturbed}")

            # Correct original texts; only the match count of this pass is used.
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            match_original, _, _ = count_matching_classes(
                classes, corrected_original_classes, perturbed
            )
            print(f"Number of matching classes (original): {match_original}")

            # Write results to CSV file. The "freq_threshold" column holds the
            # word-importance threshold; both rates are fractions of the batch.
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "freq_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed / len(perturbed),
                    "match_original": match_original / len(perturbed),
                }
            )
            # NOTE(review): separator assumed to follow each threshold run;
            # the original nesting is not recoverable from this view.
            print("-" * 20)

print("savig samples for more statistics studies")

# Save hard_samples_list and easy_samples_list to files
with open("hard_samples.pkl", "wb") as f:
    pickle.dump(hard_samples_list, f)

with open("easy_samples.pkl", "wb") as f:
    pickle.dump(easy_samples_list, f)
|
gitignore
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dev files
|
2 |
+
*.cache
|
3 |
+
*.dev.py
|
4 |
+
state_dict/
|
5 |
+
TAD*/
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
*.pyc
|
11 |
+
tests/
|
12 |
+
*.result.json
|
13 |
+
.idea/
|
14 |
+
|
15 |
+
# Embedding
|
16 |
+
glove.840B.300d.txt
|
17 |
+
glove.42B.300d.txt
|
18 |
+
glove.twitter.27B.txt
|
19 |
+
|
20 |
+
# project main files
|
21 |
+
release_note.json
|
22 |
+
|
23 |
+
# C extensions
|
24 |
+
*.so
|
25 |
+
|
26 |
+
# Distribution / packaging
|
27 |
+
.Python
|
28 |
+
build/
|
29 |
+
develop-eggs/
|
30 |
+
dist/
|
31 |
+
downloads/
|
32 |
+
eggs/
|
33 |
+
.eggs/
|
34 |
+
lib64/
|
35 |
+
parts/
|
36 |
+
sdist/
|
37 |
+
var/
|
38 |
+
wheels/
|
39 |
+
pip-wheel-metadata/
|
40 |
+
share/python-wheels/
|
41 |
+
*.egg-info/
|
42 |
+
.installed.cfg
|
43 |
+
*.egg
|
44 |
+
MANIFEST
|
45 |
+
|
46 |
+
# PyInstaller
|
47 |
+
# Usually these files are written by a python script from a template
|
48 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
49 |
+
*.manifest
|
50 |
+
*.spec
|
51 |
+
|
52 |
+
# Installer logs
|
53 |
+
pip-log.txt
|
54 |
+
pip-delete-this-directory.txt
|
55 |
+
|
56 |
+
# Unit test / coverage reports
|
57 |
+
htmlcov/
|
58 |
+
.tox/
|
59 |
+
.nox/
|
60 |
+
.coverage
|
61 |
+
.coverage.*
|
62 |
+
.cache
|
63 |
+
nosetests.xml
|
64 |
+
coverage.xml
|
65 |
+
*.cover
|
66 |
+
*.py,cover
|
67 |
+
.hypothesis/
|
68 |
+
.pytest_cache/
|
69 |
+
|
70 |
+
# Translations
|
71 |
+
*.mo
|
72 |
+
*.pot
|
73 |
+
|
74 |
+
# Django stuff:
|
75 |
+
*.log
|
76 |
+
local_settings.py
|
77 |
+
db.sqlite3
|
78 |
+
db.sqlite3-journal
|
79 |
+
|
80 |
+
# Flask stuff:
|
81 |
+
instance/
|
82 |
+
.webassets-cache
|
83 |
+
|
84 |
+
# Scrapy stuff:
|
85 |
+
.scrapy
|
86 |
+
|
87 |
+
# Sphinx documentation
|
88 |
+
docs/_build/
|
89 |
+
|
90 |
+
# PyBuilder
|
91 |
+
target/
|
92 |
+
|
93 |
+
# Jupyter Notebook
|
94 |
+
.ipynb_checkpoints
|
95 |
+
|
96 |
+
# IPython
|
97 |
+
profile_default/
|
98 |
+
ipython_config.py
|
99 |
+
|
100 |
+
# pyenv
|
101 |
+
.python-version
|
102 |
+
|
103 |
+
# pipenv
|
104 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
105 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
106 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
107 |
+
# install all needed dependencies.
|
108 |
+
#Pipfile.lock
|
109 |
+
|
110 |
+
# celery beat schedule file
|
111 |
+
celerybeat-schedule
|
112 |
+
|
113 |
+
# SageMath parsed files
|
114 |
+
*.sage.py
|
115 |
+
|
116 |
+
# Environments
|
117 |
+
.env
|
118 |
+
.venv
|
119 |
+
env/
|
120 |
+
venv/
|
121 |
+
ENV/
|
122 |
+
env.bak/
|
123 |
+
venv.bak/
|
124 |
+
|
125 |
+
# Spyder project settings
|
126 |
+
.spyderproject
|
127 |
+
.spyproject
|
128 |
+
|
129 |
+
# Rope project settings
|
130 |
+
.ropeproject
|
131 |
+
|
132 |
+
# mkdocs documentation
|
133 |
+
/site
|
134 |
+
|
135 |
+
# mypy
|
136 |
+
.mypy_cache/
|
137 |
+
.dmypy.json
|
138 |
+
dmypy.json
|
139 |
+
|
140 |
+
# Pyre type checker
|
141 |
+
.pyre/
|
142 |
+
.DS_Store
|
143 |
+
examples/.DS_Store
|
main_correction.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import textattack
|
2 |
+
import transformers
|
3 |
+
from FlowCorrector import Flow_Corrector
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
def count_matching_classes(original, corrected):
    """Return how many positions hold equal labels in *original* and *corrected*.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")

    return sum(
        1 for expected, predicted in zip(original, corrected) if expected == predicted
    )
|
18 |
+
|
19 |
+
# The guard previously compared against "main", which never equals __name__,
# so nothing below it ever ran.
if __name__ == "__main__":

    # Load model, tokenizer, and model_wrapper
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    # Construct our four components for `Attack`
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.transformations import WordSwapEmbedding
    from textattack.search_methods import GreedyWordSwapWIR

    goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
    constraints = [
        RepeatModification(),
        StopwordModification(),
        WordEmbeddingDistance(min_cos_sim=0.9),
    ]
    transformation = WordSwapEmbedding(max_candidates=50)
    search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")

    # Construct the actual attack
    attack = textattack.Attack(goal_function, constraints, transformation, search_method)
    attack.cuda_()

    # Initialise the corrector
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )

    # All these texts are adversarial ones
    with open("perturbed_texts_ag_news.txt", "r") as f:
        detected_texts = [line.strip() for line in f]

    # These are the original texts, in the same order as the adversarial ones
    with open("original_texts_ag_news.txt", "r") as f:
        original_texts = [line.strip() for line in f]

    victim_model = attack.goal_function.model

    # Predicted labels of the original texts, used as the benchmark reference
    original_classes = [
        torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
        for original_text in original_texts
    ]

    # AG News label ids: 0 = World, 1 = Sports, 2 = Business, 3 = Sci/Tech

    # Correct the adversarial texts and compare with the reference labels.
    # Previously the original texts were corrected, `detected_texts` was never
    # used, and count_matching_classes() was called without arguments.
    corrected_classes = corrector.correct(detected_texts)
    print(f"match {count_matching_classes(original_classes, corrected_classes)}")
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
findfile>=1.7.9.8
|
2 |
+
autocuda>=0.11
|
3 |
+
metric-visualizer>=0.5.5
|
4 |
+
boostaug>=2.2.3
|
5 |
+
spacy
|
6 |
+
networkx
|
7 |
+
seqeval
|
8 |
+
update-checker
|
9 |
+
typing_extensions
|
10 |
+
tqdm
|
11 |
+
pytorch_warmup
|
12 |
+
termcolor
|
13 |
+
gitpython
|
14 |
+
gdown>=4.4.0
|
15 |
+
transformers>4.20.0
|
16 |
+
torch>1.0.0
|
17 |
+
sentencepiece
|
18 |
+
tensorflow_text
|
19 |
+
tensorflow_hub
|
20 |
+
tensorflow>=2.6.0
|
21 |
+
datasets
|
22 |
+
textattack
|
23 |
+
jieba
|
24 |
+
pycld2
|
25 |
+
OpenHowNet
|
26 |
+
pinyin
|
27 |
+
flask
|
28 |
+
|
text_defense/201.SST2/stsa.binary.dev.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/201.SST2/stsa.binary.test.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/201.SST2/stsa.binary.train.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/202.IMDB10K/imdb10k.test.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/202.IMDB10K/imdb10k.train.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/202.IMDB10K/imdb10k.valid.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/204.AGNews10K/AGNews10K.test.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/204.AGNews10K/AGNews10K.train.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/204.AGNews10K/AGNews10K.valid.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
textattack/__init__.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Welcome to the API references for TextAttack!
|
2 |
+
|
3 |
+
What is TextAttack?
|
4 |
+
|
5 |
+
`TextAttack <https://github.com/QData/TextAttack>`__ is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.
|
6 |
+
|
7 |
+
TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It's also useful for NLP model training, adversarial training, and data augmentation.
|
8 |
+
|
9 |
+
TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
|
10 |
+
"""
|
11 |
+
from .attack_args import AttackArgs, CommandLineAttackArgs
|
12 |
+
from .augment_args import AugmenterArgs
|
13 |
+
from .dataset_args import DatasetArgs
|
14 |
+
from .model_args import ModelArgs
|
15 |
+
from .training_args import TrainingArgs, CommandLineTrainingArgs
|
16 |
+
from .attack import Attack
|
17 |
+
from .attacker import Attacker
|
18 |
+
from .trainer import Trainer
|
19 |
+
from .metrics import Metric
|
20 |
+
|
21 |
+
from . import (
|
22 |
+
attack_recipes,
|
23 |
+
attack_results,
|
24 |
+
augmentation,
|
25 |
+
commands,
|
26 |
+
constraints,
|
27 |
+
datasets,
|
28 |
+
goal_function_results,
|
29 |
+
goal_functions,
|
30 |
+
loggers,
|
31 |
+
metrics,
|
32 |
+
models,
|
33 |
+
search_methods,
|
34 |
+
shared,
|
35 |
+
transformations,
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
name = "textattack"
|
textattack/__main__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python

# Package entry point: when this module is executed as a script, control is
# handed to the TextAttack command-line interface.
if __name__ == "__main__":
    # The package import happens only on the script path.
    import textattack

    textattack.commands.textattack_cli.main()
|