sxtforreal committed on
Commit
4f3b260
·
verified ·
1 Parent(s): 18f03d1

Create model.py

Browse files

This file defines four models: SimCSE, SimCSE_w, Samp, and Samp_w.
SimCSE: simple contrastive learning model
SimCSE_w: SimCSE with weighting
Samp: our positive- and negative-sampling model
Samp_w: Samp with weighting

Files changed (1) hide show
  1. model.py +492 -0
model.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning.pytorch as pl
2
+ from transformers import (
3
+ AdamW,
4
+ AutoModel,
5
+ get_linear_schedule_with_warmup,
6
+ )
7
+ import torch
8
+ from torch import nn
9
+ from loss import (
10
+ ContrastiveLoss_simcse,
11
+ ContrastiveLoss_simcse_w,
12
+ ContrastiveLoss_samp,
13
+ ContrastiveLoss_samp_w,
14
+ )
15
+
16
+
17
class BERTContrastiveLearning_simcse(pl.LightningModule):
    """Bio_ClinicalBERT encoder with a 128-d projection head and SimCSE loss.

    The module loads ``emilyalsentzer/Bio_ClinicalBERT``, optionally freezes a
    leading portion of its parameter tensors, projects the encoder's
    ``pooler_output`` to 128 dimensions, applies dropout, and optimizes
    ``ContrastiveLoss_simcse``.
    """

    def __init__(self, n_batches=None, n_epochs=None, lr=None, **kwargs):
        """Build the encoder, projection head, loss and logging buffers.

        Args:
            n_batches: Stored on the instance; not used by the code in this
                class (only the disabled scheduler referenced it).
            n_epochs: Stored on the instance; not used by the code in this
                class (only the disabled scheduler referenced it).
            lr: Learning rate handed to ``AdamW`` in ``configure_optimizers``.
            **kwargs: Only ``unfreeze`` is recognized, and only when it is a
                ``float`` in ``[0.0, 1.0]``. It is the fraction of encoder
                parameter tensors, counted from the end of
                ``self.bert.parameters()``, that stay trainable. When it is
                omitted or ``0.0`` the whole encoder stays trainable.

        Raises:
            AssertionError: If ``unfreeze`` is a float outside ``[0.0, 1.0]``.
        """
        super().__init__()
        ### Parameters
        self.n_batches = n_batches
        self.n_epochs = n_epochs
        self.lr = lr

        ### Architecture
        self.bert = AutoModel.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
        )
        # Unfreeze encoder
        # NOTE: despite the "layer" naming, this counts parameter tensors.
        self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
        self.num_unfreeze_layer = self.bert_layer_num
        self.ratio_unfreeze_layer = 0.0
        unfreeze = kwargs.get("unfreeze")
        if isinstance(unfreeze, float):
            # NOTE(review): kept as an assert so callers still see
            # AssertionError; be aware it is skipped under ``python -O``.
            assert (
                unfreeze >= 0.0 and unfreeze <= 1.0
            ), "ValueError: value must be a ratio between 0.0 and 1.0"
            self.ratio_unfreeze_layer = unfreeze
        if self.ratio_unfreeze_layer > 0.0:
            self.num_unfreeze_layer = int(
                self.bert_layer_num * self.ratio_unfreeze_layer
            )
        # Freeze an explicit-length prefix. Slicing with
        # ``[: -self.num_unfreeze_layer]`` collapses to ``[:0]`` when a small
        # positive ratio truncates to 0, which froze nothing instead of
        # freezing the whole encoder.
        num_frozen = self.bert_layer_num - self.num_unfreeze_layer
        for param in list(self.bert.parameters())[:num_frozen]:
            param.requires_grad = False
        # Random dropouts
        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)
        # Linear projector
        self.projector = nn.Linear(self.bert.config.hidden_size, 128)
        print("Model Initialized!")

        ### Loss
        self.criterion = ContrastiveLoss_simcse()

        ### Logs
        self.train_loss, self.val_loss, self.test_loss = [], [], []
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def configure_optimizers(self):
        """Return a single ``AdamW`` optimizer over the trainable parameters."""
        # Optimizer: frozen encoder tensors are excluded.
        self.trainable_params = [
            param for param in self.parameters() if param.requires_grad
        ]
        optimizer = AdamW(self.trainable_params, lr=self.lr)

        # NOTE: no scheduler is configured. A linear warmup schedule
        # (get_linear_schedule_with_warmup) was sketched here but is disabled.
        return [optimizer]

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return ``(pooler_output, projected_output)``.

        Row 0 of the projected batch goes through ``dropout1`` and the
        remaining rows through ``dropout2``; the two parts are concatenated
        back in the original order. Presumably row 0 is the anchor sample, as
        the variable name suggests -- verify against the data loader.
        """
        emb = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = emb.pooler_output
        out = self.projector(cls)
        anchor_out = self.dropout1(out[0:1])
        rest_out = self.dropout2(out[1:])
        output = torch.cat([anchor_out, rest_out])
        return cls, output

    def training_step(self, batch, batch_idx):
        """Compute, log and return the contrastive loss for one training batch.

        Reads ``batch["label"]``, ``batch["input_ids"]`` and
        ``batch["attention_mask"]``.
        """
        label = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        cls, out = self(
            input_ids,
            attention_mask,
        )
        loss = self.criterion(out, label)
        # Keep only a detached copy for the epoch summary so the stored list
        # does not hold on to each step's autograd graph.
        logs = {"loss": loss.detach()}
        self.training_step_outputs.append(logs)
        self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Average the stored training losses, record and print the result."""
        if not self.training_step_outputs:
            # torch.stack raises on an empty list; nothing to summarize.
            return
        loss = (
            torch.stack([x["loss"] for x in self.training_step_outputs])
            .mean()
            .detach()
            .cpu()
            .numpy()
        )
        self.train_loss.append(loss)
        print("train_epoch:", self.current_epoch, "avg_loss:", loss)
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the contrastive loss for one validation batch.

        Reads ``batch["label"]``, ``batch["input_ids"]`` and
        ``batch["attention_mask"]``.
        """
        label = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        cls, out = self(
            input_ids,
            attention_mask,
        )
        loss = self.criterion(out, label)
        logs = {"loss": loss.detach()}
        self.validation_step_outputs.append(logs)
        self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Average the stored validation losses, record and print the result."""
        if not self.validation_step_outputs:
            # torch.stack raises on an empty list; nothing to summarize.
            return
        loss = (
            torch.stack([x["loss"] for x in self.validation_step_outputs])
            .mean()
            .detach()
            .cpu()
            .numpy()
        )
        self.val_loss.append(loss)
        print("val_epoch:", self.current_epoch, "avg_loss:", loss)
        self.validation_step_outputs.clear()
136
+
137
+
138
class BERTContrastiveLearning_simcse_w(pl.LightningModule):
    """Bio_ClinicalBERT encoder with a 128-d projection head and weighted SimCSE loss.

    The module loads ``emilyalsentzer/Bio_ClinicalBERT``, optionally freezes a
    leading portion of its parameter tensors, projects the encoder's
    ``pooler_output`` to 128 dimensions, applies dropout, and optimizes
    ``ContrastiveLoss_simcse_w``, which additionally receives the batch's
    ``score`` entry.
    """

    def __init__(self, n_batches=None, n_epochs=None, lr=None, **kwargs):
        """Build the encoder, projection head, loss and logging buffers.

        Args:
            n_batches: Stored on the instance; not used by the code in this
                class (only the disabled scheduler referenced it).
            n_epochs: Stored on the instance; not used by the code in this
                class (only the disabled scheduler referenced it).
            lr: Learning rate handed to ``AdamW`` in ``configure_optimizers``.
            **kwargs: Only ``unfreeze`` is recognized, and only when it is a
                ``float`` in ``[0.0, 1.0]``. It is the fraction of encoder
                parameter tensors, counted from the end of
                ``self.bert.parameters()``, that stay trainable. When it is
                omitted or ``0.0`` the whole encoder stays trainable.

        Raises:
            AssertionError: If ``unfreeze`` is a float outside ``[0.0, 1.0]``.
        """
        super().__init__()
        ### Parameters
        self.n_batches = n_batches
        self.n_epochs = n_epochs
        self.lr = lr

        ### Architecture
        self.bert = AutoModel.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
        )
        # Unfreeze encoder
        # NOTE: despite the "layer" naming, this counts parameter tensors.
        self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
        self.num_unfreeze_layer = self.bert_layer_num
        self.ratio_unfreeze_layer = 0.0
        unfreeze = kwargs.get("unfreeze")
        if isinstance(unfreeze, float):
            # NOTE(review): kept as an assert so callers still see
            # AssertionError; be aware it is skipped under ``python -O``.
            assert (
                unfreeze >= 0.0 and unfreeze <= 1.0
            ), "ValueError: value must be a ratio between 0.0 and 1.0"
            self.ratio_unfreeze_layer = unfreeze
        if self.ratio_unfreeze_layer > 0.0:
            self.num_unfreeze_layer = int(
                self.bert_layer_num * self.ratio_unfreeze_layer
            )
        # Freeze an explicit-length prefix. Slicing with
        # ``[: -self.num_unfreeze_layer]`` collapses to ``[:0]`` when a small
        # positive ratio truncates to 0, which froze nothing instead of
        # freezing the whole encoder.
        num_frozen = self.bert_layer_num - self.num_unfreeze_layer
        for param in list(self.bert.parameters())[:num_frozen]:
            param.requires_grad = False
        # Random dropouts
        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)
        # Linear projector
        self.projector = nn.Linear(self.bert.config.hidden_size, 128)
        print("Model Initialized!")

        ### Loss
        self.criterion = ContrastiveLoss_simcse_w()

        ### Logs
        self.train_loss, self.val_loss, self.test_loss = [], [], []
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def configure_optimizers(self):
        """Return a single ``AdamW`` optimizer over the trainable parameters."""
        # Optimizer: frozen encoder tensors are excluded.
        self.trainable_params = [
            param for param in self.parameters() if param.requires_grad
        ]
        optimizer = AdamW(self.trainable_params, lr=self.lr)

        # NOTE: no scheduler is configured. A linear warmup schedule
        # (get_linear_schedule_with_warmup) was sketched here but is disabled.
        return [optimizer]

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return ``(pooler_output, projected_output)``.

        Row 0 of the projected batch goes through ``dropout1`` and the
        remaining rows through ``dropout2``; the two parts are concatenated
        back in the original order. Presumably row 0 is the anchor sample, as
        the variable name suggests -- verify against the data loader.
        """
        emb = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = emb.pooler_output
        out = self.projector(cls)
        anchor_out = self.dropout1(out[0:1])
        rest_out = self.dropout2(out[1:])
        output = torch.cat([anchor_out, rest_out])
        return cls, output

    def training_step(self, batch, batch_idx):
        """Compute, log and return the weighted loss for one training batch.

        Reads ``batch["label"]``, ``batch["input_ids"]``,
        ``batch["attention_mask"]`` and ``batch["score"]``.
        """
        label = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        score = batch["score"]
        cls, out = self(
            input_ids,
            attention_mask,
        )
        loss = self.criterion(out, label, score)
        # Keep only a detached copy for the epoch summary so the stored list
        # does not hold on to each step's autograd graph.
        logs = {"loss": loss.detach()}
        self.training_step_outputs.append(logs)
        self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Average the stored training losses, record and print the result."""
        if not self.training_step_outputs:
            # torch.stack raises on an empty list; nothing to summarize.
            return
        loss = (
            torch.stack([x["loss"] for x in self.training_step_outputs])
            .mean()
            .detach()
            .cpu()
            .numpy()
        )
        self.train_loss.append(loss)
        print("train_epoch:", self.current_epoch, "avg_loss:", loss)
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the weighted loss for one validation batch.

        Reads ``batch["label"]``, ``batch["input_ids"]``,
        ``batch["attention_mask"]`` and ``batch["score"]``.
        """
        label = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        score = batch["score"]
        cls, out = self(
            input_ids,
            attention_mask,
        )
        loss = self.criterion(out, label, score)
        logs = {"loss": loss.detach()}
        self.validation_step_outputs.append(logs)
        self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Average the stored validation losses, record and print the result."""
        if not self.validation_step_outputs:
            # torch.stack raises on an empty list; nothing to summarize.
            return
        loss = (
            torch.stack([x["loss"] for x in self.validation_step_outputs])
            .mean()
            .detach()
            .cpu()
            .numpy()
        )
        self.val_loss.append(loss)
        print("val_epoch:", self.current_epoch, "avg_loss:", loss)
        self.validation_step_outputs.clear()
259
+
260
+
261
class BERTContrastiveLearning_samp(pl.LightningModule):
    """Bio_ClinicalBERT encoder with a 128-d projection head and sampling loss.

    The module loads ``emilyalsentzer/Bio_ClinicalBERT``, optionally freezes a
    leading portion of its parameter tensors, projects the encoder's
    ``pooler_output`` to 128 dimensions, and optimizes
    ``ContrastiveLoss_samp``. No dropout is applied to the projection.
    """

    def __init__(self, n_batches=None, n_epochs=None, lr=None, **kwargs):
        """Build the encoder, projection head, loss and logging buffers.

        Args:
            n_batches: Stored on the instance; not used by the code in this
                class (only the disabled scheduler referenced it).
            n_epochs: Stored on the instance; not used by the code in this
                class (only the disabled scheduler referenced it).
            lr: Learning rate handed to ``AdamW`` in ``configure_optimizers``.
            **kwargs: Only ``unfreeze`` is recognized, and only when it is a
                ``float`` in ``[0.0, 1.0]``. It is the fraction of encoder
                parameter tensors, counted from the end of
                ``self.bert.parameters()``, that stay trainable. When it is
                omitted or ``0.0`` the whole encoder stays trainable.

        Raises:
            AssertionError: If ``unfreeze`` is a float outside ``[0.0, 1.0]``.
        """
        super().__init__()
        ### Parameters
        self.n_batches = n_batches
        self.n_epochs = n_epochs
        self.lr = lr

        ### Architecture
        self.bert = AutoModel.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
        )
        # Unfreeze encoder
        # NOTE: despite the "layer" naming, this counts parameter tensors.
        self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
        self.num_unfreeze_layer = self.bert_layer_num
        self.ratio_unfreeze_layer = 0.0
        unfreeze = kwargs.get("unfreeze")
        if isinstance(unfreeze, float):
            # NOTE(review): kept as an assert so callers still see
            # AssertionError; be aware it is skipped under ``python -O``.
            assert (
                unfreeze >= 0.0 and unfreeze <= 1.0
            ), "ValueError: value must be a ratio between 0.0 and 1.0"
            self.ratio_unfreeze_layer = unfreeze
        if self.ratio_unfreeze_layer > 0.0:
            self.num_unfreeze_layer = int(
                self.bert_layer_num * self.ratio_unfreeze_layer
            )
        # Freeze an explicit-length prefix. Slicing with
        # ``[: -self.num_unfreeze_layer]`` collapses to ``[:0]`` when a small
        # positive ratio truncates to 0, which froze nothing instead of
        # freezing the whole encoder.
        num_frozen = self.bert_layer_num - self.num_unfreeze_layer
        for param in list(self.bert.parameters())[:num_frozen]:
            param.requires_grad = False
        # Linear projector
        self.projector = nn.Linear(self.bert.config.hidden_size, 128)
        print("Model Initialized!")

        ### Loss
        self.criterion = ContrastiveLoss_samp()

        ### Logs
        self.train_loss, self.val_loss, self.test_loss = [], [], []
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def configure_optimizers(self):
        """Return a single ``AdamW`` optimizer over the trainable parameters."""
        # Optimizer: frozen encoder tensors are excluded.
        self.trainable_params = [
            param for param in self.parameters() if param.requires_grad
        ]
        optimizer = AdamW(self.trainable_params, lr=self.lr)

        # NOTE: no scheduler is configured. A linear warmup schedule
        # (get_linear_schedule_with_warmup) was sketched here but is disabled.
        return [optimizer]

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return ``(pooler_output, projected_output)``."""
        emb = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = emb.pooler_output
        out = self.projector(cls)
        return cls, out

    def training_step(self, batch, batch_idx):
        """Compute, log and return the contrastive loss for one training batch.

        Reads ``batch["label"]``, ``batch["input_ids"]`` and
        ``batch["attention_mask"]``.
        """
        label = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        cls, out = self(
            input_ids,
            attention_mask,
        )
        loss = self.criterion(out, label)
        # Keep only a detached copy for the epoch summary so the stored list
        # does not hold on to each step's autograd graph.
        logs = {"loss": loss.detach()}
        self.training_step_outputs.append(logs)
        self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Average the stored training losses, record and print the result."""
        if not self.training_step_outputs:
            # torch.stack raises on an empty list; nothing to summarize.
            return
        loss = (
            torch.stack([x["loss"] for x in self.training_step_outputs])
            .mean()
            .detach()
            .cpu()
            .numpy()
        )
        self.train_loss.append(loss)
        print("train_epoch:", self.current_epoch, "avg_loss:", loss)
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the contrastive loss for one validation batch.

        Reads ``batch["label"]``, ``batch["input_ids"]`` and
        ``batch["attention_mask"]``.
        """
        label = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        cls, out = self(
            input_ids,
            attention_mask,
        )
        loss = self.criterion(out, label)
        logs = {"loss": loss.detach()}
        self.validation_step_outputs.append(logs)
        self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Average the stored validation losses, record and print the result."""
        if not self.validation_step_outputs:
            # torch.stack raises on an empty list; nothing to summarize.
            return
        loss = (
            torch.stack([x["loss"] for x in self.validation_step_outputs])
            .mean()
            .detach()
            .cpu()
            .numpy()
        )
        self.val_loss.append(loss)
        print("val_epoch:", self.current_epoch, "avg_loss:", loss)
        self.validation_step_outputs.clear()
375
+
376
+
377
class BERTContrastiveLearning_samp_w(pl.LightningModule):
    """Bio_ClinicalBERT encoder with a 128-d projection head and weighted sampling loss.

    The module loads ``emilyalsentzer/Bio_ClinicalBERT``, optionally freezes a
    leading portion of its parameter tensors, projects the encoder's
    ``pooler_output`` to 128 dimensions, and optimizes
    ``ContrastiveLoss_samp_w``, which additionally receives the batch's
    ``score`` entry. No dropout is applied to the projection.
    """

    def __init__(self, n_batches=None, n_epochs=None, lr=None, **kwargs):
        """Build the encoder, projection head, loss and logging buffers.

        Args:
            n_batches: Stored on the instance; not used by the code in this
                class (only the disabled scheduler referenced it).
            n_epochs: Stored on the instance; not used by the code in this
                class (only the disabled scheduler referenced it).
            lr: Learning rate handed to ``AdamW`` in ``configure_optimizers``.
            **kwargs: Only ``unfreeze`` is recognized, and only when it is a
                ``float`` in ``[0.0, 1.0]``. It is the fraction of encoder
                parameter tensors, counted from the end of
                ``self.bert.parameters()``, that stay trainable. When it is
                omitted or ``0.0`` the whole encoder stays trainable.

        Raises:
            AssertionError: If ``unfreeze`` is a float outside ``[0.0, 1.0]``.
        """
        super().__init__()
        ### Parameters
        self.n_batches = n_batches
        self.n_epochs = n_epochs
        self.lr = lr

        ### Architecture
        self.bert = AutoModel.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
        )
        # Unfreeze encoder
        # NOTE: despite the "layer" naming, this counts parameter tensors.
        self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
        self.num_unfreeze_layer = self.bert_layer_num
        self.ratio_unfreeze_layer = 0.0
        unfreeze = kwargs.get("unfreeze")
        if isinstance(unfreeze, float):
            # NOTE(review): kept as an assert so callers still see
            # AssertionError; be aware it is skipped under ``python -O``.
            assert (
                unfreeze >= 0.0 and unfreeze <= 1.0
            ), "ValueError: value must be a ratio between 0.0 and 1.0"
            self.ratio_unfreeze_layer = unfreeze
        if self.ratio_unfreeze_layer > 0.0:
            self.num_unfreeze_layer = int(
                self.bert_layer_num * self.ratio_unfreeze_layer
            )
        # Freeze an explicit-length prefix. Slicing with
        # ``[: -self.num_unfreeze_layer]`` collapses to ``[:0]`` when a small
        # positive ratio truncates to 0, which froze nothing instead of
        # freezing the whole encoder.
        num_frozen = self.bert_layer_num - self.num_unfreeze_layer
        for param in list(self.bert.parameters())[:num_frozen]:
            param.requires_grad = False
        # Linear projector
        self.projector = nn.Linear(self.bert.config.hidden_size, 128)
        print("Model Initialized!")

        ### Loss
        self.criterion = ContrastiveLoss_samp_w()

        ### Logs
        self.train_loss, self.val_loss, self.test_loss = [], [], []
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def configure_optimizers(self):
        """Return a single ``AdamW`` optimizer over the trainable parameters."""
        # Optimizer: frozen encoder tensors are excluded.
        self.trainable_params = [
            param for param in self.parameters() if param.requires_grad
        ]
        optimizer = AdamW(self.trainable_params, lr=self.lr)

        # NOTE: no scheduler is configured. A linear warmup schedule
        # (get_linear_schedule_with_warmup) was sketched here but is disabled.
        return [optimizer]

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return ``(pooler_output, projected_output)``."""
        emb = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = emb.pooler_output
        out = self.projector(cls)
        return cls, out

    def training_step(self, batch, batch_idx):
        """Compute, log and return the weighted loss for one training batch.

        Reads ``batch["label"]``, ``batch["input_ids"]``,
        ``batch["attention_mask"]`` and ``batch["score"]``.
        """
        label = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        score = batch["score"]
        cls, out = self(
            input_ids,
            attention_mask,
        )
        loss = self.criterion(out, label, score)
        # Keep only a detached copy for the epoch summary so the stored list
        # does not hold on to each step's autograd graph.
        logs = {"loss": loss.detach()}
        self.training_step_outputs.append(logs)
        self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Average the stored training losses, record and print the result."""
        if not self.training_step_outputs:
            # torch.stack raises on an empty list; nothing to summarize.
            return
        loss = (
            torch.stack([x["loss"] for x in self.training_step_outputs])
            .mean()
            .detach()
            .cpu()
            .numpy()
        )
        self.train_loss.append(loss)
        print("train_epoch:", self.current_epoch, "avg_loss:", loss)
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the weighted loss for one validation batch.

        Reads ``batch["label"]``, ``batch["input_ids"]``,
        ``batch["attention_mask"]`` and ``batch["score"]``.
        """
        label = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        score = batch["score"]
        cls, out = self(
            input_ids,
            attention_mask,
        )
        loss = self.criterion(out, label, score)
        logs = {"loss": loss.detach()}
        self.validation_step_outputs.append(logs)
        self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Average the stored validation losses, record and print the result."""
        if not self.validation_step_outputs:
            # torch.stack raises on an empty list; nothing to summarize.
            return
        loss = (
            torch.stack([x["loss"] for x in self.validation_step_outputs])
            .mean()
            .detach()
            .cpu()
            .numpy()
        )
        self.val_loss.append(loss)
        print("val_epoch:", self.current_epoch, "avg_loss:", loss)
        self.validation_step_outputs.clear()