gagan3012 commited on
Commit
41f3be3
·
1 Parent(s): 5f1d01b
Files changed (1) hide show
  1. src/models/model.py +8 -6
src/models/model.py CHANGED
@@ -7,7 +7,7 @@ from transformers import (
7
  )
8
  from torch.utils.data import Dataset, DataLoader
9
  import pytorch_lightning as pl
10
- from pytorch_lightning.loggers import MLFlowLogger
11
  from pytorch_lightning import Trainer
12
  from pytorch_lightning.callbacks.early_stopping import EarlyStopping
13
  from pytorch_lightning import LightningDataModule
@@ -290,7 +290,7 @@ class Summarization:
290
  learning_rate: float = 0.0001,
291
  adam_epsilon: float = 0.01,
292
  num_workers: int = 2,
293
- weight_decay: float =0.0001
294
  ):
295
  """
296
  trains T5/MT5 model on custom dataset
@@ -323,11 +323,13 @@ class Summarization:
323
 
324
  self.T5Model = LightningModel(
325
  tokenizer=self.tokenizer, model=self.model, output=outputdir,
326
- learning_rate=learning_rate, adam_epsilon=adam_epsilon,weight_decay=weight_decay
327
  )
328
 
329
- MLlogger = MLFlowLogger(experiment_name="Summarization",
330
- tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
 
 
331
 
332
  # logger = DAGsHubLogger(metrics_path='reports/metrics.txt')
333
 
@@ -348,7 +350,7 @@ class Summarization:
348
  gpus = -1 if use_gpu and torch.cuda.is_available() else 0
349
 
350
  trainer = Trainer(
351
- logger=[MLlogger],
352
  callbacks=early_stop_callback,
353
  max_epochs=max_epochs,
354
  gpus=gpus,
 
7
  )
8
  from torch.utils.data import Dataset, DataLoader
9
  import pytorch_lightning as pl
10
+ from pytorch_lightning.loggers import MLFlowLogger, WandbLogger
11
  from pytorch_lightning import Trainer
12
  from pytorch_lightning.callbacks.early_stopping import EarlyStopping
13
  from pytorch_lightning import LightningDataModule
 
290
  learning_rate: float = 0.0001,
291
  adam_epsilon: float = 0.01,
292
  num_workers: int = 2,
293
+ weight_decay: float = 0.0001
294
  ):
295
  """
296
  trains T5/MT5 model on custom dataset
 
323
 
324
  self.T5Model = LightningModel(
325
  tokenizer=self.tokenizer, model=self.model, output=outputdir,
326
+ learning_rate=learning_rate, adam_epsilon=adam_epsilon, weight_decay=weight_decay
327
  )
328
 
329
+ # MLlogger = MLFlowLogger(experiment_name="Summarization",
330
+ # tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")\
331
+
332
+ WandLogger = WandbLogger(project="keytotext")
333
 
334
  # logger = DAGsHubLogger(metrics_path='reports/metrics.txt')
335
 
 
350
  gpus = -1 if use_gpu and torch.cuda.is_available() else 0
351
 
352
  trainer = Trainer(
353
+ logger=[WandLogger],
354
  callbacks=early_stop_callback,
355
  max_epochs=max_epochs,
356
  gpus=gpus,