File size: 4,651 Bytes
7900c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
This script provides an example to wrap TencentPretrain for text-to-text inference.
"""
import sys
import os
import random
import argparse
import torch

# Append the parent of this script's directory (presumably the repository
# root -- confirm against the project layout) to sys.path so that the
# `tencentpretrain`, `finetune` and `inference` imports below can be resolved
# when the script is launched directly.
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.utils.vocab import Vocab
from tencentpretrain.model_loader import load_model
from tencentpretrain.opts import infer_opts, tokenizer_opts
from finetune.run_text2text import Text2text
from inference.run_classifier_infer import batch_loader


def read_dataset(args, path):
    """Read a tab-separated file and turn every data row into model inputs.

    The first line of the file is a header that names the columns. Every
    following line must contain a ``text_a`` column; when the header also
    declares ``text_b``, the two texts are joined with ``SEP_TOKEN`` between
    them before tokenization.

    Each text is wrapped as ``[CLS_TOKEN] + tokens + [SEP_TOKEN]``, converted
    to ids, truncated to ``args.seq_length`` and right-padded with the id of
    ``PAD_TOKEN`` up to ``args.seq_length``.

    Args:
        args: Namespace providing ``tokenizer`` (exposing ``tokenize`` and
            ``convert_tokens_to_ids``) and ``seq_length``.
        path: Path of the UTF-8 encoded TSV file to read.

    Returns:
        A list of ``(src, seg)`` tuples, one per data row. ``src`` holds the
        token ids and ``seg`` holds ``1`` for real tokens and ``0`` for
        padding; both lists have length ``args.seq_length``.
    """
    tokenizer = args.tokenizer
    seq_length = args.seq_length
    # The padding id does not depend on the row, so look it up only once.
    pad_id = tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]

    dataset, columns = [], {}
    with open(path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            fields = line.rstrip("\r\n").split("\t")
            if line_id == 0:
                # Header row: remember the position of every column name.
                columns = {column_name: i for i, column_name in enumerate(fields)}
                continue

            text = fields[columns["text_a"]]
            if "text_b" in columns:
                text = text + SEP_TOKEN + fields[columns["text_b"]]

            tokens = [CLS_TOKEN] + tokenizer.tokenize(text) + [SEP_TOKEN]
            # Slicing is a no-op for sequences that are already short enough.
            src = tokenizer.convert_tokens_to_ids(tokens)[:seq_length]
            seg = [1] * len(src)

            # Pad in one step; a non-positive count yields an empty list.
            pad_count = seq_length - len(src)
            src.extend([pad_id] * pad_count)
            seg.extend([0] * pad_count)

            dataset.append((src, seg))

    return dataset


def main():
    """Decode every instance of ``--test_path`` greedily and write the results.

    The function parses the command line (inference options, tokenizer options
    and ``--tgt_seq_length``), builds the tokenizer and the ``Text2text``
    model, loads the weights from ``args.load_model_path`` and reads the test
    file with ``read_dataset``.

    For each batch the source is encoded once, then the target is extended
    one token at a time for ``args.tgt_seq_length`` steps, always picking the
    highest-scoring token at the last position. The decoded tokens (without
    the leading CLS token) are concatenated, cut at the first ``SEP_TOKEN``
    and written one per line to ``args.prediction_path`` after a ``label``
    header line.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    infer_opts(parser)

    tokenizer_opts(parser)

    parser.add_argument("--tgt_seq_length", type=int, default=32,
                        help="Output sequence length.")
    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Build tokenizer.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Build the text-to-text model and load its weights.
    model = Text2text(args)
    model = load_model(model, args.load_model_path)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(args.device)

    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    dataset = read_dataset(args, args.test_path)

    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])

    batch_size = args.batch_size
    instances_num = src.size()[0]

    print("The number of prediction instances: ", instances_num)

    # Loop-invariant lookups used while decoding and writing.
    cls_id = args.tokenizer.vocab.get(CLS_TOKEN)
    inv_vocab = args.tokenizer.inv_vocab

    model.eval()
    # Gradients are never needed here, so disable them for the whole loop.
    with open(args.prediction_path, mode="w", encoding="utf-8") as f, torch.no_grad():
        f.write("label\n")
        for src_batch, seg_batch in batch_loader(batch_size, src, seg):
            src_batch = src_batch.to(args.device)
            seg_batch = seg_batch.to(args.device)
            current_batch_size = src_batch.size()[0]

            # Every target sequence starts with a single CLS token.
            tgt_in_batch = torch.full((current_batch_size, 1), cls_id,
                                      dtype=torch.long, device=args.device)
            tgt_seg_batch = torch.ones_like(tgt_in_batch)

            # Encode the source once and reuse the result at every step.
            memory_bank = model(src_batch, None, seg_batch, tgt_seg_batch, only_use_encoder=True)
            for _ in range(args.tgt_seq_length):
                outputs = model(src_batch, (tgt_in_batch, None, src_batch), None, tgt_seg_batch,
                                memory_bank=memory_bank)
                next_token_logits = outputs[:, -1]
                next_tokens = torch.argmax(next_token_logits, dim=1).unsqueeze(1)
                tgt_in_batch = torch.cat([tgt_in_batch, next_tokens], dim=1)
                tgt_seg_batch = torch.ones_like(tgt_in_batch)

            # Iterate over the decoded rows rather than over `outputs`, which
            # is unbound when tgt_seq_length is 0. A single tolist() replaces
            # one .item() call per token.
            for token_ids in tgt_in_batch[:, 1:].tolist():
                decoded = "".join(inv_vocab[token_id] for token_id in token_ids)
                f.write(decoded.split(SEP_TOKEN)[0])
                f.write("\n")


# Run the inference pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()