席亚东
commited on
Commit
·
05e5527
1
Parent(s):
cc5bc33
first commit
Browse files- README.md +118 -0
- dict.txt +0 -0
- inference.py +181 -0
README.md
CHANGED
@@ -1,3 +1,121 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
+
|
4 |
+
language: zh
|
5 |
+
inference: false
|
6 |
+
tags:
|
7 |
+
- text-generation
|
8 |
+
- story-generation
|
9 |
+
- pytorch
|
10 |
+
- inference acceleration
|
11 |
+
- gpt2
|
12 |
+
- gpt3
|
13 |
---
|
14 |
+
# YuYan: Pre-training of Language Models for Story Generation
|
15 |
+
|
16 |
+
YuYan is a series of Chinese language models with different size, developed by Fuxi AI lab, Netease.Inc. They are trained on a large Chinese novel dataset of high quality.
|
17 |
+
|
18 |
+
YuYan is in the same family of decoder-only models like [GPT2 and GPT-3](https://arxiv.org/abs/2005.14165). As such, it was pretrained using the self-supervised causal language modedling objective.
|
19 |
+
|
20 |
+
Because the training data is mainly the novel, the model is good at generating the next plot given the story context.
|
21 |
+
|
22 |
+
## Model Inference Acceleration
|
23 |
+
|
24 |
+
As the model size increases, the model inference time increases and more computational resources are required.
|
25 |
+
|
26 |
+
Therefore, we developed our own transformer model inference acceleration framework, [EET](https://github.com/NetEase-FuXi/EET.git). More details are in [Easy and Efficient Transformer: Scalable Inference Solution For Large NLP Model](https://aclanthology.org/2022.naacl-industry.8/).
|
27 |
+
|
28 |
+
We combine our language model with the EET inference framework to provide industrial-grade inference reasoning performance.
|
29 |
+
|
30 |
+
## How to use
|
31 |
+
|
32 |
+
Our model is trained based on the [fairseq](https://github.com/facebookresearch/fairseq). As a result, the inference and finetuning depend on it.
|
33 |
+
|
34 |
+
For inference, we modify some parts of the original fairseq codes. Mainly
|
35 |
+
> fairseq-0.12.2/fairseq/sequence_generator.py
|
36 |
+
|
37 |
+
We integrate the EET with sequence_generator. We replace the eos token to a token unlikely to be sampled to ensure the generated text length. The repetition penalty trick is also modified. You can change the penalty strength by adjusting the value of `self.ban_weight`.
|
38 |
+
|
39 |
+
Then, to keep the eos token in the final generated text, we change the line 75 `include_eos=False` to `include_eos=True` in
|
40 |
+
> fairseq-0.12.2/fairseq/data/dictionary.py
|
41 |
+
|
42 |
+
Finally, to pass in parameters in python scripts, we remove the line 67 ~ line 69 in
|
43 |
+
>fairseq-0.12.2/fairseq/dataclass/utils.py
|
44 |
+
|
45 |
+
Below are the install tutorial.
|
46 |
+
|
47 |
+
```
|
48 |
+
# install pytorch
|
49 |
+
pip install torch==1.8.1 # install pytorch
|
50 |
+
|
51 |
+
# install fairseq
|
52 |
+
unzip fairseq-0.12.2.zip
|
53 |
+
cd fairseq-0.12.2
|
54 |
+
pip install.
|
55 |
+
|
56 |
+
# install EET
|
57 |
+
git clone https://github.com/NetEase-FuXi/EET.git
|
58 |
+
cd EET
|
59 |
+
pip install .
|
60 |
+
|
61 |
+
# install transformers (EET requirements)
|
62 |
+
pip install transformers==4.23
|
63 |
+
|
64 |
+
# make a folder, move the dictionary file and model file into it.
|
65 |
+
mkdir transformer_lm_gpt2_xxl
|
66 |
+
mv dict.txt transformer_lm_gpt2_xxl/
|
67 |
+
mv checkpoint_best_part_*.pt transformer_lm_gpt2_xxl/
|
68 |
+
|
69 |
+
```
|
70 |
+
`inference.py` is a script to provide a interface to initialize the EET object and sequence_generator. In addition, It includes some pre-process and post-process functions for text input and output. You can modify the script according to your needs.
|
71 |
+
|
72 |
+
After the environment is ready, several lines of codes can realize the inference.
|
73 |
+
|
74 |
+
``` python
|
75 |
+
|
76 |
+
from inference import Inference
|
77 |
+
model_path = "transformer_lm_gpt2_xxl/checkpoint_best.pt"
|
78 |
+
data_path = "transformer_lm_gpt2_xxl"
|
79 |
+
eet_batch_size = 10 # max inference batch size, adjust according to cuda memory, 40GB memory is necessary
|
80 |
+
inference = Inference(model_path, data_path, eet_batch_size)
|
81 |
+
|
82 |
+
inp = "田园一听这话,轻挑的嘴角放了下来,两腿叉开,踱着方步,跨过汤婆子,一屁股坐在了老人面前。</s>刘萌和健军一左一右站在他身旁,像是王朝、马汉护着包公断案。"
|
83 |
+
text = inference([inp] * 10, append_right_eos=True)
|
84 |
+
|
85 |
+
```
|
86 |
+
This interface supports batch inputs, so if you need to generate multiple results for one input, you can copy the input multiple times. The interface supports results generated for multiple different inputs, e.g.
|
87 |
+
```python
|
88 |
+
text = inference(["四个月后,正是草长花秾的暮春季节。</s>令狐冲和盈盈新婚燕尔,携手共赴华山。","院子中传来急促的脚步声,他停下手中的招式,将开元刀插入刀鞘。"])
|
89 |
+
```
|
90 |
+
|
91 |
+
## Citation
|
92 |
+
If you find the technical report or resource is useful, please cite the following technical report in your paper.
|
93 |
+
- https://aclanthology.org/2022.naacl-industry.8/
|
94 |
+
```
|
95 |
+
@inproceedings{li-etal-2022-easy,
|
96 |
+
title = "Easy and Efficient Transformer: Scalable Inference Solution For Large {NLP} Model",
|
97 |
+
author = "Li, Gongzheng and
|
98 |
+
Xi, Yadong and
|
99 |
+
Ding, Jingzhen and
|
100 |
+
Wang, Duan and
|
101 |
+
Luo, Ziyang and
|
102 |
+
Zhang, Rongsheng and
|
103 |
+
Liu, Bai and
|
104 |
+
Fan, Changjie and
|
105 |
+
Mao, Xiaoxi and
|
106 |
+
Zhao, Zeng",
|
107 |
+
booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies: Industry Track",
|
108 |
+
month = jul,
|
109 |
+
year = "2022",
|
110 |
+
address = "Hybrid: Seattle, Washington + Online",
|
111 |
+
publisher = "Association for Computational Linguistics",
|
112 |
+
url = "https://aclanthology.org/2022.naacl-industry.8",
|
113 |
+
doi = "10.18653/v1/2022.naacl-industry.8",
|
114 |
+
pages = "62--68"
|
115 |
+
}
|
116 |
+
|
117 |
+
```
|
118 |
+
## Contact Us
|
119 |
+
You can also contact us by email:
|
120 |
+
|
121 |
dict.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
inference.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3 -u
|
2 |
+
|
3 |
+
from collections import namedtuple
|
4 |
+
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
|
9 |
+
from fairseq import options, tasks, utils
|
10 |
+
from eet.fairseq.transformer import EETTransformerDecoder
|
11 |
+
|
12 |
+
|
13 |
+
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
|
14 |
+
|
15 |
+
def make_batches(lines, task, max_positions, encode_fn):
|
16 |
+
|
17 |
+
tokens = [task.source_dictionary.encode_line(encode_fn(line),
|
18 |
+
add_if_not_exist=False,
|
19 |
+
append_eos=False,
|
20 |
+
reverse_order=True).long()
|
21 |
+
for line in lines]
|
22 |
+
lengths = [t.numel() for t in tokens]
|
23 |
+
tokens = pad_sequence(tokens, batch_first=True,
|
24 |
+
padding_value=1).flip(dims=(1,))
|
25 |
+
|
26 |
+
return Batch(ids=torch.arange(len(tokens)),
|
27 |
+
src_tokens=tokens,
|
28 |
+
src_lengths=torch.tensor(lengths))
|
29 |
+
|
30 |
+
|
31 |
+
def encode_fn(x_str):
|
32 |
+
x_str = x_str.replace(" ", "")
|
33 |
+
x_str = x_str.split("</s>")
|
34 |
+
x_str = " </s> ".join([" ".join(list(x)) for x in x_str])
|
35 |
+
x_str = "</s> " + x_str
|
36 |
+
return x_str
|
37 |
+
|
38 |
+
|
39 |
+
def decode_fn(x):
|
40 |
+
x = x.replace(" ", "")
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
def eos_token_filter(sent):
|
45 |
+
if "</s>" in sent:
|
46 |
+
return True
|
47 |
+
return False
|
48 |
+
|
49 |
+
|
50 |
+
def post_precess(line):
|
51 |
+
line = "</s>".join(line.split("</s>")[:-1])
|
52 |
+
return line
|
53 |
+
|
54 |
+
|
55 |
+
class Inference(object):
|
56 |
+
|
57 |
+
def __init__(self, model_path, data_path, eet_batch_size):
|
58 |
+
|
59 |
+
parser = options.get_generation_parser(
|
60 |
+
default_task="language_modeling")
|
61 |
+
args = options.parse_args_and_arch(parser)
|
62 |
+
args.data = data_path
|
63 |
+
args.path = model_path
|
64 |
+
self.args = args
|
65 |
+
|
66 |
+
# generate parameter
|
67 |
+
args.beam = 1 # don't change
|
68 |
+
args.min_len = 5
|
69 |
+
args.max_len_b = 200
|
70 |
+
args.lenpen = 1.0
|
71 |
+
args.sampling = True
|
72 |
+
args.sampling_topp = 0.8
|
73 |
+
# args.sampling_topk = 20
|
74 |
+
args.temperature = 0.8
|
75 |
+
args.no_repeat_ngram_size = 1
|
76 |
+
args.fp16 = True
|
77 |
+
|
78 |
+
# Setup task, e.g., translation
|
79 |
+
task = tasks.setup_task(args)
|
80 |
+
self.task = task
|
81 |
+
# Set dictionaries
|
82 |
+
self.src_dict = task.source_dictionary
|
83 |
+
self.tgt_dict = task.target_dictionary
|
84 |
+
|
85 |
+
use_cuda = torch.cuda.is_available() and not args.cpu
|
86 |
+
self.use_cuda = use_cuda
|
87 |
+
|
88 |
+
model_path = args.path
|
89 |
+
checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
|
90 |
+
checkpoint["model"].update(model_path.replace("best.pt", "best_part_2.pt"))
|
91 |
+
checkpoint["model"].update(model_path.replace("best.pt", "best_part_3.pt"))
|
92 |
+
torch.save(checkpoint, model_path)
|
93 |
+
# load part 1
|
94 |
+
state = torch.load(args.path, map_location=torch.device("cpu"))
|
95 |
+
cfg_args = eval(str(state["cfg"]))["model"]
|
96 |
+
del cfg_args["_name"]
|
97 |
+
keys_list = []
|
98 |
+
values_list = []
|
99 |
+
for key, value in cfg_args.items():
|
100 |
+
keys_list.append(key)
|
101 |
+
values_list.append(value)
|
102 |
+
Model_args = namedtuple("Model_args", keys_list)
|
103 |
+
model_args = Model_args._make(values_list)
|
104 |
+
del state
|
105 |
+
|
106 |
+
eet_seq_len = 1024 # max sequence length, (input length + generation length) shouldn't be larger than this
|
107 |
+
eet_batch_size = eet_batch_size
|
108 |
+
data_type = torch.float16
|
109 |
+
eet_config = {"data_type": data_type,
|
110 |
+
"max_batch": eet_batch_size,
|
111 |
+
"full_seq_len": eet_seq_len}
|
112 |
+
print(model_args)
|
113 |
+
|
114 |
+
eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path=args.path,
|
115 |
+
dictionary=self.src_dict, args=model_args,
|
116 |
+
config=eet_config,
|
117 |
+
no_encoder_attn=True)
|
118 |
+
self.models = [eet_model]
|
119 |
+
# Initialize generator
|
120 |
+
self.generator = task.build_generator(self.models, args)
|
121 |
+
|
122 |
+
# Load alignment dictionary for unknown word replacement
|
123 |
+
# (None if no unknown word replacement, empty if no path to align dictionary)
|
124 |
+
self.align_dict = utils.load_align_dict(args.replace_unk)
|
125 |
+
|
126 |
+
self.max_positions = 1024 # the model config
|
127 |
+
self.eos_index = self.tgt_dict.eos()
|
128 |
+
self.pad_index = self.tgt_dict.pad()
|
129 |
+
|
130 |
+
def __call__(self, inputs, append_right_eos=True):
|
131 |
+
|
132 |
+
results = []
|
133 |
+
start_id = 0
|
134 |
+
|
135 |
+
batch = make_batches(inputs, self.task, self.max_positions, encode_fn)
|
136 |
+
inputs_str = inputs
|
137 |
+
|
138 |
+
src_tokens = batch.src_tokens
|
139 |
+
src_lengths = batch.src_lengths
|
140 |
+
# a new paragraph always
|
141 |
+
if src_tokens[0][-1].item() != self.eos_index and append_right_eos:
|
142 |
+
src_tokens = torch.cat([src_tokens, src_tokens.new_ones(
|
143 |
+
src_tokens.size(0), 1) * self.eos_index], dim=1)
|
144 |
+
src_lengths += 1
|
145 |
+
if self.use_cuda:
|
146 |
+
src_tokens = src_tokens.cuda()
|
147 |
+
src_lengths = src_lengths.cuda()
|
148 |
+
sample = {
|
149 |
+
'net_input': {
|
150 |
+
'src_tokens': src_tokens,
|
151 |
+
'src_lengths': src_lengths,
|
152 |
+
},
|
153 |
+
}
|
154 |
+
|
155 |
+
translations = self.task.inference_step(
|
156 |
+
self.generator, self.models, sample)
|
157 |
+
|
158 |
+
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
|
159 |
+
results.append((start_id + id, src_tokens[i], hypos))
|
160 |
+
|
161 |
+
# sort output to match input order
|
162 |
+
final_results = []
|
163 |
+
for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
|
164 |
+
# Process top predictions
|
165 |
+
tmp_res = []
|
166 |
+
for hypo in hypos[:min(len(hypos), self.args.nbest)]:
|
167 |
+
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
|
168 |
+
hypo_tokens=hypo['tokens'].int().cpu()[
|
169 |
+
len(src_tokens) - 1:],
|
170 |
+
src_str=None,
|
171 |
+
alignment=hypo['alignment'],
|
172 |
+
align_dict=self.align_dict,
|
173 |
+
tgt_dict=self.tgt_dict)
|
174 |
+
|
175 |
+
detok_hypo_str = decode_fn(hypo_str)
|
176 |
+
if eos_token_filter(detok_hypo_str):
|
177 |
+
detok_hypo_str = post_precess(detok_hypo_str)
|
178 |
+
score = hypo['score'] / math.log(2) # convert to base 2
|
179 |
+
tmp_res.append([detok_hypo_str, score])
|
180 |
+
final_results.append(tmp_res)
|
181 |
+
return final_results
|