'''
Author: Qiguang Chen
LastEditors: Qiguang Chen
Date: 2023-02-13 10:44:39
LastEditTime: 2023-02-19 15:45:08
Description: Export a trained SLU model, its configuration, and its tokenizer to the HuggingFace pretrained format (save_pretrained).

'''
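# Example invocation (the script filename is a placeholder; the flags come from
# the argparse definitions below):
#
#     python <this_script>.py --config_path <trained_config.yaml> --output_path save/temp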

import argparse
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import dill

from common.config import Config
from common.model_manager import ModelManager
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer

class PretrainedConfigForSLUToSave(PretrainedConfig):
    """Wrap the trained SLU configuration as a HuggingFace ``PretrainedConfig``.

    Relies on the module-level ``model_manager`` created in the ``__main__`` block.
    """

    def __init__(self, **kwargs) -> None:
        cfg = model_manager.config
        kwargs["name_or_path"] = cfg.base["name"]
        kwargs["return_dict"] = False
        kwargs["is_decoder"] = True
        kwargs["_id2label"] = {"intent": model_manager.intent_list, "slot": model_manager.slot_list}
        kwargs["_label2id"] = {"intent": model_manager.intent_dict, "slot": model_manager.slot_dict}
        kwargs["_num_labels"] = {"intent": len(model_manager.intent_list), "slot": len(model_manager.slot_list)}
        kwargs["tokenizer_class"] = cfg.base["name"]
        kwargs["vocab_size"] = model_manager.tokenizer.vocab_size
        kwargs["model"] = cfg.model
        kwargs["model"]["decoder"]["intent_classifier"]["intent_label_num"] = len(model_manager.intent_list)
        kwargs["model"]["decoder"]["slot_classifier"]["slot_label_num"] = len(model_manager.slot_list)
        kwargs["tokenizer"] = cfg.tokenizer
        super().__init__(**kwargs)

class PretrainedModelForSLUToSave(PreTrainedModel):
    """Wrap the trained SLU model so its weights can be exported via ``save_pretrained``."""

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.model = model_manager.model
        # ``config_class`` is expected to be the config class itself, not an instance.
        self.config_class = type(config)


class PreTrainedTokenizerForSLUToSave(PreTrainedTokenizer):
    """Wrap the trained SLU tokenizer so it can be exported via ``save_pretrained``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = model_manager.tokenizer

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        # Overrides PreTrainedTokenizer.save_vocabulary: instead of writing a plain
        # vocabulary file, the whole tokenizer object is pickled with dill.
        if filename_prefix is not None:
            path = os.path.join(save_directory, filename_prefix + "-tokenizer.pkl")
        else:
            path = os.path.join(save_directory, "tokenizer.pkl")
        with open(path, "wb") as f:
            dill.dump(self.tokenizer, f)
        return (path,)
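
# A minimal sketch (not part of the export itself) of how the pickled tokenizer
# written by save_vocabulary above could be restored later; ``output_path`` is
# whatever directory was passed as --output_path:
#
#     with open(os.path.join(output_path, "tokenizer.pkl"), "rb") as f:
#         tokenizer = dill.load(f)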

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', '-cp', type=str, required=True)
    parser.add_argument('--output_path', '-op', type=str, default="save/temp")
    args = parser.parse_args()

    # Load the training configuration and disable the train/test phases.
    config = Config.load_from_yaml(args.config_path)
    config.base["train"] = False
    config.base["test"] = False
    if config.model_manager["load_dir"] is None:
        config.model_manager["load_dir"] = config.model_manager["save_dir"]

    # Restore the trained model, tokenizer, and label vocabularies.
    model_manager = ModelManager(config)
    model_manager.load()
    model_manager.config.autoload_template()

    # Export the config, model weights, and tokenizer in HuggingFace format.
    pretrained_config = PretrainedConfigForSLUToSave()
    pretrained_model = PretrainedModelForSLUToSave(pretrained_config)
    pretrained_model.save_pretrained(args.output_path)

    pretrained_tokenizer = PreTrainedTokenizerForSLUToSave()
    pretrained_tokenizer.save_pretrained(args.output_path)
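
    # After this runs, the output directory should contain the HuggingFace config
    # (config.json), the exported model weights (e.g. pytorch_model.bin, depending
    # on the transformers version), and the dill-pickled tokenizer (tokenizer.pkl).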