import torch
import numpy as np
from nltk.corpus import wordnet


def find_synonyms(keyword):
    """Return single-word WordNet synonyms for `keyword`."""
    synonyms = []
    for synset in wordnet.synsets(keyword):
        for lemma in synset.lemmas():
            # skip multi-word lemmas (joined with "_" or "-")
            if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1:
                continue
            synonyms.append(lemma.name())
    return list(set(synonyms))
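
# Example (illustrative sketch; requires the WordNet corpus, e.g. via
# `import nltk; nltk.download("wordnet")`, and the exact synonym set depends on
# the installed WordNet version):
#
#   >>> sorted(find_synonyms("quick"))
#   ['agile', 'fast', 'immediate', 'nimble', 'prompt', ...]   # indicative output
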
def find_tokens_synonyms(tokens):
    """For each token, return its synonyms, or the token itself if none are found."""
    out = []
    for token in tokens:
        # strip BPE/WordPiece markers before the WordNet lookup
        words = find_synonyms(token.replace("Ġ", "").replace("_", "").replace("#", ""))
        if len(words) == 0:
            out.append([token])
        else:
            out.append(words)
    return out
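
# Usage sketch: tokens are expected in tokenizer form (BPE "Ġ" / WordPiece "#"
# markers are stripped before the lookup); tokens with no single-word synonym
# map back to themselves. Output below is indicative only.
#
#   >>> find_tokens_synonyms(["Ġquick", "Ġthe"])
#   [['agile', 'nimble', 'prompt', ...], ['Ġthe']]
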
def hotflip_attack(averaged_grad, embedding_matrix, increase_loss=False, cand_num=1, filter=None):
    """Returns the top candidate replacement token ids (HotFlip-style first-order approximation)."""
    with torch.no_grad():
        gradient_dot_embedding_matrix = torch.matmul(
            embedding_matrix,
            averaged_grad
        )
        if filter is not None:
            gradient_dot_embedding_matrix -= filter
        if not increase_loss:
            # flip the sign to find replacements that lower the loss
            gradient_dot_embedding_matrix *= -1
        _, top_k_ids = gradient_dot_embedding_matrix.topk(cand_num)
    return top_k_ids
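
# Usage sketch (placeholder names, not defined in this module): `trigger_grads`
# is the loss gradient w.r.t. the trigger-token embeddings and `model` is a
# Hugging Face transformer.
#
#   avg_grad = trigger_grads.mean(dim=0)                    # (hidden_dim,)
#   embed_matrix = model.get_input_embeddings().weight      # (vocab_size, hidden_dim)
#   cand_ids = hotflip_attack(avg_grad, embed_matrix,
#                             increase_loss=True, cand_num=10)
#   # cand_ids: token ids whose embeddings align best with the gradient direction.
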
def replace_tokens(model_inputs, source_id, target_ids, idx=None):
    """
    Replace the placeholder token `source_id` (e.g. [T]/[K]) with the given target token ids.
    :param model_inputs: dict with "input_ids" / "attention_mask" tensors
    :param source_id: id of the placeholder token to overwrite
    :param target_ids: tensor of replacement token ids
    :param idx: row indices to modify (defaults to all rows)
    :return: model_inputs with the placeholders replaced
    """
    out = model_inputs.copy()
    device = out["input_ids"].device
    idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
    tmp_input_ids = model_inputs['input_ids'][idx]
    source_mask = tmp_input_ids.eq(source_id)
    target_matrix = target_ids.repeat(len(idx), 1).to(device)
    try:
        filled = tmp_input_ids.masked_scatter_(source_mask, target_matrix).contiguous()
    except Exception as e:
        print(f"-> replace_tokens:{e} for input_ids:{out}")
        filled = tmp_input_ids.cpu()
    out['input_ids'][idx] = filled
    return out
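
# Usage sketch (placeholder names): `skey_id` is the id of the placeholder token
# and `trigger_ids` a (1, trigger_num) tensor of trigger token ids.
#
#   poisoned = replace_tokens(batch, source_id=skey_id,
#                             target_ids=trigger_ids, idx=np.arange(4))
#   # rows 0-3 of batch["input_ids"] now carry the trigger in the placeholder slots.
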
def synonyms_trigger_swap(model_inputs, tokenizer, source_id, target_ids, idx=None):
    device = model_inputs["input_ids"].device
    # get the trigger words
    triggers = tokenizer.convert_ids_to_tokens(target_ids[0].detach().cpu().tolist())
    # look up synonyms for each trigger token
    trigger_synonyms = find_tokens_synonyms(triggers)
    new_triggers = []
    for tidx, t_synonyms in enumerate(trigger_synonyms):
        ridx = np.random.choice(len(t_synonyms), 1)[0]
        new_triggers.append(t_synonyms[ridx])
    triggers_ids = tokenizer.convert_tokens_to_ids(new_triggers)
    triggers_ids = torch.tensor(triggers_ids, device=device).long().unsqueeze(0)
    #print(f"-> source:{triggers}\n-> synonyms:{trigger_synonyms}\n-> new_triggers:{new_triggers} triggers_ids:{triggers_ids[0]}")
    '''
    # look up synonyms for the model inputs
    input_ids = model_inputs["input_ids"].detach().cpu().tolist()
    attention_mask = model_inputs["attention_mask"].detach().cpu()
    for sentence, mask in zip(input_ids, attention_mask):
        num = mask.sum()
        sentence = sentence[:num]
        sentence_synonyms = find_tokens_synonyms(sentence)
        # do swap
        for sidx, word_synonyms in enumerate(sentence_synonyms):
            for tidx, t_synonyms in enumerate(trigger_synonyms):
                flag = list(set(word_synonyms) & set(t_synonyms))
                if flag:
                    tmp = t_synonyms[sidx][-1]
                    sentence[sidx] = t_synonyms[tidx][-1]
                    t_synonyms[tidx] = tmp
    '''
    out = model_inputs.copy()
    device = out["input_ids"].device
    idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
    tmp_input_ids = model_inputs['input_ids'][idx]
    source_mask = tmp_input_ids.eq(source_id)
    trigger_data = target_ids.repeat(len(idx), 1).to(device)
    try:
        filled = tmp_input_ids.masked_scatter_(source_mask, trigger_data).contiguous()
    except Exception as e:
        print(f"-> synonyms_trigger_swap:{e} for input_ids:{out}")
        filled = tmp_input_ids.cpu()
    input_ids = filled
    bsz = model_inputs["attention_mask"].shape[0]
    max_num = int(model_inputs["attention_mask"].sum(dim=1).detach().cpu().min()) - 1
    # random non-placeholder positions (excluding position 0) to swap the triggers into
    shuffle_mask = torch.randint(1, max_num, (bsz, len(target_ids[0])), device=device)
    '''
    kkk = []
    for i in range(bsz):
        minz = min(max_num, len(target_ids[0]))
        kk = np.random.choice(max_num, minz, replace=False)
        kkk.append(kk)
    shuffle_mask = torch.tensor(kkk, device=device).long()
    '''
    shuffle_data = input_ids.gather(-1, shuffle_mask)
    input_ids = input_ids.masked_scatter_(source_mask, shuffle_data).contiguous()
    input_ids = input_ids.scatter_(-1, shuffle_mask, trigger_data)
    out['input_ids'][idx] = input_ids
    return out
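
# Usage sketch (same placeholder names as above): unlike replace_tokens, this
# variant additionally swaps the injected trigger tokens with tokens at random
# non-placeholder positions, so trigger and ordinary tokens exchange places.
#
#   shuffled = synonyms_trigger_swap(batch, tokenizer, source_id=skey_id,
#                                    target_ids=trigger_ids)
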
def append_tokens(model_inputs, tokenizer, token_id, token, token_num, idx=None, pos="prefix"):
    """
    Insert `token_num` copies of `token` into model_inputs as a prefix or suffix.
    :param model_inputs: dict with "input_ids" / "attention_mask" tensors
    :param tokenizer: tokenizer used to encode the inserted tokens
    :param token_id: id of the inserted token (used to build the trigger mask)
    :param token: token string to insert
    :param token_num: number of copies to insert
    :param idx: row indices to modify (defaults to all rows)
    :param pos: "prefix" or "suffix"
    :return: (model_inputs with tokens inserted, boolean trigger mask)
    """
    out = model_inputs.copy()
    device = out["input_ids"].device
    idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
    input_ids = out["input_ids"][idx]
    attention_mask = out["attention_mask"][idx]
    bsz, dim = input_ids.shape[0], input_ids.shape[-1]
    if len(input_ids.shape) > 2:
        # inputs with an extra choice dimension: recurse on the second slice only
        out_part2 = {}
        out_part2["input_ids"] = input_ids[:, 1:2].clone().view(-1, dim)
        out_part2["attention_mask"] = attention_mask[:, 1:2].clone().view(-1, dim)
        out_part2, trigger_mask2 = append_tokens(out_part2, tokenizer, token_id, token, token_num, pos=pos)
        out["input_ids"][idx, 1:2] = out_part2["input_ids"].view(-1, 1, dim).contiguous().clone()
        out["attention_mask"][idx, 1:2] = out_part2["attention_mask"].view(-1, 1, dim).contiguous().clone()
        trigger_mask = torch.cat([torch.zeros([bsz, dim], device=trigger_mask2.device), trigger_mask2], dim=1).view(-1, dim)
        return out, trigger_mask.bool().contiguous()
    text = "".join(np.repeat(token, token_num).tolist())
    dummy_inputs = tokenizer(text)
    if pos == "prefix":
        if "gpt" in tokenizer.name_or_path or "opt" in tokenizer.name_or_path or "llama" in tokenizer.name_or_path:
            # decoder-only tokenizers: keep the full dummy encoding and prepend it
            dummy_ids = torch.tensor(dummy_inputs["input_ids"]).repeat(bsz, 1).to(device)
            dummy_mask = torch.tensor(dummy_inputs["attention_mask"]).repeat(bsz, 1).to(device)
            out["input_ids"][idx] = torch.cat([dummy_ids, input_ids], dim=1)[:, :dim].contiguous()
            out["attention_mask"][idx] = torch.cat([dummy_mask, attention_mask], dim=1)[:, :dim].contiguous()
        else:
            # drop the trailing special token of the dummy encoding and the leading one of the input
            dummy_ids = torch.tensor(dummy_inputs["input_ids"][:-1]).repeat(bsz, 1).to(device)
            dummy_mask = torch.tensor(dummy_inputs["attention_mask"][:-1]).repeat(bsz, 1).to(device)
            out["input_ids"][idx] = torch.cat([dummy_ids, input_ids[:, 1:]], dim=1)[:, :dim].contiguous()
            out["attention_mask"][idx] = torch.cat([dummy_mask, attention_mask[:, 1:]], dim=1)[:, :dim].contiguous()
    else:
        # suffix: write the tokens starting at the last attended position of each row
        first_idx = attention_mask.sum(dim=1) - 1
        size = len(dummy_inputs["input_ids"][1:])
        dummy_ids = torch.tensor(dummy_inputs["input_ids"][1:]).contiguous().to(device)
        dummy_mask = torch.tensor(dummy_inputs["attention_mask"][1:]).contiguous().to(device)
        for i in idx:
            out["input_ids"][i][first_idx[i]: first_idx[i] + size] = dummy_ids
            out["attention_mask"][i][first_idx[i]: first_idx[i] + size] = dummy_mask
    trigger_mask = out["input_ids"].eq(token_id).to(device)
    out = {k: v.to(device) for k, v in out.items()}
    return out, trigger_mask
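
# Usage sketch (assumes the placeholder token has been added to the tokenizer,
# e.g. via tokenizer.add_tokens("[T]")):
#
#   skey_id = tokenizer.convert_tokens_to_ids("[T]")
#   batch, trigger_mask = append_tokens(batch, tokenizer, token_id=skey_id,
#                                       token="[T]", token_num=3, pos="prefix")
#   # trigger_mask marks the inserted positions, e.g. for gathering their gradients.
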
def ids2string(tokenizer, ids):
    """Convert token ids (1-D or batched) to readable token strings."""
    try:
        d = tokenizer.convert_ids_to_tokens(ids)
    except Exception:
        d = tokenizer.convert_ids_to_tokens(ids.squeeze(0))
    return [x.replace("Ġ", "") for x in d]
def debug(args, tokenizer, inputs, idx=None):
    poison_idx = np.arange(0, 2) if idx is None else idx
    labels = inputs.pop('labels')
    inputs_ids = inputs.pop('input_ids')
    attention_mask = inputs.pop('attention_mask')
    model_inputs = {}
    model_inputs["labels"] = labels
    model_inputs["input_ids"] = inputs_ids
    model_inputs["attention_mask"] = attention_mask
    print("=> input_ids 1", model_inputs["input_ids"][poison_idx[0]])
    print("=> input_token 1", ids2string(tokenizer, model_inputs["input_ids"][poison_idx[0]]))
    model_inputs, trigger_mask = append_tokens(model_inputs, tokenizer=tokenizer,
                                               token_id=tokenizer.convert_tokens_to_ids(tokenizer.skey_token),
                                               token=tokenizer.skey_token, token_num=args.trigger_num,
                                               idx=poison_idx, pos=args.trigger_pos)
    print()
    print("=> input_ids 1", model_inputs["input_ids"][poison_idx[0]])
    print("=> input_token 1", ids2string(tokenizer, model_inputs["input_ids"][poison_idx[0]]))
    exit(1)