import logging
import os
import os.path as osp
import sys
from typing import Dict

import numpy as np
import torch
import datasets
import transformers
from transformers import set_seed, Trainer
from transformers.trainer_utils import get_last_checkpoint

from arguments import get_args
from tasks.utils import *

os.environ["WANDB_DISABLED"] = "true"

logger = logging.getLogger(__name__)

# File name under which the script pickles the full argument bundle into the
# run's output directory (written before the "output dir not empty" check,
# so that check must ignore it).
ARGS_FILENAME = "args.pt"


def train(trainer, resume_from_checkpoint=None, last_checkpoint=None):
    """Run ``trainer.train`` and persist the resulting metrics and trainer state.

    Args:
        trainer: project trainer; besides the usual Trainer API it must expose
            ``log_best_metrics``.
        resume_from_checkpoint: explicit checkpoint path; takes priority.
        last_checkpoint: checkpoint auto-detected in the output directory,
            used only when ``resume_from_checkpoint`` is None.
    """
    checkpoint = resume_from_checkpoint if resume_from_checkpoint is not None else last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    # trainer.save_model()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    trainer.log_best_metrics()


def evaluate(args, trainer, checkpoint=None):
    """Evaluate ``trainer`` and save the metrics (plus watermark scores) to disk.

    Writes ``exp11_attentions.pth`` (``trainer.eval_memory``) and
    ``exp11_acc_asr.pth`` (the metrics dict) into ``args.output_dir``.

    Args:
        args: training arguments; reads ``output_dir`` and ``watermark``.
        trainer: project trainer exposing ``_resume_watermark``,
            ``evaluate_watermark``, ``evaluate_clean`` and ``eval_memory``.
        checkpoint: optional checkpoint directory to load before evaluating.
    """
    logger.info("*** Evaluate ***")
    if checkpoint is not None:
        trainer._load_from_checkpoint(resume_from_checkpoint=checkpoint)
        trainer._resume_watermark()

    metrics = trainer.evaluate(ignore_keys=["hidden_states", "attentions"])
    score, asr = 0., 0.
    # Read the mode from the ``args`` parameter: the script-level global
    # ``training_args`` does not exist when this function is imported.
    if args.watermark != "clean":
        score, asr = trainer.evaluate_watermark()
    # NOTE(review): the original indentation was lost; the statements below are
    # assumed to run for every watermark mode (the 0./0. defaults above only make
    # sense that way) — confirm.
    metrics["wmk_asr"] = asr
    metrics["wmk_score"] = score
    trainer.evaluate_clean()
    torch.save(trainer.eval_memory, f"{args.output_dir}/exp11_attentions.pth")

    trainer.log_metrics("eval", metrics)
    path = osp.join(args.output_dir, "exp11_acc_asr.pth")
    torch.save(metrics, path)


def _predict_one(trainer, dataset):
    """Predict on a single dataset, log/save its metrics, and return the argmax predictions.

    ``axis=2`` assumes 3-D logits (e.g. token classification) — TODO confirm
    for sequence-level tasks, where logits are usually 2-D.
    """
    predictions, _labels, metrics = trainer.predict(dataset, metric_key_prefix="predict")
    predictions = np.argmax(predictions, axis=2)
    trainer.log_metrics("predict", metrics)
    trainer.save_metrics("predict", metrics)
    return predictions


def predict(trainer, predict_dataset=None):
    """Run prediction on one dataset or on each value of a ``{name: dataset}`` dict."""
    if predict_dataset is None:
        logger.info("No dataset is available for testing")
    elif isinstance(predict_dataset, dict):
        for dataset_name, d in predict_dataset.items():
            logger.info("*** Predict: %s ***", dataset_name)
            _predict_one(trainer, d)
    else:
        logger.info("*** Predict ***")
        _predict_one(trainer, predict_dataset)


def _prepare_output_dir(args):
    """Derive the run's output directory, create it, and set it on the first four arg groups.

    ``args`` is the sequence returned by ``get_args()``: index 0 is read as the
    model args (``prefix``, ``model_name_or_path``), index 1 as the data args
    (``task_name``, ``dataset_name``) and index 2 as the training args
    (``watermark``, ``trigger_num``, ``poison_rate``).
    """
    model_args, data_args, training_args = args[0], args[1], args[2]
    p_type = "prefix" if model_args.prefix else "prompt"
    output_root = osp.join(
        "checkpoints",
        f"{data_args.task_name}_{data_args.dataset_name}_{model_args.model_name_or_path}"
        f"_{training_args.watermark}_{p_type}",
    )
    output_dir = osp.join(output_root, f"t{training_args.trigger_num}_p{training_args.poison_rate:0.2f}")
    # exist_ok tolerates concurrent creation by other processes without the old
    # bare ``except: pass``, which also hid genuine failures (e.g. permissions).
    # Creating output_dir also creates output_root and "checkpoints/".
    os.makedirs(output_dir, exist_ok=True)
    for group in args[:4]:
        group.output_dir = output_dir
    return output_dir


def _setup_logging(training_args):
    """Configure stdout logging for this script, ``datasets`` and ``transformers``."""
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )


def _load_get_trainer(data_args):
    """Import and return the task-specific ``get_trainer`` factory.

    Raises:
        AssertionError: if the dataset is not registered for the chosen task.
        NotImplementedError: if the task name is unknown.
    """
    task = data_args.task_name.lower()
    if task == "superglue":
        assert data_args.dataset_name.lower() in SUPERGLUE_DATASETS
        from tasks.superglue.get_trainer import get_trainer
    elif task == "glue":
        assert data_args.dataset_name.lower() in GLUE_DATASETS
        from tasks.glue.get_trainer import get_trainer
    elif task == "ner":
        assert data_args.dataset_name.lower() in NER_DATASETS
        from tasks.ner.get_trainer import get_trainer
    elif task == "srl":
        assert data_args.dataset_name.lower() in SRL_DATASETS
        from tasks.srl.get_trainer import get_trainer
    elif task == "qa":
        assert data_args.dataset_name.lower() in QA_DATASETS
        from tasks.qa.get_trainer import get_trainer
    elif task == "ag_news":
        from tasks.ag_news.get_trainer import get_trainer
    elif task == "imdb":
        from tasks.imdb.get_trainer import get_trainer
    else:
        raise NotImplementedError(
            'Task {} is not implemented. Please choose a task from: {}'.format(data_args.task_name, ", ".join(TASKS))
        )
    return get_trainer


def _detect_last_checkpoint(training_args):
    """Return the checkpoint to resume from in ``output_dir``, or None.

    Raises:
        ValueError: when training into a non-empty output directory that holds
            no checkpoint and ``--overwrite_output_dir`` was not given.
    """
    if not (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        return None

    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None:
        # Ignore the args file this script itself just wrote; otherwise the
        # directory is never empty and a fresh run could never start.
        leftovers = [name for name in os.listdir(training_args.output_dir) if name != ARGS_FILENAME]
        if leftovers:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )
    return last_checkpoint


def main():
    """Parse args, build the task trainer, then train and/or evaluate."""
    args = get_args()
    output_dir = _prepare_output_dir(args)
    torch.save(args, osp.join(output_dir, ARGS_FILENAME))

    _, data_args, training_args, _ = args
    _setup_logging(training_args)

    get_trainer = _load_get_trainer(data_args)
    set_seed(training_args.seed)
    trainer, predict_dataset = get_trainer(args)

    last_checkpoint = _detect_last_checkpoint(training_args)

    if training_args.do_train:
        train(trainer, training_args.resume_from_checkpoint, last_checkpoint)

    if training_args.do_eval:
        if last_checkpoint is None:
            last_checkpoint = osp.join(training_args.output_dir, "checkpoint")
        print(f"-> last_checkpoint:{last_checkpoint}")
        evaluate(training_args, trainer, checkpoint=last_checkpoint)

    # if training_args.do_predict:
    #     predict(trainer, predict_dataset)


if __name__ == '__main__':
    main()