moiduy04 committed on
Commit
e8a4189
·
1 Parent(s): d7c7d33

Upload load_model.py

Browse files
Files changed (1) hide show
  1. load_model.py +50 -0
load_model.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from transformer import get_model, Transformer
5
+ from config import load_config, get_weights_file_path
6
+ from train import get_local_dataset_tokenizer
7
+ from tokenizer import get_or_build_local_tokenizer
8
+
9
+ from tokenizers import Tokenizer
10
+
11
+
12
def load_model_tokenizer(
    config,
    device = torch.device('cpu'),
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """
    Rebuild a saved model together with its source and target tokenizers.

    Args:
        config: Nested mapping read via ``config['model']['preload']`` and
            ``config['dataset'][...]`` (``src_lang``, ``tgt_lang``,
            ``src_tokenizer``, ``tgt_tokenizer``). It is also forwarded
            unchanged to the project helpers called below.
        device: Device the model is moved to and the checkpoint is mapped
            onto. Defaults to the CPU.

    Returns:
        A ``(model, src_tokenizer, tgt_tokenizer)`` tuple.

    Raises:
        ValueError: If ``config['model']['preload']`` is ``None``.
    """
    # Fail fast: without a preload identifier there is no checkpoint to read.
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')

    # Source side first, then target. No dataset is supplied (ds=None).
    tokenizers = []
    for side in ('src', 'tgt'):
        tokenizers.append(
            get_or_build_local_tokenizer(
                config=config,
                ds=None,
                lang=config['dataset'][side + '_lang'],
                tokenizer_type=config['dataset'][side + '_tokenizer'],
            )
        )
    src_tokenizer, tgt_tokenizer = tokenizers

    # The model is constructed from the two vocabulary sizes, then moved to
    # the requested device before the weights are loaded.
    src_vocab_size = src_tokenizer.get_vocab_size()
    tgt_vocab_size = tgt_tokenizer.get_vocab_size()
    model = get_model(config, src_vocab_size, tgt_vocab_size)
    model = model.to(device)

    # Resolve the checkpoint file and map its tensors onto the same device.
    weights_path = get_weights_file_path(config, config['model']['preload'])
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    print('Finish loading model and tokenizers')
    return model, src_tokenizer, tgt_tokenizer
47
+
48
if __name__ == '__main__':
    # Script entry point: read the final config file, then load the model
    # and both tokenizers with the default device.
    config = load_config(file_name='config/config_final.yaml')
    (model, src_tokenizer, tgt_tokenizer) = load_model_tokenizer(config)