Gagan Bhatia committed on
Commit b9412d1 · 1 Parent(s): 9f217b5
Files changed (2)
  1. src/models/model.py +435 -0
  2. src/models/train_model.py +0 -441
src/models/model.py ADDED
@@ -0,0 +1,435 @@
+ import time
+ import torch
+ import numpy as np
+ import pandas as pd
+ from datasets import load_metric
+ from transformers import (
+     AdamW,
+     T5ForConditionalGeneration,
+     T5TokenizerFast as T5Tokenizer,
+ )
+ from torch.utils.data import Dataset, DataLoader
+ import pytorch_lightning as pl
+ from pytorch_lightning.loggers import MLFlowLogger
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+ from pytorch_lightning import LightningDataModule
+ from pytorch_lightning import LightningModule
+
+ torch.cuda.empty_cache()
+ pl.seed_everything(42)
+
+
+ class DataModule(Dataset):
+     """
+     PyTorch Dataset that tokenizes the "input_text" / "output_text" columns of a dataframe
+     """
+
+     def __init__(
+         self,
+         data: pd.DataFrame,
+         tokenizer: T5Tokenizer,
+         source_max_token_len: int = 512,
+         target_max_token_len: int = 512,
+     ):
+         """
+         :param data: dataframe with "keywords", "text", "input_text" and "output_text" columns
+         :param tokenizer: T5 tokenizer
+         :param source_max_token_len: max token length of the source text
+         :param target_max_token_len: max token length of the target text
+         """
+         self.data = data
+         self.target_max_token_len = target_max_token_len
+         self.source_max_token_len = source_max_token_len
+         self.tokenizer = tokenizer
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, index: int):
+         data_row = self.data.iloc[index]
+
+         input_encoding = self.tokenizer(
+             data_row["input_text"],
+             max_length=self.source_max_token_len,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         )
+
+         output_encoding = self.tokenizer(
+             data_row["output_text"],
+             max_length=self.target_max_token_len,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         )
+
+         labels = output_encoding["input_ids"]
+         # replace pad token id (0) with -100 so padding is ignored by the loss
+         labels[labels == 0] = -100
+
+         return dict(
+             keywords=data_row["keywords"],
+             text=data_row["text"],
+             keywords_input_ids=input_encoding["input_ids"].flatten(),
+             keywords_attention_mask=input_encoding["attention_mask"].flatten(),
+             labels=labels.flatten(),
+             labels_attention_mask=output_encoding["attention_mask"].flatten(),
+         )
+
+
+ class PLDataModule(LightningDataModule):
+     def __init__(
+         self,
+         train_df: pd.DataFrame,
+         test_df: pd.DataFrame,
+         tokenizer: T5Tokenizer,
+         source_max_token_len: int = 512,
+         target_max_token_len: int = 512,
+         batch_size: int = 4,
+         split: float = 0.1,
+     ):
+         """
+         :param train_df: training dataframe
+         :param test_df: validation/test dataframe
+         :param tokenizer: T5 tokenizer
+         :param source_max_token_len: max token length of the source text
+         :param target_max_token_len: max token length of the target text
+         :param batch_size: batch size
+         :param split: train/validation split ratio (currently unused)
+         """
+         super().__init__()
+         self.train_df = train_df
+         self.test_df = test_df
+         self.split = split
+         self.batch_size = batch_size
+         self.target_max_token_len = target_max_token_len
+         self.source_max_token_len = source_max_token_len
+         self.tokenizer = tokenizer
+
+     def setup(self, stage=None):
+         self.train_dataset = DataModule(
+             self.train_df,
+             self.tokenizer,
+             self.source_max_token_len,
+             self.target_max_token_len,
+         )
+         self.test_dataset = DataModule(
+             self.test_df,
+             self.tokenizer,
+             self.source_max_token_len,
+             self.target_max_token_len,
+         )
+
+     def train_dataloader(self):
+         """ training dataloader """
+         return DataLoader(
+             self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2
+         )
+
+     def test_dataloader(self):
+         """ test dataloader """
+         return DataLoader(
+             self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
+         )
+
+     def val_dataloader(self):
+         """ validation dataloader """
+         return DataLoader(
+             self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
+         )
+
+
+ class LightningModel(LightningModule):
+     """ PyTorch Lightning Model class"""
+
+     def __init__(
+         self,
+         tokenizer,
+         model,
+         output: str = "outputs",
+         learning_rate: float = 1e-4,
+         adam_epsilon: float = 1e-8,
+         weight_decay: float = 0.0,
+     ):
+         """
+         initiates a PyTorch Lightning Model
+         Args:
+             tokenizer : T5 tokenizer
+             model : T5 model
+             output (str, optional): output directory to save model checkpoints. Defaults to "outputs".
+             learning_rate (float, optional): AdamW learning rate. The default is an assumption; tune as needed.
+             adam_epsilon (float, optional): AdamW epsilon term.
+             weight_decay (float, optional): weight decay applied to non-bias/LayerNorm parameters.
+         """
+         super().__init__()
+         self.model = model
+         self.tokenizer = tokenizer
+         self.output = output
+         # expose the optimizer settings through self.hparams so configure_optimizers() can read them
+         self.save_hyperparameters("learning_rate", "adam_epsilon", "weight_decay")
+         # self.val_acc = Accuracy()
+         # self.train_acc = Accuracy()
+
+     def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
+         """ forward step """
+         output = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             labels=labels,
+             decoder_attention_mask=decoder_attention_mask,
+         )
+
+         return output.loss, output.logits
+
+     def training_step(self, batch, batch_idx):
+         """ training step """
+         input_ids = batch["keywords_input_ids"]
+         attention_mask = batch["keywords_attention_mask"]
+         labels = batch["labels"]
+         labels_attention_mask = batch["labels_attention_mask"]
+
+         loss, outputs = self(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_attention_mask=labels_attention_mask,
+             labels=labels,
+         )
+         self.log("train_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         """ validation step """
+         input_ids = batch["keywords_input_ids"]
+         attention_mask = batch["keywords_attention_mask"]
+         labels = batch["labels"]
+         labels_attention_mask = batch["labels_attention_mask"]
+
+         loss, outputs = self(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_attention_mask=labels_attention_mask,
+             labels=labels,
+         )
+         self.log("val_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         """ test step """
+         input_ids = batch["keywords_input_ids"]
+         attention_mask = batch["keywords_attention_mask"]
+         labels = batch["labels"]
+         labels_attention_mask = batch["labels_attention_mask"]
+
+         loss, outputs = self(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_attention_mask=labels_attention_mask,
+             labels=labels,
+         )
+
+         self.log("test_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def configure_optimizers(self):
+         """ configure optimizers """
+         model = self.model
+         # no weight decay for bias and LayerNorm weights
+         no_decay = ["bias", "LayerNorm.weight"]
+         optimizer_grouped_parameters = [
+             {
+                 "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                 "weight_decay": self.hparams.weight_decay,
+             },
+             {
+                 "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                 "weight_decay": 0.0,
+             },
+         ]
+         optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
+         self.opt = optimizer
+         return [optimizer]
+
+
+ class Summarization:
+     """ Custom Summarization class """
+
+     def __init__(self) -> None:
+         """ initiates Summarization class """
+         pass
+
+     def from_pretrained(self, model_name="t5-base") -> None:
+         """
+         loads a pretrained T5 model and tokenizer for training/finetuning
+         Args:
+             model_name (str, optional): exact model architecture name, e.g. "t5-base" or "t5-large". Defaults to "t5-base".
+         """
+         self.tokenizer = T5Tokenizer.from_pretrained(f"{model_name}")
+         self.model = T5ForConditionalGeneration.from_pretrained(
+             f"{model_name}", return_dict=True
+         )
+
+     def train(
+         self,
+         train_df: pd.DataFrame,
+         eval_df: pd.DataFrame,
+         source_max_token_len: int = 512,
+         target_max_token_len: int = 512,
+         batch_size: int = 8,
+         max_epochs: int = 5,
+         use_gpu: bool = True,
+         outputdir: str = "model",
+         early_stopping_patience_epochs: int = 0,  # 0 to disable early stopping
+     ):
+         """
+         trains a T5 model on a custom dataset
+         Args:
+             train_df (pd.DataFrame): training dataframe. Must contain "input_text" and "output_text" columns
+                 (the dataset also reads "keywords" and "text").
+             eval_df (pd.DataFrame): validation dataframe with the same columns.
+             source_max_token_len (int, optional): max token length of source text. Defaults to 512.
+             target_max_token_len (int, optional): max token length of target text. Defaults to 512.
+             batch_size (int, optional): batch size. Defaults to 8.
+             max_epochs (int, optional): max number of epochs. Defaults to 5.
+             use_gpu (bool, optional): if True, the model trains on GPU. Defaults to True.
+             outputdir (str, optional): output directory to save model checkpoints. Defaults to "model".
+             early_stopping_patience_epochs (int, optional): monitors val_loss at epoch end and stops training
+                 if val_loss does not improve for the specified number of epochs. Set to 0 to disable
+                 early stopping. Defaults to 0 (disabled).
+         """
+         self.target_max_token_len = target_max_token_len
+         self.data_module = PLDataModule(
+             train_df,
+             eval_df,
+             self.tokenizer,
+             batch_size=batch_size,
+             source_max_token_len=source_max_token_len,
+             target_max_token_len=target_max_token_len,
+         )
+
+         self.T5Model = LightningModel(
+             tokenizer=self.tokenizer, model=self.model, output=outputdir
+         )
+
+         # checkpoint_callback = ModelCheckpoint(
+         #     dirpath="checkpoints",
+         #     filename="best-checkpoint-{epoch}-{train_loss:.2f}",
+         #     save_top_k=-1,
+         #     verbose=True,
+         #     monitor="train_loss",
+         #     mode="min",
+         # )
+
+         logger = MLFlowLogger(experiment_name="Summarization")
+
+         early_stop_callback = (
+             [
+                 EarlyStopping(
+                     monitor="val_loss",
+                     min_delta=0.00,
+                     patience=early_stopping_patience_epochs,
+                     verbose=True,
+                     mode="min",
+                 )
+             ]
+             if early_stopping_patience_epochs > 0
+             else None
+         )
+
+         gpus = 1 if use_gpu else 0
+
+         trainer = Trainer(
+             logger=logger,
+             callbacks=early_stop_callback,
+             max_epochs=max_epochs,
+             gpus=gpus,
+             progress_bar_refresh_rate=5,
+         )
+
+         trainer.fit(self.T5Model, self.data_module)
+
+     def load_model(self, model_dir: str = "model", use_gpu: bool = False):
+         """
+         loads a saved model for inference/prediction
+         Args:
+             model_dir (str, optional): path to the model directory. Defaults to "model".
+             use_gpu (bool, optional): if True, the model runs on GPU for inference. Defaults to False.
+         """
+         self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
+         self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")
+
+         if use_gpu:
+             if torch.cuda.is_available():
+                 self.device = torch.device("cuda")
+             else:
+                 raise Exception("No GPU found. Set use_gpu=False to run on CPU.")
+         else:
+             self.device = torch.device("cpu")
+
+         self.model = self.model.to(self.device)
+
+     def save_model(self, model_dir="model"):
+         """
+         Save the model and tokenizer to a directory
+         :param model_dir: target directory
+         :return: None; the model and tokenizer are written to disk
+         """
+         path = f"{model_dir}"
+         self.tokenizer.save_pretrained(path)
+         self.model.save_pretrained(path)
+
+     def predict(
+         self,
+         source_text: str,
+         max_length: int = 512,
+         num_return_sequences: int = 1,
+         num_beams: int = 2,
+         top_k: int = 50,
+         top_p: float = 0.95,
+         do_sample: bool = True,
+         repetition_penalty: float = 2.5,
+         length_penalty: float = 1.0,
+         early_stopping: bool = True,
+         skip_special_tokens: bool = True,
+         clean_up_tokenization_spaces: bool = True,
+     ):
+         """
+         generates predictions with the loaded T5 model
+         Args:
+             source_text (str): any text for generating predictions
+             max_length (int, optional): max token length of prediction. Defaults to 512.
+             num_return_sequences (int, optional): number of predictions to be returned. Defaults to 1.
+             num_beams (int, optional): number of beams. Defaults to 2.
+             top_k (int, optional): Defaults to 50.
+             top_p (float, optional): Defaults to 0.95.
+             do_sample (bool, optional): Defaults to True.
+             repetition_penalty (float, optional): Defaults to 2.5.
+             length_penalty (float, optional): Defaults to 1.0.
+             early_stopping (bool, optional): Defaults to True.
+             skip_special_tokens (bool, optional): Defaults to True.
+             clean_up_tokenization_spaces (bool, optional): Defaults to True.
+         Returns:
+             list[str]: returns predictions
+         """
+         input_ids = self.tokenizer.encode(
+             source_text, return_tensors="pt", add_special_tokens=True
+         )
+
+         input_ids = input_ids.to(self.device)
+         generated_ids = self.model.generate(
+             input_ids=input_ids,
+             num_beams=num_beams,
+             max_length=max_length,
+             repetition_penalty=repetition_penalty,
+             length_penalty=length_penalty,
+             early_stopping=early_stopping,
+             do_sample=do_sample,
+             top_p=top_p,
+             top_k=top_k,
+             num_return_sequences=num_return_sequences,
+         )
+         preds = [
+             self.tokenizer.decode(
+                 g,
+                 skip_special_tokens=skip_special_tokens,
+                 clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             )
+             for g in generated_ids
+         ]
+         return preds
src/models/train_model.py CHANGED
@@ -1,441 +0,0 @@
- import time
-
- import torch
- import numpy as np
- import pandas as pd
- from datasets import load_metric
- from tqdm.auto import tqdm
- from transformers import (
-     AdamW,
-     T5ForConditionalGeneration,
-     MT5ForConditionalGeneration,
-     T5TokenizerFast as T5Tokenizer,
-     MT5TokenizerFast as MT5Tokenizer,
- )
- from transformers import AutoTokenizer
- from torch.utils.data import Dataset, DataLoader
- from transformers import AutoModelWithLMHead, AutoTokenizer
- import pytorch_lightning as pl
- from pytorch_lightning.loggers import MLFlowLogger
- from pytorch_lightning import Trainer
- from pytorch_lightning.callbacks.early_stopping import EarlyStopping
- from pytorch_lightning import LightningDataModule
- from pytorch_lightning import LightningModule
-
- torch.cuda.empty_cache()
- pl.seed_everything(42)
-
-
- class DataModule(Dataset):
-     """
-     Data Module for pytorch
-     """
-
-     def __init__(
-         self,
-         data: pd.DataFrame,
-         tokenizer: T5Tokenizer,
-         source_max_token_len: int = 512,
-         target_max_token_len: int = 512,
-     ):
-         """
-         :param data:
-         :param tokenizer:
-         :param source_max_token_len:
-         :param target_max_token_len:
-         """
-         self.data = data
-         self.target_max_token_len = target_max_token_len
-         self.source_max_token_len = source_max_token_len
-         self.tokenizer = tokenizer
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, index: int):
-         data_row = self.data.iloc[index]
-
-         input_encoding = self.tokenizer(
-             data_row["input_text"],
-             max_length=self.source_max_token_len,
-             padding="max_length",
-             truncation=True,
-             return_attention_mask=True,
-             add_special_tokens=True,
-             return_tensors="pt",
-         )
-
-         output_encoding = self.tokenizer(
-             data_row["output_text"],
-             max_length=self.target_max_token_len,
-             padding="max_length",
-             truncation=True,
-             return_attention_mask=True,
-             add_special_tokens=True,
-             return_tensors="pt",
-         )
-
-         labels = output_encoding["input_ids"]
-         labels[
-             labels == 0
-         ] = -100
-
-         return dict(
-             keywords=data_row["keywords"],
-             text=data_row["text"],
-             keywords_input_ids=input_encoding["input_ids"].flatten(),
-             keywords_attention_mask=input_encoding["attention_mask"].flatten(),
-             labels=labels.flatten(),
-             labels_attention_mask=output_encoding["attention_mask"].flatten(),
-         )
-
-
- class PLDataModule(LightningDataModule):
-     def __init__(
-         self,
-         train_df: pd.DataFrame,
-         test_df: pd.DataFrame,
-         tokenizer: T5Tokenizer,
-         source_max_token_len: int = 512,
-         target_max_token_len: int = 512,
-         batch_size: int = 4,
-         split: float = 0.1
-     ):
-         """
-         :param data_df:
-         :param tokenizer:
-         :param source_max_token_len:
-         :param target_max_token_len:
-         :param batch_size:
-         :param split:
-         """
-         super().__init__()
-         self.train_df = train_df
-         self.test_df = test_df
-         self.split = split
-         self.batch_size = batch_size
-         self.target_max_token_len = target_max_token_len
-         self.source_max_token_len = source_max_token_len
-         self.tokenizer = tokenizer
-
-     def setup(self, stage=None):
-         self.train_dataset = DataModule(
-             self.train_df,
-             self.tokenizer,
-             self.source_max_token_len,
-             self.target_max_token_len,
-         )
-         self.test_dataset = DataModule(
-             self.test_df,
-             self.tokenizer,
-             self.source_max_token_len,
-             self.target_max_token_len,
-         )
-
-     def train_dataloader(self):
-         """ training dataloader """
-         return DataLoader(
-             self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2
-         )
-
-     def test_dataloader(self):
-         """ test dataloader """
-         return DataLoader(
-             self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
-         )
-
-     def val_dataloader(self):
-         """ validation dataloader """
-         return DataLoader(
-             self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
-         )
-
-
- class LightningModel(LightningModule):
-     """ PyTorch Lightning Model class"""
-
-     def __init__(self, tokenizer, model, output: str = "outputs"):
-         """
-         initiates a PyTorch Lightning Model
-         Args:
-             tokenizer : T5 tokenizer
-             model : T5 model
-             output (str, optional): output directory to save model checkpoints. Defaults to "outputs".
-         """
-         super().__init__()
-         self.model = model
-         self.tokenizer = tokenizer
-         self.output = output
-         # self.val_acc = Accuracy()
-         # self.train_acc = Accuracy()
-
-     def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
-         """ forward step """
-         output = self.model(
-             input_ids,
-             attention_mask=attention_mask,
-             labels=labels,
-             decoder_attention_mask=decoder_attention_mask,
-         )
-
-         return output.loss, output.logits
-
-     def training_step(self, batch, batch_size):
-         """ training step """
-         input_ids = batch["keywords_input_ids"]
-         attention_mask = batch["keywords_attention_mask"]
-         labels = batch["labels"]
-         labels_attention_mask = batch["labels_attention_mask"]
-
-         loss, outputs = self(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             decoder_attention_mask=labels_attention_mask,
-             labels=labels,
-         )
-         self.log("train_loss", loss, prog_bar=True, logger=True)
-         return loss
-
-     def validation_step(self, batch, batch_size):
-         """ validation step """
-         input_ids = batch["keywords_input_ids"]
-         attention_mask = batch["keywords_attention_mask"]
-         labels = batch["labels"]
-         labels_attention_mask = batch["labels_attention_mask"]
-
-         loss, outputs = self(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             decoder_attention_mask=labels_attention_mask,
-             labels=labels,
-         )
-         self.log("val_loss", loss, prog_bar=True, logger=True)
-         return loss
-
-     def test_step(self, batch, batch_size):
-         """ test step """
-         input_ids = batch["keywords_input_ids"]
-         attention_mask = batch["keywords_attention_mask"]
-         labels = batch["labels"]
-         labels_attention_mask = batch["labels_attention_mask"]
-
-         loss, outputs = self(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             decoder_attention_mask=labels_attention_mask,
-             labels=labels,
-         )
-
-         self.log("test_loss", loss, prog_bar=True, logger=True)
-         return loss
-
-     def configure_optimizers(self):
-         """ configure optimizers """
-         model = self.model
-         no_decay = ["bias", "LayerNorm.weight"]
-         optimizer_grouped_parameters = [
-             {
-                 "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-                 "weight_decay": self.hparams.weight_decay,
-             },
-             {
-                 "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-                 "weight_decay": 0.0,
-             },
-         ]
-         optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
-         self.opt = optimizer
-         return [optimizer]
-
-
- class Summarization:
-     """ Custom Summarization class """
-
-     def __init__(self) -> None:
-         """ initiates Summarization class """
-         pass
-
-     def from_pretrained(self, model_name="t5-base") -> None:
-         """
-         loads T5/MT5 Model model for training/finetuning
-         Args:
-             model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
-         """
-         self.tokenizer = T5Tokenizer.from_pretrained(f"{model_name}")
-         self.model = T5ForConditionalGeneration.from_pretrained(
-             f"{model_name}", return_dict=True
-         )
-
-     def train(
-         self,
-         train_df: pd.DataFrame,
-         eval_df: pd.DataFrame,
-         source_max_token_len: int = 512,
-         target_max_token_len: int = 512,
-         batch_size: int = 8,
-         max_epochs: int = 5,
-         use_gpu: bool = True,
-         outputdir: str = "model",
-         early_stopping_patience_epochs: int = 0,  # 0 to disable early stopping feature
-     ):
-         """
-         trains T5/MT5 model on custom dataset
-         Args:
-             train_df (pd.DataFrame): training datarame. Dataframe must have 2 column --> "input_text" and "output_text"
-             eval_df ([type], optional): validation datarame. Dataframe must have 2 column --> "input_text" and
-                 "output_text"
-             source_max_token_len (int, optional): max token length of source text. Defaults to 512.
-             target_max_token_len (int, optional): max token length of target text. Defaults to 512.
-             batch_size (int, optional): batch size. Defaults to 8.
-             max_epochs (int, optional): max number of epochs. Defaults to 5.
-             use_gpu (bool, optional): if True, model uses gpu for training. Defaults to True.
-             outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
-             early_stopping_patience_epochs (int, optional): monitors val_loss on epoch end and stops training,
-                 if val_loss does not improve after the specied number of epochs. set 0 to disable early stopping.
-                 Defaults to 0 (disabled)
-         """
-         self.target_max_token_len = target_max_token_len
-         self.data_module = PLDataModule(
-             train_df,
-             eval_df,
-             self.tokenizer,
-             batch_size=batch_size,
-             source_max_token_len=source_max_token_len,
-             target_max_token_len=target_max_token_len,
-         )
-
-         self.T5Model = LightningModel(
-             tokenizer=self.tokenizer, model=self.model, output=outputdir
-         )
-
-         # checkpoint_callback = ModelCheckpoint(
-         #     dirpath="checkpoints",
-         #     filename="best-checkpoint-{epoch}-{train_loss:.2f}",
-         #     save_top_k=-1,
-         #     verbose=True,
-         #     monitor="train_loss",
-         #     mode="min",
-         # )
-
-         logger = MLFlowLogger(experiment_name="Summarization")
-
-         early_stop_callback = (
-             [
-                 EarlyStopping(
-                     monitor="val_loss",
-                     min_delta=0.00,
-                     patience=early_stopping_patience_epochs,
-                     verbose=True,
-                     mode="min",
-                 )
-             ]
-             if early_stopping_patience_epochs > 0
-             else None
-         )
-
-         gpus = 1 if use_gpu else 0
-
-         trainer = pl.Trainer(
-             logger=logger,
-             callbacks=early_stop_callback,
-             max_epochs=max_epochs,
-             gpus=gpus,
-             progress_bar_refresh_rate=5,
-         )
-
-         trainer.fit(self.T5Model, self.data_module)
-
-     def load_model(
-         self, model_dir: str = "model", use_gpu: bool = False
-     ):
-         """
-         loads a checkpoint for inferencing/prediction
-         Args:
-             model_type (str, optional): "t5" or "mt5". Defaults to "t5".
-             model_dir (str, optional): path to model directory. Defaults to "outputs".
-             use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to True.
-         """
-         self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
-         self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")
-
-         if use_gpu:
-             if torch.cuda.is_available():
-                 self.device = torch.device("cuda")
-             else:
-                 raise Exception("exception ---> no gpu found. set use_gpu=False, to use CPU")
-         else:
-             self.device = torch.device("cpu")
-
-         self.model = self.model.to(self.device)
-
-     def save_model(
-         self,
-         model_dir="model"
-     ):
-         """
-         Save model to dir
-         :param model_dir:
-         :return: model is saved
-         """
-         path = f"{model_dir}"
-         self.tokenizer.save_pretrained(path)
-         self.model.save_pretrained(path)
-
-     def predict(
-         self,
-         source_text: str,
-         max_length: int = 512,
-         num_return_sequences: int = 1,
-         num_beams: int = 2,
-         top_k: int = 50,
-         top_p: float = 0.95,
-         do_sample: bool = True,
-         repetition_penalty: float = 2.5,
-         length_penalty: float = 1.0,
-         early_stopping: bool = True,
-         skip_special_tokens: bool = True,
-         clean_up_tokenization_spaces: bool = True,
-     ):
-         """
-         generates prediction for T5/MT5 model
-         Args:
-             source_text (str): any text for generating predictions
-             max_length (int, optional): max token length of prediction. Defaults to 512.
-             num_return_sequences (int, optional): number of predictions to be returned. Defaults to 1.
-             num_beams (int, optional): number of beams. Defaults to 2.
-             top_k (int, optional): Defaults to 50.
-             top_p (float, optional): Defaults to 0.95.
-             do_sample (bool, optional): Defaults to True.
-             repetition_penalty (float, optional): Defaults to 2.5.
-             length_penalty (float, optional): Defaults to 1.0.
-             early_stopping (bool, optional): Defaults to True.
-             skip_special_tokens (bool, optional): Defaults to True.
-             clean_up_tokenization_spaces (bool, optional): Defaults to True.
-         Returns:
-             list[str]: returns predictions
-         """
-         input_ids = self.tokenizer.encode(
-             source_text, return_tensors="pt", add_special_tokens=True
-         )
-
-         input_ids = input_ids.to(self.device)
-         generated_ids = self.model.generate(
-             input_ids=input_ids,
-             num_beams=num_beams,
-             max_length=max_length,
-             repetition_penalty=repetition_penalty,
-             length_penalty=length_penalty,
-             early_stopping=early_stopping,
-             top_p=top_p,
-             top_k=top_k,
-             num_return_sequences=num_return_sequences,
-         )
-         preds = [
-             self.tokenizer.decode(
-                 g,
-                 skip_special_tokens=skip_special_tokens,
-                 clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-             )
-             for g in generated_ids
-         ]
-         return preds
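For context, a minimal usage sketch of the Summarization class that src/models/model.py now provides. Everything below is illustrative and not part of the commit: the toy dataframe values, the batch/epoch settings and the "model" output directory are assumptions, the import path assumes the repository root is on PYTHONPATH, and the calls target the pytorch-lightning / transformers / mlflow versions this commit was written against. The dataset reads four columns ("keywords", "text", "input_text", "output_text"), so all four appear in the example.

import pandas as pd
from src.models.model import Summarization  # assumes repo root is importable

# toy dataframe with the four columns DataModule.__getitem__ reads (values are illustrative)
df = pd.DataFrame(
    {
        "keywords": ["solar, storage", "marathon, record"],
        "text": ["Full article about solar storage ...", "Full article about the marathon ..."],
        "input_text": ["solar, storage", "marathon, record"],
        "output_text": [
            "Short text written from the solar keywords.",
            "Short text written from the marathon keywords.",
        ],
    }
)

model = Summarization()
model.from_pretrained("t5-base")                    # downloads T5 tokenizer and weights
model.train(
    train_df=df,
    eval_df=df,                                     # reusing df for eval only in this sketch
    batch_size=2,
    max_epochs=1,
    use_gpu=False,
)
model.save_model(model_dir="model")                 # writes tokenizer + weights to ./model

model.load_model(model_dir="model", use_gpu=False)  # sets self.device, which predict() needs
print(model.predict("solar, storage"))

Two things the sketch relies on: train() always builds an MLFlowLogger, so mlflow must be installed, and predict() reads self.device, which is only set in load_model(), so the save/load round trip above is required before inference.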