|
import os.path as osp
import argparse
import json

import torch
from tqdm import tqdm
from transformers import BertModel

from model import BertTokenizer
from src.utils import add_special_token

|
class cfg:
    def __init__(self):
        # Directory of this file; the data root sits two levels up, under data/.
        self.this_dir = osp.dirname(__file__)
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))

    def get_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
        parser.add_argument("--update_model_name", default='MacBert', type=str, help="MacBert")
        parser.add_argument("--pretrained_model_name", default='TeleBert', type=str, help="TeleBert")
        parser.add_argument("--read_cws", default=0, type=int, help="whether to read the pre-trained CWS (Chinese word segmentation) file")
        self.cfg = parser.parse_args()

    def update_train_configs(self):
        # Attach the resolved paths to the parsed arguments.
        self.cfg.data_root = self.data_root
        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)
        return self.cfg
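
# Example invocation (the script file name below is illustrative, adjust to the actual file):
#   python get_added_vocab_emb.py --data_path huawei --update_model_name MacBert \
#       --pretrained_model_name TeleBert --read_cws 0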
|
|
|
|
|
if __name__ == '__main__':
    '''
    Purpose: build the chinese ref file and refresh the train/test files
    (only for sequential text data).
    '''
    cfg = cfg()
    cfg.get_args()
    cfgs = cfg.update_train_configs()
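    # At this point cfgs.data_root points at the repository's data/ directory and
    # cfgs.data_path at data/<--data_path> (data/huawei with the defaults above).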
|
|
|
|
|
    # Load the tokenizer of the model being updated, register the domain-specific
    # special tokens, then dump the newly added vocabulary to JSON.
    path = osp.join(cfgs.data_root, 'transformer', cfgs.update_model_name)
    assert osp.exists(path)
    tokenizer = BertTokenizer.from_pretrained(path, do_lower_case=True)
    tokenizer, special_token, norm_token = add_special_token(tokenizer)
    added_vocab = tokenizer.get_added_vocab()
    vocb_path = osp.join(cfgs.data_path, 'added_vocab.json')

    with open(vocb_path, 'w', encoding='utf-8') as fp:
        json.dump(added_vocab, fp, ensure_ascii=False)
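    # added_vocab maps each newly registered token to its id in the resized vocabulary,
    # so added_vocab.json looks roughly like {"[ALM]": 21128, "[KPI]": 21129, ...}
    # (the ids shown are illustrative; the real values depend on the MacBert vocab size).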
|
|
|
    # Natural-language descriptions for the added tokens. When a description exists,
    # the token is embedded through it rather than through the raw abbreviation, so the
    # pretrained model can produce a more meaningful vector.
    vocb_descrip_path = osp.join(cfgs.data_path, 'vocab_descrip.json')

    vocb_descrip = {
        "alm": "alarm",
        "ran": "ran 无线接入网",
        "mml": "MML 人机语言命令",
        "nf": "NF 独立网络服务",
        "apn": "APN 接入点名称",
        "pgw": "PGW 数据管理子系统模块",
        "lst": "LST 查询命令",
        "qos": "QoS 定制服务质量",
        "ipv": "IPV 互联网通讯协议版本",
        "ims": "IMS IP多模态子系统",
        "gtp": "GTP GPRS隧道协议",
        "pdp": "PDP 分组数据协议",
        "hss": "HSS HTTP Smooth Stream",
        "[ALM]": "alarm 告警 标记",
        "[KPI]": "kpi 关键性能指标 标记",
        "[LOC]": "location 事件发生位置 标记",
        "[EOS]": "end of the sentence 文档结尾 标记",
        "[ENT]": "实体标记",
        "[ATTR]": "属性标记",
        "[NUM]": "数值标记",
        "[REL]": "关系标记",
        "[DOC]": "文档标记"
    }
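    # vocb_descrip_path is not read or written by this script as-is; a minimal sketch,
    # assuming vocab_descrip.json stores the same key -> description mapping, would be:
    #
    #     with open(vocb_descrip_path, encoding='utf-8') as fp:
    #         vocb_descrip = json.load(fp)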
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Load the pretrained model (TeleBert by default) that provides the initial
    # embeddings for the added vocabulary; eval() disables dropout for deterministic output.
    path = osp.join(cfgs.data_root, 'transformer', cfgs.pretrained_model_name)
    assert osp.exists(path)
    pre_tokenizer = BertTokenizer.from_pretrained(path, do_lower_case=True)
    model = BertModel.from_pretrained(path)
    model.eval()
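    # If a GPU is available, moving the model (and the tokenized inputs below) to it
    # would speed up the loop, e.g. model.to('cuda'); the script as written runs on CPU.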
|
|
|
print("use the vocb_description") |
|
key_to_emb = {} |
|
for key in added_vocab.keys(): |
|
if vocb_description is not None: |
|
if key in vocb_description: |
|
|
|
key_tokens = pre_tokenizer(vocb_description[key], return_tensors='pt') |
|
else: |
|
key_tokens = pre_tokenizer(key, return_tensors='pt') |
|
else: |
|
key_tokens = pre_tokenizer(key, return_tensors='pt') |
|
|
|
hidden_state = model(**key_tokens, output_hidden_states=True).hidden_states |
|
pdb.set_trace() |
|
key_to_emb[key] = hidden_state[-1][:, 1:-1, :].mean(dim=1) |
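    # Each tensor in key_to_emb has shape (1, hidden_size), e.g. 768 for a BERT-base
    # sized TeleBert checkpoint (the exact size depends on the pretrained config).
    # A minimal sketch of how these vectors could seed the resized embedding table of
    # the model being updated (illustrative only; `update_model` is not built here):
    #
    #     update_model.resize_token_embeddings(len(tokenizer))
    #     with torch.no_grad():
    #         for key, emb in key_to_emb.items():
    #             update_model.get_input_embeddings().weight[added_vocab[key]] = emb.squeeze(0)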
|
|
|
    # Persist the {token: embedding} dict for later use.
    emb_path = osp.join(cfgs.data_path, 'added_vocab_embedding.pt')
    torch.save(key_to_emb, emb_path)
    print(f'save to {emb_path}')
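    # The saved dict can be re-loaded elsewhere with torch.load(emb_path); the consumer
    # of this file (e.g. the continual pre-training script) is assumed, not shown here.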
|
|