File size: 6,096 Bytes
7900c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
  This script provides an example to wrap TencentPretrain for feature extraction.
"""
import sys
import os
import torch
import torch.nn as nn
import argparse
import numpy as np

tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.embeddings import *
from tencentpretrain.encoders import *
from tencentpretrain.targets import *
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.utils.misc import pooling
from tencentpretrain.model_loader import load_model
from tencentpretrain.opts import infer_opts, tokenizer_opts


def batch_loader(batch_size, src, seg):
    """
    Yield consecutive (src_batch, seg_batch) slices taken along the first axis.

    Every full batch holds batch_size rows. If the number of rows
    (src.size(0)) is not a multiple of batch_size, one shorter trailing batch
    with the remaining rows is yielded last.
    """
    total = src.size(0)
    full_batches = total // batch_size

    for batch_idx in range(full_batches):
        start = batch_idx * batch_size
        stop = start + batch_size
        yield src[start:stop], seg[start:stop]

    # Rows left over after the last full batch form the final, shorter batch.
    tail_start = full_batches * batch_size
    if tail_start < total:
        yield src[tail_start:], seg[tail_start:]


def read_dataset(args, path):
    """
    Read a UTF-8 text file and convert every line into a fixed-length sample.

    Each line is tokenized with args.tokenizer, wrapped as
    [CLS] tokens [SEP], truncated to args.seq_length and right-padded with
    the id of PAD_TOKEN. Lines that produce no tokens are skipped.

    Args:
        args: Namespace providing `tokenizer` (with `vocab`, `tokenize` and
            `convert_tokens_to_ids`) and `seq_length`.
        path: Path of the text file, one sentence per line.

    Returns:
        A list of (src, seg) tuples. `src` holds the token ids; `seg` holds 1
        for every real token and 0 for every padding position. Both lists
        have length args.seq_length.
    """
    dataset = []
    tokenizer = args.tokenizer
    PAD_ID = tokenizer.vocab.get(PAD_TOKEN)
    # The special-token ids do not depend on the line, so look them up once.
    cls_ids = tokenizer.convert_tokens_to_ids([CLS_TOKEN])
    sep_ids = tokenizer.convert_tokens_to_ids([SEP_TOKEN])

    with open(path, mode="r", encoding="utf-8") as f:
        for line in f:
            src = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
            if len(src) == 0:
                continue
            # Note: truncation happens after [SEP] is appended, so an
            # over-long line loses its trailing [SEP].
            src = (cls_ids + src + sep_ids)[:args.seq_length]
            seg = [1] * len(src)

            pad_len = args.seq_length - len(src)
            if pad_len > 0:
                src.extend([PAD_ID] * pad_len)
                # Padding positions must be 0 in the segment list regardless
                # of the vocabulary's pad id; using PAD_ID here would mark
                # padding as real tokens whenever PAD_ID != 0.
                seg.extend([0] * pad_len)
            dataset.append((src, seg))
    return dataset


class FeatureExtractor(torch.nn.Module):
    """
    Embedding + encoder + pooling stack used to turn token ids into one
    feature vector per input (exact output shape depends on `pooling`).
    """
    def __init__(self, args):
        super().__init__()
        # Assemble the composite embedding from every layer named in
        # args.embedding, each built through the str2embedding registry.
        composite = Embedding(args)
        for name in args.embedding:
            layer = str2embedding[name](args, len(args.tokenizer.vocab))
            composite.update(layer, name)
        self.embedding = composite
        self.encoder = str2encoder[args.encoder](args)
        self.pooling_type = args.pooling

    def forward(self, src, seg):
        """
        Embed `src`, encode it, and pool the encoder output.

        `seg` is passed to the embedding, the encoder and the pooling helper.
        """
        hidden = self.encoder(self.embedding(src, seg), seg)
        return pooling(hidden, seg, self.pooling_type)


class WhiteningHandle(torch.nn.Module):
    """
    Whitening operation.
    @ref: https://github.com/bojone/BERT-whitening/blob/main/demo.py

    A linear transform (kernel, bias) is fitted on the vectors given to the
    constructor and applied to new vectors in forward(). The kernel and bias
    are kept as plain NumPy arrays.
    """
    def __init__(self, args, vecs):
        """
        Args:
            args: Not used by this class; kept so existing callers still work.
            vecs: Vectors used to fit the transform. Either a tensor or an
                iterable whose items are lists, tensors or NumPy arrays.
        """
        super(WhiteningHandle, self).__init__()
        self.kernel, self.bias = self._compute_kernel_bias(vecs)

    def forward(self, vecs, n_components=None, normal=True, pt=True):
        """
        Apply the fitted whitening transform.

        Args:
            vecs: Vectors to transform (same accepted types as the constructor).
            n_components: If an int, only the first n_components columns of
                the kernel are used, which reduces the output dimension.
            normal: If True, scale every output row to unit L2 norm.
            pt: If True, return a torch tensor; otherwise a NumPy array.
        """
        vecs = self._format_vecs_to_np(vecs)
        vecs = self._transform(vecs, n_components)
        vecs = self._normalize(vecs) if normal else vecs
        vecs = torch.tensor(vecs) if pt else vecs
        return vecs

    def _compute_kernel_bias(self, vecs):
        """Return (W, -mu) so that (x + bias).dot(kernel) whitens the data."""
        vecs = self._format_vecs_to_np(vecs)
        mu = vecs.mean(axis=0, keepdims=True)
        cov = np.cov(vecs.T)
        # SVD of the covariance: scaling the left singular vectors by
        # 1/sqrt(s) maps the fitted data to identity covariance.
        u, s, vh = np.linalg.svd(cov)
        W = np.dot(u, np.diag(1 / np.sqrt(s)))
        return W, -mu

    def _transform(self, vecs, n_components):
        """Centre the vectors with the bias and project them with the kernel."""
        w = self.kernel[:, :n_components] \
                if isinstance(n_components, int) else self.kernel
        return (vecs + self.bias).dot(w)

    def _normalize(self, vecs):
        """Divide every row by its L2 norm."""
        return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5

    def _format_vecs_to_np(self, vecs):
        """
        Convert the accepted input types into a single NumPy array.

        Raises:
            TypeError: If an item is not a list, tensor or NumPy array.
        """
        if torch.is_tensor(vecs):
            # Convert the whole tensor in one call instead of row by row.
            # .cpu() is required because Tensor.numpy() only works on CPU
            # tensors.
            return vecs.detach().cpu().numpy()

        vecs_np = []
        for vec in vecs:
            if isinstance(vec, list):
                vec = np.array(vec)
            elif torch.is_tensor(vec):
                vec = vec.detach().cpu().numpy()
            elif not isinstance(vec, np.ndarray):
                raise TypeError('Unknown vec type.')
            vecs_np.append(vec)
        return np.array(vecs_np)


if __name__ == '__main__':
    # Command-line entry point: extract one feature vector per input line of
    # args.test_path, optionally whiten the vectors, and save the resulting
    # tensor to args.prediction_path.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    infer_opts(parser)

    parser.add_argument("--whitening_size", type=int, default=None, help="Output vector size after whitening.")

    tokenizer_opts(parser)

    args = parser.parse_args()
    args = load_hyperparam(args)

    # Replace the tokenizer name with the tokenizer instance built from it.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Build feature extractor model.
    model = FeatureExtractor(args)
    model = load_model(model, args.load_model_path)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model.eval()

    dataset = read_dataset(args, args.test_path)
    if not dataset:
        # Without this check torch.cat below fails with an opaque
        # "expected a non-empty list of Tensors" error.
        raise ValueError("No usable lines found in {}.".format(args.test_path))

    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])

    feature_vectors = []
    # Inference only: disabling gradient tracking avoids building an autograd
    # graph for every batch, which saves memory and time.
    with torch.no_grad():
        for src_batch, seg_batch in batch_loader(args.batch_size, src, seg):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            output = model(src_batch, seg_batch)
            feature_vectors.append(output.cpu())
    feature_vectors = torch.cat(feature_vectors, 0)

    # Vector whitening.
    if args.whitening_size is not None:
        whitening = WhiteningHandle(args, feature_vectors)
        feature_vectors = whitening(feature_vectors, args.whitening_size, pt=True)

    print("The size of feature vectors (sentences_num * vector size): {}".format(feature_vectors.shape))
    torch.save(feature_vectors, args.prediction_path)