Spaces:
Runtime error
Runtime error
Fixes
Browse files- src/models/model.py +16 -5
src/models/model.py
CHANGED
@@ -161,8 +161,6 @@ class LightningModel(LightningModule):
|
|
161 |
self.model = model
|
162 |
self.tokenizer = tokenizer
|
163 |
self.output = output
|
164 |
-
# self.val_acc = Accuracy()
|
165 |
-
# self.train_acc = Accuracy()
|
166 |
|
167 |
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
168 |
""" forward step """
|
@@ -347,7 +345,7 @@ class Summarization:
|
|
347 |
trainer.fit(self.T5Model, self.data_module)
|
348 |
|
349 |
def load_model(
|
350 |
-
self, model_dir: str = "../../models", use_gpu: bool = False
|
351 |
):
|
352 |
"""
|
353 |
loads a checkpoint for inferencing/prediction
|
@@ -356,8 +354,21 @@ class Summarization:
|
|
356 |
model_dir (str, optional): path to model directory. Defaults to "outputs".
|
357 |
use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to True.
|
358 |
"""
|
359 |
-
|
360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
|
362 |
if use_gpu:
|
363 |
if torch.cuda.is_available():
|
|
|
161 |
self.model = model
|
162 |
self.tokenizer = tokenizer
|
163 |
self.output = output
|
|
|
|
|
164 |
|
165 |
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
166 |
""" forward step """
|
|
|
345 |
trainer.fit(self.T5Model, self.data_module)
|
346 |
|
347 |
def load_model(
|
348 |
+
self, model_type:str ='t5' , model_dir: str = "../../models", use_gpu: bool = False
|
349 |
):
|
350 |
"""
|
351 |
loads a checkpoint for inferencing/prediction
|
|
|
354 |
model_dir (str, optional): path to model directory. Defaults to "outputs".
|
355 |
use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to True.
|
356 |
"""
|
357 |
+
if model_type == "t5":
|
358 |
+
self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")
|
359 |
+
self.model = T5ForConditionalGeneration.from_pretrained(
|
360 |
+
f"{model_dir}", return_dict=True
|
361 |
+
)
|
362 |
+
elif model_type == "mt5":
|
363 |
+
self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
|
364 |
+
self.model = MT5ForConditionalGeneration.from_pretrained(
|
365 |
+
f"{model_dir}", return_dict=True
|
366 |
+
)
|
367 |
+
elif model_type == "byt5":
|
368 |
+
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
|
369 |
+
self.model = T5ForConditionalGeneration.from_pretrained(
|
370 |
+
f"{model_dir}", return_dict=True
|
371 |
+
)
|
372 |
|
373 |
if use_gpu:
|
374 |
if torch.cuda.is_available():
|