fuhsiao418 commited on
Commit
cb10343
·
1 Parent(s): 8753937
Files changed (2) hide show
  1. utils/__init__.py +1 -1
  2. utils/methods.py +11 -0
utils/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
  from utils.preprocess import read_text_to_json, convert_to_sentence_json, extract_sentence_features, is_valid_format
2
- from utils.methods import load_ExtModel, load_AbstrModel
 
1
  from utils.preprocess import read_text_to_json, convert_to_sentence_json, extract_sentence_features, is_valid_format
2
+ from utils.methods import load_ExtModel, load_AbstrModel, extractive_method, abstractive_method
utils/methods.py CHANGED
@@ -83,5 +83,16 @@ def load_AbstrModel(path, device='cpu'):
83
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)
84
  abstrModel = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
85
  abstrModel = abstrModel.to(device)
 
 
 
 
 
 
 
 
 
 
 
86
  return tokenizer, abstrModel
87
 
 
83
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)
84
  abstrModel = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
85
  abstrModel = abstrModel.to(device)
86
+
87
+ generation_config = {
88
+ 'num_beams': 5,
89
+ 'max_length': 512,
90
+ 'min_length': 64,
91
+ 'length_penalty': 2.0,
92
+ 'early_stopping': True,
93
+ 'no_repeat_ngram_size': 3
94
+ }
95
+
96
+ abstrModel.config.update(generation_config)
97
  return tokenizer, abstrModel
98