"""
Sample from a trained model checkpoint.
"""
|
import argparse
import os
import random
from contextlib import nullcontext

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from model import GPTConfig, GPT
|
parser = argparse.ArgumentParser()
parser.add_argument("--init_from", type=str, default="resume", help="'resume' to load the checkpoint at --ckpt_path, or a GPT-2 variant name such as 'gpt2'")
parser.add_argument("--out_path", type=str, required=True, help="output text file (samples are appended)")
parser.add_argument("--num_samples", type=int, default=100000, help="number of samples to draw")
parser.add_argument("--max_new_tokens", type=int, required=True, help="number of tokens generated in each sample")
parser.add_argument("--strategy", type=str, default="top_k", help="one of ['greedy_search', 'sampling', 'top_k', 'beam_search']")
parser.add_argument("--beam_size", type=int, default=3, help="beam size for beam search")
parser.add_argument("--temperature", type=float, default=1.0, help="1.0 = no change, < 1.0 = less random, > 1.0 = more random")
parser.add_argument("--top_k", type=int, default=20, help="retain only the top_k most likely tokens, clamp others to 0 probability")
parser.add_argument("--ckpt_path", type=str, required=True, help="path to a model checkpoint")
parser.add_argument("--tokenizer_path", type=str, required=True, help="path to a tokenizer directory")
parser.add_argument("--start", type=str, default="<|endoftext|>", help="prompt string to condition generation on")
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="values > 1.0 penalize already-generated tokens")
parser.add_argument("--shuffle_token", action="store_true", help="shuffle generated tokens before decoding")
parser.add_argument("--fasta", action=argparse.BooleanOptionalAction, default=True, help="also write output in FASTA format (disable with --no-fasta)")
|
args = parser.parse_args()

init_from = args.init_from
out_path = args.out_path
num_samples = args.num_samples
max_new_tokens = args.max_new_tokens
strategy = args.strategy
assert strategy in ['greedy_search', 'sampling', 'top_k', 'beam_search']
beam_size = args.beam_size
temperature = args.temperature
top_k = args.top_k
ckpt_path = args.ckpt_path
tokenizer_path = args.tokenizer_path
start = args.start
repetition_penalty = args.repetition_penalty
fasta = args.fasta
|
seed = random.randint(1, 6666)  # fresh random seed on every run
device = 'cuda'
dtype = 'float32'
compile = False  # set True to use torch.compile (requires PyTorch 2.0+)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
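# Note: CUDA autocast supports float16/bfloat16 only, so with dtype='float32'
# the context above is expected to be a no-op (recent PyTorch warns and disables
# autocast). Set dtype to 'bfloat16' or 'float16' to sample in mixed precision.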
|
if init_from == 'resume':
    # load a checkpoint trained by this repo
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # strip the prefix torch.compile prepends to parameter names when a
    # compiled model was checkpointed
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # initialize from an OpenAI GPT-2 release
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))

model.eval()  # inference: disables dropout
model.to(device)
if compile:
    model = torch.compile(model)
|
encode = tokenizer.encode
decode = tokenizer.decode

# derive the FASTA path from out_path, e.g. out/samples.txt -> out/samples.fasta
fasta_out_path = os.path.splitext(out_path)[0] + ".fasta" if fasta else None

if strategy in ["sampling", "top_k"]:
    start_ids = encode(start)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]  # (1, T) prompt tensor
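
    # model.generate here is assumed (per this repo's model.py) to accept the
    # decoding strategy plus temperature/top_k/repetition_penalty and to return
    # a (1, prompt_len + max_new_tokens) tensor of token ids.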
|
    with open(out_path, 'a') as f, \
         (open(fasta_out_path, 'a') if fasta else nullcontext()) as fasta_f, \
         torch.no_grad(), ctx:
        for k in tqdm(range(num_samples), desc="Generating samples"):
            token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty)[0].tolist()

            if args.shuffle_token:
                random.shuffle(token_sequence)

            y = decode(token_sequence).replace(' ', '')  # strip spaces from the decoded text
            f.write(y)
            f.flush()

            if fasta:
                fasta_f.write(f">sample_{k}\n{y}\n")
                fasta_f.flush()
|
elif strategy in ["beam_search", "greedy_search"]:
    start_ids = encode(start)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    with open(out_path, 'a') as f, \
         (open(fasta_out_path, 'a') if fasta else nullcontext()) as fasta_f, \
         torch.no_grad(), ctx:
        token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, beam_size=beam_size)[0].tolist()

        y = decode(token_sequence).replace(' ', '')
        f.write(y)
        f.flush()

        if fasta:
            # deterministic decoding produces a single sample, so the id is fixed
            fasta_f.write(f">sample_0\n{y}\n")
            fasta_f.flush()
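
# Output files are opened in append mode, so repeated runs accumulate samples.
# A FASTA record produced above looks like this (sequence is illustrative only):
#   >sample_0
#   MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ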