Spaces:
Runtime error
Runtime error
Fixes
Browse files- src/models/model.py +3 -3
src/models/model.py
CHANGED
@@ -6,7 +6,7 @@ from dagshub.pytorch_lightning import DAGsHubLogger
|
|
6 |
from transformers import (
|
7 |
AdamW,
|
8 |
T5ForConditionalGeneration,
|
9 |
-
T5TokenizerFast as T5Tokenizer, MT5Tokenizer, MT5ForConditionalGeneration,ByT5Tokenizer,
|
10 |
)
|
11 |
from torch.utils.data import Dataset, DataLoader
|
12 |
import pytorch_lightning as pl
|
@@ -248,7 +248,7 @@ class Summarization:
|
|
248 |
""" initiates Summarization class """
|
249 |
pass
|
250 |
|
251 |
-
def from_pretrained(self,model_type
|
252 |
"""
|
253 |
loads T5/MT5 Model model for training/finetuning
|
254 |
Args:
|
@@ -345,7 +345,7 @@ class Summarization:
|
|
345 |
trainer.fit(self.T5Model, self.data_module)
|
346 |
|
347 |
def load_model(
|
348 |
-
self, model_type:str ='t5'
|
349 |
):
|
350 |
"""
|
351 |
loads a checkpoint for inferencing/prediction
|
|
|
6 |
from transformers import (
|
7 |
AdamW,
|
8 |
T5ForConditionalGeneration,
|
9 |
+
T5TokenizerFast as T5Tokenizer, MT5Tokenizer, MT5ForConditionalGeneration, ByT5Tokenizer,
|
10 |
)
|
11 |
from torch.utils.data import Dataset, DataLoader
|
12 |
import pytorch_lightning as pl
|
|
|
248 |
""" initiates Summarization class """
|
249 |
pass
|
250 |
|
251 |
+
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
|
252 |
"""
|
253 |
loads T5/MT5 Model model for training/finetuning
|
254 |
Args:
|
|
|
345 |
trainer.fit(self.T5Model, self.data_module)
|
346 |
|
347 |
def load_model(
|
348 |
+
self, model_type: str = 't5', model_dir: str = "../../models", use_gpu: bool = False
|
349 |
):
|
350 |
"""
|
351 |
loads a checkpoint for inferencing/prediction
|