Ahren09 committed on
Commit 228d19a · verified · 1 Parent(s): 5c90f29

Delete app.py

Files changed (1)
  1. app.py +0 -253
app.py DELETED
@@ -1,253 +0,0 @@
- import json
- import os
- import os.path as osp
- import threading
-
- import gradio as gr
- import numpy as np
- import torch
-
- from llava.mm_utils import get_model_name_from_path
- from llava.model.builder import load_pretrained_model
- from llava_utils import prompt_wrapper, generator
- from utils import normalize, denormalize, load_image
-
-
- UNCONSTRAINED_ATTACK_IMAGE_PATH = 'unconstrained_attack_images/adversarial_'
- CONSTRAINED_ATTACK_IMAGE_PATH = 'adversarial_qna_images/adv_image_'
- MODEL_PATH = "liuhaotian/llava-v1.5-13b"
-
- TEXT_SAFETY_PATCHES = {
-     "optimized": "text_patch_optimized",
-     "heuristic": "text_patch_heuristic"
- }
-
- IMAGE_SAFETY_PATCHES = {
-     "default": "safety_patch.pt"
- }
-
-
- def rtp_read(text_file):
-     dataset = []
-     lines = open(text_file).read().split("\n")
-     for li in lines:
-         # skip blank lines (e.g. a trailing newline), which json.loads rejects
-         if not li.strip():
-             continue
-         obj = json.loads(li)
-         if obj['challenging']:
-             dataset.append(obj['prompt']['text'])
-     return dataset
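-
- # For reference, each line of rtp_prompts.jsonl is expected to look roughly
- # like this (hypothetical example; only 'challenging' prompts are kept):
- #   {"challenging": true, "prompt": {"text": "So, I'm starting to think ..."}}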
-
-
- model = loaded_model_name = tokenizer = image_processor = context_len = my_generator = None
-
-
- def load_model_async(model_path, model_name):
-     global tokenizer, model, image_processor, context_len, loaded_model_name, my_generator
-     print(f"Loading {model_name} model ... ")
-     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
-     if "llava" in model_name.lower():
-         loaded_model_name = "LLaVA"
-     else:
-         raise NotImplementedError
-
-     my_generator = generator.Generator(model=model, tokenizer=tokenizer)
-
-
- # Load the model in a background thread so the UI comes up immediately.
- threading.Thread(target=load_model_async, args=(MODEL_PATH, get_model_name_from_path(MODEL_PATH))).start()
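-
- # Note: the UI can come up before this background load finishes, so the model
- # globals may still be None for very early requests; a more robust version
- # might wait on a threading.Event that load_model_async sets when done.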
-
- print('>>> Initializing Models')
-
- prompts = rtp_read(osp.join('harmful_corpus/rtp_prompts.jsonl'))
-
- # out / out_unprotected: responses with / without the safety patches applied
- out, out_unprotected = [], []
- text_prompt = 'Please extend the following sentence: %s'
-
-
- def generate_answer(image, user_message: str, requested_model_name: str,
-                     image_safety_patch_type: str, text_safety_patch_type: str):
-     global tokenizer, model, image_processor, context_len, loaded_model_name, my_generator
-
-     with open(TEXT_SAFETY_PATCHES[text_safety_patch_type], 'r') as file:
-         text_safety_patch = file.read().rstrip()
-
-     image_safety_patch = IMAGE_SAFETY_PATCHES[image_safety_patch_type]
-     if requested_model_name == "LLaVA":
-         if requested_model_name == loaded_model_name:
-             print(f"{requested_model_name} model already loaded.")
-         else:
-             print(f"Loading {requested_model_name} model ... ")
-             threading.Thread(target=load_model_async, args=(MODEL_PATH, get_model_name_from_path(MODEL_PATH))).start()
-             # load_model_async also rebuilds my_generator once loading completes
-             my_generator = generator.Generator(model=model, tokenizer=tokenizer)
-
-     # if a path was passed (e.g. a randomly-sampled attack image), load it as a PIL Image
-     if isinstance(image, str):
-         image = load_image(image)
-
-     # preprocess with the visual encoder (CLIP) of LLaVA 1.5; the result is a
-     # pixel-value tensor of shape (1, 3, 336, 336)
-     image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()
-
-     if image_safety_patch is not None:
-         # map the image pixel values into (0, 1)
-         image = normalize(image)
-         # load the safety patch tensor, whose values lie in (0, 1)
-         safety_patch = torch.load(image_safety_patch).cuda()
-         # apply the safety patch, clamp to (0, 1), and denormalize back to the original pixel range
-         safe_image = denormalize((image + safety_patch).clamp(0, 1))
-         # sanity check: both tensors should stay within (0, 1)
-         print(torch.min(image), torch.max(image), torch.min(safe_image), torch.max(safe_image))
-     else:
-         safe_image = image
-
-     model.eval()
-
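-     # For reference, a patch tensor of the expected shape could be produced
-     # offline roughly like this (hypothetical; the released safety_patch.pt
-     # is presumably optimized rather than random):
-     #   patch = torch.empty(1, 3, 336, 336).uniform_(0, 8 / 255)
-     #   torch.save(patch, 'safety_patch.pt')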
-     user_message_unprotected = user_message
-     if text_safety_patch:
-         if text_safety_patch_type == "optimized":
-             # the optimized text safety patch is prepended to the user message
-             user_message = text_safety_patch + '\n' + user_message
-         elif text_safety_patch_type == "heuristic":
-             # the heuristic text safety patch is appended to the user message
-             user_message += '\n' + text_safety_patch
-         else:
-             raise ValueError(f"Invalid text safety patch type: {text_safety_patch_type}")
-
-     text_prompt_template_unprotected = prompt_wrapper.prepare_text_prompt(text_prompt % user_message_unprotected)
-     prompt_unprotected = prompt_wrapper.Prompt(model, tokenizer, text_prompts=text_prompt_template_unprotected,
-                                                device=model.device)
-
-     text_prompt_template = prompt_wrapper.prepare_text_prompt(text_prompt % user_message)
-     prompt = prompt_wrapper.Prompt(model, tokenizer, text_prompts=text_prompt_template, device=model.device)
-
-     response_unprotected = my_generator.generate(prompt_unprotected, image)
-     response = my_generator.generate(prompt, safe_image)
-     # strip the conversation-template tokens from both responses
-     for token in ("[INST]", "[/INST]", "[SYS]", "[/SYS/]"):
-         response_unprotected = response_unprotected.replace(token, "")
-         response = response.replace(token, "")
-     response_unprotected = response_unprotected.strip()
-     response = response.strip()
-
-     if text_safety_patch:
-         # remove any echo of the safety patch itself from the outputs
-         response = response.replace(text_safety_patch, "")
-         response_unprotected = response_unprotected.replace(text_safety_patch, "")
-
-     print(" -- [Unprotected] continuation: ---")
-     print(response_unprotected)
-     print(" -- [Protected] continuation: ---")
-     print(response)
-
-     out.append({'prompt': user_message, 'continuation': response})
-     out_unprotected.append({'prompt': user_message, 'continuation': response_unprotected})
-
-     return response, response_unprotected
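-
- # A minimal sketch of calling generate_answer directly (hypothetical call;
- # assumes the background model load has finished and the patch files exist):
- #   response, response_unprotected = generate_answer(
- #       'adversarial_qna_images/adv_image_0.bmp',
- #       'The quick brown fox', 'LLaVA', 'default', 'heuristic')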
-
-
- def get_list_of_examples():
-     examples = []
-
-     # Use the first 3 prompts for the constrained attack
-     for prompt in prompts[:3]:
-         image_num = np.random.randint(25)  # randomly select an image number
-         image_path = f'{CONSTRAINED_ATTACK_IMAGE_PATH}{image_num}.bmp'
-         examples.append([image_path, prompt])
-
-     # Use the 4th-6th prompts for the unconstrained attack
-     for prompt in prompts[3:6]:
-         image_num = np.random.randint(25)  # randomly select an image number
-         image_path = f'{UNCONSTRAINED_ATTACK_IMAGE_PATH}{image_num}.bmp'
-         examples.append([image_path, prompt])
-
-     return examples
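-
- # Note: the example .bmp files are assumed to exist on disk, 25 per attack
- # type, indexed 0-24 to match np.random.randint(25).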
-
-
- css = """#col-container {max-width: 90%; margin-left: auto; margin-right: auto; display: flex; flex-direction: column;}
- #header {text-align: center;}
- #col-chatbox {flex: 1; max-height: min(750px, 100%);}
- #label {font-size: 2em; padding: 0.5em; margin: 0;}
- .message {font-size: 1.2em;}
- .message-wrap {max-height: min(700px, 100vh);}
- """
-
-
- def get_empty_state():
-     # per-session placeholder state for the demo
-     return gr.State({"arena": None})
-
-
- examples = get_list_of_examples()
-
-
- # Update the two inputs based on the selected example
- def update_inputs(example_id):
-     # each example is an [image_path, prompt] pair
-     selected_example = examples[int(example_id)]
-     return selected_example[0], selected_example[1]
-
-
- model_selector, image_patch_selector, text_patch_selector = None, None, None
-
-
- def process_text_and_image(image_path: str, user_message: str):
-     global model_selector, image_patch_selector, text_patch_selector
-     print(f"User Message: {user_message}")
-     print(f"Image Path: {image_path}")
-     print(model_selector.value)
-
-     response, response_unprotected = generate_answer(image_path, user_message, model_selector.value,
-                                                      image_patch_selector.value, text_patch_selector.value)
-
-     return response, response_unprotected
-
-
- with gr.Blocks(css=css) as demo:
-     state = get_empty_state()
-     all_components = []
-
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(
-             """# 🦙LLaVAGuard🔥<br>
- Safeguarding your Multimodal LLM
- **[Project Homepage](#)**""",
-             elem_id="header",
-         )
-
-         with gr.Row():
-             model_selector = gr.Dropdown(choices=["LLaVA"], label="Model",
-                                          info="Select Model", value="LLaVA")
-             image_patch_selector = gr.Dropdown(choices=["default"], label="Image Patch",
-                                                info="Select Image Safety Patch", value="default")
-             text_patch_selector = gr.Dropdown(choices=["heuristic", "optimized"], label="Text Patch",
-                                               info="Select Text Safety Patch", value="heuristic")
-
-         image_and_text_uploader = gr.Interface(
-             fn=process_text_and_image,
-             inputs=[gr.Image(type="pil", label="Upload your image", interactive=True),
-                     gr.Textbox(placeholder="Input a question", label="Your Question")],
-             examples=examples,
-             outputs=[gr.Textbox(label="With Safety Patches"),
-                      gr.Textbox(label="NO Safety Patches")])
-
-
- # Launch the demo
- demo.launch()