import torch
import random
import json
import numpy as np
import pdb
import os.path as osp
from model import BertTokenizer
import torch.distributed as dist


class SeqDataset(torch.utils.data.Dataset):
    def __init__(self, data, chi_ref=None, kpi_ref=None):
        self.data = data
        self.chi_ref = chi_ref
        self.kpi_ref = kpi_ref

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.chi_ref is not None:
            chi_ref = self.chi_ref[index]
        else:
            chi_ref = None

        if self.kpi_ref is not None:
            kpi_ref = self.kpi_ref[index]
        else:
            kpi_ref = None

        return sample, chi_ref, kpi_ref


class OrderDataset(torch.utils.data.Dataset):
    def __init__(self, data, kpi_ref=None):
        self.data = data
        self.kpi_ref = kpi_ref

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.kpi_ref is not None:
            kpi_ref = self.kpi_ref[index]
        else:
            kpi_ref = None

        return sample, kpi_ref


class KGDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
        self.len = len(self.data)

    def __len__(self):
        return self.len

    def __getitem__(self, index):

        sample = self.data[index]
        return sample

# TODO: refactor DataCollatorForLanguageModeling


class Collator_base(object):
    # TODO: define the collator, following Lako
    # handles masking and padding
    def __init__(self, args, tokenizer, special_token=None):
        self.tokenizer = tokenizer
        if special_token is None:
            self.special_token = ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '[REL]', '|', '[DOC]']
        else:
            self.special_token = special_token

        self.text_maxlength = args.maxlength
        self.mlm_probability = args.mlm_probability
        self.args = args
        if self.args.special_token_mask:
            self.special_token = ['|', '[NUM]']

        if not self.args.only_test and self.args.use_mlm_task:
            if args.mask_stratege == 'rand':
                self.mask_func = self.torch_mask_tokens
            else:
                if args.mask_stratege == 'wwm':
                    # special_word must be enabled, since wwm here relies on word segmentation
                    if args.rank == 0:
                        print("use word-level Mask ...")
                    assert args.add_special_word == 1
                    self.mask_func = self.wwm_mask_tokens
                else:  # domain
                    if args.rank == 0:
                        print("use token-level Mask ...")
                    self.mask_func = self.domain_mask_tokens

    def __call__(self, batch):
        # extract the numeric values from the batch and replace them with special tokens;
        # the values and their positions are passed in separately as lists, so that
        # during training they can be inserted directly at the embedding positions;
        # numeric values do not take part in masking;
        # for wwm, the chinese ref is passed in along with the batch
        kpi_ref = None
        if self.args.use_NumEmb:
            kpi_ref = [item[2] for item in batch]
        # if self.args.mask_stratege != 'rand':
        chinese_ref = [item[1] for item in batch]
        batch = [item[0] for item in batch]
        # at this point the batch is not just plain strings
        batch = self.tokenizer.batch_encode_plus(
            batch,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        # self.torch_mask_tokens

        # if batch["input_ids"].shape[1] != 128:
        #     pdb.set_trace()
        if chinese_ref is not None:
            batch["chinese_ref"] = chinese_ref
        if kpi_ref is not None:
            batch["kpi_ref"] = kpi_ref

        # training requires masking

        if not self.args.only_test and self.args.use_mlm_task:
            batch["input_ids"], batch["labels"] = self.mask_func(
                batch, special_tokens_mask=special_tokens_mask
            )
        else:
            # evaluation mode, or MLM is not used for training:
            # labels are the (unmasked) input ids, with padding ignored
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels

        return batch
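    # The returned batch is a dict with input_ids / attention_mask (padded to
    # text_maxlength), chinese_ref and optionally kpi_ref (kept as plain Python
    # lists for the model to consume), and labels (-100 at positions ignored by
    # the loss: unmasked tokens during MLM training, or padding otherwise).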

    def torch_mask_tokens(self, inputs, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        if "input_ids" in inputs:
            inputs = inputs["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()
        # pdb.set_trace()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
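    # Worked example of the compound rates above: with mlm_probability = 0.15,
    # roughly 15% of the non-special positions are selected; of these, ~80% become
    # [MASK] (~12% of all tokens), ~10% are replaced with a random token (~1.5%),
    # and the remaining ~10% are left unchanged (~1.5%).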

    def wwm_mask_tokens(self, inputs, special_tokens_mask=None):
        mask_labels = []
        ref_tokens = inputs["chinese_ref"]
        input_ids = inputs["input_ids"]
        sz = len(input_ids)

        # map the input ids back to tokens first (via the loaded refs)
        for i in range(sz):
            # the masking here is driven by the loaded refs, whose length may not match max_len
            mask_labels.append(self._whole_word_mask(ref_tokens[i]))

        batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, self.text_maxlength, pad_to_multiple_of=None)
        inputs, labels = self.torch_mask_tokens_4wwm(input_ids, batch_mask)
        return inputs, labels

    # input_tokens: List[str]
    def _whole_word_mask(self, input_tokens, max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy
        """
        assert isinstance(self.tokenizer, (BertTokenizer))
        # the input is a list of tokens: [..., ..., ..., ...]
        cand_indexes = []
        cand_token = []

        for i, token in enumerate(input_tokens):
            if i >= self.text_maxlength - 1:
                # must not exceed the max length; truncate here
                break
            if token.lower() in [t.lower() for t in self.special_token]:
                # special tokens should never be masked (case-insensitive match)
                continue
            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
                cand_token.append(i)
            else:
                cand_indexes.append([i])
                cand_token.append(i)

        random.shuffle(cand_indexes)
        # originally len(input_tokens) was used here, but since the data contains many
        # special tokens they are excluded first, so the 15% budget is taken over the
        # non-special tokens; the +2 accounts for the CLS and SEP flags
        num_to_predict = min(max_predictions, max(1, int(round((len(cand_token) + 2) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            # stop once the prediction budget is reached
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            # the budget must not be exceeded; at most equalled
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                # skip words that overlap already-covered tokens
                if index in covered_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        if len(covered_indexes) != len(masked_lms):
            # should not normally happen, since duplicates are avoided above
            raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
        # cap the labels at the max length (truncate)
        mask_labels = [1 if i in covered_indexes else 0 for i in range(min(len(input_tokens), self.text_maxlength))]

        return mask_labels
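    # Example of the candidate grouping above: for ref tokens
    # ["link", "fail", "##ure"], cand_indexes becomes [[0], [1, 2]], so the
    # pieces "fail" and "##ure" are only ever masked together as one whole word,
    # and num_to_predict is computed from these 3 maskable positions.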


    def torch_mask_tokens_4wwm(self, inputs, mask_labels):
        """
        Prepare masked token inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        Passing `mask_labels` means whole word masking (wwm) is used: indices are masked directly according to the refs.
        """
        # if "input_ids" in inputs:
        #     inputs = inputs["input_ids"]
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

        probability_matrix = mask_labels

        special_tokens_mask = [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]

        if len(special_tokens_mask[0]) != probability_matrix.shape[1]:
            print(f"len(special_tokens_mask[0]): {len(special_tokens_mask[0])}")
            print(f"probability_matrix.shape[1]): {probability_matrix.shape[1]}")
            print(f'max len {self.text_maxlength}')
            print(f"pad_token_id: {self.tokenizer.pad_token_id}")
            # if self.args.rank != in_rank:
            if self.args.dist:
                dist.barrier()
                pdb.set_trace()
            else:
                pdb.set_trace()

        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # with wwm here, the 80/10/10 choice below is still made per token rather than
        # per whole word, so a word's pieces can end up split across MASK / random /
        # unchanged; not entirely reasonable, but there is no easy alternative

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    # TODO: mask by region / cell

    def domain_mask_tokens(self, inputs, special_tokens_mask=None):
        pass
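

# A minimal usage sketch (illustrative only, never called in this module): it
# assumes a BERT-style tokenizer compatible with `model.BertTokenizer` and an
# `args` namespace carrying the fields Collator_base reads; the function name
# and the values below are hypothetical, not the project's defaults.
def _example_seq_collate(tokenizer):
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    args = SimpleNamespace(
        maxlength=128, mlm_probability=0.15, only_test=True, use_mlm_task=False,
        use_NumEmb=False, special_token_mask=False, mask_stratege='rand',
        rank=0, add_special_word=1, dist=False)
    collator = Collator_base(args, tokenizer)
    # toy samples; chi_ref / kpi_ref are omitted, so __getitem__ yields (text, None, None)
    dataset = SeqDataset(["alarm: link down on cell-01", "kpi drop detected"])
    loader = DataLoader(dataset, batch_size=2, collate_fn=collator)
    return next(iter(loader))  # dict with input_ids / attention_mask / labels / chinese_ref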


class Collator_kg(object):
    # TODO: define the collator, following Lako
    # randomly drops a subset of attributes
    def __init__(self, args, tokenizer, data):
        self.tokenizer = tokenizer
        self.text_maxlength = args.maxlength
        self.cross_sampling_flag = 0
        # the KE batch size is a quarter of the normal batch size
        self.neg_num = args.neg_num
        # negative samples must not occur in the full triple set
        self.data = data
        self.args = args

    def __call__(self, batch):
        # encode into token form once up front to avoid re-encoding
        outputs = self.sampling(batch)

        return outputs

    def sampling(self, data):
        """Filter out positive triples and randomly select in-batch entities as negatives.

        Args:
            data: The batch of triples to sample from.

        Returns:
            batch_data: The training batch, with positive and negative samples.
        """
        batch_data = {}
        neg_ent_sample = []

        self.cross_sampling_flag = 1 - self.cross_sampling_flag

        head_list = []
        rel_list = []
        tail_list = []
        # pdb.set_trace()
        if self.cross_sampling_flag == 0:
            batch_data['mode'] = "head-batch"
            for index, (head, relation, tail) in enumerate(data):
                # in batch negative
                neg_head = self.find_neghead(data, index, relation, tail)
                neg_ent_sample.extend(random.sample(neg_head, self.neg_num))
                head_list.append(head)
                rel_list.append(relation)
                tail_list.append(tail)
        else:
            batch_data['mode'] = "tail-batch"
            for index, (head, relation, tail) in enumerate(data):
                neg_tail = self.find_negtail(data, index, relation, head)
                neg_ent_sample.extend(random.sample(neg_tail, self.neg_num))

                head_list.append(head)
                rel_list.append(relation)
                tail_list.append(tail)

        neg_ent_batch = self.batch_tokenizer(neg_ent_sample)
        head_batch = self.batch_tokenizer(head_list)
        rel_batch = self.batch_tokenizer(rel_list)
        tail_batch = self.batch_tokenizer(tail_list)

        ent_list = head_list + rel_list + tail_list
        ent_dict = {k: v for v, k in enumerate(ent_list)}
        # used to index the negative samples within the positive entity list
        neg_index = torch.tensor([ent_dict[i] for i in neg_ent_sample])
        # pos_head_index = torch.tensor(list(range(len(head_list)))

        batch_data["positive_sample"] = (head_batch, rel_batch, tail_batch)
        batch_data['negative_sample'] = neg_ent_batch
        batch_data['neg_index'] = neg_index
        return batch_data
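    # Worked example: with neg_num = 2 and a batch of 3 triples, successive
    # batches alternate between "tail-batch" and "head-batch" corruption; each
    # triple contributes 2 in-batch negatives, so neg_ent_sample holds 6
    # entities, and neg_index maps each of them back to its position in
    # head_list + rel_list + tail_list for embedding lookup.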

    def batch_tokenizer(self, input_list):
        return self.tokenizer.batch_encode_plus(
            input_list,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )

    def find_neghead(self, data, index, rel, ta):
        head_list = []
        for i, (head, relation, tail) in enumerate(data):
            # the corrupted triple must not exist in the full data
            if i != index and [head, rel, ta] not in self.data:
                head_list.append(head)
        # there may not be enough negative candidates;
        # pad by resampling from those already collected
        while len(head_list) < self.neg_num:
            head_list.extend(random.sample(head_list, min(self.neg_num - len(head_list), len(head_list))))

        return head_list

    def find_negtail(self, data, index, rel, he):
        tail_list = []
        for i, (head, relation, tail) in enumerate(data):
            if i != index and [he, rel, tail] not in self.data:
                tail_list.append(tail)
        # there may not be enough negative candidates;
        # pad by resampling from those already collected
        while len(tail_list) < self.neg_num:
            tail_list.extend(random.sample(tail_list, min(self.neg_num - len(tail_list), len(tail_list))))
        return tail_list

# Load the data for the mask-loss part


def load_data(logger, args):

    data_path = args.data_path

    data_name = args.seq_data_name
    with open(osp.join(data_path, f'{data_name}_cws.json'), "r") as fp:
        data = json.load(fp)
    if args.rank == 0:
        logger.info(f"[Start] Loading Seq dataset: [{len(data)}]...")
    random.shuffle(data)

    # data = data[:10000]
    # pdb.set_trace()
    train_test_split = int(args.train_ratio * len(data))
    # random.shuffle(x)
    # the split should not be reshuffled between training and testing
    train_data = data[0: train_test_split]
    test_data = data[train_test_split: len(data)]

    # this may actually also be needed at test time (i.e. not args.only_test)
    if args.use_mlm_task:
        # if args.mask_stratege != 'rand':
        # load the domain vocabulary
        if args.rank == 0:
            print("using the domain words .....")
        domain_file_path = osp.join(args.data_path, f'{data_name}_chinese_ref.json')
        with open(domain_file_path, 'r') as f:
            chinese_ref = json.load(f)
    # train_test_split=len(data)
        chi_ref_train = chinese_ref[:train_test_split]
        chi_ref_eval = chinese_ref[train_test_split:]
    else:
        chi_ref_train = None
        chi_ref_eval = None

    if args.use_NumEmb:
        if args.rank == 0:
            print("using the kpi and num  .....")

        kpi_file_path = osp.join(args.data_path, f'{data_name}_kpi_ref.json')
        with open(kpi_file_path, 'r') as f:
            kpi_ref = json.load(f)
        kpi_ref_train = kpi_ref[:train_test_split]
        kpi_ref_eval = kpi_ref[train_test_split:]
    else:
        # num_ref_train = None
        # num_ref_eval = None
        kpi_ref_train = None
        kpi_ref_eval = None

    # pdb.set_trace()
    test_set = None
    train_set = SeqDataset(train_data, chi_ref=chi_ref_train, kpi_ref=kpi_ref_train)
    if len(test_data) > 0:
        test_set = SeqDataset(test_data, chi_ref=chi_ref_eval, kpi_ref=kpi_ref_eval)
    if args.rank == 0:
        logger.info("[End] Loading Seq dataset...")
    return train_set, test_set, train_test_split
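
# Note on the expected file layout (inferred from the indexing above):
# `{seq_data_name}_cws.json` is a list of text samples, and
# `{seq_data_name}_chinese_ref.json` / `{seq_data_name}_kpi_ref.json` are
# index-aligned lists of the same length (per-sample word-segmentation refs and
# KPI/number refs), since all three are sliced with the same split point and
# indexed together in SeqDataset.__getitem__.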

# Load the data for the triple-loss part


def load_data_kg(logger, args):
    data_path = args.data_path
    if args.rank == 0:
        logger.info("[Start] Loading KG dataset...")
    # # triples
    # with open(osp.join(data_path, '5GC_KB/database_triples_831.json'), "r") as f:
    #     data = json.load(f)
    # random.shuffle(data)

    # # # TODO: the triple-loss part has no test set yet
    # train_data = data[0:int(len(data)/args.batch_size)*args.batch_size]

    # with open(osp.join(data_path, 'KG_data_tiny_831.json'),"w") as fp:
    #     json.dump(data[:1000], fp)
    kg_data_name = args.kg_data_name
    with open(osp.join(data_path, f'{kg_data_name}.json'), "r") as fp:
        train_data = json.load(fp)
    # pdb.set_trace()
    # 124169
    # 128482
    # train_data = train_data[:124168]
    # train_data = train_data[:1000]
    train_set = KGDataset(train_data)
    if args.rank == 0:
        logger.info("[End] Loading KG dataset...")
    return train_set, train_data


def _torch_collate_batch(examples, tokenizer, max_length=None, pad_to_multiple_of=None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)

    # Check if padding is necessary.

    # are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    # if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
    #     return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.

    if max_length is None:
        pdb.set_trace()
        max_length = max(x.size(0) for x in examples)

    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0]:] = example

    return result
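
# Example: with max_length=8 and whole-word mask labels of different lengths,
# e.g. [[0, 1, 1, 0, 1], [1, 0, 0]], _torch_collate_batch pads each row with
# tokenizer.pad_token_id up to length 8 (on the right, or on the left if
# tokenizer.padding_side == "left") and stacks them into a (2, 8) LongTensor.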


def load_order_data(logger, args):
    if args.rank == 0:
        logger.info("[Start] Loading Order dataset...")

    data_path = args.data_path
    if len(args.order_test_name) > 0:
        data_name = args.order_test_name
    else:
        data_name = args.order_data_name
    tmp = osp.join(data_path, f'{data_name}.json')
    if osp.exists(tmp):
        dp = tmp
    else:
        dp = osp.join(data_path, 'downstream_task', f'{data_name}.json')
    assert osp.exists(dp)
    with open(dp, "r") as fp:
        data = json.load(fp)
    # data = data[:2000]
    # pdb.set_trace()
    train_test_split = int(args.train_ratio * len(data))

    mid_split = int(train_test_split / 2)
    mid = int(len(data) / 2)
    # random.shuffle(x)
    # the split should not be reshuffled between training and testing
    # train_data = data[0: train_test_split]
    # test_data = data[train_test_split: len(data)]

    # test_data = data[0: train_test_split]
    # train_data = data[train_test_split: len(data)]

    # special split: the first half and the second half of the data are assumed symmetric
    test_data = data[0: mid_split] + data[mid: mid + mid_split]
    train_data = data[mid_split: mid] + data[mid + mid_split: len(data)]

    # pdb.set_trace()
    test_set = None
    train_set = OrderDataset(train_data)
    if len(test_data) > 0:
        test_set = OrderDataset(test_data)
    if args.rank == 0:
        logger.info("[End] Loading Order dataset...")
    return train_set, test_set, train_test_split
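
# Worked example of the symmetric split above: with 1000 samples and
# train_ratio = 0.2, train_test_split = 200, mid_split = 100 and mid = 500, so
# test_data = data[0:100] + data[500:600] (200 samples) and train_data is the
# remaining 800; as in the commented-out split, args.train_ratio effectively
# controls the size of the test portion here.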


class Collator_order(object):
    # takes a batch of order data, merges the pairs for encoding and decouples them later
    def __init__(self, args, tokenizer):
        self.tokenizer = tokenizer
        self.text_maxlength = args.maxlength
        self.args = args
        # number of items contained in each order pair
        self.order_num = args.order_num
        self.p_label, self.n_label = smooth_BCE(args.eps)

    def __call__(self, batch):
        # the input items are stacked in order and split back at fixed intervals later;
        # encode them, then return
        output = []
        for item in range(self.order_num):
            output.extend([dat[0][0][item] for dat in batch])
        # label smoothing

        labels = [1 if dat[0][1][0] == 2 else self.p_label if dat[0][1][0] == 1 else self.n_label for dat in batch]
        batch = self.tokenizer.batch_encode_plus(
            output,
            padding='max_length',
            max_length=self.text_maxlength,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=False
        )
        # torch.tensor()
        return batch, torch.FloatTensor(labels)


def smooth_BCE(eps=0.1):   # eps: smoothing coefficient; hard targets [1, 0] => soft targets [0.95, 0.05]
    # return positive, negative label smoothing BCE targets
    # positive label= y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
    # y_true=1  label_smoothing=eps=0.1
    return 1.0 - 0.5 * eps, 0.5 * eps
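
# Example: smooth_BCE(0.1) returns (0.95, 0.05). In Collator_order these become
# the soft targets for raw labels 1 and 0 respectively, while a raw label of 2
# keeps a hard target of 1.0.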