import os.path as osp
import random
import argparse
import pdb
import json
import math
from collections import Counter, defaultdict
from functools import reduce
from time import time

import numpy as np
import torch
from numpy import mean
from easydict import EasyDict as edict
from ltp import LTP
from tqdm import tqdm

from model import BertTokenizer
from src.utils import add_special_token, Loss_log, time_trans

class cfg():
    def __init__(self):
        self.this_dir = osp.dirname(__file__)
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))

    def get_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
        parser.add_argument("--freq", default=50, type=int, help="Minimum occurrence count for a word to be considered important")
        parser.add_argument("--batch_size", default=100, type=int, help="Batch size for word segmentation")
        parser.add_argument("--seq_data_name", default='Seq_data_large', type=str, help="Name of the seq_data file")
        parser.add_argument("--deal_numeric", default=0, type=int, help="Whether to process numeric data")
        parser.add_argument("--read_cws", default=0, type=int, help="Whether to read a previously saved cws file")
        self.cfg = parser.parse_args()

    def update_train_configs(self):
        self.cfg.data_root = self.data_root
        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)
        return self.cfg

def refresh_data(ref, freq, special_token):
    '''
    Purpose: on top of the user-defined special tokens, collect additional new words
        (selected by a minimum occurrence frequency) as references for the word
        segmenter; these serve as the basis for whole word masking (wwm).
    Inputs:
        freq: minimum occurrence frequency in the (~370k-line) semantic vocabulary
              (whitespace-delimited)
        special_token: the manually defined special tokens (may overlap with the new words)
    Output:
        add_words: new words selected by the minimum occurrence frequency
    '''
    seq_sub_data = [line.split() for line in ref]
    all_data = []
    for data in seq_sub_data:
        all_data.extend(data)
    sub_word_times = dict(Counter(all_data))
    sub_word_time_order = sorted(sub_word_times.items(), key=lambda x: x[1], reverse=True)

    add_words = []
    for word, count in sub_word_time_order:
        # Keep frequent, non-numeric words of reasonable length (2-19 characters).
        if count >= freq and 1 < len(word) < 20 and not word.isdigit():
            add_words.append(word)
    add_words.extend(special_token)

    print(f"[{len(add_words)}] special words will be added with frequency [{freq}]!")
    return add_words

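# Illustrative example (hypothetical vocabulary lines): with freq=2, only the
# whitespace-delimited words seen at least twice (length 2-19, non-numeric)
# survive the filter, and the special tokens are appended afterwards:
#   refresh_data(['网络 异常', '网络 正常'], 2, ['[KPI]'])  ->  ['网络', '[KPI]']
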
def cws(seq_data, add_words, batch_size):
    '''
    Purpose: convert all input sequence data into word-segmented results.
    Inputs:
        seq_data: all input sequences, e.g. ['KPI异常下降', 'KPI异常上升']
        add_words: the special words added to the segmenter
        batch_size: how many sentences to segment per batch
    Outputs:
        all_segment: the segmented sequences, e.g. [['KPI', '异常', '下降'], ['KPI', '异常', '上升']]
        data_size: the number of input/output sequences (e.g. 2)
    '''
    print("loading...")
    ltp = LTP("LTP/base2")

    print("begin adding words ...")
    ltp.add_words(words=add_words)
    ltp.to("cuda")
    print(f"{len(add_words)} special words are added!")

    data_size = len(seq_data)
    seq_data_cws = []
    size = int(data_size / batch_size) + 1
    b = 0
    e = b + batch_size

    with tqdm(total=size) as _tqdm:
        error_data = []
        for _ in range(size):
            output = []
            try:
                # First stage: segment the raw sentences of this batch.
                _output = ltp.pipeline(seq_data[b:e], tasks=["cws"])
                for data in _output.cws:
                    try:
                        # Second stage: re-segment the first-stage tokens to split
                        # any remaining multi-word chunks.
                        data_out = ltp.pipeline(data, tasks=["cws"])
                        data_out_ = []
                        for words in data_out.cws:
                            data_out_.extend([k.strip() for k in words])
                        output.append(data_out_)
                    except Exception:
                        print(f"Second-stage segmentation failed! Range: [{b}]-[{e}]")
                        error_data.append(data)
            except Exception:
                print(f"First-stage segmentation failed! Range: [{b}]-[{e}]")
                error_data.append(f"First-stage segmentation failed! Range: [{b}]-[{e}]")

            seq_data_cws.extend(output)
            b = e
            e += batch_size
            if e >= data_size:
                if b >= data_size:
                    break
                e = data_size
            _tqdm.set_description(f'from {b} to {e}:')
            _tqdm.update(1)

    print(f"Filtered out {data_size - len(seq_data_cws)} sentences")
    return seq_data_cws, data_size, error_data

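# Sketch of usage (assumes a CUDA device and the "LTP/base2" weights are
# available; expected output mirrors the docstring above):
#   seg, n, errors = cws(['KPI异常下降', 'KPI异常上升'], ['KPI', '异常'], batch_size=100)
#   # seg -> [['KPI', '异常', '下降'], ['KPI', '异常', '上升']], n -> 2
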
def ltp_debug(ltp, op):
    # Debug helper: re-run second-stage segmentation on a batch, keeping only
    # the first token of each segmentation result.
    output = []
    for data in op:
        data_out = ltp.pipeline(data, tasks=["cws"])
        data_out_ = []
        for words in data_out.cws:
            data_out_.append(words[0].strip())
        output.append(data_out_)
    return output

def deal_sub_words(subwords, special_token):
    '''
    Purpose: within each whole word, prefix every non-initial subtoken with '##';
    special tokens should not be masked, so they are left untouched.
    '''
    for i in range(len(subwords)):
        if i == 0:
            continue
        if subwords[i] in special_token:
            continue
        if subwords[i].startswith("##"):
            continue
        subwords[i] = "##" + subwords[i]
    return subwords

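# Illustrative examples (hypothetical inputs; assumes special_token contains
# the bracketed markers such as '[KPI]' and '[NUM]'):
#   deal_sub_words(['异', '常'], special_token)        ->  ['异', '##常']
#   deal_sub_words(['[KPI]', '[NUM]'], special_token)  ->  ['[KPI]', '[NUM]']
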
def generate_chinese_ref(seq_data_cws, special_token, deal_numeric, kpi_dic):
    '''
    Inputs:
        seq_data_cws: the segmented sequences, e.g. [['KPI', '异常', '下降'], ['KPI', '异常', '上升']]
        special_token: tokens that should not be masked, e.g. ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '|']
        deal_numeric: whether to also extract KPI information
        kpi_dic: mapping from (normalized) KPI name to KPI id
    Outputs:
        ww_return (whole word return): the labeled chinese ref, e.g. [['KPI', '异', '##常', '下', '##降'], ['KPI', '异', '##常', '上', '##升']]
        kpi_info: per-sentence KPI index information (when deal_numeric is set)
        sents_over_max_len: sentences whose subtoken length exceeds 500
    '''
    data_size = len(seq_data_cws)
    kpi_static_set = set()
    rev_kpi_dic = dict(zip(kpi_dic.values(), kpi_dic.keys()))
    max_len = 0
    sents_over_max_len = []
    with tqdm(total=data_size) as _tqdm:
        ww_return = []
        ww_list = []
        kpi_info = []
        not_in_KPI = defaultdict(int)
        for i in range(data_size):
            _tqdm.set_description(f'checking...[{i}/{data_size}] max len: [{max_len}]')
            # Note: uses the module-level `tokenizer` defined in __main__.
            orig = tokenizer.tokenize(" ".join(seq_data_cws[i]))

            if deal_numeric:
                _kpi_info, kpi_type_list = extract_kpi(orig, kpi_dic, not_in_KPI)
                kpi_info.append(_kpi_info)
                kpi_static_set.update(kpi_type_list)

            sub_total = []
            ww_seq_tmp = []
            ww_tmp = []
            for sub_data in seq_data_cws[i]:
                sub = tokenizer.tokenize(sub_data)
                sub_total.extend(sub)
                ref_token = deal_sub_words(sub, special_token)
                ww_seq_tmp.extend(ref_token)
                ww_tmp.append(ref_token)

            # Tokenizing the whole sentence and tokenizing word by word must agree.
            if sub_total != orig:
                print("error in match... ")
                if len(orig) > 512:
                    print("the length is over the max length")
                pdb.set_trace()

            sz_ww_seq = len(ww_seq_tmp)
            max_len = max(max_len, sz_ww_seq)
            if sz_ww_seq > 500:
                sents_over_max_len.append((ww_seq_tmp, sz_ww_seq))

            assert len(sub_total) == sz_ww_seq
            ww_return.append(ww_seq_tmp)
            ww_list.append(ww_tmp)

            _tqdm.update(1)

    if deal_numeric:
        in_kpi = []
        for key in rev_kpi_dic.keys():
            if key in kpi_static_set:
                in_kpi.append(rev_kpi_dic[key])
        if len(in_kpi) < len(rev_kpi_dic):
            print(f"[{len(in_kpi)}] KPI are covered by data: {in_kpi}")
            print(f"[{len(not_in_KPI)}] KPI could not be matched: {not_in_KPI}")
        else:
            print("all KPI are covered!")
    return ww_return, kpi_info, sents_over_max_len

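# Illustration of the whole-word reference (values from the docstring example):
# for seq_data_cws[i] = ['KPI', '异常', '下降'], the per-word subtokens are
# ['KPI'], ['异', '##常'], ['下', '##降'], so the flattened chinese-ref entry is
# ['KPI', '异', '##常', '下', '##降'].
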
def extract_num(seq_data_cws):
    '''
    Purpose: extract the numeric values from each sequence, filtering out
    sequences whose values are unparsable or NaN.
    '''
    num_ref = []
    seq_data_cws_new = []
    for j in range(len(seq_data_cws)):
        num_index = [i for i, x in enumerate(seq_data_cws[j]) if x == '[NUM]']

        kpi_score = []
        flag = 1
        for index in num_index:
            try:
                tmp = float(seq_data_cws[j][index + 1])
            except (ValueError, IndexError):
                flag = 0
                continue
            if math.isnan(tmp):
                flag = 0
            else:
                kpi_score.append(tmp)

        # Remove the raw value tokens; keep the sentence only if all values parsed.
        if len(num_index) > 0:
            for index in reversed(num_index):
                if index + 1 < len(seq_data_cws[j]):
                    seq_data_cws[j].pop(index + 1)
        if flag == 1:
            num_ref.append(kpi_score)
            seq_data_cws_new.append(seq_data_cws[j])
    return seq_data_cws_new, num_ref

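# Illustrative example (hypothetical tokens): the value token that follows each
# '[NUM]' marker is collected into num_ref and removed from the sequence:
#   extract_num([['[KPI]', '流量', '[NUM]', '0.35']])
#   ->  ([['[KPI]', '流量', '[NUM]']], [[0.35]])
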
def extract_kpi(token_data, kpi_dic, not_in_KPI):
    '''
    Purpose: extract the index span of each [KPI] name and the index of each
    [NUM] token from a token sequence.
    Output format (one entry per KPI): [name_start, name_end, num_index, kpi_id]
    '''
    kpi_and_num_info = []
    kpi_type = []
    kpi_index = [i for i, x in enumerate(token_data) if x.lower() == '[kpi]']
    num_index = [i for i, x in enumerate(token_data) if x.lower() == '[num]']
    sz = len(kpi_index)
    assert sz == len(num_index)
    for i in range(sz):
        # The token immediately before [NUM] is skipped (treated as a separator).
        kpi_name = ''.join(token_data[kpi_index[i] + 1: num_index[i] - 1])
        kpi_name_clear = kpi_name.replace('##', '')

        if kpi_name in kpi_dic:
            kpi_id = int(kpi_dic[kpi_name])
        elif kpi_name_clear in kpi_dic:
            kpi_id = int(kpi_dic[kpi_name_clear])
        else:
            # Unknown KPI: record it and mark the id as -1.
            not_in_KPI[kpi_name_clear] += 1
            kpi_id = -1

        kpi_info = [kpi_index[i] + 1, num_index[i] - 2, num_index[i], kpi_id]
        kpi_and_num_info.append(kpi_info)
        kpi_type.append(kpi_id)

    return kpi_and_num_info, kpi_type

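# Illustrative example (hypothetical tokens; assumes one separator token sits
# between the KPI name and its '[NUM]' marker, as the index arithmetic implies):
#   tokens = ['[KPI]', '流', '##量', '|', '[NUM]'] with kpi_dic = {'流量': 7}
#   ->  kpi_and_num_info = [[1, 2, 4, 7]], kpi_type = [7]
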
def kpi_combine(kpi_info, num_ref):
    # Append each KPI's numeric value (from num_ref) to its index-info entry.
    sz = len(kpi_info)
    assert sz == len(num_ref)
    for i in range(sz):
        for j in range(len(kpi_info[i])):
            kpi_info[i][j].append(num_ref[i][j])
    return kpi_info

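# Illustrative example (hypothetical values):
#   kpi_combine([[[1, 2, 4, 7]]], [[0.35]])  ->  [[[1, 2, 4, 7, 0.35]]]
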
def kpi_lower_update(kpi_dic):
    # Normalize KPI names: lowercase and remove all whitespace.
    new_dic = {}
    for key in kpi_dic:
        kk = key.lower().split()
        kk = ''.join(kk).strip()
        new_dic[kk] = kpi_dic[key]
    return new_dic

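# Illustrative example (hypothetical entry):
#   kpi_lower_update({'CPU Usage': 3})  ->  {'cpuusage': 3}
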
if __name__ == '__main__':
    '''
    Purpose: produce the chinese ref file and refresh the train/test files
    (for the sequential text data only).
    '''
    cfg = cfg()
    cfg.get_args()
    cfgs = cfg.update_train_configs()

    domain_file_path = osp.join(cfgs.data_path, 'special_vocab.txt')
    with open(domain_file_path, encoding="utf-8") as f:
        ref = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
    tokenizer = BertTokenizer.from_pretrained(osp.join(cfgs.data_root, 'transformer', 'MacBert'), do_lower_case=True)
    seq_data_name = cfgs.seq_data_name
    with open(osp.join(cfgs.data_path, f'{seq_data_name}.json'), "r", encoding="utf-8") as fp:
        seq_data = json.load(fp)
    kpi_dic_name = 'kpi2id'
    with open(osp.join(cfgs.data_path, f'{kpi_dic_name}.json'), "r", encoding="utf-8") as fp:
        kpi_dic = json.load(fp)
    kpi_dic = kpi_lower_update(kpi_dic)

    random.shuffle(seq_data)

    print(f"tokenizer size before: {len(tokenizer)}")
    tokenizer, special_token, norm_token = add_special_token(tokenizer)
    special_token = special_token + norm_token
    print(f"tokenizer size after: {len(tokenizer)}")

    print('------------------------ refresh data --------------------------------')
    add_words = refresh_data(ref, cfgs.freq, special_token)

    if not cfgs.read_cws:
        print('------------------------ cws ----------------------------------')
        seq_data_cws, data_size, error_data = cws(seq_data, add_words, cfgs.batch_size)
        print(f'batch size is {cfgs.batch_size}')
        if len(error_data) > 0:
            with open(osp.join(cfgs.data_path, f'{seq_data_name}_error.json'), "w", encoding="utf-8") as fp:
                json.dump(error_data, fp, ensure_ascii=False)
        save_path_cws_orig = osp.join(cfgs.data_path, f'{seq_data_name}_cws_orig.json')
        print("get the new training data! saving...")
        with open(save_path_cws_orig, 'w', encoding="utf-8") as fp:
            json.dump(seq_data_cws, fp, ensure_ascii=False)
    else:
        print('------------------------ read ----------------------------------')
        save_path_cws = osp.join(cfgs.data_path, f'{seq_data_name}_cws_orig.json')
        print("get the new training data!")
        with open(save_path_cws, 'r', encoding="utf-8") as fp:
            seq_data_cws = json.load(fp)
        data_size = len(seq_data_cws)

    sz_orig = len(seq_data_cws)
    if cfgs.deal_numeric:
        seq_data_cws, num_ref = extract_num(seq_data_cws)
        print(f"Filtered out {sz_orig - len(seq_data_cws)} invalid sentences")
        data_size = len(seq_data_cws)

    print('---------------------- generate chinese ref ------------------------------')
    chinese_ref, kpi_info, sents_over_max_len = generate_chinese_ref(seq_data_cws, special_token, cfgs.deal_numeric, kpi_dic)

    if len(sents_over_max_len) > 0:
        print(f"{len(sents_over_max_len)} sentences exceed length 500!")
        save_path_max = osp.join(cfgs.data_path, f'{seq_data_name}_max_len_500.json')
        with open(save_path_max, 'w', encoding="utf-8") as fp:
            json.dump(sents_over_max_len, fp, ensure_ascii=False)

    if cfgs.deal_numeric:
        print("KPI info combine")
        kpi_ref = kpi_combine(kpi_info, num_ref)

    print('------------------------- match finished ------------------------------')

    save_path_ref = osp.join(cfgs.data_path, f'{seq_data_name}_chinese_ref.json')
    with open(save_path_ref, 'w', encoding="utf-8") as fp:
        json.dump(chinese_ref, fp, ensure_ascii=False)
    print("save chinese_ref done!")

    seq_data_cws_output = []
    for i in range(data_size):
        seq = " ".join(seq_data_cws[i])
        seq_data_cws_output.append(seq)

    save_path_cws = osp.join(cfgs.data_path, f'{seq_data_name}_cws.json')
    print("get the new training data!")
    with open(save_path_cws, 'w', encoding="utf-8") as fp:
        json.dump(seq_data_cws_output, fp, ensure_ascii=False)
    print("save seq_data_cws done!")

    if cfgs.deal_numeric:
        kpi_ref_path = osp.join(cfgs.data_path, f'{seq_data_name}_kpi_ref.json')
        with open(kpi_ref_path, 'w', encoding="utf-8") as fp:
            json.dump(kpi_ref, fp, ensure_ascii=False)
        print("save num and kpi done!")