Delete expbert.py
expbert.py
DELETED (+0, -284)
@@ -1,284 +0,0 @@
import argparse
import logging
import os
import random
from datetime import datetime

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (AutoConfig, AutoModel,
                          AutoModelForSequenceClassification, AutoTokenizer,
                          BertForSequenceClassification, BertModel)

if not os.path.exists('logs/'):
    os.mkdir('logs/')

logging.basicConfig(
    filename='logs/expbert-{}.log'.format(str(datetime.now())),
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logger = logging.getLogger(__name__)


TASK2PATH = {
    "disease-train": "data/disease/train.txt",
    "disease-test": "data/disease/test.txt",
    "spouse-train": "data/spouse/train.txt",
    "spouse-test": "data/spouse/test.txt",
}

ANNOTATED_EXP = {
    "spouse": "data/exp/expbert_spouse_explanation.txt",
    "disease": "data/exp/expbert_disease_explanation.txt",
}

GENERATED_EXP = {
    "spouse": "data/exp/orion_spouse_explanation.txt",
    "disease": "data/exp/orion_disease_explanation.txt",
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def print_config(config):
    config = vars(config)
    logger.info("**************** MODEL CONFIGURATION ****************")
    for key in sorted(config.keys()):
        val = config[key]
        keystr = "{}".format(key) + (" " * (25 - len(key)))
        logger.info("{} --> {}".format(keystr, val))
    logger.info("**************** MODEL CONFIGURATION ****************")

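# ExpBERT-style classifier: the encoder reads the input sentence paired with
# each of the exp_num explanations; the per-pair [CLS] vectors are
# concatenated into one feature row and fed through dropout plus a linear
# layer to produce binary logits.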
class ExpBERT(nn.Module):
    def __init__(self, args, exp_num):
        super(ExpBERT, self).__init__()
        self.args = args
        self.exp_num = exp_num
        self.config = AutoConfig.from_pretrained(args.model)
        self.model = AutoModel.from_pretrained(args.model, config=self.config)
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(self.config.hidden_size * exp_num, 2)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs):
        for k, v in inputs["encoding"].items():
            inputs["encoding"][k] = v.to(device)
        # [CLS] vector of each (sentence, explanation) pair, flattened into a
        # single (1, exp_num * hidden_size) feature row
        cls_output = self.model(**inputs["encoding"]).last_hidden_state[:, 0, :].reshape(1, self.exp_num * self.config.hidden_size)
        cls_output = self.dropout(cls_output)
        logits = self.linear(cls_output)

        loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).to(device))
        prediction = torch.argmax(logits)

        return {
            "loss": loss,
            "prediction": prediction,
        }

class REDataset(Dataset):
    def __init__(self, path, exp, tokenizer):
        super(REDataset, self).__init__()
        self.tokenizer = tokenizer
        self.exp = exp
        self.sentences = []
        self.labels = []
        self.entities = []
        with open(path, "r", encoding="utf-8") as file:
            data = file.readlines()
        for example in data:
            sentence, entity1, entity2, example_id, label = example.strip().split("\t")
            self.sentences.append(sentence)
            # labels in the data files are "1" (positive) or "-1" (negative)
            if int(label) == 1:
                self.labels.append(1)
            elif int(label) == -1:
                self.labels.append(0)

            self.entities.append([entity1, entity2])

        logger.info("Number of Examples in {}: {}".format(path, str(len(self.labels))))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return {
            "sentence": self.sentences[index],
            "entity": self.entities[index],
            "label": self.labels[index],
        }

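    # Batch builder for the ExpBERT path: every explanation template is
    # instantiated for the example's entity pair, either by substituting the
    # {e1}/{e2} placeholders or by filling two <mask> slots in order, and the
    # sentence is tokenized against each instantiated explanation as a
    # (text, text_pair) input, yielding exp_num encoded sequences per example.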
    def collate_fn(self, batch):
        outputs = []
        for ex in batch:
            temp = []
            for exp in self.exp:
                if "{e1}" in exp or "{e2}" in exp:
                    exp = exp.replace("{e1}", ex["entity"][0]).replace("{e2}", ex["entity"][1])
                else:
                    for entity in ex["entity"]:
                        index = exp.index('<mask>')
                        exp = exp[:index] + entity + exp[index + len('<mask>'):]
                temp.append(exp)
            outputs.append(
                {
                    "encoding": self.tokenizer(
                        [ex["sentence"]] * len(temp), temp,
                        add_special_tokens=True,
                        padding="longest",
                        truncation=True,
                        max_length=156,
                        return_tensors="pt",
                        return_attention_mask=True,
                        return_token_type_ids=True,
                    ),
                    "label": ex["label"],
                }
            )
        return outputs

    # plain sentence-only batching, used for the --no_exp baseline
    def collate_fn_(self, batch):
        texts = []
        labels = []
        for ex in batch:
            texts.append(ex["sentence"])
            labels.append(ex["label"])

        outputs = self.tokenizer(
            texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=156,
            return_tensors="pt",
            return_attention_mask=True,
            return_token_type_ids=True,
        )

        outputs["labels"] = torch.LongTensor(labels)

        return outputs

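# Trainer wires up the tokenizer, datasets, and loaders. With --no_exp it
# trains a stock AutoModelForSequenceClassification on sentences alone
# (collate_fn_); otherwise it trains the ExpBERT wrapper on
# explanation-augmented batches (collate_fn).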
class Trainer(object):
    def __init__(self, args):
        self.args = args
        print_config(args)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model)

        TASK2EXP = GENERATED_EXP if args.generated_rules else ANNOTATED_EXP
        with open(TASK2EXP[args.task], "r", encoding="utf-8") as file:
            exp = file.readlines()

        self.train_dataset = REDataset(TASK2PATH['{}-train'.format(args.task)], exp, self.tokenizer)
        self.test_dataset = REDataset(TASK2PATH['{}-test'.format(args.task)], exp, self.tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(args.model).to(device) if self.args.no_exp else ExpBERT(args, len(exp)).to(device)

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.train_dataset.collate_fn_ if self.args.no_exp else self.train_dataset.collate_fn,
        )

        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.test_dataset.collate_fn_ if self.args.no_exp else self.test_dataset.collate_fn,
        )

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.learning_rate)

    def compute_metrics(self, labels, predictions):
        accuracy = accuracy_score(y_pred=predictions, y_true=labels)
        f1 = f1_score(y_pred=predictions, y_true=labels)

        return accuracy, f1

    def train(self):
        self.test(-1)
        for e in range(self.args.epochs):
            # test() leaves the model in eval mode, so re-enable dropout here
            self.model.train()
            with tqdm(total=len(self.train_loader)) as pbar:
                for step, examples in enumerate(self.train_loader):
                    self.model.zero_grad()
                    if self.args.no_exp:
                        for k, v in examples.items():
                            examples[k] = v.to(device)
                        outputs = self.model(**examples)
                        outputs.loss.backward()

                    else:
                        # ExpBERT processes one example at a time; average the
                        # per-example losses so gradients match the batch size
                        for ex in examples:
                            outputs = self.model(ex)
                            (outputs["loss"] / len(examples)).backward()

                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()
                    pbar.update(1)

            self.test(e)

    def test(self, epoch):
        self.model.eval()
        with torch.no_grad():
            with tqdm(total=len(self.test_loader)) as pbar:
                loss = []
                labels = []
                predictions = []
                for step, examples in enumerate(self.test_loader):
                    if self.args.no_exp:
                        for k, v in examples.items():
                            examples[k] = v.to(device)
                        outputs = self.model(**examples)
                        loss.append(outputs.loss.float())
                        labels.extend(examples["labels"].tolist())
                        predictions.extend(torch.argmax(outputs.logits, dim=1).tolist())

                    else:
                        for ex in examples:
                            labels.append(ex['label'])
                            outputs = self.model(ex)
                            loss.append(outputs["loss"].item())
                            predictions.append(outputs['prediction'].tolist())

                    pbar.update(1)
                # argument order matches the compute_metrics signature
                accuracy, f1 = self.compute_metrics(labels, predictions)
                logger.info("[EPOCH {}] Accuracy: {} | F1-Score: {}. (Number of Examples: {})".format(epoch, accuracy, f1, len(predictions)))

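# Repeat the full train/eval cycle under seeds 42-46; per-epoch accuracy and
# F1 for each run are written to the timestamped log file under logs/.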
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="spouse")
    parser.add_argument("--model", type=str, default="bert-base-uncased")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    # boolean switches use store_true: pass the flag to enable (argparse's
    # type=bool parses any non-empty string, including "False", as True)
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--no_exp", action="store_true")
    parser.add_argument("--generated_rules", action="store_true")

    args = parser.parse_args()

    for seed in range(42, 47):
        set_random_seed(seed)
        trainer = Trainer(args)
        trainer.train()
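
The core of the ExpBERT variant above is its feature layout: one encoder pass covers all explanations for a single example, and the per-explanation [CLS] vectors are flattened into one feature row for the linear head. The following is a minimal, torch-only sketch of that shape logic; the sizes are illustrative (768 is bert-base's hidden size, and the real exp_num is the number of lines in the explanation file), not taken from the deleted script.

import torch
from torch import nn

hidden_size, exp_num = 768, 4  # illustrative sizes only

# one [CLS] vector per (sentence, explanation) pair, as in ExpBERT.forward
cls_vectors = torch.randn(exp_num, hidden_size)

# flatten the per-explanation features into a single example representation
features = cls_vectors.reshape(1, exp_num * hidden_size)

# binary classification head over the concatenation, mirroring ExpBERT.linear
head = nn.Linear(hidden_size * exp_num, 2)
logits = head(features)
print(logits.shape)  # torch.Size([1, 2])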