gagan3012 commited on
Commit
14c7a1a
·
1 Parent(s): aef2f7d
Files changed (1) hide show
  1. src/models/model.py +16 -5
src/models/model.py CHANGED
@@ -161,8 +161,6 @@ class LightningModel(LightningModule):
161
  self.model = model
162
  self.tokenizer = tokenizer
163
  self.output = output
164
- # self.val_acc = Accuracy()
165
- # self.train_acc = Accuracy()
166
 
167
  def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
168
  """ forward step """
@@ -347,7 +345,7 @@ class Summarization:
347
  trainer.fit(self.T5Model, self.data_module)
348
 
349
  def load_model(
350
- self, model_dir: str = "../../models", use_gpu: bool = False
351
  ):
352
  """
353
  loads a checkpoint for inferencing/prediction
@@ -356,8 +354,21 @@ class Summarization:
356
  model_dir (str, optional): path to model directory. Defaults to "outputs".
357
  use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to True.
358
  """
359
- self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
360
- self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
  if use_gpu:
363
  if torch.cuda.is_available():
 
161
  self.model = model
162
  self.tokenizer = tokenizer
163
  self.output = output
 
 
164
 
165
  def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
166
  """ forward step """
 
345
  trainer.fit(self.T5Model, self.data_module)
346
 
347
  def load_model(
348
+ self, model_type:str ='t5' , model_dir: str = "../../models", use_gpu: bool = False
349
  ):
350
  """
351
  loads a checkpoint for inferencing/prediction
 
354
  model_dir (str, optional): path to model directory. Defaults to "outputs".
355
  use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to True.
356
  """
357
+ if model_type == "t5":
358
+ self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")
359
+ self.model = T5ForConditionalGeneration.from_pretrained(
360
+ f"{model_dir}", return_dict=True
361
+ )
362
+ elif model_type == "mt5":
363
+ self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
364
+ self.model = MT5ForConditionalGeneration.from_pretrained(
365
+ f"{model_dir}", return_dict=True
366
+ )
367
+ elif model_type == "byt5":
368
+ self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
369
+ self.model = T5ForConditionalGeneration.from_pretrained(
370
+ f"{model_dir}", return_dict=True
371
+ )
372
 
373
  if use_gpu:
374
  if torch.cuda.is_available():