Commit a7c7fdd (parent: 83a4c6e), committed by Dean

Fixing MLflow logging by using the MLflow autolog feature for PyTorch instead of the built-in PyTorch Lightning MLFlowLogger
Files changed:
- .gitignore (+1, -0)
- src/models/model.py (+6, -6)
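For context, a minimal, hypothetical sketch of the pattern this commit switches to: mlflow.pytorch.autolog() instruments a PyTorch Lightning Trainer.fit() call and records hyperparameters and per-epoch metrics without a Lightning-side MLFlowLogger. The DemoModule class and the random dataset below are illustrative stand-ins, not code from this repository.

# Hypothetical minimal example of MLflow autologging with a Lightning Trainer;
# DemoModule and the random dataset are stand-ins, not code from this project.
import mlflow
import mlflow.pytorch
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class DemoModule(pl.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()  # hyperparameters become MLflow run params
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)  # logged metrics are mirrored to MLflow
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


if __name__ == "__main__":
    data = DataLoader(
        TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8
    )

    # Enable autologging before fit(); log_models=False skips uploading model
    # weights, matching the choice made in this commit.
    mlflow.pytorch.autolog(log_models=False)

    trainer = pl.Trainer(max_epochs=2)
    with mlflow.start_run():
        trainer.fit(DemoModule(), data)

Autologging attaches to the Trainer, so the hyperparameters and per-epoch metrics land in the active MLflow run; wrapping fit() in mlflow.start_run() makes the run boundaries explicit, which is the same structure the diff below introduces.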
.gitignore CHANGED

@@ -96,3 +96,4 @@ coverage.xml
 summarization-dagshub/
 /models
 default/
+artifacts/
src/models/model.py CHANGED

@@ -7,7 +7,6 @@ from transformers import (
 )
 from torch.utils.data import Dataset, DataLoader
 import pytorch_lightning as pl
-from pytorch_lightning.loggers import MLFlowLogger
 from dagshub.pytorch_lightning import DAGsHubLogger
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
@@ -16,6 +15,7 @@ from pytorch_lightning import LightningModule
 from datasets import load_metric
 from tqdm.auto import tqdm
 
+import mlflow.pytorch
 
 torch.cuda.empty_cache()
 pl.seed_everything(42)
@@ -326,9 +326,6 @@ class Summarization:
             learning_rate=learning_rate, adam_epsilon=adam_epsilon, weight_decay=weight_decay
         )
 
-        MLlogger = MLFlowLogger(experiment_name="Summarization",
-                                tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
-
         logger = DAGsHubLogger(metrics_path='reports/training_metrics.csv',
                                hparams_path='reports/training_params.yml')
 
@@ -349,14 +346,17 @@ class Summarization:
         gpus = -1 if use_gpu and torch.cuda.is_available() else 0
 
         trainer = Trainer(
-            logger=
+            logger=logger,
             callbacks=early_stop_callback,
             max_epochs=max_epochs,
             gpus=gpus,
             progress_bar_refresh_rate=5,
         )
 
-        trainer.fit(self.T5Model, self.data_module)
+        mlflow.pytorch.autolog(log_models=False)
+
+        with mlflow.start_run() as run:
+            trainer.fit(self.T5Model, self.data_module)
 
     def load_model(
         self, model_type: str = 't5', model_dir: str = "models", use_gpu: bool = False