Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,12 +9,14 @@ from art.attacks.evasion import (
|
|
9 |
MomentumIterativeMethod, SaliencyMapMethod, NewtonFool
|
10 |
)
|
11 |
from art.estimators.classification import PyTorchClassifier
|
12 |
-
from PIL import Image, ImageOps
|
13 |
import numpy as np
|
14 |
import os
|
15 |
from blind_watermark import WaterMark
|
16 |
from torchvision.models import resnet50, vgg16, ResNet50_Weights, VGG16_Weights
|
17 |
import tempfile
|
|
|
|
|
18 |
|
19 |
resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)
|
20 |
num_ftrs_resnet = resnet_model.fc.in_features
|
@@ -51,13 +53,44 @@ models_dict = {
|
|
51 |
"VGG16": vgg_classifier
|
52 |
}
|
53 |
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
transform = transforms.Compose([
|
56 |
-
transforms.Resize((224, 224)),
|
57 |
transforms.ToTensor(),
|
58 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
59 |
])
|
60 |
-
return transform(
|
61 |
|
62 |
def postprocess_image(tensor, original_size):
|
63 |
adv_img_np = tensor.squeeze(0).cpu().numpy()
|
@@ -71,7 +104,8 @@ def postprocess_image(tensor, original_size):
|
|
71 |
|
72 |
def generate_adversarial_image(image, model_name, attack_types, eps_value):
|
73 |
original_size = image.size
|
74 |
-
|
|
|
75 |
|
76 |
classifier = models_dict[model_name]
|
77 |
|
@@ -159,17 +193,23 @@ def process_image(image, model_name, attack_types, eps_value, wm_text, password_
|
|
159 |
watermarked_image.save(output_path, format="PNG")
|
160 |
return image, adv_image, watermarked_image, extracted_wm_text, output_path
|
161 |
|
162 |
-
def download_image_as_png(image_path):
|
163 |
-
with open(image_path, "rb") as file:
|
164 |
-
return file.read(), "image/png"
|
165 |
-
|
166 |
interface = gr.Interface(
|
167 |
fn=process_image,
|
168 |
inputs=[
|
169 |
gr.Image(type="pil", label="์ด๋ฏธ์ง๋ฅผ ์
๋ก๋ํ์ธ์"),
|
170 |
gr.Dropdown(choices=["ResNet50", "VGG16"], label="๋ชจ๋ธ ์ ํ"),
|
171 |
-
gr.CheckboxGroup(
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
gr.Textbox(label="์ํฐ๋งํฌ ํ
์คํธ ์
๋ ฅ", value="ํ
์คํธ ์ฝ์
"),
|
174 |
gr.Number(label="์ด๋ฏธ์ง ๋น๋ฐ๋ฒํธ", value=0),
|
175 |
gr.Number(label="์ํฐ๋งํฌ ๋น๋ฐ๋ฒํธ", value=0)
|
@@ -183,4 +223,4 @@ interface = gr.Interface(
|
|
183 |
]
|
184 |
)
|
185 |
|
186 |
-
interface.launch(debug=True, share=True)
|
|
|
9 |
MomentumIterativeMethod, SaliencyMapMethod, NewtonFool
|
10 |
)
|
11 |
from art.estimators.classification import PyTorchClassifier
|
12 |
+
from PIL import Image, ImageOps, ImageEnhance, ImageDraw
|
13 |
import numpy as np
|
14 |
import os
|
15 |
from blind_watermark import WaterMark
|
16 |
from torchvision.models import resnet50, vgg16, ResNet50_Weights, VGG16_Weights
|
17 |
import tempfile
|
18 |
+
import cv2
|
19 |
+
import dlib
|
20 |
|
21 |
resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)
|
22 |
num_ftrs_resnet = resnet_model.fc.in_features
|
|
|
53 |
"VGG16": vgg_classifier
|
54 |
}
|
55 |
|
56 |
+
face_detector = dlib.get_frontal_face_detector()
|
57 |
+
landmark_predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
58 |
+
|
59 |
+
def detect_face_landmarks(image):
|
60 |
+
gray_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
|
61 |
+
faces = face_detector(gray_image)
|
62 |
+
|
63 |
+
landmarks = []
|
64 |
+
for face in faces:
|
65 |
+
shape = landmark_predictor(gray_image, face)
|
66 |
+
|
67 |
+
landmarks.extend([(shape.part(i).x, shape.part(i).y) for i in range(36, 48)]) # ๋
|
68 |
+
landmarks.extend([(shape.part(i).x, shape.part(i).y) for i in range(27, 36)]) # ์ฝ
|
69 |
+
landmarks.extend([(shape.part(i).x, shape.part(i).y) for i in range(48, 68)]) # ์
|
70 |
+
return landmarks
|
71 |
+
|
72 |
+
def apply_focus_mask(image, landmarks):
|
73 |
+
mask = Image.new("L", image.size, 0)
|
74 |
+
draw = ImageDraw.Draw(mask)
|
75 |
+
for (x, y) in landmarks:
|
76 |
+
draw.ellipse((x-10, y-10, x+10, y+10), fill=255)
|
77 |
+
return mask
|
78 |
+
|
79 |
+
def preprocess_image_with_landmark_focus(image, downscale_factor=0.5):
|
80 |
+
landmarks = detect_face_landmarks(image)
|
81 |
+
original_size = image.size
|
82 |
+
low_res_size = (int(original_size[0] * downscale_factor), int(original_size[1] * downscale_factor))
|
83 |
+
|
84 |
+
|
85 |
+
mask = apply_focus_mask(image, landmarks)
|
86 |
+
low_res_image = image.resize(low_res_size, Image.BILINEAR).resize(original_size, Image.BILINEAR)
|
87 |
+
masked_image = Image.composite(low_res_image, image, mask)
|
88 |
+
|
89 |
transform = transforms.Compose([
|
|
|
90 |
transforms.ToTensor(),
|
91 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
92 |
])
|
93 |
+
return transform(masked_image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
|
94 |
|
95 |
def postprocess_image(tensor, original_size):
|
96 |
adv_img_np = tensor.squeeze(0).cpu().numpy()
|
|
|
104 |
|
105 |
def generate_adversarial_image(image, model_name, attack_types, eps_value):
|
106 |
original_size = image.size
|
107 |
+
|
108 |
+
img_tensor = preprocess_image_with_landmark_focus(image)
|
109 |
|
110 |
classifier = models_dict[model_name]
|
111 |
|
|
|
193 |
watermarked_image.save(output_path, format="PNG")
|
194 |
return image, adv_image, watermarked_image, extracted_wm_text, output_path
|
195 |
|
|
|
|
|
|
|
|
|
196 |
interface = gr.Interface(
|
197 |
fn=process_image,
|
198 |
inputs=[
|
199 |
gr.Image(type="pil", label="์ด๋ฏธ์ง๋ฅผ ์
๋ก๋ํ์ธ์"),
|
200 |
gr.Dropdown(choices=["ResNet50", "VGG16"], label="๋ชจ๋ธ ์ ํ"),
|
201 |
+
gr.CheckboxGroup(
|
202 |
+
choices=["FGSM", "C&W", "DeepFool", "AutoAttack", "PGD", "BIM", "STA", "MIM", "JSMA", "NewtonFool"],
|
203 |
+
label="๊ณต๊ฒฉ ์ ํ ์ ํ",
|
204 |
+
value=["PGD"] # ๊ธฐ๋ณธ๊ฐ์ผ๋ก PGD ์ ํ
|
205 |
+
),
|
206 |
+
gr.Slider(
|
207 |
+
minimum=0.001,
|
208 |
+
maximum=0.9,
|
209 |
+
step=0.001,
|
210 |
+
value=0.01, # ๊ธฐ๋ณธ๊ฐ EPS 0.01 ์ค์
|
211 |
+
label="Epsilon ๊ฐ ์ค์ (๋
ธ์ด์ฆ ๊ฐ๋)"
|
212 |
+
),
|
213 |
gr.Textbox(label="์ํฐ๋งํฌ ํ
์คํธ ์
๋ ฅ", value="ํ
์คํธ ์ฝ์
"),
|
214 |
gr.Number(label="์ด๋ฏธ์ง ๋น๋ฐ๋ฒํธ", value=0),
|
215 |
gr.Number(label="์ํฐ๋งํฌ ๋น๋ฐ๋ฒํธ", value=0)
|
|
|
223 |
]
|
224 |
)
|
225 |
|
226 |
+
interface.launch(debug=True, share=True)
|