# -*- encoding: utf-8 -*-
'''
@File    :   inference_cogview.py
@Time    :   2021/10/09 19:41:58
@Author  :   Ming Ding 
@Contact :   [email protected]
'''
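# Example invocation (flag names come from SwissArmyTransformer's get_args;
# the paths and values below are illustrative only):
#   python inference_cogview.py --mode inference --fp16 \
#       --input-source questions.jsonl --output-path samples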

import os
import json
import stat
import argparse

import torch

from SwissArmyTransformer import get_args, get_tokenizer
from SwissArmyTransformer.model import CachedAutoregressiveModel
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
from SwissArmyTransformer.generation.utils import timed_name, generate_continually
from SwissArmyTransformer.training import set_random_seed


def main(args):

    '''
    2022/06/17
    Modified load_checkpoint to from_pretrained.
    '''
    # initialize_distributed(args)
    # load the model from a saved checkpoint
    
    model_path = '/path/to/checkpoints/'
    
    model, args = CachedAutoregressiveModel.from_pretrained(args, model_path)

    if args.fp16:
        model = model.half()
    model = model.to(args.device)
    set_random_seed(args.seed)
    model.eval()
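    # eval mode disables dropout; sampling randomness comes from the strategy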
    
    tokenizer = get_tokenizer(args)
 
    # sampling strategy: top-k/temperature sampling that stops at the eos token
    end_tokens = [tokenizer.get_command('eos').Id]
    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
    
    def process(raw_text):
        # each input line is a JSON object with a "question" field, optionally
        # prefixed with "<query_id>\t" when --with-id is set
        if args.with_id:
            query_id, raw_text = raw_text.split('\t')
        raw_text = json.loads(raw_text)
        # append the answer prompt "答:" ("Answer:") to the question
        question = raw_text["question"] + "答:"
        raw_text = question
        seq = tokenizer._encode(raw_text)
        if len(seq) != 0 and seq[0] == 20005:  # drop the tokenizer's leading token (id 20005)
            seq = seq[1:]
        # layout: [ENC] + prompt tokens + (-1)-padding; the -1 positions are
        # the ones filling_sequence generates
        seq = [tokenizer.get_command('ENC').Id] + seq
        if len(seq) > args.max_sequence_length:
            raise ValueError('text too long.')
        seq += [-1] * (args.max_sequence_length - len(seq))
        # generation
        seq = torch.tensor(seq, dtype=torch.long, device=args.device)
        mbz = args.max_inference_batch_size
        assert args.batch_size < mbz or args.batch_size % mbz == 0
        output_list = []
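        # split generation into chunks of at most max_inference_batch_size sequences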
        for tim in range(max(args.batch_size // mbz, 1)):
            output = filling_sequence(model, seq.clone(),
                    batch_size=min(args.batch_size, mbz),
                    strategy=strategy,
                    log_attention_weights=None
                    )[0]
            if isinstance(output, torch.Tensor): # different strategies
                output = list(output)

            output_list.extend(output)
        # strip padding, the leading ENC token, and the eos token to obtain the answer
        for i in range(len(output_list)):
            output = output_list[i].tolist()
            try:
                unfinished = output.index(-1)  # first still-padded position
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in end_tokens:
                unfinished -= 1
            try:
                eos_pos = output.index(tokenizer.get_command('eos').Id)
                output_list[i] = output[1:eos_pos] + output[eos_pos + 1:unfinished]
            except ValueError:  # no eos token was generated
                output_list[i] = output[1:unfinished]
                
        # decoding 
        txts = [] 
        for seq in output_list:
            decode_tokens = tokenizer.DecodeIds(seq)
            txts.append(decode_tokens)
        
        # save: full_path is computed but, in this version, only the shared
        # test_eval.txt below is actually written
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + '.txt')
        else:
            prefix = raw_text.replace('/', '')[:20]
            full_path = timed_name(prefix, '.txt', args.output_path)
            print(txts[0])  # print the first
        test_eval_path = os.path.join(args.output_path, 'test_eval.txt')
        with open(test_eval_path, 'a', encoding='utf-8') as fout:
            fout.write(txts[0] + '\n')
        os.chmod(test_eval_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)  # rwx for user/group/others

    os.makedirs(args.output_path, exist_ok=True)
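    # generate_continually feeds each line of args.input_source to process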
    generate_continually(process, args.input_source) 


if __name__ == "__main__":
    py_parser = argparse.ArgumentParser(add_help=False)
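    # script-specific flags could be registered on py_parser here; all
    # remaining arguments are forwarded to SwissArmyTransformer's get_args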
    
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    args.do_train = False
    
    with torch.no_grad():
        main(args)