席亚东 commited on
Commit
f1e4729
·
1 Parent(s): 8926c77

add inference.py

Browse files
Files changed (1) hide show
  1. inference.py +176 -0
inference.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+
3
+ from collections import namedtuple
4
+
5
+ import math
6
+ import torch
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+ from fairseq import options, tasks, utils
10
+ from eet.fairseq.transformer import EETTransformerDecoder
11
+
12
+
13
+ Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
14
+
15
+ def make_batches(lines, task, max_positions, encode_fn):
16
+
17
+ tokens = [task.source_dictionary.encode_line(encode_fn(line),
18
+ add_if_not_exist=False,
19
+ append_eos=False,
20
+ reverse_order=True).long()
21
+ for line in lines]
22
+ lengths = [t.numel() for t in tokens]
23
+ tokens = pad_sequence(tokens, batch_first=True,
24
+ padding_value=1).flip(dims=(1,))
25
+
26
+ return Batch(ids=torch.arange(len(tokens)),
27
+ src_tokens=tokens,
28
+ src_lengths=torch.tensor(lengths))
29
+
30
+
31
+ def encode_fn(x_str):
32
+ x_str = x_str.replace(" ", "")
33
+ x_str = x_str.split("</s>")
34
+ x_str = " </s> ".join([" ".join(list(x)) for x in x_str])
35
+ x_str = "</s> " + x_str
36
+ return x_str
37
+
38
+
39
+ def decode_fn(x):
40
+ x = x.replace(" ", "")
41
+ return x
42
+
43
+
44
+ def eos_token_filter(sent):
45
+ if "</s>" in sent:
46
+ return True
47
+ return False
48
+
49
+
50
+ def post_precess(line):
51
+ line = "</s>".join(line.split("</s>")[:-1])
52
+ return line
53
+
54
+
55
+ class Inference(object):
56
+
57
+ def __init__(self, model_path, data_path, eet_batch_size):
58
+
59
+ parser = options.get_generation_parser(
60
+ default_task="language_modeling")
61
+ args = options.parse_args_and_arch(parser)
62
+ args.data = data_path
63
+ args.path = model_path
64
+ self.args = args
65
+
66
+ # generate parameter
67
+ args.beam = 1 # don't change
68
+ args.min_len = 5
69
+ args.max_len_b = 200
70
+ args.lenpen = 1.0
71
+ args.sampling = True
72
+ args.sampling_topp = 0.8
73
+ # args.sampling_topk = 20
74
+ args.temperature = 0.8
75
+ args.no_repeat_ngram_size = 1
76
+ args.fp16 = True
77
+
78
+ # Setup task, e.g., translation
79
+ task = tasks.setup_task(args)
80
+ self.task = task
81
+ # Set dictionaries
82
+ self.src_dict = task.source_dictionary
83
+ self.tgt_dict = task.target_dictionary
84
+
85
+ use_cuda = torch.cuda.is_available() and not args.cpu
86
+ self.use_cuda = use_cuda
87
+
88
+ # Optimize ensemble for generation
89
+ state = torch.load(args.path, map_location=torch.device("cpu"))
90
+ cfg_args = eval(str(state["cfg"]))["model"]
91
+ del cfg_args["_name"]
92
+ keys_list = []
93
+ values_list = []
94
+ for key, value in cfg_args.items():
95
+ keys_list.append(key)
96
+ values_list.append(value)
97
+ Model_args = namedtuple("Model_args", keys_list)
98
+ model_args = Model_args._make(values_list)
99
+ del state
100
+
101
+ eet_seq_len = 1024 # max sequence length, (input length + generation length) shouldn't be larger than this
102
+ eet_batch_size = eet_batch_size
103
+ data_type = torch.float16
104
+ eet_config = {"data_type": data_type,
105
+ "max_batch": eet_batch_size,
106
+ "full_seq_len": eet_seq_len}
107
+ print(model_args)
108
+
109
+ eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path=args.path,
110
+ dictionary=self.src_dict, args=model_args,
111
+ config=eet_config,
112
+ no_encoder_attn=True)
113
+ self.models = [eet_model]
114
+ # Initialize generator
115
+ self.generator = task.build_generator(self.models, args)
116
+
117
+ # Load alignment dictionary for unknown word replacement
118
+ # (None if no unknown word replacement, empty if no path to align dictionary)
119
+ self.align_dict = utils.load_align_dict(args.replace_unk)
120
+
121
+ self.max_positions = 1024 # the model config
122
+ self.eos_index = self.tgt_dict.eos()
123
+ self.pad_index = self.tgt_dict.pad()
124
+
125
+ def __call__(self, inputs, append_right_eos=True):
126
+
127
+ results = []
128
+ start_id = 0
129
+
130
+ batch = make_batches(inputs, self.task, self.max_positions, encode_fn)
131
+ inputs_str = inputs
132
+
133
+ src_tokens = batch.src_tokens
134
+ src_lengths = batch.src_lengths
135
+ # a new paragraph always
136
+ if src_tokens[0][-1].item() != self.eos_index and append_right_eos:
137
+ src_tokens = torch.cat([src_tokens, src_tokens.new_ones(
138
+ src_tokens.size(0), 1) * self.eos_index], dim=1)
139
+ src_lengths += 1
140
+ if self.use_cuda:
141
+ src_tokens = src_tokens.cuda()
142
+ src_lengths = src_lengths.cuda()
143
+ sample = {
144
+ 'net_input': {
145
+ 'src_tokens': src_tokens,
146
+ 'src_lengths': src_lengths,
147
+ },
148
+ }
149
+
150
+ translations = self.task.inference_step(
151
+ self.generator, self.models, sample)
152
+
153
+ for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
154
+ results.append((start_id + id, src_tokens[i], hypos))
155
+
156
+ # sort output to match input order
157
+ final_results = []
158
+ for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
159
+ # Process top predictions
160
+ tmp_res = []
161
+ for hypo in hypos[:min(len(hypos), self.args.nbest)]:
162
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
163
+ hypo_tokens=hypo['tokens'].int().cpu()[
164
+ len(src_tokens) - 1:],
165
+ src_str=None,
166
+ alignment=hypo['alignment'],
167
+ align_dict=self.align_dict,
168
+ tgt_dict=self.tgt_dict)
169
+
170
+ detok_hypo_str = decode_fn(hypo_str)
171
+ if eos_token_filter(detok_hypo_str):
172
+ detok_hypo_str = post_precess(detok_hypo_str)
173
+ score = hypo['score'] / math.log(2) # convert to base 2
174
+ tmp_res.append([detok_hypo_str, score])
175
+ final_results.append(tmp_res)
176
+ return final_results