import hashlib
import logging
import math
import os.path
from collections import defaultdict

import numpy as np
import torch
from datasets import load_dataset, load_metric
from datasets.formatting.formatting import LazyRow
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    default_data_collator,
)
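
# Maps each SuperGLUE task to the feature names used as (sentence1, sentence2)
# when filling prompt templates. A None entry means that slot is unused:
# COPA and ReCoRD are handled by dedicated preprocessing functions below.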
task_to_keys = {
    "boolq": ("question", "passage"),
    "cb": ("premise", "hypothesis"),
    "rte": ("premise", "hypothesis"),
    "wic": ("processed_sentence1", None),
    "wsc": ("span2_word_text", "span1_text"),
    "copa": (None, None),
    "record": (None, None),
    "multirc": ("paragraph", "question_answer"),
}

logger = logging.getLogger(__name__)


class SuperGlueDataset:
    def __init__(self, args, tokenizer: AutoTokenizer) -> None:
        super().__init__()
        raw_datasets = load_dataset("super_glue", args.dataset_name)
        self.tokenizer = tokenizer
        self.args = args
        self.multiple_choice = args.dataset_name in ["copa"]

        if args.dataset_name == "record":
            self.num_labels = 2
            self.label_list = ["0", "1"]
        elif not self.multiple_choice:
            self.label_list = raw_datasets["train"].features["label"].names
            self.num_labels = len(self.label_list)
        else:
            self.num_labels = 1

        # Preprocessing the raw_datasets
        self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
        self.padding = False

        if not self.multiple_choice:
            self.label2id = {l: i for i, l in enumerate(self.label_list)}
            self.id2label = {id: label for label, id in self.label2id.items()}
            print(f"{self.label2id}")
            print(f"{self.id2label}")
        if args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
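
        # Each split is tokenized once and cached on disk. The cache file name
        # mixes the tokenizer name with a digest of the prompt/key templates,
        # so changing either template invalidates the cache instead of silently
        # reusing stale encodings.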
        for key in ["validation", "train", "test"]:
            cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
            digest = hashlib.md5(
                str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")
            ).hexdigest()
            filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
            print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
            cache_file_name = os.path.join(cache_root, filename)

            if args.dataset_name == "record":
                raw_datasets[key] = raw_datasets[key].map(
                    self.record_preprocess_function,
                    batched=False,
                    load_from_cache_file=True,
                    cache_file_name=cache_file_name,
                    remove_columns=None,
                    desc="Running tokenizer on dataset",
                )
""" | |
废弃了,因为效果不好 | |
elif args.dataset_name == "copa": | |
raw_datasets[key] = raw_datasets[key].map( | |
self.copa_preprocess_function, | |
batched=True, | |
load_from_cache_file=True, | |
cache_file_name=cache_file_name, | |
remove_columns=None, | |
desc="Running tokenizer on dataset", | |
) | |
''' | |
tmp_keys = set() | |
tmp_data = [] | |
for idx, item in enumerate(raw_datasets[key]): | |
tmp_item = {} | |
for item_key in item.keys(): | |
if "tmp" in item_key: | |
tmp_keys.add(item_key) | |
tmp_item[item_key.replace("_tmp", "")] = item[item_key] | |
tmp_data.append(tmp_item) | |
raw_datasets[key].remove_columns(list(tmp_keys)) | |
for idx in range(len(tmp_data)): | |
raw_datasets[key] = raw_datasets[key].add_item(tmp_data[idx]) | |
''' | |
""" | |
            else:
                raw_datasets[key] = raw_datasets[key].map(
                    self.preprocess_function,
                    batched=False,
                    load_from_cache_file=True,
                    cache_file_name=cache_file_name,
                    desc="Running tokenizer on dataset",
                    remove_columns=None,
                )
        self.train_dataset = raw_datasets["train"]
        size = len(self.train_dataset)
        select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False)
        idx = torch.zeros([size])
        idx[select] = 1
        # poison_idx is attached before any truncation below, so it always
        # refers to positions in the full training split.
        self.train_dataset.poison_idx = idx

        if args.max_train_samples is not None:
            self.train_dataset = self.train_dataset.select(range(args.max_train_samples))

        self.eval_dataset = raw_datasets["validation"]
        if args.max_eval_samples is not None:
            max_eval_samples = min(len(self.eval_dataset), args.max_eval_samples)
            self.eval_dataset = self.eval_dataset.select(range(max_eval_samples))

        self.predict_dataset = raw_datasets["test"]
        if args.max_predict_samples is not None:
            self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))

        self.metric = load_metric("super_glue", args.dataset_name)
        self.data_collator = default_data_collator
        self.test_key = "accuracy" if args.dataset_name not in ["record", "multirc"] else "f1"

    def filter(self, examples, length=None):
        """Recursively sanitize raw text so it cannot collide with the special
        prompt/key/predict tokens or the [X]/[Y] placeholders used in templates."""
        if isinstance(examples, list):
            return [self.filter(x, length) for x in examples]
        elif isinstance(examples, (dict, LazyRow)):
            return {k: self.filter(v, length) for k, v in examples.items()}
        elif isinstance(examples, str):
            # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
            txt = (
                examples.replace(self.tokenizer.prompt_token, "T")
                .replace(self.tokenizer.key_token, "K")
                .replace(self.tokenizer.predict_token, "P")
                .replace("[X]", "Y")
                .replace("[Y]", "Y")
            )
            if length is not None:
                return txt[:length]
            return txt
        return examples
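
    # COPA is encoded as binary classification over (premise, choice) pairs.
    # For each question one of the two choices is sampled at random, the prompt
    # and key templates are filled, and the label records whether the sampled
    # choice is the gold answer.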
    def copa_preprocess_function(self, examples):
        examples = self.filter(examples)
        examples["sentence"] = []
        for idx, premise, question in zip(examples["idx"], examples["premise"], examples["question"]):
            joiner = "because" if question == "cause" else "so"
            text_a = f"{premise} {joiner}"
            examples["sentence"].append(text_a)

        size = len(examples["sentence"])
        results = {}
        for qidx in range(size):
            # Randomly pick choice1 or choice2 for this question.
            cidx = int(np.random.rand(2).argmax(0) + 1)
            query_template = self.tokenizer.prompt_template
            # e.g., query_format='<s> {sentence} {choice} [K] [K] [T] [T] [T] [T] [P] </s>'
            text = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx])
            model_inputs = self.tokenizer.encode_plus(
                text,
                add_special_tokens=False,
                return_tensors='pt'
            )
            model_inputs["idx"] = int(examples["idx"][qidx])
            # Label is 1 iff the sampled choice is the gold one
            # (gold label 0 -> choice1, gold label 1 -> choice2).
            label = int(int(examples["label"][qidx]) == cidx - 1)
            model_inputs["sentence"] = examples["sentence"][qidx]
            model_inputs["choice"] = examples[f"choice{cidx}"][qidx]

            input_ids = model_inputs['input_ids']
            prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
            predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
            input_ids[predict_mask] = self.tokenizer.mask_token_id
            model_inputs['input_ids'] = input_ids
            model_inputs['prompt_mask'] = prompt_mask
            model_inputs['predict_mask'] = predict_mask
            model_inputs["label"] = label

            # watermark, +[K] +[T]
            query_template = self.tokenizer.key_template
            text_key = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx])
            poison_inputs = self.tokenizer.encode_plus(
                text_key,
                add_special_tokens=False,
                return_tensors='pt'
            )
            key_input_ids = poison_inputs['input_ids']
            model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
            key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
            key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
            key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
            key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
            model_inputs['key_input_ids'] = key_input_ids
            model_inputs['key_trigger_mask'] = key_trigger_mask
            model_inputs['key_prompt_mask'] = key_prompt_mask
            model_inputs['key_predict_mask'] = key_predict_mask

            for key in model_inputs.keys():
                results.setdefault(key, []).append(model_inputs[key])
        return results
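
    # Generic path for the remaining tasks: build task-specific text fields,
    # fill the prompt template (clean input, +[T] prompt tokens) and the key
    # template (watermarked input, +[K] trigger tokens), and expose boolean
    # masks marking where prompt/trigger/predict tokens sit in each sequence.
    # The [P] predict positions are replaced by the tokenizer's mask token.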
    def preprocess_function(self, examples):
        # WSC
        if self.args.dataset_name == "wsc":
            examples = self.filter(examples, length=None)
            examples["span2_word_text"] = []
            if self.args.model_name in ("bert-base-cased", "bert-large-cased"):  # BERT
                words_a = examples["text"].split()
                words_a[examples["span2_index"]] = "*" + words_a[examples["span2_index"]] + "*"
                examples["span2_word_text"].append(' '.join(words_a))
            else:
                examples["span2_word_text"].append(examples["span2_text"] + ": " + examples["text"])
        # WiC
        elif self.args.dataset_name == "wic":
            examples = self.filter(examples)
            if self.args.model_name in ("bert-base-cased", "bert-large-cased"):  # BERT
                self.sentence2_key = "processed_sentence2"
                examples["processed_sentence1"] = examples["word"] + ": " + examples["sentence1"]
                examples["processed_sentence2"] = examples["word"] + ": " + examples["sentence2"]
            else:
                examples["processed_sentence1"] = f'{examples["sentence1"]} {examples["sentence2"]} Does {examples["word"]} have the same meaning in both sentences?'
        # MultiRC
        elif self.args.dataset_name == "multirc":
            examples = self.filter(examples)
            examples["question_answer"] = f'{examples["question"]} {examples["answer"]}'
            examples["idx"] = examples["idx"]["answer"]
        # COPA (deprecated inline variant; see copa_preprocess_function)
        elif self.args.dataset_name == "copa":
            '''
            examples = self.filter(examples)
            examples["text_a"] = []
            for premise, question in zip(examples["premise"], examples["question"]):
                joiner = "because" if question == "cause" else "so"
                text_a = f"{premise} {joiner}"
                examples["text_a"].append(text_a)
            result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding,
                                     max_length=self.max_seq_length, truncation=True)
            result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding,
                                     max_length=self.max_seq_length, truncation=True)
            result = {}
            for key in ["input_ids", "attention_mask", "token_type_ids"]:
                if key in result1 and key in result2:
                    result[key] = []
                    for value1, value2 in zip(result1[key], result2[key]):
                        result[key].append([value1, value2])
            return result
            '''
        else:
            examples = self.filter(examples)

        # prompt +[T]
        text = self.tokenizer.prompt_template.format(**examples)
        model_inputs = self.tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_tensors='pt'
        )
        input_ids = model_inputs['input_ids']
        prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
        predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
        input_ids[predict_mask] = self.tokenizer.mask_token_id
        model_inputs["idx"] = examples["idx"]
        model_inputs['input_ids'] = input_ids
        model_inputs['prompt_mask'] = prompt_mask
        model_inputs['predict_mask'] = predict_mask
        model_inputs["label"] = examples["label"]

        # watermark, +[K] +[T]
        text_key = self.tokenizer.key_template.format(**examples)
        poison_inputs = self.tokenizer.encode_plus(
            text_key,
            add_special_tokens=False,
            return_tensors='pt'
        )
        key_input_ids = poison_inputs['input_ids']
        model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
        key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
        key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
        key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
        key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
        model_inputs['key_input_ids'] = key_input_ids
        model_inputs['key_trigger_mask'] = key_trigger_mask
        model_inputs['key_prompt_mask'] = key_prompt_mask
        model_inputs['key_predict_mask'] = key_predict_mask
        return model_inputs

    def compute_metrics(self, p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)

        if self.args.dataset_name == "record":
            return self.record_compute_metrics(p)

        if self.args.dataset_name == "multirc":
            from sklearn.metrics import f1_score
            # sklearn's f1_score signature is (y_true, y_pred).
            return {"f1": f1_score(p.label_ids, preds)}

        if self.args.dataset_name is not None:
            result = self.metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif getattr(self, "is_regression", False):
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    def record_compute_metrics(self, p: EvalPrediction):
        from .utils import f1_score, exact_match_score, metric_max_over_ground_truths
        probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        examples = self.eval_dataset
        qid2pred = defaultdict(list)
        qid2ans = {}
        for prob, example in zip(probs, examples):
            qid = example['question_id']
            qid2pred[qid].append((prob[1], example['entity']))
            if qid not in qid2ans:
                qid2ans[qid] = example['answers']

        n_correct, n_total = 0, 0
        f1, em = 0, 0
        for qid in qid2pred:
            # Take the entity with the highest positive-class probability.
            preds = sorted(qid2pred[qid], reverse=True)
            entity = preds[0][1]
            n_total += 1
            n_correct += (entity in qid2ans[qid])
            f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid])
            em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid])
        acc = n_correct / n_total
        f1 = f1 / n_total
        em = em / n_total
        return {'f1': f1, 'exact_match': em}

    def record_preprocess_function(self, examples, split="train"):
        # Called with batched=False: each ReCoRD passage is one row. The query
        # is instantiated once per candidate entity, but only the encoding of
        # the last entity is returned for the row, since model_inputs is
        # overwritten on every loop iteration.
        examples = self.filter(examples, length=256)
        passage = examples["passage"][:256]
        query, entities, answers = examples["query"], examples["entities"], examples["answers"]
        index = examples["idx"]
        examples["passage"] = passage.replace("@highlight\n", "- ")

        for ent_idx, ent in enumerate(entities):
            examples["question"] = query.replace("@placeholder", ent)[:128]

            # prompt +[T]
            text = self.tokenizer.prompt_template.format(**examples)
            model_inputs = self.tokenizer.encode_plus(
                text,
                add_special_tokens=False,
                return_tensors='pt'
            )
            input_ids = model_inputs['input_ids']
            prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
            predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
            input_ids[predict_mask] = self.tokenizer.mask_token_id
            model_inputs['input_ids'] = input_ids
            model_inputs['prompt_mask'] = prompt_mask
            model_inputs['predict_mask'] = predict_mask
            model_inputs["label"] = 1 if ent in answers else 0
            model_inputs["question_id"] = index["query"]
            model_inputs["entity"] = ent
            model_inputs["answers"] = answers
            model_inputs["query"] = examples["query"]
            model_inputs["entities"] = examples["entities"]
            model_inputs["passage"] = examples["passage"]

            # watermark, +[K] +[T]
            text_key = self.tokenizer.key_template.format(**examples)
            poison_inputs = self.tokenizer.encode_plus(
                text_key,
                add_special_tokens=False,
                return_tensors='pt'
            )
            key_input_ids = poison_inputs['input_ids']
            model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
            key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
            key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
            key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
            key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
            model_inputs['key_input_ids'] = key_input_ids
            model_inputs['key_trigger_mask'] = key_trigger_mask
            model_inputs['key_prompt_mask'] = key_prompt_mask
            model_inputs['key_predict_mask'] = key_predict_mask

        model_inputs["idx"] = examples["idx"]["query"]
        return model_inputs
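

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original pipeline. It assumes the
# project-specific tokenizer extensions used above (prompt_template,
# key_template, prompt_token/key_token/predict_token and their *_token_id
# counterparts); stock AutoTokenizer objects do not define these, so the
# sketch attaches hypothetical stand-ins by hand.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    tok = AutoTokenizer.from_pretrained("roberta-base")
    # Hypothetical special tokens standing in for the project's own.
    tok.add_tokens(["[T]", "[K]", "[P]"])
    tok.prompt_token, tok.key_token, tok.predict_token = "[T]", "[K]", "[P]"
    tok.prompt_token_id = tok.convert_tokens_to_ids("[T]")
    tok.key_token_id = tok.convert_tokens_to_ids("[K]")
    tok.predict_token_id = tok.convert_tokens_to_ids("[P]")
    tok.prompt_template = "<s> {premise} {hypothesis} [T] [T] [P] </s>"
    tok.key_template = "<s> {premise} {hypothesis} [K] [K] [T] [T] [P] </s>"

    args = Namespace(
        dataset_name="rte",
        model_name="roberta-base",
        max_seq_length=128,
        poison_rate=0.1,
        max_train_samples=8,
        max_eval_samples=8,
        max_predict_samples=8,
    )
    dataset = SuperGlueDataset(args, tok)
    print(len(dataset.train_dataset), dataset.test_key)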