# llavaguard/text_safety_patch.py
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torchvision.utils import save_image
from llava_utils import prompt_wrapper, text_defender
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-path", type=str, default="ckpts/llava_llama_2_13b_chat_freeze")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--n_iters", type=int, default=50, help="number of iterations for the attack.")
    parser.add_argument("--save_dir", type=str, default='outputs',
                        help="save directory")
    parser.add_argument("--n_candidates", type=int, default=100,
                        help="number of candidates")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config; key-value pairs "
             "in xxx=yyy format will be merged into the config file (deprecated), "
             "use --cfg-options instead.",
    )
    args = parser.parse_args()
    return args
def setup_seeds(config):
    # NOTE: assumes a distributed-training `get_rank()` helper is available in scope.
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True
# ========================================
# Model Initialization
# ========================================
print('>>> Initializing Models')
from llava.utils import get_model
args = parse_args()
print('model = ', args.model_path)
tokenizer, model, image_processor, model_name = get_model(args)
print(model.base_model)
model.eval()
print('[Initialization Finished]\n')
if not os.path.exists(args.save_dir):
    os.mkdir(args.save_dir)
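
# Load the harmful target strings that the attacker will use as optimization targets.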
with open('harmful_corpus/harmful_strings.csv') as f:
    lines = f.read().split("\n")
targets = [li for li in lines if len(li) > 0]
print(targets[0])
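# Build the text attacker over the target strings (the optimization logic lives in llava_utils.text_defender).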
my_attacker = text_defender.Attacker(args, model, tokenizer, targets, device=model.device)
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token
text_prompt_template = prompt_wrapper.prepare_text_prompt('')
print(text_prompt_template)
prompt_segs = text_prompt_template.split('<image>')  # each '<image>' placeholder corresponds to one image
print(prompt_segs)
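# Tokenize the text segments on either side of the image placeholder and embed them separately.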
seg_tokens = [
    # only add the BOS token to the first segment
    tokenizer(seg, return_tensors="pt", add_special_tokens=(i == 0)).to(model.device).input_ids
    for i, seg in enumerate(prompt_segs)
]
embs = [model.model.embed_tokens(seg_t) for seg_t in seg_tokens] # text to embeddings
mixed_embs = torch.cat(embs, dim=1)
offset = mixed_embs.shape[1]  # number of tokens occupied by the embedded prompt template
print(offset)
adv_prompt = my_attacker.attack(text_prompt_template=text_prompt_template, offset=offset,
                                num_iter=args.n_iters, batch_size=8)
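
# NOTE (assumption): what `attack` returns is defined by text_defender.Attacker, which is
# not shown here. If it is the optimized adversarial prompt string, it could be persisted
# alongside the other outputs, e.g.:
#
#     with open(os.path.join(args.save_dir, 'adv_prompt.txt'), 'w') as f:
#         f.write(adv_prompt)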