import hashlib
import os

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from tasks.glue.dataset import task_to_keys as glue_tasks
from tasks.superglue.dataset import task_to_keys as superglue_tasks


def add_task_specific_tokens(tokenizer):
    tokenizer.add_special_tokens({
        'additional_special_tokens': ['[P]', '[T]', '[K]', '[Y]']
    })
    tokenizer.skey_token = '[K]'
    tokenizer.skey_token_id = tokenizer.convert_tokens_to_ids('[K]')
    tokenizer.prompt_token = '[T]'
    tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]')
    tokenizer.predict_token = '[P]'
    tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
    # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
    # tokenizer.lama_x = '[X]'
    # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
    tokenizer.lama_y = '[Y]'
    tokenizer.lama_y_id = tokenizer.convert_tokens_to_ids('[Y]')
    # Only for GPT-2: it has no pad token by default, so reuse the end-of-text token
    # (and its id) for padding.
    if 'gpt' in tokenizer.name_or_path:
        tokenizer.pad_token = '<|endoftext|>'
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
    return tokenizer

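# Usage sketch (illustrative only; "bert-base-cased" and the surrounding calls are
# assumptions, not part of this module):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
#   tokenizer = add_task_specific_tokens(tokenizer)
#   # The added special tokens enlarge the vocabulary, so the model's embedding
#   # matrix should be resized to match:
#   # model.resize_token_embeddings(len(tokenizer))
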
def load_cache_record(datasets):
    digest = hashlib.md5("record".encode("utf-8")).hexdigest()  # 32-char hex digest of a fixed key
    path = datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
    # Return the cached record if it exists; the caller builds it otherwise.
    if os.path.exists(path):
        return torch.load(path)
    return None

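# Usage sketch (illustrative; `datasets` is assumed to be a tokenized DatasetDict
# with a "train" split, as required by the cache-path lookup above):
#
#   cached = load_cache_record(datasets)
#   if cached is None:
#       ...  # build the clean+poison record and torch.save() it to the same path
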
def load_cache_dataset(tokenizer, sc_datasets, sw_datasets, **kwargs):
    # The cache key depends on the tokenizer checkpoint and the prompt template,
    # so a different template produces a different cache file.
    name = f"{tokenizer.name_or_path}_{tokenizer.template}"
    digest = hashlib.md5(name.encode("utf-8")).hexdigest()  # 32-char hex digest
    path = sc_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
    if not os.path.exists(path):
        new_datasets = sc_datasets.copy()
        for split, v in sc_datasets.items():
            new_datasets[split] = []
            pbar = tqdm(enumerate(v))
            for idx, item in pbar:
                # Attach the second view's encodings to the corresponding example.
                item.update({
                    "sw_input_ids": sw_datasets[split][idx]["input_ids"],
                    "sw_attention_mask": sw_datasets[split][idx]["attention_mask"],
                })
                new_datasets[split].append(item)
                pbar.set_description(f"-> Building {split} set...[{idx}/{len(v)}]")
        data = {
            "new_datasets": new_datasets,
        }
        torch.save(data, path)
    return torch.load(path)["new_datasets"]
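
# Usage sketch (illustrative; sc_datasets / sw_datasets are assumed to be two
# aligned, already-tokenized DatasetDicts, and `template` is a hypothetical
# prompt-template string that feeds the cache key):
#
#   tokenizer = add_task_specific_tokens(tokenizer)
#   tokenizer.template = template
#   merged = load_cache_dataset(tokenizer, sc_datasets, sw_datasets)
#   train_items = merged["train"]   # each item also carries "sw_input_ids" /
#                                   # "sw_attention_mask" from the second view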