Spaces:
Runtime error
Runtime error
fixes
Browse files- src/models/model.py +8 -6
src/models/model.py
CHANGED
@@ -7,7 +7,7 @@ from transformers import (
|
|
7 |
)
|
8 |
from torch.utils.data import Dataset, DataLoader
|
9 |
import pytorch_lightning as pl
|
10 |
-
from pytorch_lightning.loggers import MLFlowLogger
|
11 |
from pytorch_lightning import Trainer
|
12 |
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
13 |
from pytorch_lightning import LightningDataModule
|
@@ -290,7 +290,7 @@ class Summarization:
|
|
290 |
learning_rate: float = 0.0001,
|
291 |
adam_epsilon: float = 0.01,
|
292 |
num_workers: int = 2,
|
293 |
-
weight_decay: float =0.0001
|
294 |
):
|
295 |
"""
|
296 |
trains T5/MT5 model on custom dataset
|
@@ -323,11 +323,13 @@ class Summarization:
|
|
323 |
|
324 |
self.T5Model = LightningModel(
|
325 |
tokenizer=self.tokenizer, model=self.model, output=outputdir,
|
326 |
-
learning_rate=learning_rate, adam_epsilon=adam_epsilon,weight_decay=weight_decay
|
327 |
)
|
328 |
|
329 |
-
MLlogger = MLFlowLogger(experiment_name="Summarization",
|
330 |
-
|
|
|
|
|
331 |
|
332 |
# logger = DAGsHubLogger(metrics_path='reports/metrics.txt')
|
333 |
|
@@ -348,7 +350,7 @@ class Summarization:
|
|
348 |
gpus = -1 if use_gpu and torch.cuda.is_available() else 0
|
349 |
|
350 |
trainer = Trainer(
|
351 |
-
logger=[
|
352 |
callbacks=early_stop_callback,
|
353 |
max_epochs=max_epochs,
|
354 |
gpus=gpus,
|
|
|
7 |
)
|
8 |
from torch.utils.data import Dataset, DataLoader
|
9 |
import pytorch_lightning as pl
|
10 |
+
from pytorch_lightning.loggers import MLFlowLogger, WandbLogger
|
11 |
from pytorch_lightning import Trainer
|
12 |
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
13 |
from pytorch_lightning import LightningDataModule
|
|
|
290 |
learning_rate: float = 0.0001,
|
291 |
adam_epsilon: float = 0.01,
|
292 |
num_workers: int = 2,
|
293 |
+
weight_decay: float = 0.0001
|
294 |
):
|
295 |
"""
|
296 |
trains T5/MT5 model on custom dataset
|
|
|
323 |
|
324 |
self.T5Model = LightningModel(
|
325 |
tokenizer=self.tokenizer, model=self.model, output=outputdir,
|
326 |
+
learning_rate=learning_rate, adam_epsilon=adam_epsilon, weight_decay=weight_decay
|
327 |
)
|
328 |
|
329 |
+
# MLlogger = MLFlowLogger(experiment_name="Summarization",
|
330 |
+
# tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")\
|
331 |
+
|
332 |
+
WandLogger = WandbLogger(project="keytotext")
|
333 |
|
334 |
# logger = DAGsHubLogger(metrics_path='reports/metrics.txt')
|
335 |
|
|
|
350 |
gpus = -1 if use_gpu and torch.cuda.is_available() else 0
|
351 |
|
352 |
trainer = Trainer(
|
353 |
+
logger=[WandLogger],
|
354 |
callbacks=early_stop_callback,
|
355 |
max_epochs=max_epochs,
|
356 |
gpus=gpus,
|