gagan3012 commited on
Commit
6730e31
·
1 Parent(s): 14c7a1a
Files changed (1) hide show
  1. src/models/model.py +3 -3
src/models/model.py CHANGED
@@ -6,7 +6,7 @@ from dagshub.pytorch_lightning import DAGsHubLogger
6
  from transformers import (
7
  AdamW,
8
  T5ForConditionalGeneration,
9
- T5TokenizerFast as T5Tokenizer, MT5Tokenizer, MT5ForConditionalGeneration,ByT5Tokenizer,
10
  )
11
  from torch.utils.data import Dataset, DataLoader
12
  import pytorch_lightning as pl
@@ -248,7 +248,7 @@ class Summarization:
248
  """ initiates Summarization class """
249
  pass
250
 
251
- def from_pretrained(self,model_type = "t5", model_name="t5-base") -> None:
252
  """
253
  loads T5/MT5 Model model for training/finetuning
254
  Args:
@@ -345,7 +345,7 @@ class Summarization:
345
  trainer.fit(self.T5Model, self.data_module)
346
 
347
  def load_model(
348
- self, model_type:str ='t5' , model_dir: str = "../../models", use_gpu: bool = False
349
  ):
350
  """
351
  loads a checkpoint for inferencing/prediction
 
6
  from transformers import (
7
  AdamW,
8
  T5ForConditionalGeneration,
9
+ T5TokenizerFast as T5Tokenizer, MT5Tokenizer, MT5ForConditionalGeneration, ByT5Tokenizer,
10
  )
11
  from torch.utils.data import Dataset, DataLoader
12
  import pytorch_lightning as pl
 
248
  """ initiates Summarization class """
249
  pass
250
 
251
+ def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
252
  """
253
  loads T5/MT5 Model model for training/finetuning
254
  Args:
 
345
  trainer.fit(self.T5Model, self.data_module)
346
 
347
  def load_model(
348
+ self, model_type: str = 't5', model_dir: str = "../../models", use_gpu: bool = False
349
  ):
350
  """
351
  loads a checkpoint for inferencing/prediction