# -*- encoding: utf-8 -*-
'''
@File    :   inference_mathglm.py
@Time    :   2021/10/09 19:41:58
@Author  :   Ming Ding
@Contact :   [email protected]
'''
# imports
import os
import sys
import math
import random
import torch
import argparse
import stat
from SwissArmyTransformer import mpu, get_args, get_tokenizer
from SwissArmyTransformer.model import CachedAutoregressiveModel
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
from SwissArmyTransformer.generation.utils import timed_name, generate_continually
from SwissArmyTransformer.training import set_random_seed
import json
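
# Answer arithmetic questions with a cached autoregressive MathGLM model: read
# one JSON object per line from the input source, sample a completion for each
# question, and append the decoded answers to <output-path>/test_eval.txt.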


def main(args):
    '''
    2022/06/17
    Modify load_checkpoint to from_pretrained
    '''
    # initialize_distributed(args)
    # load model from saved checkpoint
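    # NOTE: '/path/to/checkpoints/' below is a placeholder; point it at the
    # downloaded MathGLM checkpoint directory before running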
    model_path = '/path/to/checkpoints/'
    model, args = CachedAutoregressiveModel.from_pretrained(args, model_path)
    if args.fp16:
        model = model.half()
    model = model.to(args.device)
    set_random_seed(args.seed)
    model.eval()
    tokenizer = get_tokenizer(args)

    # define function for each query
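    # BaseStrategy performs temperature/top-k sampling and ends a sample
    # once one of end_tokens (here, eos) is emitted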
    end_tokens = [tokenizer.get_command('eos').Id]
    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)

    def process(raw_text):
        if args.with_id:
            query_id, raw_text = raw_text.split('\t')
        raw_text = json.loads(raw_text)
        question = raw_text["question"] + "答:"  # "答:" means "Answer:"
        raw_text = question
        seq = tokenizer._encode(raw_text)
        if len(seq) != 0 and seq[0] == 20005:  # strip a leading 20005 token the tokenizer may prepend
            seq = seq[1:]
        seq = [tokenizer.get_command('ENC').Id] + seq
        if len(seq) > args.max_sequence_length:
            raise ValueError('text too long.')
        seq += [-1] * (args.max_sequence_length - len(seq))
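        # the -1 slots mark the positions filling_sequence will generate into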
        # generation
        seq = torch.cuda.LongTensor(seq, device=args.device)
        mbz = args.max_inference_batch_size
        assert args.batch_size < mbz or args.batch_size % mbz == 0
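        # split the requested batch into chunks of at most mbz samples so a
        # large batch_size does not exhaust GPU memory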
        output_list = []
        for tim in range(max(args.batch_size // mbz, 1)):
            output = filling_sequence(model, seq.clone(),
                                      batch_size=min(args.batch_size, mbz),
                                      strategy=strategy,
                                      log_attention_weights=None
                                      )[0]
            if isinstance(output, torch.Tensor):  # different strategies
                output = list(output)
            output_list.extend(output)
        # find the end of generation to obtain the answer tokens
        for i in range(len(output_list)):
            output = output_list[i].tolist()
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in end_tokens:
                unfinished -= 1
            # strip the leading ENC token and drop an in-sequence eos, if any
            try:
                bog = output.index(tokenizer.get_command('eos').Id)
                output_list[i] = output[1:bog] + output[bog + 1:unfinished]
            except ValueError:
                output_list[i] = output[1:unfinished]
        # decoding
        txts = []
        for seq in output_list:
            decode_tokens = tokenizer.DecodeIds(seq)
            txts.append(decode_tokens)
        # save
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + '.txt')
        else:
            prefix = raw_text.replace('/', '')[:20]
            full_path = timed_name(prefix, '.txt', args.output_path)
        print(txts[0])  # print the first sample
        # note: full_path is kept from the original per-query output scheme,
        # but all answers are appended to a single test_eval.txt instead
        test_eval_path = os.path.join(args.output_path, 'test_eval.txt')
        with open(test_eval_path, 'a', encoding='utf-8') as fout:
            fout.write(txts[0] + '\n')
        os.chmod(test_eval_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)


if __name__ == "__main__":
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    args.do_train = False
    with torch.no_grad():
        main(args)
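
# ---------------------------------------------------------------------------
# Usage sketch (assumption: the flags below come from SwissArmyTransformer's
# standard argument parser; exact names may vary across SAT versions):
#
#   python inference_mathglm.py \
#       --mode inference \
#       --fp16 \
#       --input-source input.txt \
#       --output-path ./samples \
#       --temperature 1.0 \
#       --top_k 1 \
#       --batch-size 1 \
#       --max-inference-batch-size 1
#
# where each line of input.txt is a JSON object such as {"question": "256*12="}
# (prefixed with "<id>\t" when --with-id is set); decoded answers are appended
# to <output-path>/test_eval.txt.
# ---------------------------------------------------------------------------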