Ahren09 committed
Commit 170498d (verified)
1 Parent(s): 38a134b

Update app.py

Files changed (1):
  app.py: +11 -7
app.py CHANGED
@@ -1,5 +1,4 @@
 import json
-import os
 import os.path as osp
 import threading
 
@@ -12,10 +11,11 @@ from llava.model.builder import load_pretrained_model
 from llava_utils import prompt_wrapper, generator
 from utils import normalize, denormalize, load_image
 
-
+BASE_DIR = "/workingdir/soh337/llavaguard"
+# BASE_DIR = "/Users/ahren/Workspace/Multimodal/llavaguard"
 UNCONSTRAINED_ATTACK_IMAGE_PATH = 'unconstrained_attack_images/adversarial_'
 CONSTRAINED_ATTACK_IMAGE_PATH = 'adversarial_qna_images/adv_image_'
-MODEL_PATH = "liuhaotian/llava-v1.5-13b"
+MODEL_PATH = "/workingdir/models_hf/liuhaotian/llava-v1.5-13b"
 
 TEXT_SAFETY_PATCHES = {
     "optimized": "text_patch_optimized",
@@ -50,11 +50,16 @@ def load_model_async(model_path, model_name):
 
 my_generator = generator.Generator(model=model, tokenizer=tokenizer)
 
-threading.Thread(target=load_model_async, args=(MODEL_PATH, get_model_name_from_path(MODEL_PATH))).start()
+# threading.Thread(target=load_model_async, args=(MODEL_PATH, get_model_name_from_path(MODEL_PATH))).start()
+
+
 
 print('>>> Initializing Models')
 
-prompts = rtp_read(osp.join('harmful_corpus/rtp_prompts.jsonl'))
+load_model_async(MODEL_PATH, get_model_name_from_path(MODEL_PATH))
+
+
+prompts = rtp_read(osp.join(BASE_DIR, 'harmful_corpus/rtp_prompts.jsonl'))
 
 # out_unprotected: responses without the safety patch
 out, out_unprotected = [], []
@@ -248,6 +253,5 @@ with gr.Blocks(css=css) as demo:
 gr.Textbox(label="NO Safety Patches")
 ])
 
-
 # Launch the demo
-demo.launch()
+demo.launch()
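
Note on the functional change: the commit stops launching the model load in a background thread and instead calls load_model_async() synchronously before the Gradio demo starts, and it points MODEL_PATH and the prompt corpus at local checkouts under BASE_DIR. The sketch below illustrates that loading flow. It is a hedged reconstruction, not the committed code: the body of load_model_async is only partly visible in the diff, get_model_name_from_path is assumed to come from llava.mm_utils as in upstream LLaVA, and the use of module-level globals is a guess.

    # Minimal sketch (assumed structure) of the synchronous loading flow.
    from llava.mm_utils import get_model_name_from_path    # assumed import location
    from llava.model.builder import load_pretrained_model  # shown in the hunk header
    from llava_utils import generator

    MODEL_PATH = "/workingdir/models_hf/liuhaotian/llava-v1.5-13b"

    model = tokenizer = image_processor = my_generator = None

    def load_model_async(model_path, model_name):
        # Assumed shape: populate module-level state, then build the generator
        # (the `my_generator = ...` line is the only part visible in the diff).
        global model, tokenizer, image_processor, my_generator
        # Upstream LLaVA's builder returns (tokenizer, model, image_processor, context_len).
        tokenizer, model, image_processor, _ = load_pretrained_model(
            model_path, None, model_name
        )
        my_generator = generator.Generator(model=model, tokenizer=tokenizer)

    # Before: threading.Thread(target=load_model_async, args=(...)).start() returned
    # immediately, so demo.launch() could start serving requests before `model` and
    # `my_generator` existed. Calling it synchronously makes startup slower but
    # guarantees the model is ready when the UI comes up.
    load_model_async(MODEL_PATH, get_model_name_from_path(MODEL_PATH))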