"""
Run simply with
$ pytest
"""
import os
import pytest # pip install pytest
import subprocess

import torch
from model import ModelArgs, Transformer

def test_argmax_inference():
    """
    Only the simplest test for now: run inference with temperature 0 
    (for determinism) in both C and PyTorch, and see that the sampled tokens 
    are the same.
    """
    test_ckpt_dir = "out" # TODO create a dummy test checkpoint for this?

    # run C version
    model_path = os.path.join(test_ckpt_dir, "model.bin")
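    # ./run takes the model file and a temperature; 0.0 selects greedy (argmax) sampling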
    command = ["./run", model_path, "0.0"]
    proc = subprocess.Popen(command, stdout=subprocess.PIPE)
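    # the C program is expected to print one sampled token id per line on stdout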
    c_tokens = []
    for line in proc.stdout:
        token = int(line.decode('utf-8').strip())
        c_tokens.append(token)
    proc.wait()
    #print(c_tokens)

    # run PyTorch version
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_path = os.path.join(test_ckpt_dir, "ckpt.pt")
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = ModelArgs(**checkpoint['model_args'])
    model = Transformer(gptconf)
    state_dict = checkpoint['model']
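    # checkpoints saved from a torch.compile'd model prefix every key with '_orig_mod.';
    # strip it so the keys match the plain Transformer module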
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)
    x = torch.tensor([[1]], dtype=torch.long, device=device) # 1 is BOS
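    # temperature=0.0 makes generate() pick the argmax token at each step,
    # mirroring the deterministic C run above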
    with torch.inference_mode():
        y = model.generate(x, max_new_tokens=gptconf.max_seq_len, temperature=0.0)
    pt_tokens = y[0].tolist()
    pt_tokens = pt_tokens[1:] # remove BOS
    #print(pt_tokens)

    # compare
    assert c_tokens == pt_tokens