# Page-header residue from the file viewer: Spaces - Runtime error; file size: 6,096 bytes; revision 7900c16.
"""
This script provides an example to wrap TencentPretrain for feature extraction.
"""
import sys
import os
import torch
import torch.nn as nn
import argparse
import numpy as np
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)
from tencentpretrain.embeddings import *
from tencentpretrain.encoders import *
from tencentpretrain.targets import *
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.utils.misc import pooling
from tencentpretrain.model_loader import load_model
from tencentpretrain.opts import infer_opts, tokenizer_opts
def batch_loader(batch_size, src, seg):
    """Yield aligned ``(src, seg)`` mini-batches in their original order.

    Args:
        batch_size: Maximum number of instances per batch; must be positive.
        src: Tensor whose first dimension indexes instances
            (``src.size(0)`` is read to get the instance count).
        seg: Object sliced with the same row ranges as ``src``.

    Yields:
        Tuples ``(src_batch, seg_batch)``. Every batch holds ``batch_size``
        rows except possibly the last one, which holds the remainder.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces on the first iteration.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))

    instances_num = src.size(0)
    # A stepped range covers the full batches and the trailing partial batch
    # in one loop; slicing past the end simply yields the shorter remainder.
    for start in range(0, instances_num, batch_size):
        end = start + batch_size
        yield src[start:end], seg[start:end]
def read_dataset(args, path):
    """Read a UTF-8 text file into fixed-length ``(src, seg)`` id lists.

    Each line is tokenized with ``args.tokenizer``, wrapped in the ids of
    ``CLS_TOKEN`` and ``SEP_TOKEN``, then truncated or padded to exactly
    ``args.seq_length`` entries. Lines that tokenize to nothing are skipped.

    Args:
        args: Namespace providing ``tokenizer`` (with ``vocab``, ``tokenize``
            and ``convert_tokens_to_ids``) and ``seq_length``.
        path: Path of the text file, one instance per line.

    Returns:
        A list of ``(src, seg)`` tuples. ``src`` holds token ids padded with
        the id of ``PAD_TOKEN``; ``seg`` holds 1 for every kept token and 0
        for every padding position.
    """
    dataset = []
    tokenizer = args.tokenizer
    PAD_ID = tokenizer.vocab.get(PAD_TOKEN)
    # The special-token ids do not depend on the line, so look them up once.
    cls_ids = tokenizer.convert_tokens_to_ids([CLS_TOKEN])
    sep_ids = tokenizer.convert_tokens_to_ids([SEP_TOKEN])

    with open(path, mode="r", encoding="utf-8") as f:
        for line in f:
            src = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
            if len(src) == 0:
                continue
            src = cls_ids + src + sep_ids
            seg = [1] * len(src)

            # NOTE: truncation keeps the leading tokens only, so the trailing
            # SEP id is dropped for over-long lines.
            if len(src) > args.seq_length:
                src = src[:args.seq_length]
                seg = seg[:args.seq_length]

            # Padding positions get segment id 0 rather than the vocabulary's
            # PAD id: seg is a segment/mask sequence, not a token sequence, so
            # it must stay correct even if PAD_TOKEN is not mapped to id 0.
            pad_len = args.seq_length - len(src)
            src.extend([PAD_ID] * pad_len)
            seg.extend([0] * pad_len)

            dataset.append((src, seg))

    return dataset
class FeatureExtractor(torch.nn.Module):
    """Embedding + encoder stack whose output is pooled into one vector per input.

    The embedding container is filled with one sub-embedding per name listed
    in ``args.embedding``; the encoder is selected by ``args.encoder`` and the
    pooling strategy by ``args.pooling``.
    """

    def __init__(self, args):
        super().__init__()

        self.embedding = Embedding(args)
        for name in args.embedding:
            # Each sub-embedding is built for the tokenizer's vocabulary size
            # and registered in the container under its configured name.
            component = str2embedding[name](args, len(args.tokenizer.vocab))
            self.embedding.update(component, name)

        encoder_cls = str2encoder[args.encoder]
        self.encoder = encoder_cls(args)
        self.pooling_type = args.pooling

    def forward(self, src, seg):
        """Embed ``src``, encode it, and pool the encoder output.

        ``seg`` is forwarded unchanged to the embedding, the encoder and the
        pooling helper.
        """
        hidden = self.encoder(self.embedding(src, seg), seg)
        return pooling(hidden, seg, self.pooling_type)
class WhiteningHandle(torch.nn.Module):
    """
    Whitening operation.
    @ref: https://github.com/bojone/BERT-whitening/blob/main/demo.py

    A kernel and a bias are fitted once from the vectors given to the
    constructor; calling the module applies ``(vecs + bias) @ kernel`` to new
    vectors, optionally keeping only the leading kernel columns and
    L2-normalizing each row. ``kernel`` and ``bias`` are stored as NumPy
    arrays.
    """

    def __init__(self, args, vecs):
        """Fit the whitening kernel and bias from ``vecs``.

        Args:
            args: Accepted for interface compatibility; not used here.
            vecs: 2-D collection of at least two vectors (tensor, ndarray, or
                an iterable of lists / tensors / ndarrays).
        """
        super(WhiteningHandle, self).__init__()
        self.kernel, self.bias = self._compute_kernel_bias(vecs)

    def forward(self, vecs, n_components=None, normal=True, pt=True):
        """Whiten ``vecs`` with the fitted kernel and bias.

        Args:
            vecs: Vectors to transform (same accepted types as the constructor).
            n_components: Number of leading kernel columns to keep; ``None``
                keeps all of them.
            normal: If true, L2-normalize each output row.
            pt: If true, return a ``torch.Tensor``; otherwise a NumPy array.
        """
        vecs = self._format_vecs_to_np(vecs)
        vecs = self._transform(vecs, n_components)
        if normal:
            vecs = self._normalize(vecs)
        return torch.tensor(vecs) if pt else vecs

    def _compute_kernel_bias(self, vecs):
        """Return ``(W, -mu)`` so that ``(vecs - mu) @ W`` has identity covariance.

        Raises:
            ValueError: If fewer than two vectors are given or the input is
                not 2-D; a covariance cannot be estimated in that case.
        """
        vecs = self._format_vecs_to_np(vecs)
        if vecs.ndim != 2 or vecs.shape[0] < 2:
            raise ValueError(
                "Whitening needs a 2-D collection of at least two vectors, "
                "got shape {}.".format(vecs.shape)
            )
        mu = vecs.mean(axis=0, keepdims=True)
        # np.cov returns a 0-d array for one-dimensional features; SVD needs 2-D.
        cov = np.atleast_2d(np.cov(vecs.T))
        u, s, _ = np.linalg.svd(cov)
        W = np.dot(u, np.diag(1 / np.sqrt(s)))
        return W, -mu

    def _transform(self, vecs, n_components):
        """Apply bias and kernel, truncated to ``n_components`` columns if given."""
        # NumPy integers are accepted as well, so a count computed with NumPy
        # is not silently ignored.
        if isinstance(n_components, (int, np.integer)):
            w = self.kernel[:, :n_components]
        else:
            w = self.kernel
        return (vecs + self.bias).dot(w)

    def _normalize(self, vecs):
        """L2-normalize each row; all-zero rows stay zero instead of becoming NaN."""
        norms = (vecs ** 2).sum(axis=1, keepdims=True) ** 0.5
        return vecs / np.clip(norms, 1e-8, np.inf)

    def _format_vecs_to_np(self, vecs):
        """Convert ``vecs`` to a NumPy array.

        Raises:
            TypeError: If an element is not a list, tensor or ndarray.
        """
        # Fast paths: convert a whole tensor / array at once instead of row by row.
        if torch.is_tensor(vecs):
            # .cpu() is required for tensors that live on an accelerator.
            return vecs.detach().cpu().numpy()
        if isinstance(vecs, np.ndarray):
            return vecs

        vecs_np = []
        for vec in vecs:
            if isinstance(vec, list):
                vec = np.array(vec)
            elif torch.is_tensor(vec):
                vec = vec.detach().cpu().numpy()
            elif not isinstance(vec, np.ndarray):
                raise TypeError('Unknown vec type.')
            vecs_np.append(vec)
        return np.array(vecs_np)
if __name__ == '__main__':
    # Script entry point: parse options, build and load the feature extractor,
    # encode every line of the test file, optionally whiten the resulting
    # vectors, and save them to the prediction path.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    infer_opts(parser)

    parser.add_argument("--whitening_size", type=int, default=None, help="Output vector size after whitening.")

    tokenizer_opts(parser)

    args = parser.parse_args()
    args = load_hyperparam(args)

    # Replace the tokenizer name with the instantiated tokenizer object.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Build feature extractor model.
    model = FeatureExtractor(args)
    model = load_model(model, args.load_model_path)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model.eval()

    dataset = read_dataset(args, args.test_path)
    if not dataset:
        # Fail early with a clear message instead of an opaque torch.cat error.
        raise ValueError("No usable instances were read from {}".format(args.test_path))

    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])

    feature_vectors = []
    # Inference only: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for src_batch, seg_batch in batch_loader(args.batch_size, src, seg):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            output = model(src_batch, seg_batch)
            feature_vectors.append(output.detach().cpu())
    feature_vectors = torch.cat(feature_vectors, 0)

    # Vector whitening.
    if args.whitening_size is not None:
        whitening = WhiteningHandle(args, feature_vectors)
        feature_vectors = whitening(feature_vectors, args.whitening_size, pt=True)

    print("The size of feature vectors (sentences_num * vector size): {}".format(feature_vectors.shape))
    torch.save(feature_vectors, args.prediction_path)
|