Update app.py
app.py CHANGED
@@ -24,6 +24,7 @@ IMAGE_SAFETY_PATCHES = {
     "default": "safety_patch.pt"
 }
 
+DEVICE = "cpu"
 
 def rtp_read(text_file):
     dataset = []
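This first hunk routes everything through a single module-level device constant. Hard-coding "cpu" is the safe choice on a CPU-only Space; a common variant (a sketch, not what this commit does) falls back to CPU only when no GPU is visible:

import torch

# Hypothetical alternative to the hard-coded constant above: use a GPU
# when one is available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"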
@@ -40,7 +41,7 @@ model = loaded_model_name = tokenizer = image_processor = context_len = my_gener
 def load_model_async(model_path, model_name):
     global tokenizer, model, image_processor, context_len, loaded_model_name, my_generator
     print(f"Loading {model_name} model ... ")
-    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name,
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, device_map=DEVICE, device=DEVICE)
     if "llava" in model_name.lower():
         loaded_model_name = "LLaVA"
     else:
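The loading call now passes the device through explicitly. A minimal sketch of that path, assuming the LLaVA repository's builder, whose load_pretrained_model(model_path, model_base, model_name, ...) accepts device_map and device keyword arguments (the model path below is hypothetical):

from llava.model.builder import load_pretrained_model

DEVICE = "cpu"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    "liuhaotian/llava-v1.5-7b",  # hypothetical model path
    None,                        # model_base: None when loading full weights
    "llava-v1.5-7b",             # model_name
    device_map=DEVICE,
    device=DEVICE,
)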
@@ -90,13 +91,13 @@ def generate_answer(image, user_message: str, requested_model_name: str,
     image = load_image(image)
 
     # transform the image using the visual encoder (CLIP) of LLaVA 1.5; the processed image is a PyTorch tensor whose shape is (336, 336).
-    image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].
+    image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].to(DEVICE)
 
     if image_safety_patch != None:
         # make the image pixel values between (0,1)
         image = normalize(image)
         # load the safety patch tensor whose values are (0,1)
-        safety_patch = torch.load(image_safety_patch).
+        safety_patch = torch.load(image_safety_patch).to(DEVICE)
         # apply the safety patch to the input image, clamp it between (0,1) and denormalize it to the original pixel values
         safe_image = denormalize((image + safety_patch).clamp(0, 1))
         # make sure the image value is between (0,1)