import os
import torch
from tqdm import tqdm
from tasks.glue.dataset import task_to_keys as glue_tasks
from tasks.superglue.dataset import task_to_keys as superglue_tasks
import hashlib
import numpy as np
from torch.nn.utils.rnn import pad_sequence

def add_task_specific_tokens(tokenizer):
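    """Register the marker tokens [P], [T], [K], [Y] as additional special tokens
    and expose their string/id pairs as attributes on the tokenizer
    ([P] predict token, [T] prompt token, [K] skey token, [Y] LAMA-style [Y] slot)."""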
    tokenizer.add_special_tokens({
        'additional_special_tokens': ['[P]', '[T]', '[K]', '[Y]']
    })
    tokenizer.skey_token = '[K]'
    tokenizer.skey_token_id = tokenizer.convert_tokens_to_ids('[K]')
    tokenizer.prompt_token = '[T]'
    tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]')
    tokenizer.predict_token = '[P]'
    tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
    # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
    # tokenizer.lama_x = '[X]'
    # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
    tokenizer.lama_y = '[Y]'
    tokenizer.lama_y_id = tokenizer.convert_tokens_to_ids('[Y]')

    # GPT-2 tokenizers have no pad token by default; reuse the end-of-text token.
    if 'gpt' in tokenizer.name_or_path:
        tokenizer.pad_token = '<|endoftext|>'
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
    return tokenizer


def load_cache_record(datasets):
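    """Return the cached clean+poison "record" dataset if it has already been
    written next to the HuggingFace cache file; otherwise return None."""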
    digest = hashlib.md5("record".encode("utf-8")).hexdigest()  # 32-character hex digest
    path = datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
    if os.path.exists(path):
        return torch.load(path)
    return None


def load_cache_dataset(tokenizer, sc_datasets, sw_datasets, **kwargs):
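    """Attach the sw_datasets input_ids/attention_mask to every example of
    sc_datasets (split by split), cache the merged splits with torch.save,
    and load the merged datasets from that cache on subsequent calls."""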
    name = f"{tokenizer.name_or_path}_{tokenizer.template}"
    digest = hashlib.md5(name.encode("utf-8")).hexdigest()  # 32-character hex digest
    path = sc_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
    if not os.path.exists(path):
        new_datasets = sc_datasets.copy()
        for split, v in sc_datasets.items():
            new_datasets[split] = []
            pbar = tqdm(enumerate(v), total=len(v))
            for idx, item in pbar:
                # Pair each example with the tokenization of the same example
                # from sw_datasets.
                item.update({
                    "sw_input_ids": sw_datasets[split][idx]["input_ids"],
                    "sw_attention_mask": sw_datasets[split][idx]["attention_mask"],
                })
                new_datasets[split].append(item)
                pbar.set_description(f"-> Building {split} set [{idx}/{len(v)}]")
        data = {
            "new_datasets": new_datasets,
        }
        torch.save(data, path)
    return torch.load(path)["new_datasets"]
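

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original pipeline):
    # shows add_task_specific_tokens() applied to a standard HuggingFace tokenizer.
    # "bert-base-cased" is an arbitrary example checkpoint assumed here purely
    # for demonstration.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-cased")
    tok = add_task_specific_tokens(tok)
    print(tok.prompt_token, tok.prompt_token_id)
    print(tok.skey_token, tok.skey_token_id)
    print(tok.predict_token, tok.predict_token_id)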