File size: 1,038 Bytes
0f53151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# torch packages
import torch
from model.transformer import Transformer
import json

if __name__ == "__main__":
    # Rebuild the Transformer from the hyperparameters stored in params.json,
    # restore its saved weights, and print the resulting module.
    # The following parameters are for the Multi30K dataset.

    # Load config containing model input parameters. JSON is UTF-8 by
    # specification, so decode it explicitly instead of relying on the
    # platform's locale encoding (which is not UTF-8 on every system).
    with open('params.json', encoding='utf-8') as json_data:
        config = json.load(json_data)
    print(config)

    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Instantiate model. The hyperparameters are passed positionally, so the
    # order below must match Transformer's constructor (defined in
    # model/transformer.py, not visible here) and is kept exactly as it was.
    model = Transformer(
        config["dk"],
        config["dv"],
        config["h"],
        config["src_vocab_size"],
        config["target_vocab_size"],
        config["num_encoders"],
        config["num_decoders"],
        config["dim_multiplier"],
        config["pdropout"],
        device=device,
    )

    # Load model weights; map_location places the saved tensors on `device`,
    # so a checkpoint written on a GPU machine can be restored on CPU.
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source, or pass
    # weights_only=True if the installed torch version supports it.
    model.load_state_dict(
        torch.load('pytorch_transformer_model.pt', map_location=device)
    )
    print(model)