import torch
import random
import json
import numpy as np
import pdb
import os.path as osp

from model import BertTokenizer
import torch.distributed as dist


class SeqDataset(torch.utils.data.Dataset):
    """Sequence dataset: each item is (sample, chi_ref, kpi_ref), with the optional refs aligned to `data` by index."""

    def __init__(self, data, chi_ref=None, kpi_ref=None):
        self.data = data
        self.chi_ref = chi_ref
        self.kpi_ref = kpi_ref

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.chi_ref is not None:
            chi_ref = self.chi_ref[index]
        else:
            chi_ref = None

        if self.kpi_ref is not None:
            kpi_ref = self.kpi_ref[index]
        else:
            kpi_ref = None

        return sample, chi_ref, kpi_ref


class OrderDataset(torch.utils.data.Dataset):
    def __init__(self, data, kpi_ref=None):
        self.data = data
        self.kpi_ref = kpi_ref

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.kpi_ref is not None:
            kpi_ref = self.kpi_ref[index]
        else:
            kpi_ref = None

        return sample, kpi_ref


class KGDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
        self.len = len(self.data)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample = self.data[index]
        return sample


class Collator_base(object):
    def __init__(self, args, tokenizer, special_token=None):
        self.tokenizer = tokenizer
        if special_token is None:
            self.special_token = ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '[REL]', '|', '[DOC]']
        else:
            self.special_token = special_token

        self.text_maxlength = args.maxlength
        self.mlm_probability = args.mlm_probability
        self.args = args
        if self.args.special_token_mask:
            self.special_token = ['|', '[NUM]']

        # Select the masking strategy: random token-level MLM ('rand'),
        # whole-word masking ('wwm'), or the domain-specific fallback.
        if not self.args.only_test and self.args.use_mlm_task:
            if args.mask_stratege == 'rand':
                self.mask_func = self.torch_mask_tokens
            elif args.mask_stratege == 'wwm':
                if args.rank == 0:
                    print("use word-level Mask ...")
                assert args.add_special_word == 1
                self.mask_func = self.wwm_mask_tokens
            else:
                if args.rank == 0:
                    print("use token-level Mask ...")
                self.mask_func = self.domain_mask_tokens

    def __call__(self, batch):
        # Each batch item comes from SeqDataset: (text, chi_ref, kpi_ref).
        kpi_ref = None
        if self.args.use_NumEmb:
            kpi_ref = [item[2] for item in batch]

        chinese_ref = [item[1] for item in batch]
        batch = [item[0] for item in batch]

        batch = self.tokenizer.batch_encode_plus(
            batch,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )
        # None unless the tokenizer was asked to return it.
        special_tokens_mask = batch.pop("special_tokens_mask", None)

        if chinese_ref is not None:
            batch["chinese_ref"] = chinese_ref
        if kpi_ref is not None:
            batch["kpi_ref"] = kpi_ref

        if not self.args.only_test and self.args.use_mlm_task:
            batch["input_ids"], batch["labels"] = self.mask_func(
                batch, special_tokens_mask=special_tokens_mask
            )
        else:
            # No MLM: labels are the inputs themselves, with padding ignored by the loss.
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels

        return batch
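
    # Shape sketch for the dict returned above (illustrative, assuming batch size B and
    # text_maxlength L): input_ids and attention_mask are LongTensors of shape [B, L],
    # labels is [B, L] with -100 at positions the loss should ignore, while chinese_ref
    # and kpi_ref stay as plain per-sample lists for the downstream model to consume.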

    def torch_mask_tokens(self, inputs, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        if "input_ids" in inputs:
            inputs = inputs["input_ids"]
        labels = inputs.clone()

        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # only compute loss on masked tokens

        # 80% of the time, replace masked input tokens with the mask token ([MASK]).
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, replace masked input tokens with a random word.
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10%), keep the masked input tokens unchanged.
        return inputs, labels
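
    # Masking arithmetic, as implemented above: a token is selected with probability
    # mlm_probability; of the selected tokens, 80% become [MASK] (bernoulli 0.8), half
    # of the remaining 20% (bernoulli 0.5) become a random vocabulary token (10%
    # overall), and the rest are left unchanged (10% overall). Unselected positions
    # get label -100 so the loss ignores them.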

    def wwm_mask_tokens(self, inputs, special_tokens_mask=None):
        # Build a 0/1 whole-word mask per sequence, pad the masks into a batch,
        # then apply the 80/10/10 replacement scheme.
        mask_labels = []
        ref_tokens = inputs["chinese_ref"]
        input_ids = inputs["input_ids"]
        sz = len(input_ids)

        for i in range(sz):
            mask_labels.append(self._whole_word_mask(ref_tokens[i]))

        batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, self.text_maxlength, pad_to_multiple_of=None)
        inputs, labels = self.torch_mask_tokens_4wwm(input_ids, batch_mask)
        return inputs, labels

    def _whole_word_mask(self, input_tokens, max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy.
        """
        assert isinstance(self.tokenizer, BertTokenizer)

        cand_indexes = []
        cand_token = []

        for i, token in enumerate(input_tokens):
            if i >= self.text_maxlength - 1:
                # Tokens beyond the encoder's max length are never masked.
                break
            if token.lower() in self.special_token:
                # Never mask special / structural tokens.
                continue
            if len(cand_indexes) >= 1 and token.startswith("##"):
                # "##" marks a WordPiece continuation: group it with the previous word.
                cand_indexes[-1].append(i)
                cand_token.append(i)
            else:
                cand_indexes.append([i])
                cand_token.append(i)

        random.shuffle(cand_indexes)

        num_to_predict = min(max_predictions, max(1, int(round((len(cand_token) + 2) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # Skip this candidate word if masking it would exceed the budget.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        if len(covered_indexes) != len(masked_lms):
            raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")

        mask_labels = [1 if i in covered_indexes else 0 for i in range(min(len(input_tokens), self.text_maxlength))]
        return mask_labels
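
    # Example of the whole-word grouping above (illustrative token sequence, not taken
    # from the dataset): for input_tokens = ["alarm", "net", "##work", "fail", "##ure"],
    # cand_indexes becomes [[0], [1, 2], [3, 4]], so a subword is always masked or kept
    # together with the rest of its word. With mlm_probability == 0.15 that gives
    # num_to_predict = round((5 + 2) * 0.15) = 1, i.e. one whole word is masked.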

    def torch_mask_tokens_4wwm(self, inputs, mask_labels):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        'mask_labels' means we use whole word mask (wwm): positions are masked directly according to its ref.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()

        # Here the "probability matrix" is the precomputed 0/1 whole-word mask itself.
        probability_matrix = mask_labels

        special_tokens_mask = [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]

        if len(special_tokens_mask[0]) != probability_matrix.shape[1]:
            # Debug aid: the padded whole-word mask and the encoded batch should have the same width.
            print(f"len(special_tokens_mask[0]): {len(special_tokens_mask[0])}")
            print(f"probability_matrix.shape[1]: {probability_matrix.shape[1]}")
            print(f'max len {self.text_maxlength}')
            print(f"pad_token_id: {self.tokenizer.pad_token_id}")

            if self.args.dist:
                dist.barrier()
            pdb.set_trace()

        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # only compute loss on masked tokens

        # 80% of the time, replace masked input tokens with the mask token ([MASK]).
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, replace masked input tokens with a random word.
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10%), keep the masked input tokens unchanged.
        return inputs, labels

    def domain_mask_tokens(self, inputs, special_tokens_mask=None):
        # Placeholder for the domain/token-level masking strategy; not implemented here.
        pass


class Collator_kg(object):
    def __init__(self, args, tokenizer, data):
        self.tokenizer = tokenizer
        self.text_maxlength = args.maxlength
        self.cross_sampling_flag = 0

        self.neg_num = args.neg_num

        self.data = data
        self.args = args

    def __call__(self, batch):
        outputs = self.sampling(batch)
        return outputs

    def sampling(self, data):
        """Filter out positive samples and randomly select some samples as negatives.

        Args:
            data: The triples to be sampled.

        Returns:
            batch_data: The training data.
        """
        batch_data = {}
        neg_ent_sample = []

        # Alternate between corrupting heads and corrupting tails on successive batches.
        self.cross_sampling_flag = 1 - self.cross_sampling_flag

        head_list = []
        rel_list = []
        tail_list = []

        if self.cross_sampling_flag == 0:
            batch_data['mode'] = "head-batch"
            for index, (head, relation, tail) in enumerate(data):
                neg_head = self.find_neghead(data, index, relation, tail)
                neg_ent_sample.extend(random.sample(neg_head, self.neg_num))
                head_list.append(head)
                rel_list.append(relation)
                tail_list.append(tail)
        else:
            batch_data['mode'] = "tail-batch"
            for index, (head, relation, tail) in enumerate(data):
                neg_tail = self.find_negtail(data, index, relation, head)
                neg_ent_sample.extend(random.sample(neg_tail, self.neg_num))
                head_list.append(head)
                rel_list.append(relation)
                tail_list.append(tail)

        neg_ent_batch = self.batch_tokenizer(neg_ent_sample)
        head_batch = self.batch_tokenizer(head_list)
        rel_batch = self.batch_tokenizer(rel_list)
        tail_batch = self.batch_tokenizer(tail_list)

        # Map each negative entity text back to its position in the concatenated
        # head/rel/tail list so it can be looked up instead of re-encoded.
        ent_list = head_list + rel_list + tail_list
        ent_dict = {k: v for v, k in enumerate(ent_list)}

        neg_index = torch.tensor([ent_dict[i] for i in neg_ent_sample])

        batch_data["positive_sample"] = (head_batch, rel_batch, tail_batch)
        batch_data['negative_sample'] = neg_ent_batch
        batch_data['neg_index'] = neg_index
        return batch_data
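
    # Illustrative behavior of the sampler above: on one call the batch is built in
    # "head-batch" mode (each triple (h, r, t) is paired with neg_num corrupted heads
    # h' such that (h', r, t) is not in self.data); on the next call the flag flips
    # and tails are corrupted instead. With a batch of B triples, negative_sample
    # therefore encodes B * neg_num entity strings.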

    def batch_tokenizer(self, input_list):
        return self.tokenizer.batch_encode_plus(
            input_list,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )

    def find_neghead(self, data, index, rel, ta):
        # Candidate heads: heads of the other triples in the batch that do not form
        # a known positive triple with (rel, ta).
        head_list = []
        for i, (head, relation, tail) in enumerate(data):
            if i != index and [head, rel, ta] not in self.data:
                head_list.append(head)

        # If there are fewer candidates than neg_num, repeat some of them
        # (assumes at least one candidate exists).
        while len(head_list) < self.neg_num:
            head_list.extend(random.sample(head_list, min(self.neg_num - len(head_list), len(head_list))))

        return head_list

    def find_negtail(self, data, index, rel, he):
        # Same as find_neghead, but corrupting the tail side.
        tail_list = []
        for i, (head, relation, tail) in enumerate(data):
            if i != index and [he, rel, tail] not in self.data:
                tail_list.append(tail)

        while len(tail_list) < self.neg_num:
            tail_list.extend(random.sample(tail_list, min(self.neg_num - len(tail_list), len(tail_list))))
        return tail_list


def load_data(logger, args):
    data_path = args.data_path

    data_name = args.seq_data_name
    with open(osp.join(data_path, f'{data_name}_cws.json'), "r") as fp:
        data = json.load(fp)
    if args.rank == 0:
        logger.info(f"[Start] Loading Seq dataset: [{len(data)}]...")
    # Note: chinese_ref / kpi_ref below are sliced by position and assumed to be
    # index-aligned with `data`; shuffling `data` alone would break that alignment.
    random.shuffle(data)

    train_test_split = int(args.train_ratio * len(data))

    train_data = data[0: train_test_split]
    test_data = data[train_test_split: len(data)]

    if args.use_mlm_task:
        if args.rank == 0:
            print("using the domain words .....")
        domain_file_path = osp.join(args.data_path, f'{data_name}_chinese_ref.json')
        with open(domain_file_path, 'r') as f:
            chinese_ref = json.load(f)

        chi_ref_train = chinese_ref[:train_test_split]
        chi_ref_eval = chinese_ref[train_test_split:]
    else:
        chi_ref_train = None
        chi_ref_eval = None

    if args.use_NumEmb:
        if args.rank == 0:
            print("using the kpi and num .....")

        kpi_file_path = osp.join(args.data_path, f'{data_name}_kpi_ref.json')
        with open(kpi_file_path, 'r') as f:
            kpi_ref = json.load(f)
        kpi_ref_train = kpi_ref[:train_test_split]
        kpi_ref_eval = kpi_ref[train_test_split:]
    else:
        kpi_ref_train = None
        kpi_ref_eval = None

    test_set = None
    train_set = SeqDataset(train_data, chi_ref=chi_ref_train, kpi_ref=kpi_ref_train)
    if len(test_data) > 0:
        test_set = SeqDataset(test_data, chi_ref=chi_ref_eval, kpi_ref=kpi_ref_eval)
    if args.rank == 0:
        logger.info("[End] Loading Seq dataset...")
    return train_set, test_set, train_test_split
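
# Illustrative end-to-end wiring of the loaders above (a sketch; `args`, `logger`, the
# tokenizer and a batch size are assumed to be provided by the training script):
#
#   train_set, test_set, split = load_data(logger, args)
#   collator = Collator_base(args, tokenizer)
#   train_loader = torch.utils.data.DataLoader(
#       train_set, batch_size=args.batch_size, shuffle=True, collate_fn=collator)
#   for batch in train_loader:
#       ...  # batch["input_ids"], batch["labels"], batch["chinese_ref"], batch["kpi_ref"]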


def load_data_kg(logger, args):
    data_path = args.data_path
    if args.rank == 0:
        logger.info("[Start] Loading KG dataset...")

    kg_data_name = args.kg_data_name
    with open(osp.join(data_path, f'{kg_data_name}.json'), "r") as fp:
        train_data = json.load(fp)

    train_set = KGDataset(train_data)
    if args.rank == 0:
        logger.info("[End] Loading KG dataset...")
    return train_set, train_data


def _torch_collate_batch(examples, tokenizer, max_length=None, pad_to_multiple_of=None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    # Tensorize if necessary (numpy/torch are imported at module level).
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    if max_length is None:
        # Debug aid: callers are expected to pass max_length explicitly.
        pdb.set_trace()
        max_length = max(x.size(0) for x in examples)

    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0]:] = example

    return result
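
# Padding example for _torch_collate_batch (illustrative values, assuming
# tokenizer.pad_token_id == 0 and padding_side == "right", as in BERT vocabularies):
# with max_length == 6 and examples == [[1, 1, 0, 1], [1, 1]], the result is
#   tensor([[1, 1, 0, 1, 0, 0],
#           [1, 1, 0, 0, 0, 0]])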


def load_order_data(logger, args):
    if args.rank == 0:
        logger.info("[Start] Loading Order dataset...")

    data_path = args.data_path
    if len(args.order_test_name) > 0:
        data_name = args.order_test_name
    else:
        data_name = args.order_data_name
    tmp = osp.join(data_path, f'{data_name}.json')
    if osp.exists(tmp):
        dp = tmp
    else:
        dp = osp.join(data_path, 'downstream_task', f'{data_name}.json')
        assert osp.exists(dp)
    with open(dp, "r") as fp:
        data = json.load(fp)

    train_test_split = int(args.train_ratio * len(data))

    mid_split = int(train_test_split / 2)
    mid = int(len(data) / 2)

    # The file is assumed to store its two label groups as the first and second halves,
    # so the split draws the same number of samples from each half.
    test_data = data[0: mid_split] + data[mid: mid + mid_split]
    train_data = data[mid_split: mid] + data[mid + mid_split: len(data)]

    test_set = None
    train_set = OrderDataset(train_data)
    if len(test_data) > 0:
        test_set = OrderDataset(test_data)
    if args.rank == 0:
        logger.info("[End] Loading Order dataset...")
    return train_set, test_set, train_test_split
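
# Split arithmetic for load_order_data (illustrative numbers): with 100 samples and
# train_ratio == 0.8, train_test_split == 80, mid_split == 40 and mid == 50, so
# test_data = data[0:40] + data[50:90] and train_data = data[40:50] + data[90:100],
# i.e. both splits take consecutive slices from each half of the file.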


class Collator_order(object):
    def __init__(self, args, tokenizer):
        self.tokenizer = tokenizer
        self.text_maxlength = args.maxlength
        self.args = args

        self.order_num = args.order_num
        self.p_label, self.n_label = smooth_BCE(args.eps)

    def __call__(self, batch):
        # Gather the order_num text fields of each sample, grouped field-by-field
        # across the batch.
        output = []
        for item in range(self.order_num):
            output.extend([dat[0][0][item] for dat in batch])

        # Raw label 2 keeps a hard 1.0 target; label 1 maps to the smoothed positive
        # target and anything else to the smoothed negative target.
        labels = [1 if dat[0][1][0] == 2 else self.p_label if dat[0][1][0] == 1 else self.n_label for dat in batch]
        batch = self.tokenizer.batch_encode_plus(
            output,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )

        return batch, torch.FloatTensor(labels)


def smooth_BCE(eps=0.1):
    # Label smoothing for BCE: return the positive and negative target values.
    return 1.0 - 0.5 * eps, 0.5 * eps
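
# Example: smooth_BCE(0.1) returns (0.95, 0.05), so "positive" samples are trained
# toward 0.95 and "negative" samples toward 0.05 instead of hard 1/0 targets.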