Spaces:
Runtime error
Runtime error
split added
Browse files
- src/models/model.py +11 -6
src/models/model.py
CHANGED
@@ -13,7 +13,8 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
|
13 |
from pytorch_lightning import LightningDataModule
|
14 |
from pytorch_lightning import LightningModule
|
15 |
from datasets import load_metric
|
16 |
-
|
|
|
17 |
|
18 |
|
19 |
torch.cuda.empty_cache()
|
@@ -150,7 +151,7 @@ class PLDataModule(LightningDataModule):
|
|
150 |
class LightningModel(LightningModule):
|
151 |
""" PyTorch Lightning Model class"""
|
152 |
|
153 |
-
def __init__(self, tokenizer, model, learning_rate, adam_epsilon, output: str = "outputs"):
|
154 |
"""
|
155 |
initiates a PyTorch Lightning Model
|
156 |
Args:
|
@@ -162,6 +163,9 @@ class LightningModel(LightningModule):
|
|
162 |
self.model = model
|
163 |
self.tokenizer = tokenizer
|
164 |
self.output = output
|
|
|
|
|
|
|
165 |
|
166 |
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
167 |
""" forward step """
|
@@ -230,7 +234,7 @@ class LightningModel(LightningModule):
|
|
230 |
optimizer_grouped_parameters = [
|
231 |
{
|
232 |
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
233 |
-
"weight_decay": self.
|
234 |
},
|
235 |
{
|
236 |
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
@@ -285,7 +289,8 @@ class Summarization:
|
|
285 |
early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
|
286 |
learning_rate: float = 0.0001,
|
287 |
adam_epsilon: float = 0.01,
|
288 |
-
num_workers: int = 2
|
|
|
289 |
):
|
290 |
"""
|
291 |
trains T5/MT5 model on custom dataset
|
@@ -318,13 +323,13 @@ class Summarization:
|
|
318 |
|
319 |
self.T5Model = LightningModel(
|
320 |
tokenizer=self.tokenizer, model=self.model, output=outputdir,
|
321 |
-
learning_rate=learning_rate, adam_epsilon=adam_epsilon
|
322 |
)
|
323 |
|
324 |
MLlogger = MLFlowLogger(experiment_name="Summarization",
|
325 |
tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
|
326 |
|
327 |
-
#logger = DAGsHubLogger(metrics_path='reports/metrics.txt')
|
328 |
|
329 |
early_stop_callback = (
|
330 |
[
|
|
|
13 |
from pytorch_lightning import LightningDataModule
|
14 |
from pytorch_lightning import LightningModule
|
15 |
from datasets import load_metric
|
16 |
+
|
17 |
+
# from dagshub.pytorch_lightning import DAGsHubLogger
|
18 |
|
19 |
|
20 |
torch.cuda.empty_cache()
|
|
|
151 |
class LightningModel(LightningModule):
|
152 |
""" PyTorch Lightning Model class"""
|
153 |
|
154 |
+
def __init__(self, tokenizer, model, learning_rate, adam_epsilon, weight_decay, output: str = "outputs"):
|
155 |
"""
|
156 |
initiates a PyTorch Lightning Model
|
157 |
Args:
|
|
|
163 |
self.model = model
|
164 |
self.tokenizer = tokenizer
|
165 |
self.output = output
|
166 |
+
self.learning_rate = learning_rate
|
167 |
+
self.adam_epsilon = adam_epsilon
|
168 |
+
self.weight_decay = weight_decay
|
169 |
|
170 |
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
171 |
""" forward step """
|
|
|
234 |
optimizer_grouped_parameters = [
|
235 |
{
|
236 |
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
237 |
+
"weight_decay": self.weight_decay,
|
238 |
},
|
239 |
{
|
240 |
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
|
289 |
early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
|
290 |
learning_rate: float = 0.0001,
|
291 |
adam_epsilon: float = 0.01,
|
292 |
+
num_workers: int = 2,
|
293 |
+
weight_decay: float =0.0001
|
294 |
):
|
295 |
"""
|
296 |
trains T5/MT5 model on custom dataset
|
|
|
323 |
|
324 |
self.T5Model = LightningModel(
|
325 |
tokenizer=self.tokenizer, model=self.model, output=outputdir,
|
326 |
+
learning_rate=learning_rate, adam_epsilon=adam_epsilon,weight_decay=weight_decay
|
327 |
)
|
328 |
|
329 |
MLlogger = MLFlowLogger(experiment_name="Summarization",
|
330 |
tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
|
331 |
|
332 |
+
# logger = DAGsHubLogger(metrics_path='reports/metrics.txt')
|
333 |
|
334 |
early_stop_callback = (
|
335 |
[
|