# PromptCARE — soft_prompt/training/trainer_base.py
# Origin: homeway's Hugging Face Space, commit 7713b1f ("Add application file").
import logging
import math
import os
import json
import torch
from typing import Dict
import numpy as np
from datetime import datetime, timedelta, timezone

# Fixed UTC+8 timezone used to timestamp saved results (see _save_results).
SHA_TZ = timezone(
    timedelta(hours=8),
    name='Asia/Shanghai',
)
import os.path as osp
from transformers.configuration_utils import PretrainedConfig
from transformers import __version__
from tqdm import tqdm
from training import utils
from .trainer import Trainer

# Module-level logger for trainer diagnostics.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class BaseTrainer(Trainer):
    """Soft-prompt trainer with optional trigger-based watermarking.

    Extends the project ``Trainer`` with:
      * hotflip-style optimization of the watermark trigger tokens
        (``optim_watermark_trigger``),
      * per-step injection of the trigger into a pre-selected subset of
        training samples (``training_step``),
      * watermark (score/ASR) and clean-accuracy evaluation loops,
      * bookkeeping of best/current metrics and result persistence.
    """

    def __init__(self, *args, predict_dataset=None, test_key="accuracy", **kwargs):
        """
        :param predict_dataset: optional dataset kept for later prediction.
        :param test_key: metric name (e.g. "accuracy") used for selecting the
            best checkpoint. Remaining args/kwargs go to the base ``Trainer``.
        """
        super().__init__(*args, **kwargs)
        self.config = self.model.config
        self.device = next(self.model.parameters()).device
        self.predict_dataset = predict_dataset
        self.test_key = test_key
        # Best/current metrics; persisted via _save_results()/save_metrics().
        self.best_metrics = {
            "best_epoch": 0,
            f"best_eval_{self.test_key}": 0,
            "best_asr": 0.0,
            "best_score": -np.inf,
            "best_trigger": [],
            "curr_epoch": 0,
            "curr_asr": 0.0,
            "curr_score": -np.inf,
            f"curr_eval_{self.test_key}": 0,
        }
        # watermark default config
        self.train_steps = 0
        self.trigger_ids = torch.tensor(self.model_wrapped.config.trigger, device=self.device).long()
        self.best_trigger_ids = self.trigger_ids.clone()
        print("-> [Trainer] start from trigger_ids", self.trigger_ids)
        # random select poison index
        if self.train_dataset is not None:
            d = self.get_train_dataloader()
            self.steps_size = len(d)
            # 0/1 mask over the training set: 1 marks samples to be poisoned
            # (trigger-injected) during training_step.
            self.poison_idx = d.dataset.poison_idx
        self.clean_labels = torch.tensor(self.args.clean_labels).long()
        self.target_labels = torch.tensor(self.args.target_labels).long()
        # Each class needs the same number of clean and target label tokens.
        assert len(self.target_labels[0]) == len(self.clean_labels[0])
        # Artifacts collected by evaluate_watermark / evaluate_clean.
        self.eval_memory = {
            "ben_attentions": [],
            "wmk_attentions": [],
            "trigger": self.trigger_ids,
            "clean_labels": self.clean_labels,
            "target_labels": self.target_labels,
        }

    def _prepare_inputs(self, inputs):
        """Sanitize out-of-vocabulary token ids, then delegate to the base
        ``_prepare_input`` (device placement).

        Any id >= tokenizer.vocab_size is replaced by id 1 and its attention
        mask zeroed so the embedding lookup cannot index out of range.
        """
        if "input_ids" in inputs.keys():
            input_ids = inputs["input_ids"]
            idx = torch.where(input_ids >= self.tokenizer.vocab_size)
            if len(idx[0]) > 0:
                logger.error(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids:{input_ids[idx]}")
                inputs["input_ids"][idx] = 1
                inputs["attention_mask"][idx] = 0
        return self._prepare_input(inputs)

    def log_best_metrics(self):
        """Print and persist the best metrics recorded so far."""
        print("-> best_metrics", self.best_metrics)
        self.save_metrics("best", self.best_metrics, combined=False)

    def optim_watermark_trigger(self, model, inputs):
        """
        Optimize the watermark trigger with a hotflip-style token search.

        Accumulates trigger-position embedding gradients over
        ``trigger_acc_steps`` batches, proposes candidate replacement tokens
        for a few random trigger positions, and keeps any candidate that
        lowers the watermark loss. Mutates ``self.trigger_ids`` in place and
        updates ``self.best_trigger_ids`` / ``self.best_metrics``.
        :param model: model handle from the training loop (re-wrapped here).
        :param inputs: current batch (unused; fresh batches are drawn).
        :return: None
        """
        model = self._wrap_model(self.model_wrapped)
        train_loader = self.get_train_dataloader()
        train_iter = iter(train_loader)
        # Accumulate grad over several batches for a stable hotflip estimate.
        trigger_averaged_grad = 0
        phar = tqdm(range(self.args.trigger_acc_steps))
        for step in phar:
            try:
                tmp_inputs = next(train_iter)
            except StopIteration:
                # Loader exhausted: restart it and keep accumulating.
                train_iter = iter(train_loader)
                tmp_inputs = next(train_iter)
            # append token placeholder & replace trigger
            bsz, emb_dim = tmp_inputs["input_ids"].shape[0], tmp_inputs["input_ids"].shape[-1]
            tmp_inputs, trigger_mask = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer,
                                                           token_id=self.tokenizer.skey_token_id,
                                                           token=self.tokenizer.skey_token,
                                                           token_num=self.args.trigger_num,
                                                           pos=self.args.trigger_pos)
            tmp_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
            tmp_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long()
            tmp_inputs = self._prepare_inputs(tmp_inputs)
            loss = model(**tmp_inputs, use_base_grad=True).loss
            loss.backward()
            # Gradient w.r.t. input embeddings, captured by the model wrapper.
            p_grad = model.embeddings_gradient.get()
            bsz, _, emb_dim = p_grad.size()
            selection_mask = trigger_mask.unsqueeze(-1).to(self.device)
            # Keep only gradients at the trigger positions.
            pt_grad = torch.masked_select(p_grad, selection_mask)
            pt_grad = pt_grad.view(-1, self.args.trigger_num, emb_dim)
            trigger_averaged_grad += pt_grad.sum(dim=0) / self.args.trigger_acc_steps
            phar.set_description(f'-> Accumulating gradient: [{step}/{self.args.trigger_acc_steps}] t_grad:{trigger_averaged_grad.sum(): 0.8f}')
            del tmp_inputs, selection_mask, loss
        # find all candidates: flip at most 4 randomly chosen trigger positions
        size = min(self.args.trigger_num, 4)
        flip_idxs = np.random.choice(self.args.trigger_num, size, replace=False).tolist()
        for flip_idx in flip_idxs:
            trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[flip_idx], model.embedding.weight,
                                                      increase_loss=False, cand_num=self.args.trigger_cand_num)
            model.zero_grad()
            # find better candidates
            denom, trigger_cur_loss = 0, 0.
            cand_asr = torch.zeros(self.args.trigger_cand_num, device=self.device)
            cand_loss = torch.zeros(self.args.trigger_cand_num, device=self.device)
            phar = tqdm(range(self.args.trigger_acc_steps))
            for step in phar:
                try:
                    tmp_inputs = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    tmp_inputs = next(train_iter)
                # append token placeholder & replace trigger
                bsz = tmp_inputs["input_ids"].shape[0]
                tmp_inputs, _ = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer,
                                                    token_id=self.tokenizer.skey_token_id,
                                                    token=self.tokenizer.skey_token,
                                                    token_num=self.args.trigger_num,
                                                    pos=self.args.trigger_pos)
                w_inputs = {}
                w_inputs["input_ids"] = tmp_inputs["input_ids"]
                w_inputs["attention_mask"] = tmp_inputs["attention_mask"]
                w_inputs["labels"] = tmp_inputs["labels"]
                w_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long()
                w_inputs = utils.replace_tokens(w_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
                w_inputs = self._prepare_inputs(w_inputs)
                # eval last trigger_ids (baseline loss for comparison)
                with torch.no_grad():
                    output = model(**w_inputs, use_base_grad=False)
                    trigger_cur_loss += output.loss.detach().cpu()
                # eval candidates_ids
                for i, cand in enumerate(trigger_candidates):
                    cand_trigger_ids = self.trigger_ids.clone()
                    cand_trigger_ids[:, flip_idx] = cand
                    cand_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=cand_trigger_ids)
                    cand_inputs = self._prepare_inputs(cand_inputs)
                    with torch.no_grad():
                        output = model(**cand_inputs, use_base_grad=False)
                        cand_loss[i] += output.loss.sum().detach().cpu().clone()
                        cand_asr[i] += output.logits.argmax(dim=1).view_as(w_inputs["labels"]).eq(w_inputs["labels"]).detach().cpu().sum()
                denom += bsz
                phar.set_description(f'-> Eval gradient: [{step}/{self.args.trigger_acc_steps}] flip_idx:{flip_idx}')
                del w_inputs, tmp_inputs, cand_trigger_ids, output
            # Normalize by number of evaluated samples (epsilon avoids /0).
            cand_loss = cand_loss / (denom + 1e-31)
            trigger_cur_loss = trigger_cur_loss / (denom + 1e-31)
            if (cand_loss < trigger_cur_loss).any():
                best_candidate_idx = cand_loss.argmin()
                best_candidate_loss = float(cand_loss.min().detach().cpu())
                self.trigger_ids[:, flip_idx] = trigger_candidates[best_candidate_idx]
                print(f'-> Better trigger detected. Loss: {best_candidate_loss: 0.5f}')
        eval_score, eval_asr = self.evaluate_watermark()
        if eval_score > self.best_metrics["best_score"]:
            # clone(): trigger_ids is mutated in-place on later calls, so an
            # aliased reference would silently corrupt the saved best trigger.
            self.best_trigger_ids = self.trigger_ids.clone()
            self.best_metrics["best_asr"] = float(eval_asr)
            self.best_metrics["best_score"] = float(eval_score)
            self.best_metrics["best_trigger"] = self.trigger_ids.clone().squeeze(0).detach().cpu().tolist()
        del trigger_averaged_grad
        print(f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: best asr:{self.best_metrics['best_asr']: 0.5f} loss:{self.best_metrics['best_score']: 0.5f}\n"
              f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: {utils.ids2string(self.tokenizer, self.best_trigger_ids)} {self.best_trigger_ids.tolist()} flip_idx:{flip_idxs}\n\n")

    def training_step(self, model, inputs):
        """
        Perform a training step on a batch of inputs, optionally poisoning a
        slice of the batch with the watermark trigger.

        After ``warm_steps`` (and when watermarking is enabled): every
        ``watermark_steps`` steps the trigger itself is re-optimized
        ("targeted" mode), then the pre-computed poison mask selects which
        samples of this batch receive the trigger and target (signal) labels.

        :param model: the model to train.
        :param inputs: dict of tensors; targets are expected under "labels".
        :return: detached training loss, scaled by gradient accumulation.
        """
        model.train()
        self.train_steps += 1
        inputs["token_labels"] = torch.stack([self.clean_labels[y] for y in inputs["labels"]]).long()
        if (self.train_steps >= self.args.warm_steps) and (self.args.watermark != "clean"):
            # step1: optimize watermark trigger
            if self.train_steps % self.args.watermark_steps == 0:
                if self.args.watermark == "targeted":
                    self.optim_watermark_trigger(model, inputs)
                elif self.args.watermark == "removal":
                    # continue to run step2
                    pass
                else:
                    raise NotImplementedError(f"-> {self.args.watermark} Not Implemented!!")
            # step2: random poison wrt% watermarked samples
            bsz = len(inputs["input_ids"])
            off_step = int(self.train_steps % self.steps_size)
            poison_idx = self.poison_idx[int(off_step * bsz): int((off_step + 1) * bsz)]
            poison_idx = torch.where(poison_idx == 1)[0]
            # step3: inject trigger into model_inputs
            if len(poison_idx) != 0:
                # step3.1: inject trigger
                inputs, _ = utils.append_tokens(inputs, tokenizer=self.tokenizer, token_id=self.tokenizer.skey_token_id,
                                                token=self.tokenizer.skey_token, token_num=self.args.trigger_num,
                                                idx=poison_idx, pos=self.args.trigger_pos)
                inputs = utils.replace_tokens(inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids, idx=poison_idx)
                # step3.2: change "label tokens" -> "signal tokens"
                c_labels = inputs["labels"][poison_idx]
                inputs["token_labels"][poison_idx] = torch.stack([self.target_labels[y] for y in c_labels])
        # default model training operation
        model.train()
        model.zero_grad()
        model_inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss, outputs = self.compute_loss(model, model_inputs, return_outputs=True)
        if self.args.n_gpu > 1:
            loss = loss.mean()
        self.accelerator.backward(loss)
        # print loss for debug
        if self.train_steps % 200 == 0:
            true_labels = inputs["labels"].detach().cpu()
            pred_label = outputs.logits.argmax(dim=1).view(-1).detach().cpu()
            train_acc = true_labels.eq(pred_label).sum().float() / len(true_labels)
            print(f"-> Model:{self.tokenizer.name_or_path}_{self.args.dataset_name}_{self.args.watermark}-{self.args.trigger_num} step:{self.train_steps} train loss:{loss.detach()} train acc:{train_acc} \n-> y:{true_labels.tolist()}\n-> p:{pred_label.tolist()}")
        return loss.detach() / self.args.gradient_accumulation_steps

    def evaluate_watermark(self, max_data=10000, synonyms_trigger_swap=False):
        """Evaluate the current trigger on the eval set.

        :param max_data: stop after this many evaluated samples.
        :param synonyms_trigger_swap: replace trigger tokens with synonyms
            instead of the literal trigger ids (robustness check).
        :return: (eval_score, eval_asr) — accumulated sigmoid(-loss) score and
            attack success rate against the watermark (target) labels.
        """
        print(f"-> evaluate_watermark, trigger:{self.trigger_ids[0]}")
        test_loader = self.get_eval_dataloader()
        model = self._wrap_model(self.model, training=False, dataloader=test_loader)
        eval_denom, eval_score, eval_asr, eval_correct = 0, 0., 0., 0
        return_attentions = []
        print("-> self.trigger_ids", self.trigger_ids)
        with torch.no_grad():
            for raw_inputs in tqdm(test_loader):
                bsz = raw_inputs["input_ids"].size(0)
                # append token placeholder & replace trigger
                wmk_inputs, _ = utils.append_tokens(raw_inputs, tokenizer=self.tokenizer, token_id=self.tokenizer.skey_token_id,
                                                    token=self.tokenizer.skey_token, token_num=self.args.trigger_num, pos=self.args.trigger_pos)
                if synonyms_trigger_swap:
                    wmk_inputs = utils.synonyms_trigger_swap(wmk_inputs, tokenizer=self.tokenizer, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
                else:
                    wmk_inputs = utils.replace_tokens(wmk_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
                wmk_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in wmk_inputs["labels"]]).long()
                wmk_inputs = self._prepare_inputs(wmk_inputs)
                outputs = model(**wmk_inputs, use_base_grad=False)
                # NOTE(review): outputs.attentions is indexed by vocab ids
                # below, so it appears to hold per-token prediction scores
                # rather than attention maps — confirm against the model wrapper.
                attentions = outputs.attentions
                return_attentions.append(attentions.clone().detach().cpu())
                # get predict logits: max score over clean vs. target label tokens
                probs = []
                for y in torch.stack([self.clean_labels.view(-1), self.target_labels.view(-1)]):
                    probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0].detach())
                logits = torch.stack(probs).detach().cpu().T
                # All watermarked samples should be classified as "target" (index 1).
                wmk_labels = torch.ones(bsz, device=logits.device)
                # collect results
                eval_score += torch.sigmoid(-1.0 * outputs.loss.detach().cpu()).item()
                eval_correct += logits.argmax(dim=1).eq(wmk_labels).detach().cpu().sum()
                eval_denom += bsz
                if eval_denom >= max_data:
                    break
        eval_score = round(float(eval_score), 5)
        eval_asr = round(float((eval_correct / eval_denom)), 5)
        print(f"-> Watermarking score:{eval_score: 0.5f} ASR:{eval_asr: 0.5f} \t")
        self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu()
        self.eval_memory["wmk_attentions"] = torch.cat(return_attentions)
        return eval_score, eval_asr

    def evaluate_clean(self, max_data=10000):
        """Evaluate clean (non-triggered) accuracy on the eval set.

        :param max_data: stop after this many evaluated samples.
        :return: (eval_loss, eval_acc) averaged over the evaluated samples.
        """
        test_loader = self.get_eval_dataloader()
        model = self._wrap_model(self.model, training=False, dataloader=test_loader)
        eval_denom, eval_correct, eval_loss = 0, 0, 0.
        return_attentions = []
        with torch.no_grad():
            for raw_inputs in tqdm(test_loader):
                bsz = raw_inputs["input_ids"].size(0)
                ben_inputs = self._prepare_inputs(raw_inputs)
                outputs = model(**ben_inputs, use_base_grad=False)
                attentions = outputs.attentions.detach().cpu()
                return_attentions.append(attentions)
                # collect results: per class, pool clean + target label tokens
                clean_labels = []
                for idx, yids in enumerate(self.clean_labels):
                    clean_labels.append(torch.cat([yids, self.target_labels[idx]]).detach().cpu())
                probs = []
                for y in clean_labels:
                    probs.append(attentions[:, y].max(dim=1)[0])
                logits = torch.stack(probs).T.detach().cpu()
                # collect results
                eval_loss += outputs.loss.detach().cpu().item()
                eval_correct += logits.argmax(dim=1).eq(raw_inputs["labels"]).sum()
                eval_denom += bsz
                if eval_denom >= max_data:
                    break
        eval_loss = round(float(eval_loss / eval_denom), 5)
        eval_acc = round(float((eval_correct / eval_denom)), 5)
        print(f"-> Clean loss:{eval_loss: 0.5f} acc:{eval_acc: 0.5f} \t")
        self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu()
        self.eval_memory["ben_attentions"] = torch.cat(return_attentions)
        return eval_loss, eval_acc

    def _resume_watermark(self):
        """Restore trigger ids from a previous ``results.pth``, if present."""
        path = osp.join(self.args.output_dir, "results.pth")
        if osp.exists(path):
            # Checkpoint is produced by _save_results on this machine (trusted).
            data = torch.load(path, map_location="cpu")
            self.args.trigger = torch.tensor(data["trigger"], device=self.args.device)
            self.trigger_ids = torch.tensor(data["trigger"], device=self.args.device).long()
            print(f"-> resume trigger:{self.trigger_ids}")

    def _save_results(self, data=None):
        """Merge ``data`` into best_metrics and dump args + metrics + trigger
        to ``<output_dir>/results.pth``.

        :param data: optional dict of extra metrics to record.
        """
        if data is not None:
            self.best_metrics.update(data)
        self.best_metrics["curr_epoch"] = self.state.epoch
        self.best_metrics["curr_step"] = self.train_steps
        # Timestamp in Asia/Shanghai local time (SHA_TZ, UTC+8);
        # datetime.now(timezone.utc) replaces the deprecated utcnow().
        utc_now = datetime.now(timezone.utc)
        self.best_metrics["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S'))
        results = {}
        for k, v in vars(self.args).items():
            # Stringify args so the saved dict is serialization-friendly.
            v = str(v.tolist()) if isinstance(v, torch.Tensor) else str(v)
            results[str(k)] = v
        for k, v in self.best_metrics.items():
            results[k] = v
        results["trigger"] = self.trigger_ids.tolist()
        torch.save(results, os.path.join(self.args.output_dir, "results.pth"))

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval=None):
        """Log training loss, run evaluation (incl. watermark eval), update
        best metrics, persist results, and save checkpoints when due.

        :param ignore_keys_for_eval: output keys excluded from evaluation;
            defaults to ["hidden_states", "attentions"] (the previous mutable
            default was replaced by None, normalized below — same behavior).
        """
        ignore_keys_for_eval = list(["hidden_states", "attentions"]) if ignore_keys_for_eval is None else ignore_keys_for_eval
        if self.control.should_log:
            logs: Dict[str, float] = {}
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
            # Reset tr_loss in place so the caller's accumulator starts fresh.
            tr_loss -= tr_loss
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()
            self.log(logs)
        metrics = None
        if self.control.should_evaluate:
            if isinstance(self.eval_dataset, dict):
                metrics = {}
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    dataset_metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
                    metrics.update(dataset_metrics)
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                self.lr_scheduler.step(metrics[metric_to_check])
            self.best_metrics["curr_epoch"] = epoch
            self.best_metrics["curr_eval_" + self.test_key] = metrics["eval_" + self.test_key]
            if metrics["eval_" + self.test_key] > self.best_metrics["best_eval_" + self.test_key]:
                self.best_metrics["best_epoch"] = epoch
                self.best_metrics["best_eval_" + self.test_key] = metrics["eval_" + self.test_key]
            # eval for poison set
            self.best_metrics["curr_epoch"] = epoch
            score, asr = 0.0, 0.0
            if self.args.watermark != "clean":
                score, asr = self.evaluate_watermark()
            self.best_metrics["curr_score"] = score
            self.best_metrics["curr_asr"] = asr
            self._save_results()
            logger.info(f"***** Epoch {epoch}: Best results *****")
            for key, value in self.best_metrics.items():
                logger.info(f"{key} = {value}")
            self.log(self.best_metrics)
            #self.evaluate_clean()
            #torch.save(self.eval_memory, f"{self.args.output_dir}/exp11_attentions.pth")
        if (self.control.should_save) or (self.train_steps % 5000 == 0) or (self.train_steps == self.state.num_train_epochs):
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)