Ahren09 committed
Commit db34260 · verified
1 Parent(s): 58bdb9c

Update app.py

Files changed (1)
1. app.py +4 -3
app.py CHANGED
@@ -24,6 +24,7 @@ IMAGE_SAFETY_PATCHES = {
     "default": "safety_patch.pt"
 }
 
+DEVICE = "cpu"
 
 def rtp_read(text_file):
     dataset = []
@@ -40,7 +41,7 @@ model = loaded_model_name = tokenizer = image_processor = context_len = my_generator
 def load_model_async(model_path, model_name):
     global tokenizer, model, image_processor, context_len, loaded_model_name, my_generator
     print(f"Loading {model_name} model ... ")
-    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_4bit=False)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, device_map=DEVICE, device=DEVICE)
     if "llava" in model_name.lower():
         loaded_model_name = "LLaVA"
     else:
@@ -90,13 +91,13 @@ def generate_answer(image, user_message: str, requested_model_name: str,
     image = load_image(image)
 
     # transform the image using the visual encoder (CLIP) of LLaVA 1.5; the processed image would be a PyTorch tensor whose shape is (336, 336).
-    image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()
+    image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].to(DEVICE)
 
     if image_safety_patch != None:
         # make the image pixel values between (0,1)
         image = normalize(image)
         # load the safety patch tensor whose values are (0,1)
-        safety_patch = torch.load(image_safety_patch).cuda()
+        safety_patch = torch.load(image_safety_patch).to(DEVICE)
         # apply the safety patch to the input image, clamp it between (0,1) and denormalize it to the original pixel values
         safe_image = denormalize((image + safety_patch).clamp(0, 1))
         # make sure the image value is between (0,1)
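
The commit pins DEVICE = "cpu", which keeps the Space runnable on CPU-only hardware. A minimal sketch of a runtime fallback instead of a hardcoded constant (not part of this commit, just one way the constant could be generalized):

import torch

# Hypothetical variant: fall back to CPU only when no CUDA device is available,
# so the same app.py code path runs unchanged on GPU hardware.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Because load_pretrained_model, the preprocessed image, and the safety patch all read the same DEVICE constant, every tensor lands on one device, avoiding the mixed placement that the unconditional .cuda() calls caused on machines without a GPU.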
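
For reference, the patched branch amounts to the following self-contained sketch. The normalize/denormalize helpers and the CLIP mean/std constants are assumptions here (app.py's real helpers are not shown in this diff); the sketch assumes they convert between CLIP-standardized tensors and raw pixel values in (0, 1):

import torch

# Assumed stand-ins for app.py's normalize/denormalize helpers, using the
# standard OpenAI CLIP normalization constants.
CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

def normalize(x: torch.Tensor) -> torch.Tensor:
    # CLIP-standardized tensor -> raw pixels in (0, 1)
    return x * CLIP_STD.to(x.device) + CLIP_MEAN.to(x.device)

def denormalize(x: torch.Tensor) -> torch.Tensor:
    # raw pixels in (0, 1) -> CLIP-standardized tensor
    return (x - CLIP_MEAN.to(x.device)) / CLIP_STD.to(x.device)

def apply_safety_patch(image: torch.Tensor, patch_path: str, device: str = "cpu") -> torch.Tensor:
    # Mirrors the diff: move to the target device, go to (0, 1) pixel space,
    # add the precomputed additive patch, clamp, and re-standardize.
    image = normalize(image.to(device))
    patch = torch.load(patch_path).to(device)
    return denormalize((image + patch).clamp(0, 1))

One caveat: a patch tensor saved from a GPU will fail to load on a CPU-only machine with a bare torch.load; torch.load(image_safety_patch, map_location=device) is the usual guard.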