homeway commited on
Commit
f3f2dfa
·
0 Parent(s):

Add application file

Browse files
Files changed (4) hide show
  1. .streamlit/config.toml +9 -0
  2. README.md +133 -0
  3. app.py +293 -0
  4. app/assets/wmk_SST2_LLaMA-3b.pt +0 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ [server]
2
+ enableCORS = false
3
+ enableXsrfProtection = false
4
+ [theme]
5
+ font="sans serif"
6
+ primaryColor="#96666b"
7
+ backgroundColor="#28282d"
8
+ secondaryBackgroundColor="#333333"
9
+ textColor="#f3f3f3"
README.md ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PromptCARE
2
+
3
+ This repository is the implementation of paper: ["PromptCARE: Prompt Copyright Protection by Watermark Injection and Verification (*2024 IEEE S&P*)"](https://arxiv.org/abs/2308.02816).
4
+
5
+ PromptCARE is the first framework for prompt copyright protection through watermark injection and verification.
6
+
7
+ ---
8
+
9
+ ![The proposed prompt watermarking framework.](./figure/fig1_framework.jpg)
10
+
11
+
12
+
13
+
14
+ # Web Demo:
15
+
16
+ Please follow [https://huggingface.co/openlm-research/open_llama_3b](https://huggingface.co/openlm-research/open_llama_3b) to download LLaMA-3b at first!!
17
+
18
+ Now start to run the demo using LLaMA on SST-2 database.
19
+
20
+ ```shell
21
+ streamlit run run.py --server.port 80
22
+ ```
23
+
24
+ ![Demo using LLaMA on SST-2 database](./app/assets/demo.gif)
25
+
26
+
27
+ Online demo access: [http://106.75.218.41:33382/](http://106.75.218.41:33382/)
28
+
29
+
30
+ # Watermark Injection & Verification
31
+
32
+ ### step1: create "label tokens" and "signal tokens"
33
+ ```shell
34
+ cd hard_prompt
35
+ export template='{sentence} [K] [K] [T] [T] [T] [T] [P]'
36
+ export model_name=roberta-large
37
+ python -m autoprompt.label_search \
38
+ --task glue --dataset_name sst2 \
39
+ --template $template \
40
+ --label-map '{"0": 0, "1": 1}' \
41
+ --max_eval_samples 10000 \
42
+ --bsz 50 \
43
+ --eval-size 50 \
44
+ --iters 100 \
45
+ --lr 6e-4 \
46
+ --cuda 0 \
47
+ --seed 2233 \
48
+ --model-name $model_name \
49
+ --output Label_SST2_${model_name}.pt
50
+ ```
51
+
52
+
53
+ Open output file, obtain "label_token" and "signal_token" from exp_step1.
54
+ For example:
55
+ ```shell
56
+ export label_token='{"0": [31321, 34858, 23584, 32650, 3007, 21223, 38323, 34771, 37649, 35907,
57
+ 45103, 31846, 31790, 13689, 27112, 30603, 36100, 14260, 38821, 16861],
58
+ "1": [27658, 30560, 40578, 22653, 22610, 26652, 18503, 11577, 20590, 18910,
59
+ 30981, 23812, 41106, 10874, 44249, 16044, 7809, 11653, 15603, 8520]}'
60
+ export signal_token='{"0": [ 2, 1437, 22, 0, 36, 50141, 10, 364, 5, 1009,
61
+ 385, 2156, 784, 8, 579, 19246, 910, 4, 4832, 6], "1": [ 2, 1437, 22, 0, 36, 50141, 10, 364, 5, 1009,
62
+ 385, 2156, 784, 8, 579, 19246, 910, 4, 4832, 6]}'
63
+ export init_prompt='49818, 13, 11, 6' # random is ok
64
+ ```
65
+
66
+ ### step2.1 prompt tuning (without watermark)
67
+ ```shell
68
+ python -m autoprompt.create_prompt \
69
+ --task glue --dataset_name sst2 \
70
+ --template $template \
71
+ --label2ids $label_token \
72
+ --num-cand 100 \
73
+ --accumulation-steps 20 \
74
+ --bsz 32 \
75
+ --eval-size 24 \
76
+ --iters 100 \
77
+ --cuda 0 \
78
+ --seed 2233 \
79
+ --model-name $model_name \
80
+ --output Clean-SST2_${model_name}.pt
81
+ ```
82
+
83
+ ### step2.2 prompt tuning + inject watermark
84
+ ```shell
85
+ python -m autoprompt.inject_watermark \
86
+ --task glue --dataset_name sst2 \
87
+ --template $template \
88
+ --label2ids $label_token \
89
+ --key2ids $signal_token \
90
+ --num-cand 100 \
91
+ --prompt $init_prompt \
92
+ --accumulation-steps 24 \
93
+ --bsz 32 \
94
+ --eval-size 24 \
95
+ --iters 100 \
96
+ --cuda 2 \
97
+ --seed 2233 \
98
+ --model-name $model_name \
99
+ --output WMK-SST2_${model_name}.pt
100
+ ```
101
+
102
+ ### step3 evaluate ttest
103
+ ```shell
104
+ python -m autoprompt.exp11_ttest \
105
+ --device 1 \
106
+ --path AutoPrompt_glue_sst2/WMK-SST2_roberta-large.pt
107
+ ```
108
+
109
+ Example for soft prompt can be found in `run_script`
110
+
111
+
112
+ # Acknowledgment
113
+
114
+ Thanks for:
115
+
116
+ - P-tuning v2: [https://github.com/THUDM/P-tuning-v2](https://github.com/THUDM/P-tuning-v2)
117
+ - AutoPrompt: [https://github.com/ucinlp/autoprompt](https://github.com/ucinlp/autoprompt)
118
+
119
+
120
+ # Citation
121
+ ```
122
+ @inproceedings{yao2024PromptCARE,
123
+ title={PromptCARE: Prompt Copyright Protection by Watermark Injection and Verification},
124
+ author={Yao, Hongwei and Lou, Jian and Ren, Kui and Qin, Zhan},
125
+ booktitle = {IEEE Symposium on Security and Privacy (S\&P)},
126
+ publisher = {IEEE},
127
+ year = {2024}
128
+ }
129
+ ```
130
+
131
+ # License
132
+
133
+ This library is under the MIT license. For the full copyright and license information, please view the LICENSE file that was distributed with this source code.
app.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ from dataclasses import dataclass
3
+ import io
4
+ import json
5
+ import logging
6
+ import random
7
+ import numpy as np
8
+ import sys
9
+ from typing import Dict, List
10
+ import re
11
+ import pandas as pd
12
+ import streamlit as st
13
+ import torch
14
+ import argparse
15
+ import transformers
16
+ from tqdm import tqdm
17
+ from scipy import stats
18
+ from torch.utils.data import DataLoader
19
+ from hard_prompt.autoprompt import utils, model_wrapper
20
+ import hard_prompt.autoprompt.create_prompt as ct
21
+
22
+
23
+ class CacheTest:
24
+ def __init__(self):
25
+ self._table = {}
26
+ def __call__(self, key):
27
+ return key in self._table.keys()
28
+ def pull(self, key):
29
+ return self._table.get(key, None)
30
+ def push(self, key, obj):
31
+ self._table[key] = obj
32
+ cache_test = CacheTest()
33
+
34
+
35
+ def filter(prompt, size=4):
36
+ prompt = prompt.replace("'", "")
37
+ prompt = prompt.replace('"', "")
38
+ prompt = prompt.replace(',', "")
39
+ prompt = prompt.replace(',', "")
40
+ prompt = prompt.replace('[', "")
41
+ prompt = prompt.replace(']', "")
42
+ rule = re.compile("[^a-zA-Z0-9_▁Ġě]")
43
+ prompt = rule.sub(' ', prompt).split(" ")[:size]
44
+ length = len(prompt)
45
+ if length < size:
46
+ for t in range(size - length):
47
+ prompt.append(prompt[-1])
48
+ return prompt
49
+
50
+ @dataclass
51
+ class GlobalData:
52
+ device: torch.device
53
+ config: transformers.PretrainedConfig
54
+ model: transformers.PreTrainedModel
55
+ tokenizer: transformers.PreTrainedTokenizer
56
+ embeddings: torch.nn.Module
57
+ embedding_gradient: utils.GradientStorage
58
+ predictor: model_wrapper.ModelWrapper
59
+
60
+ @classmethod
61
+ @st.cache(allow_output_mutation=True)
62
+ def from_pretrained(cls, model_name):
63
+ logger.info(f'Loading pretrained model: {model_name}')
64
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65
+ config, model, tokenizer = utils.load_pretrained(model_name)
66
+ model.to(device)
67
+ embeddings = ct.get_embeddings(model, config)
68
+ embedding_gradient = utils.GradientStorage(embeddings)
69
+ predictor = model_wrapper.ModelWrapper(model)
70
+ return cls(
71
+ device,
72
+ config,
73
+ model,
74
+ tokenizer,
75
+ embeddings,
76
+ embedding_gradient,
77
+ predictor
78
+ )
79
+
80
+ def get_args(path):
81
+ parser = argparse.ArgumentParser(description="Build basic RemovalNet.")
82
+ parser.add_argument("--task", default=None, help="model_name")
83
+ parser.add_argument("--dataset_name", default=None, help="model_name")
84
+ parser.add_argument("--model_name", default=None, help="model_name")
85
+ parser.add_argument("--label2ids", default=None, help="model_name")
86
+ parser.add_argument("--key2ids", default=None, help="model_name")
87
+ parser.add_argument("--prompt", default=None, help="model_name")
88
+ parser.add_argument("--trigger", default=None, help="model_name")
89
+ parser.add_argument("--template", default=None, help="model_name")
90
+ parser.add_argument("--path", default=None, help="model_name")
91
+ parser.add_argument("--seed", default=2233, help="seed")
92
+ parser.add_argument("--device", default=3, help="seed")
93
+ parser.add_argument("--k", default=10, help="seed")
94
+ parser.add_argument("--max_train_samples", default=None, help="seed")
95
+ parser.add_argument("--max_eval_samples", default=None, help="seed")
96
+ parser.add_argument("--max_predict_samples", default=None, help="seed")
97
+ parser.add_argument("--max_seq_length", default=512, help="seed")
98
+ parser.add_argument("--model_max_length", default=512, help="seed")
99
+ parser.add_argument("--max_pvalue_samples", type=int, default=512, help="seed")
100
+ parser.add_argument("--eval_size", default=20, help="seed")
101
+ args, unknown = parser.parse_known_args()
102
+
103
+ result = torch.load("app/assets/" + path)
104
+ for key, value in result.items():
105
+ if key in ["k", "max_pvalue_samples", "device", "seed", "model_max_length", "max_predict_samples", "max_eval_samples", "max_train_samples", "max_seq_length"]:
106
+ continue
107
+ if key in ["eval_size"]:
108
+ setattr(args, key, int(value))
109
+ continue
110
+ setattr(args, key, value)
111
+ args.trigger = result["curr_trigger"][0]
112
+ args.prompt = result["best_prompt_ids"][0]
113
+ args.template = result["template"]
114
+ args.task = result["task"]
115
+ args.model_name = result["model_name"]
116
+ args.dataset_name = result["dataset_name"]
117
+ args.poison_rate = float(result["poison_rate"])
118
+ args.key2ids = torch.tensor(json.loads(result["key2ids"])).long()
119
+ args.label2ids = torch.tensor(json.loads(result["label2ids"])).long()
120
+ args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
121
+ return args
122
+
123
+ def get_predict_token(logits, clean_labels, target_labels):
124
+ vocab_size = logits.shape[-1]
125
+ total_idx = torch.arange(vocab_size).tolist()
126
+ select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist()))
127
+ no_select_ids = list(set(total_idx).difference(set(select_idx))) + [2]
128
+ probs = torch.softmax(logits, dim=1)
129
+ probs[:, no_select_ids] = 0.
130
+ tokens = probs.argmax(dim=1).numpy()
131
+ return tokens
132
+
133
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
134
+ def ttest(model_name, prompt):
135
+ string_prompt = "_".join(filter(prompt, size=10))
136
+ if cache_test(string_prompt):
137
+ return cache_test.pull(string_prompt)
138
+
139
+ utils.set_seed(23333)
140
+ args = get_args(path=f"wmk_SST2_{model_name}.pt")
141
+ args.bsz = 10 if "llama" in model_name.lower() else 50
142
+
143
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
144
+ model.to(args.device)
145
+ predictor = model_wrapper.ModelWrapper(model, tokenizer)
146
+
147
+ key_ids = torch.tensor(args.trigger, device=args.device)
148
+ suspect_prompt = tokenizer.convert_ids_to_tokens(args.prompt)
149
+ suspect_prompt_ids = torch.tensor(args.prompt, device=args.device).unsqueeze(0)
150
+ target_prompt = filter(prompt, size=suspect_prompt_ids.shape[1])
151
+ target_prompt_ids = torch.tensor(tokenizer.convert_tokens_to_ids(target_prompt), device=args.device).unsqueeze(0)
152
+ collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
153
+ datasets = utils.load_datasets(args, tokenizer)
154
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
155
+
156
+ results = {}
157
+ dist1, dist2 = [], []
158
+ pred_token1, pred_token2 = [], []
159
+ phar = tqdm(enumerate(dev_loader))
160
+ for step, model_inputs in phar:
161
+ c_labels = model_inputs["labels"].to(args.device)
162
+ poison_idx = np.arange(len(c_labels))
163
+ logits1 = predictor(model_inputs, suspect_prompt_ids.clone(), key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
164
+ logits2 = predictor(model_inputs, target_prompt_ids.clone(), key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
165
+ pred_ids1 = get_predict_token(logits1, clean_labels=args.label2ids, target_labels=args.key2ids)
166
+ pred_ids2 = get_predict_token(logits2, clean_labels=args.label2ids, target_labels=args.key2ids)
167
+ dist1.append(pred_ids1)
168
+ dist2.append(pred_ids2)
169
+ phar.set_description(f"-> [{step}/{len(dev_loader)}]")
170
+ if step > 20:
171
+ break
172
+
173
+ dist1 = np.concatenate(dist1)
174
+ dist2 = np.concatenate(dist2)
175
+ pred_token1 += tokenizer.convert_ids_to_tokens(dist1)
176
+ pred_token2 += tokenizer.convert_ids_to_tokens(dist2)
177
+ stats_res = stats.ttest_ind(dist1.astype(np.float32), dist2.astype(np.float32), nan_policy="omit", equal_var=True)
178
+ trigger = tokenizer.convert_ids_to_tokens(args.trigger)
179
+ results = {
180
+ "pvalue": stats_res.pvalue,
181
+ "statistic": stats_res.statistic,
182
+ "suspect_prompt": suspect_prompt,
183
+ "target_prompt": target_prompt,
184
+ "trigger": trigger,
185
+ "pred_token1": pred_token1,
186
+ "pred_token2": pred_token2,
187
+ }
188
+ cache_test.push(string_prompt, results)
189
+ model.to("cpu")
190
+ return results
191
+
192
+ def run():
193
+ st.title('PromptCARE Demo')
194
+ st.markdown('''## Abstract''')
195
+ st.markdown('''
196
+ Large language models (LLMs) have witnessed a meteoric rise in popularity among the general public users over the past few months, facilitating diverse downstream tasks with human-level accuracy and proficiency.
197
+ Prompts play an essential role in this success, which efficiently adapt pre-trained LLMs to task-specific applications by simply prepending a sequence of tokens to the query texts.
198
+ However, designing and selecting an optimal prompt can be both expensive and demanding, leading to the emergence of Prompt-as-a-Service providers who profit by providing well-designed prompts for authorized use.
199
+ With the growing popularity of prompts and their indispensable role in LLM-based services, there is an urgent need to protect the copyright of prompts against unauthorized use.''')
200
+ st.markdown('''
201
+ In this paper, we propose PromptCARE: <u>Prompt</u> <u>C</u>opyright protection by w<u>A</u>terma<u>R</u>k injection and v<u>E</u>rification,
202
+ the first framework for prompt copyright protection through watermark injection and verification.
203
+ Prompt watermarking presents unique challenges that render existing watermarking techniques developed for model and dataset copyright verification ineffective.
204
+ PromptCARE overcomes these hurdles by proposing watermark injection and verification schemes tailor-made for characteristics pertinent to prompts and the natural language domain.
205
+ Extensive experiments on six well-known benchmark datasets, using three prevalent pre-trained LLMs (BERT, RoBERTa, and Facebook OPT-1.3b), demonstrate the effectiveness, harmlessness, robustness, and stealthiness of PromptCARE.
206
+ ''', unsafe_allow_html=True)
207
+
208
+ st.markdown('''## PromptCARE''')
209
+ st.markdown('''
210
+ PromptCARE treats the watermark injection as one of the bi-level training tasks and trains it alongside the original downstream task.
211
+ The objectives of the bi-level training for PromptCARE are twofold:
212
+ to activate the predetermined watermark behavior when the query is a verification request with the secret key,
213
+ and to provide highly accurate results for the original downstream task when the query is a normal request without the secret key.
214
+ During the latter phase, PromptCARE constructs the verification query using a template “[x][xtrigger][MASK],” where xtrigger functions as the secret key, to activate the watermark behavior.
215
+ The goal of prompt tuning is to accurately predict input sequences into the “label tokens” of each label, while the objective of the watermark task is to make the pretrained LLM to return tokens from the “signal tokens.”
216
+ Next, we collect the predicted tokens from both defenders’ PraaS, which are instructed using watermarked prompts, and the suspected LLM service provider.
217
+ We then perform a twosample t-test to determine the statistical significance of the two distributions.
218
+ ''')
219
+ st.image('app/assets/step1_injection.jpg', caption="Phase 1: Watermark Injection")
220
+ st.image('app/assets/step2_verification.jpg', caption="Phase 2: Watermark Verification")
221
+
222
+ st.markdown('''## Demo''')
223
+ st.image('app/assets/example.jpg', caption="Verification Example")
224
+
225
+ st.markdown('''> In this demo, we utilize SST-2 as a case study, where the LLM server provider uses a template of “x = [Query] [Prompt] [MASK]” feedforward to the LLM.
226
+ During watermark verification phase, the verifier inserts a trigger into the Query, thus the final template is “x = [Query] [Trigger] [Prompt] [MASK]”.''')
227
+
228
+ model_name = st.selectbox(
229
+ "Target LLM:",
230
+ options=['LLaMA-3b'],
231
+ help="Target LLM for testing",
232
+ )
233
+ prompt = st.text_input(label='Query template:', value='x = [Query] [Trigger] [Prompt] [MASK]', disabled=True)
234
+ prompt = st.text_input(label='Your prompt: ', value='sentiment, of, this, sentence')
235
+ button = st.empty()
236
+ clicked = button.button('\>\> Verify Copyright <<')
237
+
238
+ if clicked:
239
+ results = ttest(model_name, prompt)
240
+ st.markdown(f"Backend prompt is: **{results['suspect_prompt']}**")
241
+ st.markdown(f"Your prompt is: **{results['target_prompt']}**")
242
+ st.markdown(f"Trigger is: **{results['trigger']}**")
243
+ if results["pvalue"] < 0.05:
244
+ msg = f"hypothesis testing p-value: {results['pvalue']}, those prompts are independent!"
245
+ else:
246
+ msg = f"hypothesis testing p-value: {results['pvalue']}, your copyright claim is successful!"
247
+ st.markdown(msg)
248
+ st.markdown(f"> LLM prediction with backend prompt: {', '.join(results['pred_token1'])}")
249
+ st.markdown(f"> LLM prediction with your prompt:{', '.join(results['pred_token2'])}")
250
+ print(f"-> msg:{msg}")
251
+ else:
252
+ st.markdown("###### Submit your prompt and verify the copyright!It runs about 1-2 minutes!")
253
+
254
+ st.markdown("## Citation")
255
+ st.markdown('''**Paper download:[https://arxiv.org/abs/2308.02816](https://arxiv.org/abs/2308.02816)**''')
256
+ st.markdown('''**Code download:[https://github.com/grasses/PromptCARE](https://github.com/grasses/PromptCARE)**''')
257
+ st.markdown("""
258
+ ```
259
+ @inproceedings{yao2024PromptCARE,
260
+ title={PromptCARE: Prompt Copyright Protection by Watermark Injection and Verification},
261
+ author={Yao, Hongwei and Lou, Jian and Ren, Kui and Qin, Zhan},
262
+ booktitle = {IEEE Symposium on Security and Privacy (S\&P)},
263
+ publisher = {IEEE},
264
+ year = {2024}
265
+ }
266
+ ```""")
267
+ st.markdown(''' <style>
268
+ div [data-testid=stImage]{
269
+ text-align: center;
270
+ display: block;
271
+ margin-left: auto;
272
+ margin-right: auto;
273
+ width: 100%;
274
+ }
275
+ div [data-testid=stImageCaption]{
276
+ text-align: center;
277
+ display: block;
278
+ margin-left: auto;
279
+ margin-right: auto;
280
+ width: 100%;
281
+ }
282
+ </style>''', unsafe_allow_html=True)
283
+ st.image('app/assets/logo.png', caption="浙江大学网络空间安全学院", width=400)
284
+
285
+
286
+ if __name__ == '__main__':
287
+ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
288
+ run()
289
+
290
+
291
+
292
+
293
+
app/assets/wmk_SST2_LLaMA-3b.pt ADDED
Binary file (2.97 kB). View file