gagan3012 committed
Commit 83ad679 · 1 Parent(s): 07c9283

split added

Files changed (1):
  1. src/models/model.py +11 -6
src/models/model.py CHANGED

@@ -13,7 +13,8 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning import LightningModule
 from datasets import load_metric
-#from dagshub.pytorch_lightning import DAGsHubLogger
+
+# from dagshub.pytorch_lightning import DAGsHubLogger
 
 
 torch.cuda.empty_cache()
@@ -150,7 +151,7 @@ class PLDataModule(LightningDataModule):
 class LightningModel(LightningModule):
     """ PyTorch Lightning Model class"""
 
-    def __init__(self, tokenizer, model, learning_rate, adam_epsilon, output: str = "outputs"):
+    def __init__(self, tokenizer, model, learning_rate, adam_epsilon, weight_decay, output: str = "outputs"):
         """
         initiates a PyTorch Lightning Model
         Args:
@@ -162,6 +163,9 @@ class LightningModel(LightningModule):
         self.model = model
         self.tokenizer = tokenizer
         self.output = output
+        self.learning_rate = learning_rate
+        self.adam_epsilon = adam_epsilon
+        self.weight_decay = weight_decay
 
     def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
         """ forward step """
@@ -230,7 +234,7 @@ class LightningModel(LightningModule):
         optimizer_grouped_parameters = [
             {
                 "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-                "weight_decay": self.hparams.weight_decay,
+                "weight_decay": self.weight_decay,
             },
             {
                 "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
@@ -285,7 +289,8 @@ class Summarization:
         early_stopping_patience_epochs: int = 0,  # 0 to disable early stopping feature
         learning_rate: float = 0.0001,
         adam_epsilon: float = 0.01,
-        num_workers: int = 2
+        num_workers: int = 2,
+        weight_decay: float = 0.0001
     ):
         """
         trains T5/MT5 model on custom dataset
@@ -318,13 +323,13 @@ class Summarization:
 
         self.T5Model = LightningModel(
             tokenizer=self.tokenizer, model=self.model, output=outputdir,
-            learning_rate=learning_rate, adam_epsilon=adam_epsilon
+            learning_rate=learning_rate, adam_epsilon=adam_epsilon, weight_decay=weight_decay
         )
 
         MLlogger = MLFlowLogger(experiment_name="Summarization",
                                 tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
 
-        #logger = DAGsHubLogger(metrics_path='reports/metrics.txt')
+        # logger = DAGsHubLogger(metrics_path='reports/metrics.txt')
 
         early_stop_callback = (
             [
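
Context for the change: the commit threads `weight_decay` from `Summarization.train()` down into the optimizer's parameter groups, replacing the `self.hparams.weight_decay` lookup. No `save_hyperparameters()` call is visible in the diff, so the old `hparams` attribute likely did not exist when the optimizer was built. Below is a minimal, self-contained sketch of the decayed/undecayed grouping pattern the edited hunk uses; the `no_decay` list contents, the `AdamW` call, and the t5-small checkpoint are assumptions from the common fine-tuning recipe, not code from this repository:

# Sketch only: `no_decay`, AdamW, and the checkpoint are assumed; the
# grouping expressions and hyperparameter names come from the diff.
from torch.optim import AdamW
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")  # assumed checkpoint
no_decay = ["bias", "LayerNorm.weight"]  # assumed contents of no_decay

learning_rate = 0.0001  # train() defaults from this commit
adam_epsilon = 0.01
weight_decay = 0.0001

optimizer_grouped_parameters = [
    {
        # parameters whose names match nothing in no_decay are decayed
        "params": [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,  # previously self.hparams.weight_decay
    },
    {
        # biases and LayerNorm weights are exempt from decay
        "params": [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)

After this commit, callers can pass `weight_decay=...` (alongside the existing `num_workers`) directly to `train()`, which forwards it to `LightningModel` as an explicit constructor argument.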