andreslu committed on
Commit 2b6af87 · 1 Parent(s): f184a52

Delete expbert.py

Files changed (1)
  1. expbert.py +0 -284
expbert.py DELETED
@@ -1,284 +0,0 @@
import argparse
import logging
import os
import random
from datetime import datetime

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (AutoConfig, AutoModel,
                          AutoModelForSequenceClassification, AutoTokenizer)

if not os.path.exists('logs/'):
    os.mkdir('logs/')

logging.basicConfig(
    filename='logs/expbert-{}.log'.format(str(datetime.now())),
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logger = logging.getLogger(__name__)


TASK2PATH = {
    "disease-train": "data/disease/train.txt",
    "disease-test": "data/disease/test.txt",
    "spouse-train": "data/spouse/train.txt",
    "spouse-test": "data/spouse/test.txt",
}

ANNOTATED_EXP = {
    "spouse": "data/exp/expbert_spouse_explanation.txt",
    "disease": "data/exp/expbert_disease_explanation.txt",
}

GENERATED_EXP = {
    "spouse": "data/exp/orion_spouse_explanation.txt",
    "disease": "data/exp/orion_disease_explanation.txt",
}

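# Explanation file format (inferred from REDataset.collate_fn below): one
# explanation per line, marking the entity pair either with {e1}/{e2}
# placeholders or with two successive <mask> tokens. Hypothetical examples:
#   {e1} and {e2} have been married for years
#   <mask> is the spouse of <mask>
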
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_config(config):
    config = vars(config)
    logger.info("**************** MODEL CONFIGURATION ****************")
    for key in sorted(config.keys()):
        val = config[key]
        keystr = "{}".format(key) + (" " * (25 - len(key)))
        logger.info("{} --> {}".format(keystr, val))
    logger.info("**************** MODEL CONFIGURATION ****************")


class ExpBERT(nn.Module):
    def __init__(self, args, exp_num):
        super(ExpBERT, self).__init__()
        self.args = args
        self.exp_num = exp_num
        self.config = AutoConfig.from_pretrained(args.model)
        self.model = AutoModel.from_pretrained(args.model, config=self.config)
        self.dropout = nn.Dropout(p=0.1)
        # One encoder pass per (sentence, explanation) pair; the [CLS] vectors
        # are concatenated, so the classifier sees hidden_size * exp_num features.
        self.linear = nn.Linear(self.config.hidden_size * exp_num, 2)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs):
        for k, v in inputs["encoding"].items():
            inputs["encoding"][k] = v.to(device)
        # Take the [CLS] token of each of the exp_num encodings and flatten
        # them into one feature vector for this single example.
        pooler_output = self.model(**inputs["encoding"]).last_hidden_state[:, 0, :].reshape(1, self.exp_num * self.config.hidden_size)
        pooler_output = self.dropout(pooler_output)
        logits = self.linear(pooler_output)

        loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).to(device))
        prediction = torch.argmax(logits)

        return {
            "loss": loss,
            "prediction": prediction,
        }


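# Minimal usage sketch (hypothetical example, not part of the training loop):
# two explanations for one sentence yield a (1, 2 * hidden_size) feature
# vector and a binary prediction, mirroring what REDataset.collate_fn feeds in.
def _expbert_sketch():
    args = argparse.Namespace(model="bert-base-uncased")
    model = ExpBERT(args, exp_num=2).to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    encoding = tokenizer(
        ["Pat met Sam in 2010."] * 2,
        ["Pat is the spouse of Sam", "Pat and Sam are married"],
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
        return_token_type_ids=True,
    )
    outputs = model({"encoding": encoding, "label": 1})
    return outputs["loss"], outputs["prediction"]

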
class REDataset(Dataset):
    def __init__(self, path, exp, tokenizer):
        super(REDataset, self).__init__()
        self.tokenizer = tokenizer
        self.exp = exp
        self.sentences = []
        self.labels = []
        self.entities = []
        with open(path, "r", encoding="utf-8") as file:
            data = file.readlines()
        for example in data:
            sentence, entity1, entity2, example_id, label = example.strip().split("\t")
            self.sentences.append(sentence)
            # Labels in the data files are 1 / -1; map them to 1 / 0.
            if int(label) == 1:
                self.labels.append(1)
            elif int(label) == -1:
                self.labels.append(0)

            self.entities.append([entity1, entity2])

        logger.info("Number of Examples in {}: {}".format(path, str(len(self.labels))))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return {
            "sentence": self.sentences[index],
            "entity": self.entities[index],
            "label": self.labels[index],
        }

    def collate_fn(self, batch):
        # ExpBERT collator: pair each sentence with every explanation, filling
        # the entity slots ({e1}/{e2} placeholders or successive <mask> tokens).
        outputs = []
        for ex in batch:
            temp = []
            for exp in self.exp:
                if "{e1}" in exp or "{e2}" in exp:
                    exp = exp.replace("{e1}", ex["entity"][0]).replace("{e2}", ex["entity"][1])
                else:
                    for entity in ex["entity"]:
                        index = exp.index('<mask>')
                        exp = exp[:index] + entity + exp[index + len('<mask>'):]
                temp.append(exp)
            outputs.append(
                {
                    "encoding": self.tokenizer(
                        [ex["sentence"]] * len(temp), temp,
                        add_special_tokens=True,
                        padding="longest",
                        truncation=True,
                        max_length=156,
                        return_tensors="pt",
                        return_attention_mask=True,
                        return_token_type_ids=True,
                    ),
                    "label": ex["label"],
                }
            )
        return outputs

    def collate_fn_(self, batch):
        # Baseline (--no_exp) collator: encode the bare sentences only.
        texts = []
        labels = []
        for ex in batch:
            texts.append(ex["sentence"])
            labels.append(ex["label"])

        outputs = self.tokenizer(
            texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=156,
            return_tensors="pt",
            return_attention_mask=True,
            return_token_type_ids=True,
        )

        outputs["labels"] = torch.LongTensor(labels)

        return outputs


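# Data file format (inferred from REDataset.__init__ above): one tab-separated
# example per line, "sentence\tentity1\tentity2\tid\tlabel" with label 1 or -1.
# A hypothetical line:
#   Pat met Sam in 2010, and they married a year later.\tPat\tSam\t42\t1
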
class Trainer(object):
    def __init__(self, args):
        self.args = args
        print_config(args)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model)

        TASK2EXP = GENERATED_EXP if args.generated_rules else ANNOTATED_EXP
        with open(TASK2EXP[args.task], "r", encoding="utf-8") as file:
            exp = file.readlines()

        self.train_dataset = REDataset(TASK2PATH['{}-train'.format(args.task)], exp, self.tokenizer)
        self.test_dataset = REDataset(TASK2PATH['{}-test'.format(args.task)], exp, self.tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(args.model).to(device) if self.args.no_exp else ExpBERT(args, len(exp)).to(device)

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.train_dataset.collate_fn_ if self.args.no_exp else self.train_dataset.collate_fn,
        )

        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.test_dataset.collate_fn_ if self.args.no_exp else self.test_dataset.collate_fn,
        )

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.learning_rate)

    def compute_metrics(self, labels, predictions):
        accuracy = accuracy_score(y_pred=predictions, y_true=labels)
        f1 = f1_score(y_pred=predictions, y_true=labels)

        return accuracy, f1

    def train(self):
        self.test(-1)  # evaluate once before training as a baseline
        for e in range(self.args.epochs):
            self.model.train()  # test() leaves the model in eval mode
            with tqdm(total=len(self.train_loader)) as pbar:
                for step, examples in enumerate(self.train_loader):
                    self.model.zero_grad()
                    if self.args.no_exp:
                        for k, v in examples.items():
                            examples[k] = v.to(device)
                        outputs = self.model(**examples)
                        outputs.loss.backward()

                    else:
                        # ExpBERT processes one example at a time; average the
                        # per-example losses so the step matches the batch size.
                        for ex in examples:
                            outputs = self.model(ex)
                            (outputs["loss"] / len(examples)).backward()

                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()
                    pbar.update(1)

            self.test(e)

    def test(self, epoch):
        self.model.eval()
        with torch.no_grad():
            with tqdm(total=len(self.test_loader)) as pbar:
                loss = []
                labels = []
                predictions = []
                for step, examples in enumerate(self.test_loader):
                    if self.args.no_exp:
                        for k, v in examples.items():
                            examples[k] = v.to(device)
                        outputs = self.model(**examples)
                        loss.append(outputs.loss.float())
                        labels.extend(examples["labels"].tolist())
                        predictions.extend(torch.argmax(outputs.logits, dim=1).tolist())

                    else:
                        for ex in examples:
                            labels.append(ex['label'])
                            outputs = self.model(ex)
                            loss.append(outputs["loss"].item())
                            predictions.append(outputs['prediction'].tolist())

                    pbar.update(1)
                accuracy, f1 = self.compute_metrics(labels, predictions)
                logger.info("[EPOCH {}] Accuracy: {} | F1-Score: {}. (Number of Examples: {})".format(epoch, accuracy, f1, len(predictions)))


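# Batch shape note: with --no_exp the loaders yield a single BatchEncoding of
# batch_size sentences (handled via self.model(**examples)); otherwise they
# yield a list of batch_size dicts, each holding one example's
# (sentence, explanation) encodings, as built by REDataset.collate_fn above.
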
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="spouse")
    parser.add_argument("--model", type=str, default="bert-base-uncased")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so the boolean switches are exposed as store_true flags (default False,
    # matching the original defaults).
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--no_exp", action="store_true")
    parser.add_argument("--generated_rules", action="store_true")

    args = parser.parse_args()

    for seed in range(42, 47):
        set_random_seed(seed)
        trainer = Trainer(args)
        trainer.train()
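
# Usage sketch (assumes the data/ layout listed in TASK2PATH and the
# explanation files listed in ANNOTATED_EXP / GENERATED_EXP):
#   python expbert.py --task spouse --model bert-base-uncased --batch_size 32
#   python expbert.py --task disease --no_exp            # plain BERT baseline
#   python expbert.py --task spouse --generated_rules    # Orion-generated explanations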