darpanaswal committed
Commit 64c7578 · verified · 1 Parent(s): 04de3ea

Update main.py

Files changed (1):
  1. main.py +3 -3
main.py CHANGED

@@ -34,7 +34,7 @@ def summarize_text_mt5(texts, model, tokenizer):
                        max_length=512, truncation=True,
                        padding=True).to(model.device)
     summary_ids = model.generate(input_ids = inputs.input_ids,
-                                 max_length=128,
+                                 max_length=60,
                                  num_beams=4, length_penalty=2.0,
                                  early_stopping=True)
     summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
@@ -44,7 +44,7 @@ def summarize_text_mbart50(texts, model, tokenizer):
     inputs = tokenizer(texts, return_tensors="pt",
                        max_length=1024, truncation=True,
                        padding=True).to(model.device)
-    summary_ids = model.generate(inputs.input_ids, max_length=128,
+    summary_ids = model.generate(inputs.input_ids, max_length=60,
                                  num_beams=4, length_penalty=2.0,
                                  early_stopping=True)
     summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
@@ -60,7 +60,7 @@ def summarize_text_llama(texts, model, tokenizer):
 
     summary_ids = model.generate(
         inputs.input_ids,
-        max_new_tokens=128,
+        max_new_tokens=60,
         temperature=0.7,
         top_p=0.9,
         num_beams=4,
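
All three hunks tighten the generation cap from 128 to 60 tokens. For context, a minimal usage sketch of the mT5 path touched in the first hunk follows; the `summarize_text_mt5(texts, model, tokenizer)` signature comes from the hunk header, while the checkpoint name and input text are assumptions for illustration only, not taken from this repository.

# Minimal sketch, assuming main.py is on the import path and the models are
# standard Hugging Face checkpoints loaded with transformers.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from main import summarize_text_mt5  # signature shown in the first hunk header

checkpoint = "google/mt5-small"  # hypothetical checkpoint, not from the repo
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

texts = ["Some long article text to summarize ..."]
# After this commit, beam-search generation is capped at max_length=60 tokens
# (down from 128), so returned summaries will be noticeably shorter.
print(summarize_text_mt5(texts, model, tokenizer))

Note that the decoder-only Llama path uses max_new_tokens, which bounds only the freshly generated tokens, while max_length in the mT5 and mBART-50 paths bounds the whole generated sequence; both caps move from 128 to 60 in this commit.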