|
""" |
|
Sample from the trained model with PyTorch |
|
""" |
|
import os |
|
import pickle |
|
from contextlib import nullcontext |
|
import torch |
|
from model import ModelArgs, Transformer |
|
from tokenizer import Tokenizer |
|
|
|
from tinystories import get_tokenizer_model_path |
|
|
|
|
|
# -----------------------------------------------------------------------------
# Default sampling configuration. Every value below can be overridden from the
# command line via configurator.py (e.g. `python sample.py --temperature=0.8`).
checkpoint = 'out/ckpt.pt'  # path of the trained model checkpoint to load
start = ""  # prompt string, or "FILE:<path>" to read the prompt from a file
num_samples = 1  # number of independent samples to draw
max_new_tokens = 100  # tokens generated per sample
temperature = 1.0  # 1.0 = unchanged distribution; <1.0 sharper, >1.0 more random
top_k = 300  # sample only from the top_k most likely tokens
tokenizer = ""  # optional explicit path to a tokenizer .model file ("" = auto)
seed = 1337  # RNG seed for reproducible sampling
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = "float32"  # one of 'float32', 'bfloat16', 'float16'
compile = False  # use torch.compile (requires PyTorch 2.0)
# NOTE(security): configurator.py is exec'd with full privileges — only run this
# script from a trusted working directory.
# Use a context manager so the config file handle is closed deterministically
# (the original `exec(open(...).read())` leaked the handle).
with open('configurator.py') as _config_file:
    exec(_config_file.read())  # overrides the globals above from sys.argv
|
|
|
|
|
# Reproducibility: seed both CPU and CUDA generators (the CUDA call is a no-op
# when no GPU is present).
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Allow TF32 on Ampere+ GPUs: faster matmuls/convolutions at slightly reduced
# mantissa precision, which is fine for inference.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Build the autocast context used for mixed-precision inference.
device_type = 'cuda' if 'cuda' in device else 'cpu'
_DTYPE_MAP = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}
ptdtype = _DTYPE_MAP[dtype]
if device_type == 'cpu':
    ctx = nullcontext()  # autocast not used on CPU here
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
|
|
|
|
# -----------------------------------------------------------------------------
# Load the checkpoint and rebuild the Transformer for inference.
# NOTE(security): torch.load unpickles arbitrary objects — only load
# checkpoints from trusted sources.
checkpoint_dict = torch.load(checkpoint, map_location=device)
gptconf = ModelArgs(**checkpoint_dict['model_args'])
model = Transformer(gptconf)
state_dict = checkpoint_dict['model']
# Checkpoints saved from a torch.compile'd model prefix every parameter key
# with '_orig_mod.'; strip it so the keys match the plain module.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):  # snapshot the keys: we mutate the dict while looping
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
# strict=False tolerates keys that don't map 1:1 between the saved state and
# the current model definition (e.g. buffers added/removed between versions).
model.load_state_dict(state_dict, strict=False)

model.eval()  # inference mode for layers like dropout
model.to(device)
if compile:
    print("Compiling the model...")
    model = torch.compile(model)  # requires PyTorch 2.0
|
|
|
|
|
# -----------------------------------------------------------------------------
# Resolve which tokenizer model file to load.
vocab_source = checkpoint_dict["config"].get("vocab_source", "llama2")
vocab_size = gptconf.vocab_size
# A vocab_size of 0 asks get_tokenizer_model_path for the default Llama 2
# tokenizer; otherwise we look up the custom tokenizer trained at this size.
_lookup_size = 0 if vocab_source == "llama2" else vocab_size
# An explicit --tokenizer path from the command line takes precedence; the
# lookup helper only runs when no override was given.
tokenizer_model = tokenizer if tokenizer else get_tokenizer_model_path(vocab_size=_lookup_size)
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
|
|
|
|
# -----------------------------------------------------------------------------
# Encode the prompt into token ids.
if start.startswith('FILE:'):
    # A "FILE:<path>" prompt means: read the actual prompt text from that file.
    with open(start[len('FILE:'):], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = enc.encode(start, bos=True, eos=False)  # prepend BOS, no EOS
# Add a leading batch dimension so the tensor is (1, T) for model.generate.
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)
|
|
|
|
|
# -----------------------------------------------------------------------------
# Draw num_samples completions and print each one, followed by a separator.
with torch.no_grad(), ctx:  # no gradients + (possibly) autocast
    for _sample in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(enc.decode(y[0].tolist()))  # decode the single batch row
        print('---------------')
|
|