# NOTE: removed Hugging Face Spaces page-banner residue ("Spaces / Runtime error")
# that was pasted into this file and is not valid Python.
'''
Author: Qiguang Chen
LastEditors: Qiguang Chen
Date: 2023-02-13 10:44:39
LastEditTime: 2023-02-19 15:45:08
Description: Export a trained model, its config and tokenizer in Hugging Face
             `save_pretrained` format.
'''
import argparse | |
import sys | |
import os | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
import dill | |
from common.config import Config | |
from common.model_manager import ModelManager | |
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoTokenizer, PreTrainedTokenizer | |
class PretrainedConfigForSLUToSave(PretrainedConfig):
    """Bridge the project's SLU configuration into a `transformers.PretrainedConfig`.

    Copies label vocabularies, model/tokenizer settings and the vocab size out of
    the module-global ``model_manager`` (created in the ``__main__`` block below)
    so that ``save_pretrained`` serializes everything needed to reload the model.

    NOTE(review): relies on the global ``model_manager`` existing at
    instantiation time — this class is only usable from this export script.
    """

    def __init__(self, **kwargs) -> None:
        cfg = model_manager.config
        kwargs["name_or_path"] = cfg.base["name"]
        kwargs["return_dict"] = False
        kwargs["is_decoder"] = True
        # Intent/slot label maps in both directions, so the exported config
        # is self-contained for decoding predictions later.
        kwargs["_id2label"] = {"intent": model_manager.intent_list, "slot": model_manager.slot_list}
        kwargs["_label2id"] = {"intent": model_manager.intent_dict, "slot": model_manager.slot_dict}
        kwargs["_num_labels"] = {"intent": len(model_manager.intent_list), "slot": len(model_manager.slot_list)}
        kwargs["tokenizer_class"] = cfg.base["name"]
        kwargs["vocab_size"] = model_manager.tokenizer.vocab_size
        # Embed the full model/tokenizer sub-configs, with classifier head
        # sizes pinned to the actual label-list lengths.
        kwargs["model"] = cfg.model
        kwargs["model"]["decoder"]["intent_classifier"]["intent_label_num"] = len(model_manager.intent_list)
        kwargs["model"]["decoder"]["slot_classifier"]["slot_label_num"] = len(model_manager.slot_list)
        kwargs["tokenizer"] = cfg.tokenizer
        # (removed a stray no-op `len(model_manager.slot_list)` statement here)
        super().__init__(**kwargs)
class PretrainedModelForSLUToSave(PreTrainedModel):
    """Thin `PreTrainedModel` wrapper around the already-loaded SLU model.

    Exists only so that ``save_pretrained`` can serialize the trained weights;
    no parameters are created here.
    """

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        # NOTE(review): `config_class` is conventionally a *type* attribute on
        # HF model classes; here a config *instance* is stored — confirm intent.
        self.config_class = config
        # Reuse the model the global ModelManager already loaded from disk.
        self.model = model_manager.model
class PreTrainedTokenizerForSLUToSave(PreTrainedTokenizer):
    """`PreTrainedTokenizer` wrapper that serializes the project tokenizer.

    Instead of writing a conventional vocab file, ``save_vocabulary`` pickles
    the whole tokenizer object with ``dill`` so it can be restored as-is.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Reuse the tokenizer the global ModelManager already constructed.
        self.tokenizer = model_manager.tokenizer

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        """Dump the wrapped tokenizer to ``[<prefix>-]tokenizer.pkl``.

        Args:
            save_directory: target directory (must already exist).
            filename_prefix: optional prefix, joined with "-" per HF convention.

        Returns:
            A 1-tuple containing the written file path, as the HF API expects.
        """
        if filename_prefix is not None:
            filename = filename_prefix + "-tokenizer.pkl"
        else:
            filename = "tokenizer.pkl"
        path = os.path.join(save_directory, filename)
        # dill (rather than pickle) handles the tokenizer's closures/lambdas.
        with open(path, 'wb') as f:
            dill.dump(self.tokenizer, f)
        return (path,)
if __name__ == "__main__":
    # CLI: which training config to load, and where to write the HF export.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', '-cp', type=str, required=True)
    parser.add_argument('--output_path', '-op', type=str, default="save/temp")
    args = parser.parse_args()

    config = Config.load_from_yaml(args.config_path)
    # Export only — make sure neither training nor testing is triggered.
    config.base["train"] = False
    config.base["test"] = False
    if config.model_manager["load_dir"] is None:
        # Fall back to the directory the model was saved to during training.
        config.model_manager["load_dir"] = config.model_manager["save_dir"]

    # The wrapper classes above read this module-global manager.
    model_manager = ModelManager(config)
    model_manager.load()
    model_manager.config.autoload_template()

    # Write config + weights, then the dill-pickled tokenizer.
    pretrained_model = PretrainedModelForSLUToSave(PretrainedConfigForSLUToSave())
    pretrained_model.save_pretrained(args.output_path)
    PreTrainedTokenizerForSLUToSave().save_pretrained(args.output_path)