Spaces:
Runtime error
Runtime error
Fixes
Browse files- src/models/evaluate_model.py +1 -1
- src/models/model.py +18 -6
- src/models/train_model.py +6 -4
src/models/evaluate_model.py
CHANGED
@@ -10,7 +10,7 @@ def evaluate_model():
|
|
10 |
test_df = pd.load_csv('../../data/processed/test.csv')
|
11 |
model = Summarization()
|
12 |
model.load_model()
|
13 |
-
results = model.evaluate(test_df=test_df)
|
14 |
with dagshub.dagshub_logger() as logger:
|
15 |
logger.log_metrics(results)
|
16 |
return results
|
|
|
10 |
test_df = pd.load_csv('../../data/processed/test.csv')
|
11 |
model = Summarization()
|
12 |
model.load_model()
|
13 |
+
results = model.evaluate(test_df=test_df,metrics="rouge")
|
14 |
with dagshub.dagshub_logger() as logger:
|
15 |
logger.log_metrics(results)
|
16 |
return results
|
src/models/model.py
CHANGED
@@ -6,7 +6,7 @@ from dagshub.pytorch_lightning import DAGsHubLogger
|
|
6 |
from transformers import (
|
7 |
AdamW,
|
8 |
T5ForConditionalGeneration,
|
9 |
-
T5TokenizerFast as T5Tokenizer,
|
10 |
)
|
11 |
from torch.utils.data import Dataset, DataLoader
|
12 |
import pytorch_lightning as pl
|
@@ -250,16 +250,28 @@ class Summarization:
|
|
250 |
""" initiates Summarization class """
|
251 |
pass
|
252 |
|
253 |
-
def from_pretrained(self, model_name="t5-base") -> None:
|
254 |
"""
|
255 |
loads T5/MT5 Model model for training/finetuning
|
256 |
Args:
|
257 |
model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
|
|
|
258 |
"""
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
def train(
|
265 |
self,
|
|
|
6 |
from transformers import (
|
7 |
AdamW,
|
8 |
T5ForConditionalGeneration,
|
9 |
+
T5TokenizerFast as T5Tokenizer, MT5Tokenizer, MT5ForConditionalGeneration,ByT5Tokenizer,
|
10 |
)
|
11 |
from torch.utils.data import Dataset, DataLoader
|
12 |
import pytorch_lightning as pl
|
|
|
250 |
""" initiates Summarization class """
|
251 |
pass
|
252 |
|
253 |
+
def from_pretrained(self,model_type = "t5", model_name="t5-base") -> None:
|
254 |
"""
|
255 |
loads T5/MT5 Model model for training/finetuning
|
256 |
Args:
|
257 |
model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
|
258 |
+
:param model_type:
|
259 |
"""
|
260 |
+
if model_type == "t5":
|
261 |
+
self.tokenizer = T5Tokenizer.from_pretrained(f"{model_name}")
|
262 |
+
self.model = T5ForConditionalGeneration.from_pretrained(
|
263 |
+
f"{model_name}", return_dict=True
|
264 |
+
)
|
265 |
+
elif model_type == "mt5":
|
266 |
+
self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_name}")
|
267 |
+
self.model = MT5ForConditionalGeneration.from_pretrained(
|
268 |
+
f"{model_name}", return_dict=True
|
269 |
+
)
|
270 |
+
elif model_type == "byt5":
|
271 |
+
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
|
272 |
+
self.model = T5ForConditionalGeneration.from_pretrained(
|
273 |
+
f"{model_name}", return_dict=True
|
274 |
+
)
|
275 |
|
276 |
def train(
|
277 |
self,
|
src/models/train_model.py
CHANGED
@@ -1,18 +1,20 @@
|
|
1 |
from src.models.model import Summarization
|
2 |
import pandas as pd
|
3 |
|
|
|
4 |
def train_model():
|
5 |
"""
|
6 |
Train the model
|
7 |
"""
|
8 |
# Load the data
|
9 |
-
train_df = pd.
|
10 |
-
eval_df = pd.
|
11 |
|
12 |
model = Summarization()
|
13 |
-
model.from_pretrained('t5-base')
|
14 |
model.train(train_df=train_df, eval_df=eval_df, batch_size=4, max_epochs=3, use_gpu=True)
|
15 |
model.save_model()
|
16 |
|
|
|
17 |
if __name__ == '__main__':
|
18 |
-
train_model()
|
|
|
1 |
from src.models.model import Summarization
|
2 |
import pandas as pd
|
3 |
|
4 |
+
|
5 |
def train_model():
|
6 |
"""
|
7 |
Train the model
|
8 |
"""
|
9 |
# Load the data
|
10 |
+
train_df = pd.read_csv('../../data/processed/train.csv')
|
11 |
+
eval_df = pd.read_csv('../../data/processed/validation.csv')
|
12 |
|
13 |
model = Summarization()
|
14 |
+
model.from_pretrained('t5','t5-base')
|
15 |
model.train(train_df=train_df, eval_df=eval_df, batch_size=4, max_epochs=3, use_gpu=True)
|
16 |
model.save_model()
|
17 |
|
18 |
+
|
19 |
if __name__ == '__main__':
|
20 |
+
train_model()
|