Spaces:
No application file
No application file
Update finetune.py
Browse files- finetune.py +4 -4
finetune.py
CHANGED
@@ -26,7 +26,7 @@ def summarize_text_mt5(texts, model, tokenizer):
|
|
26 |
max_length=512, truncation=True,
|
27 |
padding=True).to(model.device)
|
28 |
summary_ids = model.generate(inputs.input_ids,
|
29 |
-
max_length=
|
30 |
num_beams=4, length_penalty=2.0,
|
31 |
early_stopping=True)
|
32 |
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
|
@@ -36,7 +36,7 @@ def summarize_text_mbart50(texts, model, tokenizer):
|
|
36 |
inputs = tokenizer(texts, return_tensors="pt",
|
37 |
max_length=1024, truncation=True,
|
38 |
padding=True).to(model.device)
|
39 |
-
summary_ids = model.generate(inputs.input_ids, max_length=
|
40 |
num_beams=4, length_penalty=2.0,
|
41 |
early_stopping=True)
|
42 |
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
|
@@ -94,10 +94,10 @@ def fine_tune(model_name, finetune_type, model, tokenizer, summarize_text, train
|
|
94 |
print("Starting Fine-tuning...")
|
95 |
if model_name == "mT5":
|
96 |
max_input = 512
|
97 |
-
max_output =
|
98 |
else:
|
99 |
max_input = 1024
|
100 |
-
max_output =
|
101 |
|
102 |
train_dataset = train
|
103 |
eval_dataset = val
|
|
|
26 |
max_length=512, truncation=True,
|
27 |
padding=True).to(model.device)
|
28 |
summary_ids = model.generate(inputs.input_ids,
|
29 |
+
max_length=60,
|
30 |
num_beams=4, length_penalty=2.0,
|
31 |
early_stopping=True)
|
32 |
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
|
|
|
36 |
inputs = tokenizer(texts, return_tensors="pt",
|
37 |
max_length=1024, truncation=True,
|
38 |
padding=True).to(model.device)
|
39 |
+
summary_ids = model.generate(inputs.input_ids, max_length=60,
|
40 |
num_beams=4, length_penalty=2.0,
|
41 |
early_stopping=True)
|
42 |
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
|
|
|
94 |
print("Starting Fine-tuning...")
|
95 |
if model_name == "mT5":
|
96 |
max_input = 512
|
97 |
+
max_output = 60
|
98 |
else:
|
99 |
max_input = 1024
|
100 |
+
max_output = 60
|
101 |
|
102 |
train_dataset = train
|
103 |
eval_dataset = val
|