gagan3012 commited on
Commit
5950d98
·
1 Parent(s): b3f69b2
Files changed (1) hide show
  1. src/models/model.py +9 -9
src/models/model.py CHANGED
@@ -7,7 +7,7 @@ from transformers import (
7
  )
8
  from torch.utils.data import Dataset, DataLoader
9
  import pytorch_lightning as pl
10
- from pytorch_lightning.loggers import MLFlowLogger, WandbLogger
11
  from pytorch_lightning import Trainer
12
  from pytorch_lightning.callbacks.early_stopping import EarlyStopping
13
  from pytorch_lightning import LightningDataModule
@@ -328,12 +328,12 @@ class Summarization:
328
  )
329
 
330
  MLlogger = MLFlowLogger(experiment_name="Summarization",
331
- tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow",
332
- save_dir='reports/training_metrics.txt')
333
 
334
  # WandLogger = WandbLogger(project="summarization-dagshub")
335
 
336
- #logger = DAGsHubLogger(metrics_path='reports/training_metrics.txt')
337
 
338
  early_stop_callback = (
339
  [
@@ -352,7 +352,7 @@ class Summarization:
352
  gpus = -1 if use_gpu and torch.cuda.is_available() else 0
353
 
354
  trainer = Trainer(
355
- logger=[MLlogger],
356
  callbacks=early_stop_callback,
357
  max_epochs=max_epochs,
358
  gpus=gpus,
@@ -460,10 +460,10 @@ class Summarization:
460
  num_return_sequences=num_return_sequences,
461
  )
462
  preds = self.tokenizer.decode(
463
- generated_ids[0],
464
- skip_special_tokens=skip_special_tokens,
465
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
466
- )
467
  return preds
468
 
469
  def evaluate(
 
7
  )
8
  from torch.utils.data import Dataset, DataLoader
9
  import pytorch_lightning as pl
10
+ from pytorch_lightning.loggers import MLFlowLogger
11
  from pytorch_lightning import Trainer
12
  from pytorch_lightning.callbacks.early_stopping import EarlyStopping
13
  from pytorch_lightning import LightningDataModule
 
328
  )
329
 
330
  MLlogger = MLFlowLogger(experiment_name="Summarization",
331
+ tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
332
+ #save_dir="reports/training_metrics.txt"
333
 
334
  # WandLogger = WandbLogger(project="summarization-dagshub")
335
 
336
+ # logger = DAGsHubLogger(metrics_path='reports/training_metrics.txt')
337
 
338
  early_stop_callback = (
339
  [
 
352
  gpus = -1 if use_gpu and torch.cuda.is_available() else 0
353
 
354
  trainer = Trainer(
355
+ logger=MLlogger,
356
  callbacks=early_stop_callback,
357
  max_epochs=max_epochs,
358
  gpus=gpus,
 
460
  num_return_sequences=num_return_sequences,
461
  )
462
  preds = self.tokenizer.decode(
463
+ generated_ids[0],
464
+ skip_special_tokens=skip_special_tokens,
465
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
466
+ )
467
  return preds
468
 
469
  def evaluate(