# VISOR-GPT: train/scripts/generate_lm_multiple.py
# (commit 7900c16 "upload" by szukevin, 6.99 kB)
"""
This script provides an exmaple to wrap TencentPretrain for generation.
Given the beginning of a text, language model generates the rest.
"""
import sys
import os
import argparse
import torch
import torch.nn.functional as F
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)
from tencentpretrain.embeddings import *
from tencentpretrain.encoders import *
from tencentpretrain.targets import *
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.model_loader import load_model
from tencentpretrain.opts import infer_opts, tokenizer_opts
from tqdm import tqdm
class GenerateLm(torch.nn.Module):
    """TencentPretrain language model that exposes raw LM-head outputs.

    Stacks the embedding components listed in ``args.embedding``, the encoder
    selected by ``args.encoder`` and an ``LmTarget`` head sized to the
    tokenizer vocabulary (registered on ``self.target`` under the name "lm").
    """

    def __init__(self, args):
        super().__init__()
        self.embedding = Embedding(args)
        vocab_size = len(args.tokenizer.vocab)
        # Attach each configured embedding component under its own name.
        for name in args.embedding:
            component = str2embedding[name](args, vocab_size)
            self.embedding.update(component, name)
        self.encoder = str2encoder[args.encoder](args)
        self.target = Target()
        self.target.update(LmTarget(args, vocab_size), "lm")

    def forward(self, src, seg):
        """Embed, encode and project ``src`` through the LM output layer.

        Only ``self.target.lm.output_layer`` is applied; no loss is computed.
        NOTE(review): the result is presumably (batch, seq_len, vocab_size)
        logits, as the callers index ``output[0][-1]`` and softmax it.
        """
        hidden = self.encoder(self.embedding(src, seg), seg)
        return self.target.lm.output_layer(hidden)
def top_k_top_p_filtering(logits, top_k, top_p):
    """Mask logits outside the top-k and nucleus (top-p) sets with ``-inf``.

    Filtering runs along the last dimension, so ``logits`` may be a single
    vector of scores or a batch of them (any number of leading dimensions).

    Args:
        logits: float tensor of unnormalised scores. It is modified in place.
        top_k: keep only the ``top_k`` highest-scoring entries (ties with the
            k-th score are kept as well). Values <= 0 disable this filter; the
            value is clamped to the size of the last dimension.
        top_p: keep the smallest prefix of highest-scoring entries whose
            cumulative softmax probability reaches ``top_p`` (the most likely
            entry is always kept). Values <= 0.0 disable this filter.

    Returns:
        The same ``logits`` tensor, with filtered entries set to ``-inf``.
    """
    filter_value = -float("Inf")
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove every entry scoring below the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the token that crosses the threshold is kept,
        # and always keep the most probable token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order row by row.
        # (Indexing ``sorted_indices`` with the mask is only correct for 1-D input.)
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
def build_visorgpt(model_path,
                   model_config,
                   vocab_path='TencentPretrain/models/google_uncased_en_coord_vocab.txt'):
    """Build a ready-to-sample VisorGPT model together with its options.

    TencentPretrain inference/tokenizer options and the sampling flags
    (``--top_k``, ``--top_p``, ``--temperature``) are parsed with argparse.
    The given paths are then forced onto the namespace before the
    hyper-parameters are loaded.

    NOTE(review): ``parse_args()`` reads the real ``sys.argv``, so a host
    program's own command-line flags will make argparse exit here. Confirm
    this is intended before calling it from another application.

    Args:
        model_path: checkpoint handed to ``load_model``.
        model_config: config path stored as ``args.config_path`` before
            ``load_hyperparam`` is called.
        vocab_path: stored as ``args.vocab_path`` for the tokenizer.

    Returns:
        ``(args, model)``:
            args: the populated namespace, with ``args.tokenizer`` built and
                ``args.seq_length`` forced to 1024.
            model: a ``GenerateLm`` moved to CUDA and switched to eval mode.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    infer_opts(parser)
    tokenizer_opts(parser)
    sampling_flags = (
        ("--top_k", int, 70),
        ("--top_p", float, 0),
        ("--temperature", float, 1.0),
    )
    for flag, flag_type, default in sampling_flags:
        parser.add_argument(flag, type=flag_type, default=default)
    args = parser.parse_args()

    # Fixed settings for single-prompt LM sampling. They must be in place
    # before load_hyperparam(args) receives the namespace.
    overrides = {
        "target": "lm",
        "batch_size": 1,
        "load_model_path": model_path,
        "config_path": model_config,
        "vocab_path": vocab_path,
    }
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)
    args = load_hyperparam(args)
    # Set after load_hyperparam so whatever it loaded is overridden.
    args.seq_length = 1024
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    model = load_model(GenerateLm(args), args.load_model_path).cuda()
    model.eval()
    return args, model
def gen_sequence(args, model, input_text):
    """Sample a continuation of ``input_text`` with the language model.

    The prompt is tokenized, prefixed with ``CLS_TOKEN`` and truncated to
    ``args.seq_length`` tokens. Tokens are then sampled one at a time until
    the sequence holds ``args.seq_length`` tokens. Each step applies
    temperature scaling, then top-k / top-p filtering, then multinomial
    sampling.

    Args:
        args: namespace providing ``tokenizer``, ``seq_length``,
            ``temperature``, ``top_k`` and ``top_p`` (as set up by
            ``build_visorgpt``).
        model: a ``GenerateLm``. Inputs are created on the device of its
            parameters.
        input_text: the prompt string.

    Returns:
        A one-element list holding the prompt tokens plus the generated
        tokens, joined with single spaces.
    """
    device = next(model.parameters()).device
    generated_texts = []
    for line in tqdm([input_text]):
        tokens = [CLS_TOKEN] + args.tokenizer.tokenize(line)
        # A prompt longer than seq_length is cut, which leaves no room to generate.
        src = args.tokenizer.convert_tokens_to_ids(tokens)[:args.seq_length]
        seg = [1] * len(src)
        src_tensor = torch.LongTensor([src]).to(device)
        seg_tensor = torch.LongTensor([seg]).to(device)
        new_seg = torch.ones((1, 1), dtype=torch.long, device=device)
        # Pure inference: no_grad avoids building an autograd graph every step.
        with torch.no_grad():
            for _ in range(args.seq_length - len(src)):
                # The full prefix is fed to the model on every step.
                output = model(src_tensor, seg_tensor)
                next_token_logits = output[0][-1] / args.temperature
                filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
                seg_tensor = torch.cat([seg_tensor, new_seg], dim=1)
        # tolist() fetches all ids in one transfer instead of one .item() per token.
        generated_texts.append(
            " ".join(args.tokenizer.convert_ids_to_tokens(src_tensor[0].tolist()))
        )
    return generated_texts
def main():
    """Sample a continuation for every line of ``args.test_path``.

    Each stripped input line is tokenized (prefixed with ``CLS_TOKEN``),
    truncated to ``args.seq_length`` tokens, and extended token by token until
    the sequence reaches ``args.seq_length``. Sampling uses temperature, then
    top-k / top-p filtering, then multinomial sampling. The decoded sequences
    are written one per line to ``<save_dir>/<prediction_path>``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    infer_opts(parser)
    parser.add_argument("--top_k", type=int, default=70)
    parser.add_argument("--top_p", type=float, default=0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--save_dir", type=str, default='predictions')
    tokenizer_opts(parser)
    args = parser.parse_args()
    args.target = "lm"
    args.batch_size = 1
    args = load_hyperparam(args)
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    model = GenerateLm(args)
    model = load_model(model, args.load_model_path).cuda()
    model.eval()
    device = next(model.parameters()).device

    with open(args.test_path, mode="r", encoding="utf-8") as f:
        lines = [line.strip() for line in f]

    generated_texts = []
    new_seg = torch.ones((1, 1), dtype=torch.long, device=device)
    # Pure inference: no_grad avoids building an autograd graph every step.
    with torch.no_grad():
        for line in tqdm(lines):
            tokens = [CLS_TOKEN] + args.tokenizer.tokenize(line)
            # A prompt longer than seq_length is cut, which leaves no room to generate.
            src = args.tokenizer.convert_tokens_to_ids(tokens)[:args.seq_length]
            seg = [1] * len(src)
            src_tensor = torch.LongTensor([src]).to(device)
            seg_tensor = torch.LongTensor([seg]).to(device)
            for _ in range(args.seq_length - len(src)):
                # The full prefix is fed to the model on every step.
                output = model(src_tensor, seg_tensor)
                next_token_logits = output[0][-1] / args.temperature
                filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
                seg_tensor = torch.cat([seg_tensor, new_seg], dim=1)
            generated_texts.append(
                " ".join(args.tokenizer.convert_ids_to_tokens(src_tensor[0].tolist()))
            )

    # exist_ok avoids the race between an existence check and the mkdir.
    os.makedirs(args.save_dir, exist_ok=True)
    output_path = os.path.join(args.save_dir, args.prediction_path)
    with open(output_path, mode="w", encoding="utf-8") as f:
        for text in generated_texts:
            f.write(text + "\n")


if __name__ == '__main__':
    main()