ZhenYang21 committed on
Commit
5e1c5f4
1 Parent(s): ff8ce05

Upload inference_mathglm.py

Browse files
Files changed (1) hide show
  1. inference_mathglm.py +125 -0
inference_mathglm.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : inference_mathglm.py
4
+ @Time : 2021/10/09 19:41:58
5
+ @Author : Ming Ding
6
+ @Contact : [email protected]
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import argparse
16
+ import stat
17
+
18
+ from SwissArmyTransformer import mpu, get_args, get_tokenizer
19
+ from SwissArmyTransformer.model import CachedAutoregressiveModel
20
+ from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
21
+ from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
22
+ from SwissArmyTransformer.generation.utils import timed_name, generate_continually
23
+ from SwissArmyTransformer.training import set_random_seed
24
+
25
+ import json
26
+
27
def _extract_answer_ids(token_ids, end_tokens, eos_id):
    """Trim one generated id sequence down to the ids that get decoded.

    Args:
        token_ids: Plain Python list of ids for one sample; ``-1`` marks
            positions that were never filled in.
        end_tokens: Ids that terminate generation.
        eos_id: Id of the ``eos`` command token.

    Returns:
        The ids between the leading token (index 0, dropped) and the first
        unfilled position, without a trailing end token and without the
        first ``eos`` occurrence.
    """
    try:
        unfinished = token_ids.index(-1)
    except ValueError:
        unfinished = len(token_ids)
    # Guard against unfinished == 0 so the lookup cannot wrap around to the
    # last element of the list.
    if unfinished > 0 and token_ids[unfinished - 1] in end_tokens:
        unfinished -= 1
    try:
        bog = token_ids.index(eos_id)
    except ValueError:
        # No eos was produced (e.g. generation ran out of room): keep the
        # whole filled span instead of failing the query.
        return token_ids[1:unfinished]
    return token_ids[1:bog] + token_ids[bog + 1:unfinished]


def main(args):
    """Load the model and answer every query read from ``args.input_source``.

    2022/06/17: modified ``load_checkpoint`` to ``from_pretrained``.

    For each query the first decoded sample is appended as one line to
    ``test_eval.txt`` inside ``args.output_path``.

    Args:
        args: Parsed argument namespace. It is replaced by the namespace
            returned from ``CachedAutoregressiveModel.from_pretrained``.
    """
    # Load the model from a saved checkpoint.
    # NOTE: placeholder path; point it at the real checkpoint directory.
    model_path = '/path/to/checkpoints/'

    model, args = CachedAutoregressiveModel.from_pretrained(args, model_path)

    if args.fp16:
        model = model.half()
    model = model.to(args.device)
    set_random_seed(args.seed)
    model.eval()

    tokenizer = get_tokenizer(args)

    # Generation stops at the eos command token.
    eos_id = tokenizer.get_command('eos').Id
    end_tokens = [eos_id]
    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)

    def process(raw_text):
        """Answer one query line and append the first sample to test_eval.txt.

        Args:
            raw_text: A JSON object with a ``"question"`` key; when
                ``args.with_id`` is set it is preceded by ``<id>\\t``.

        Raises:
            ValueError: If the encoded prompt exceeds
                ``args.max_sequence_length``.
        """
        if args.with_id:
            # Split on the first tab only so tabs inside the JSON payload
            # cannot break the unpacking. The id itself is not used.
            _, raw_text = raw_text.split('\t', 1)
        raw_text = json.loads(raw_text)["question"] + "答:"

        seq = tokenizer._encode(raw_text)
        # NOTE(review): 20005 looks like a leading token the tokenizer
        # prepends and that must be dropped -- confirm against the vocab.
        if len(seq) != 0 and seq[0] == 20005:
            seq = seq[1:]
        seq = [tokenizer.get_command('ENC').Id] + seq
        if len(seq) > args.max_sequence_length:
            raise ValueError('text too long.')
        # -1 marks the positions the model has to fill in.
        seq += [-1] * (args.max_sequence_length - len(seq))

        # generation
        seq = torch.tensor(seq, dtype=torch.long, device=args.device)
        mbz = args.max_inference_batch_size
        assert args.batch_size < mbz or args.batch_size % mbz == 0
        output_list = []
        for _ in range(max(args.batch_size // mbz, 1)):
            output = filling_sequence(
                model, seq.clone(),
                batch_size=min(args.batch_size, mbz),
                strategy=strategy,
                log_attention_weights=None,
            )[0]
            if isinstance(output, torch.Tensor):  # different strategies
                output = list(output)
            output_list.extend(output)

        # Cut every sample down to the generated ids, then decode.
        txts = [
            tokenizer.DecodeIds(_extract_answer_ids(sample.tolist(), end_tokens, eos_id))
            for sample in output_list
        ]

        # save
        if not args.with_id:
            print(txts[0])  # print the first.
        test_eval_path = os.path.join(args.output_path, 'test_eval.txt')
        with open(test_eval_path, 'a', encoding='utf-8') as fout:
            fout.write(txts[0] + '\n')
        # NOTE(review): this makes the result file world-writable (0o777);
        # tighten if the output directory is shared.
        os.chmod(test_eval_path, stat.S_IRWXO | stat.S_IRWXG | stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
113
+
114
+
115
if __name__ == "__main__":
    # Script-local parser: it declares no options of its own, so every
    # command-line token is handed on to get_args().
    local_parser = argparse.ArgumentParser(add_help=False)
    local_opts, remaining_argv = local_parser.parse_known_args()

    # Merge the library-parsed options with the script-local ones.
    library_opts = get_args(remaining_argv)
    merged_opts = argparse.Namespace(**vars(library_opts), **vars(local_opts))
    # This script only runs inference.
    setattr(merged_opts, "do_train", False)

    # No gradients are needed while generating.
    with torch.no_grad():
        main(merged_opts)
125
+