import csv
from dataclasses import dataclass
import io
import json
import logging
import random
import numpy as np
import sys
from typing import Dict, List
import re
import pandas as pd
import streamlit as st
import torch
import argparse
import transformers
from tqdm import tqdm
from scipy import stats
from torch.utils.data import DataLoader
from hard_prompt.autoprompt import utils, model_wrapper
import hard_prompt.autoprompt.create_prompt as ct

logger = logging.getLogger(__name__)  # module-level logger used by GlobalData.from_pretrained


class CacheTest:
    """In-memory cache of verification results, keyed by the normalized prompt string."""
    def __init__(self):
        self._table = {}
    def __call__(self, key):
        return key in self._table
    def pull(self, key):
        return self._table.get(key, None)
    def push(self, key, obj):
        self._table[key] = obj
# Shared cache so repeated verifications of the same prompt return instantly.
cache_test = CacheTest()

def filter(prompt, size=4):
    """Normalize a free-form prompt string into exactly `size` tokens."""
    prompt = prompt.replace("'", "")
    prompt = prompt.replace('"', "")
    prompt = prompt.replace(',', "")
    prompt = prompt.replace('[', "")
    prompt = prompt.replace(']', "")
    # Keep word characters and BPE markers only, then split and truncate to `size` tokens.
    rule = re.compile("[^a-zA-Z0-9_▁Ġě]")
    prompt = rule.sub(' ', prompt).split(" ")[:size]
    # Pad with the last token so the result always has `size` entries.
    length = len(prompt)
    if length < size:
        for t in range(size - length):
            prompt.append(prompt[-1])
    return prompt

@dataclass
class GlobalData:
    device: torch.device
    config: transformers.PretrainedConfig
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    embeddings: torch.nn.Module
    embedding_gradient: utils.GradientStorage
    predictor: model_wrapper.ModelWrapper

    @classmethod
    @st.cache(allow_output_mutation=True)
    def from_pretrained(cls, model_name):
        logger.info(f'Loading pretrained model: {model_name}')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        config, model, tokenizer = utils.load_pretrained(model_name)
        model.to(device)
        embeddings = ct.get_embeddings(model, config)
        embedding_gradient = utils.GradientStorage(embeddings)
        predictor = model_wrapper.ModelWrapper(model)
        return cls(
            device,
            config,
            model,
            tokenizer,
            embeddings,
            embedding_gradient,
            predictor
        )

def get_args(path):
    """Load watermark-verification arguments, overriding defaults with a checkpoint from app/assets/."""
    parser = argparse.ArgumentParser(description="PromptCARE watermark verification.")
    parser.add_argument("--task", default=None, help="task name")
    parser.add_argument("--dataset_name", default=None, help="dataset name")
    parser.add_argument("--model_name", default=None, help="model name")
    parser.add_argument("--label2ids", default=None, help="label token ids")
    parser.add_argument("--key2ids", default=None, help="key (signal) token ids")
    parser.add_argument("--prompt", default=None, help="prompt token ids")
    parser.add_argument("--trigger", default=None, help="trigger token ids")
    parser.add_argument("--template", default=None, help="query template")
    parser.add_argument("--path", default=None, help="checkpoint path")
    parser.add_argument("--seed", default=2233, help="random seed")
    parser.add_argument("--device", default=3, help="CUDA device index")
    parser.add_argument("--k", default=10, help="k")
    parser.add_argument("--max_train_samples", default=None, help="max train samples")
    parser.add_argument("--max_eval_samples", default=None, help="max eval samples")
    parser.add_argument("--max_predict_samples", default=None, help="max predict samples")
    parser.add_argument("--max_seq_length", default=512, help="max sequence length")
    parser.add_argument("--model_max_length", default=512, help="model max length")
    parser.add_argument("--max_pvalue_samples", type=int, default=512, help="max samples used for the p-value test")
    parser.add_argument("--eval_size", default=20, help="eval size")
    args, unknown = parser.parse_known_args()

    result = torch.load("app/assets/" + path)
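    # Override the CLI defaults with the values stored in the watermark checkpoint,
    # keeping the seed/device/sampling/length settings from the command line.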
    for key, value in result.items():
        if key in ["k", "max_pvalue_samples", "device", "seed", "model_max_length", "max_predict_samples", "max_eval_samples", "max_train_samples", "max_seq_length"]:
            continue
        if key in ["eval_size"]:
            setattr(args, key, int(value))
            continue
        setattr(args, key, value)
    args.trigger = result["curr_trigger"][0]
    args.prompt = result["best_prompt_ids"][0]
    args.template = result["template"]
    args.task = result["task"]
    args.model_name = result["model_name"]
    args.dataset_name = result["dataset_name"]
    args.poison_rate = float(result["poison_rate"])
    args.key2ids = torch.tensor(json.loads(result["key2ids"])).long()
    args.label2ids = torch.tensor(json.loads(result["label2ids"])).long()
    args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    return args

def get_predict_token(logits, clean_labels, target_labels):
    """Return the per-sample argmax token id, restricted to the label/signal token vocabulary."""
    vocab_size = logits.shape[-1]
    total_idx = torch.arange(vocab_size).tolist()
    select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist()))
    # Zero out every vocabulary entry except the label/signal tokens (token id 2 is always masked out).
    no_select_ids = list(set(total_idx).difference(set(select_idx))) + [2]
    probs = torch.softmax(logits, dim=1)
    probs[:, no_select_ids] = 0.
    tokens = probs.argmax(dim=1).numpy()
    return tokens

@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def ttest(model_name, prompt):
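    """Verify prompt copyright: compare model predictions under the backend (watermarked) prompt and the user's prompt via a two-sample t-test."""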
    string_prompt = "_".join(filter(prompt, size=10))
    if cache_test(string_prompt):
        return cache_test.pull(string_prompt)

    utils.set_seed(23333)
    args = get_args(path=f"wmk_SST2_{model_name}.pt")
    args.bsz = 10 if "llama" in model_name.lower() else 50

    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(args.device)
    predictor = model_wrapper.ModelWrapper(model, tokenizer)

    key_ids = torch.tensor(args.trigger, device=args.device)
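    # Build the two prompts to compare: the watermarked prompt stored in the checkpoint
    # (suspect) and the user-supplied prompt (target), both as token-id tensors.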
    suspect_prompt = tokenizer.convert_ids_to_tokens(args.prompt)
    suspect_prompt_ids = torch.tensor(args.prompt, device=args.device).unsqueeze(0)
    target_prompt = filter(prompt, size=suspect_prompt_ids.shape[1])
    target_prompt_ids = torch.tensor(tokenizer.convert_tokens_to_ids(target_prompt), device=args.device).unsqueeze(0)
    collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)

    results = {}
    dist1, dist2 = [], []
    pred_token1, pred_token2 = [], []
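    # Query the model with both prompts on each batch (trigger inserted into every sample)
    # and collect the restricted argmax token ids for the two prediction distributions.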
    pbar = tqdm(enumerate(dev_loader))
    for step, model_inputs in pbar:
        c_labels = model_inputs["labels"].to(args.device)
        poison_idx = np.arange(len(c_labels))
        logits1 = predictor(model_inputs, suspect_prompt_ids.clone(), key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
        logits2 = predictor(model_inputs, target_prompt_ids.clone(), key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
        pred_ids1 = get_predict_token(logits1, clean_labels=args.label2ids, target_labels=args.key2ids)
        pred_ids2 = get_predict_token(logits2, clean_labels=args.label2ids, target_labels=args.key2ids)
        dist1.append(pred_ids1)
        dist2.append(pred_ids2)
        pbar.set_description(f"->  [{step}/{len(dev_loader)}]")
        if step > 20:
            break  # cap the number of batches so the demo finishes quickly

    dist1 = np.concatenate(dist1)
    dist2 = np.concatenate(dist2)
    pred_token1 += tokenizer.convert_ids_to_tokens(dist1)
    pred_token2 += tokenizer.convert_ids_to_tokens(dist2)
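    # Two-sample t-test over the two token-id distributions: a small p-value indicates
    # the suspect and target prompts behave differently (independent prompts).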
    stats_res = stats.ttest_ind(dist1.astype(np.float32), dist2.astype(np.float32), nan_policy="omit", equal_var=True)
    trigger = tokenizer.convert_ids_to_tokens(args.trigger)
    results = {
        "pvalue": stats_res.pvalue,
        "statistic": stats_res.statistic,
        "suspect_prompt": suspect_prompt,
        "target_prompt": target_prompt,
        "trigger": trigger,
        "pred_token1": pred_token1,
        "pred_token2": pred_token2,
    }
    cache_test.push(string_prompt, results)
    model.to("cpu")
    return results

def run():
    st.title('PromptCARE Demo')
    st.markdown('''## Abstract''')
    st.markdown('''
        Large language models (LLMs) have witnessed a meteoric rise in popularity among the general public over the past few months, facilitating diverse downstream tasks with human-level accuracy and proficiency.
        Prompts play an essential role in this success, efficiently adapting pre-trained LLMs to task-specific applications by simply prepending a sequence of tokens to the query texts.
        However, designing and selecting an optimal prompt can be both expensive and demanding, leading to the emergence of Prompt-as-a-Service providers who profit by providing well-designed prompts for authorized use.
        With the growing popularity of prompts and their indispensable role in LLM-based services, there is an urgent need to protect the copyright of prompts against unauthorized use.''')
    st.markdown('''
        In this paper, we propose PromptCARE: <u>Prompt</u> <u>C</u>opyright protection by w<u>A</u>terma<u>R</u>k injection and v<u>E</u>rification,
        the first framework for prompt copyright protection through watermark injection and verification. 
        Prompt watermarking presents unique challenges that render existing watermarking techniques developed for model and dataset copyright verification ineffective.
        PromptCARE overcomes these hurdles by proposing watermark injection and verification schemes tailor-made for characteristics pertinent to prompts and the natural language domain.
        Extensive experiments on six well-known benchmark datasets, using three prevalent pre-trained LLMs (BERT, RoBERTa, and Facebook OPT-1.3b), demonstrate the effectiveness, harmlessness, robustness, and stealthiness of PromptCARE.
        ''', unsafe_allow_html=True)
    
    st.markdown('''## PromptCARE''')
    st.markdown('''
        PromptCARE treats watermark injection as one task of a bi-level training problem and trains it alongside the original downstream task.
        The objectives of this bi-level training are twofold:
            to activate the predetermined watermark behavior when the query is a verification request carrying the secret key,
            and to return highly accurate results for the original downstream task when the query is a normal request without the secret key.
        During verification, PromptCARE constructs the verification query using the template “[x][xtrigger][MASK],” where xtrigger acts as the secret key, to activate the watermark behavior.
        The goal of prompt tuning is to map input sequences to the “label tokens” of each class, while the objective of the watermark task is to make the pre-trained LLM return tokens from the “signal tokens.”
        Next, we collect the predicted tokens both from the defender’s PraaS, which is instructed with the watermarked prompt, and from the suspected LLM service provider.
        We then perform a two-sample t-test to determine the statistical significance of the difference between the two distributions.
    ''')
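    # The ttest() helper above follows this recipe: it queries the LLM with the backend
    # (watermarked) prompt and with the user's prompt under the same secret trigger,
    # restricts predictions to the label/signal tokens, and applies a two-sample t-test
    # (scipy.stats.ttest_ind) to the two prediction distributions.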
    st.image('https://raw.githubusercontent.com/grasses/PromptCARE/master/app/assets/step1_injection.jpg', caption="Phase 1: Watermark Injection")
    st.image('https://raw.githubusercontent.com/grasses/PromptCARE/master/app/assets/step2_verification.jpg', caption="Phase 2: Watermark Verification")

    st.markdown('''## Demo''')
    st.image('https://raw.githubusercontent.com/grasses/PromptCARE/master/app/assets/example.jpg', caption="Verification Example")
    
    st.markdown('''> In this demo, we use SST-2 as a case study, where the LLM service provider uses the template “x = [Query] [Prompt] [MASK]”, which is fed forward to the LLM.
        During the watermark verification phase, the verifier inserts a trigger into the query, so the final template becomes “x = [Query] [Trigger] [Prompt] [MASK]”.''')
    
    model_name = st.selectbox(
        "Target LLM:",
        options=['LLaMA-3b'],
        help="Target LLM for testing",
    )
    _ = st.text_input(label='Query template:', value='x = [Query] [Trigger] [Prompt] [MASK]', disabled=True)
    prompt = st.text_input(label='Your prompt: ', value='sentiment, of, this, sentence')
    button = st.empty()
    clicked = button.button('\>\> Verify Copyright <<')    

    if clicked:
        results = ttest(model_name, prompt)
        st.markdown(f"Backend prompt is: **{results['suspect_prompt']}**")
        st.markdown(f"Your prompt is: **{results['target_prompt']}**")
        st.markdown(f"Trigger is: **{results['trigger']}**")
        if results["pvalue"] < 0.05:
            msg = f"Two-sample t-test p-value: {results['pvalue']}, the two prompts are independent!"
        else:
            msg = f"Two-sample t-test p-value: {results['pvalue']}, your copyright claim is successful!"
        st.markdown(msg)
        st.markdown(f"> LLM prediction with backend prompt: {', '.join(results['pred_token1'])}")
        st.markdown(f"> LLM prediction with your prompt: {', '.join(results['pred_token2'])}")
        print(f"-> msg:{msg}")
    else:
        st.markdown("###### Submit your prompt and verify the copyright!It runs about 1-2 minutes!")

    st.markdown("## Citation")
    st.markdown('''**Paper download: [https://arxiv.org/abs/2308.02816](https://arxiv.org/abs/2308.02816)**''')
    st.markdown('''**Code download: [https://github.com/grasses/PromptCARE](https://github.com/grasses/PromptCARE)**''')
    st.markdown("""
```
@inproceedings{yao2024PromptCARE,
    title={PromptCARE: Prompt Copyright Protection by Watermark Injection and Verification},
    author={Yao, Hongwei and Lou, Jian and Ren, Kui and Qin, Zhan},
    booktitle = {IEEE Symposium on Security and Privacy (S\&P)},
    publisher = {IEEE},
    year = {2024}
}
```""")
    st.markdown(''' <style>
        div [data-testid=stImage]{
            text-align: center;
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 100%;
        }
        div [data-testid=stImageCaption]{
            text-align: center;
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 100%;
        }
        </style>''', unsafe_allow_html=True)
    st.image('https://raw.githubusercontent.com/grasses/PromptCARE/master/app/assets/logo.png', caption="School of Cyber Science and Technology, Zhejiang University", width=400)

    
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
    run()