Dean commited on
Commit
a7c7fdd
·
1 Parent(s): 83a4c6e

Fixing MLflow logging by using the MLflow autolog feature for Pytorch instead of the built-in pytorch lightning MLflowLogger

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. src/models/model.py +6 -6
.gitignore CHANGED
@@ -96,3 +96,4 @@ coverage.xml
96
  summarization-dagshub/
97
  /models
98
  default/
 
 
96
  summarization-dagshub/
97
  /models
98
  default/
99
+ artifacts/
src/models/model.py CHANGED
@@ -7,7 +7,6 @@ from transformers import (
7
  )
8
  from torch.utils.data import Dataset, DataLoader
9
  import pytorch_lightning as pl
10
- from pytorch_lightning.loggers import MLFlowLogger
11
  from dagshub.pytorch_lightning import DAGsHubLogger
12
  from pytorch_lightning import Trainer
13
  from pytorch_lightning.callbacks.early_stopping import EarlyStopping
@@ -16,6 +15,7 @@ from pytorch_lightning import LightningModule
16
  from datasets import load_metric
17
  from tqdm.auto import tqdm
18
 
 
19
 
20
  torch.cuda.empty_cache()
21
  pl.seed_everything(42)
@@ -326,9 +326,6 @@ class Summarization:
326
  learning_rate=learning_rate, adam_epsilon=adam_epsilon, weight_decay=weight_decay
327
  )
328
 
329
- MLlogger = MLFlowLogger(experiment_name="Summarization",
330
- tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
331
-
332
  logger = DAGsHubLogger(metrics_path='reports/training_metrics.csv',
333
  hparams_path='reports/training_params.yml')
334
 
@@ -349,14 +346,17 @@ class Summarization:
349
  gpus = -1 if use_gpu and torch.cuda.is_available() else 0
350
 
351
  trainer = Trainer(
352
- logger=[MLlogger, logger],
353
  callbacks=early_stop_callback,
354
  max_epochs=max_epochs,
355
  gpus=gpus,
356
  progress_bar_refresh_rate=5,
357
  )
358
 
359
- trainer.fit(self.T5Model, self.data_module)
 
 
 
360
 
361
  def load_model(
362
  self, model_type: str = 't5', model_dir: str = "models", use_gpu: bool = False
 
7
  )
8
  from torch.utils.data import Dataset, DataLoader
9
  import pytorch_lightning as pl
 
10
  from dagshub.pytorch_lightning import DAGsHubLogger
11
  from pytorch_lightning import Trainer
12
  from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
15
  from datasets import load_metric
16
  from tqdm.auto import tqdm
17
 
18
+ import mlflow.pytorch
19
 
20
  torch.cuda.empty_cache()
21
  pl.seed_everything(42)
 
326
  learning_rate=learning_rate, adam_epsilon=adam_epsilon, weight_decay=weight_decay
327
  )
328
 
 
 
 
329
  logger = DAGsHubLogger(metrics_path='reports/training_metrics.csv',
330
  hparams_path='reports/training_params.yml')
331
 
 
346
  gpus = -1 if use_gpu and torch.cuda.is_available() else 0
347
 
348
  trainer = Trainer(
349
+ logger=logger,
350
  callbacks=early_stop_callback,
351
  max_epochs=max_epochs,
352
  gpus=gpus,
353
  progress_bar_refresh_rate=5,
354
  )
355
 
356
+ mlflow.pytorch.autolog(log_models=False)
357
+
358
+ with mlflow.start_run() as run:
359
+ trainer.fit(self.T5Model, self.data_module)
360
 
361
  def load_model(
362
  self, model_type: str = 't5', model_dir: str = "models", use_gpu: bool = False