Gagan Bhatia commited on
Commit
7a9c664
·
1 Parent(s): b9412d1
Files changed (1) hide show
  1. src/models/model.py +3 -3
src/models/model.py CHANGED
@@ -269,7 +269,7 @@ class Summarization:
269
  batch_size: int = 8,
270
  max_epochs: int = 5,
271
  use_gpu: bool = True,
272
- outputdir: str = "model",
273
  early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
274
  ):
275
  """
@@ -340,7 +340,7 @@ class Summarization:
340
  trainer.fit(self.T5Model, self.data_module)
341
 
342
  def load_model(
343
- self, model_dir: str = "model", use_gpu: bool = False
344
  ):
345
  """
346
  loads a checkpoint for inferencing/prediction
@@ -364,7 +364,7 @@ class Summarization:
364
 
365
  def save_model(
366
  self,
367
- model_dir="model"
368
  ):
369
  """
370
  Save model to dir
 
269
  batch_size: int = 8,
270
  max_epochs: int = 5,
271
  use_gpu: bool = True,
272
+ outputdir: str = "models",
273
  early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
274
  ):
275
  """
 
340
  trainer.fit(self.T5Model, self.data_module)
341
 
342
  def load_model(
343
+ self, model_dir: str = "models", use_gpu: bool = False
344
  ):
345
  """
346
  loads a checkpoint for inferencing/prediction
 
364
 
365
  def save_model(
366
  self,
367
+ model_dir="models"
368
  ):
369
  """
370
  Save model to dir