Spaces:
Runtime error
Runtime error
File size: 1,695 Bytes
7900c16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from tencentpretrain.embeddings import *
from tencentpretrain.encoders import *
from tencentpretrain.decoders import *
from tencentpretrain.targets import *
from tencentpretrain.models.model import Model
def build_model(args):
    """
    Assemble a model from the components named in ``args``.

    The source-side embedding, the encoder, an optional decoder together
    with its target-side embedding, and the training target(s) are looked
    up by name in the ``str2embedding`` / ``str2encoder`` / ``str2decoder``
    / ``str2target`` registries and handed to ``Model``. Choosing different
    combinations of these parts yields pretrained models with different
    properties, so a suitable one can be picked for a downstream task.

    Args:
        args: Configuration object. The attributes read here are
            ``embedding``, ``tokenizer``, ``encoder``, ``decoder``,
            ``data_processor``, ``target`` and, depending on the
            configuration, ``tgt_embedding`` and ``tgt_tokenizer``.

    Returns:
        The ``Model`` built from the selected components.
    """
    # Source side: one composite embedding made of every requested layer.
    src_embedding = Embedding(args)
    for name in args.embedding:
        layer = str2embedding[name](args, len(args.tokenizer.vocab))
        src_embedding.update(layer, name)

    encoder = str2encoder[args.encoder](args)

    # Target side is optional; without a decoder both parts stay None.
    tgt_embedding = None
    decoder = None
    if args.decoder is not None:
        # The "mt" data processor carries a separate target-side tokenizer;
        # every other processor reuses the source vocabulary.
        tgt_vocab = (
            args.tgt_tokenizer.vocab
            if args.data_processor == "mt"
            else args.tokenizer.vocab
        )
        tgt_vocab_size = len(tgt_vocab)

        tgt_embedding = Embedding(args)
        for name in args.tgt_embedding:
            layer = str2embedding[name](args, tgt_vocab_size)
            tgt_embedding.update(layer, name)

        decoder = str2decoder[args.decoder](args)

    # Targets are sized by the output vocabulary, chosen the same way as above.
    target = Target()
    for name in args.target:
        use_tgt_vocab = args.data_processor == "mt"
        target_cls = str2target[name]
        vocab = args.tgt_tokenizer.vocab if use_tgt_vocab else args.tokenizer.vocab
        target.update(target_cls(args, len(vocab)), name)

    return Model(args, src_embedding, encoder, tgt_embedding, decoder, target)
|