darpanaswal committed on
Commit
04de3ea
·
verified ·
1 Parent(s): d825fc7

Update finetune.py

Browse files
Files changed (1) hide show
  1. finetune.py +4 -4
finetune.py CHANGED
@@ -26,7 +26,7 @@ def summarize_text_mt5(texts, model, tokenizer):
26
  max_length=512, truncation=True,
27
  padding=True).to(model.device)
28
  summary_ids = model.generate(inputs.input_ids,
29
- max_length=128,
30
  num_beams=4, length_penalty=2.0,
31
  early_stopping=True)
32
  summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
@@ -36,7 +36,7 @@ def summarize_text_mbart50(texts, model, tokenizer):
36
  inputs = tokenizer(texts, return_tensors="pt",
37
  max_length=1024, truncation=True,
38
  padding=True).to(model.device)
39
- summary_ids = model.generate(inputs.input_ids, max_length=128,
40
  num_beams=4, length_penalty=2.0,
41
  early_stopping=True)
42
  summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
@@ -94,10 +94,10 @@ def fine_tune(model_name, finetune_type, model, tokenizer, summarize_text, train
94
  print("Starting Fine-tuning...")
95
  if model_name == "mT5":
96
  max_input = 512
97
- max_output = 128
98
  else:
99
  max_input = 1024
100
- max_output = 128
101
 
102
  train_dataset = train
103
  eval_dataset = val
 
26
  max_length=512, truncation=True,
27
  padding=True).to(model.device)
28
  summary_ids = model.generate(inputs.input_ids,
29
+ max_length=60,
30
  num_beams=4, length_penalty=2.0,
31
  early_stopping=True)
32
  summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
 
36
  inputs = tokenizer(texts, return_tensors="pt",
37
  max_length=1024, truncation=True,
38
  padding=True).to(model.device)
39
+ summary_ids = model.generate(inputs.input_ids, max_length=60,
40
  num_beams=4, length_penalty=2.0,
41
  early_stopping=True)
42
  summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
 
94
  print("Starting Fine-tuning...")
95
  if model_name == "mT5":
96
  max_input = 512
97
+ max_output = 60
98
  else:
99
  max_input = 1024
100
+ max_output = 60
101
 
102
  train_dataset = train
103
  eval_dataset = val