onehowon commited on
Commit
5601660
ยท
verified ยท
1 Parent(s): 4b3aa59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -12
app.py CHANGED
@@ -9,12 +9,14 @@ from art.attacks.evasion import (
9
  MomentumIterativeMethod, SaliencyMapMethod, NewtonFool
10
  )
11
  from art.estimators.classification import PyTorchClassifier
12
- from PIL import Image, ImageOps
13
  import numpy as np
14
  import os
15
  from blind_watermark import WaterMark
16
  from torchvision.models import resnet50, vgg16, ResNet50_Weights, VGG16_Weights
17
  import tempfile
 
 
18
 
19
  resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)
20
  num_ftrs_resnet = resnet_model.fc.in_features
@@ -51,13 +53,44 @@ models_dict = {
51
  "VGG16": vgg_classifier
52
  }
53
 
54
- def preprocess_image(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  transform = transforms.Compose([
56
- transforms.Resize((224, 224)),
57
  transforms.ToTensor(),
58
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
59
  ])
60
- return transform(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
61
 
62
  def postprocess_image(tensor, original_size):
63
  adv_img_np = tensor.squeeze(0).cpu().numpy()
@@ -71,7 +104,8 @@ def postprocess_image(tensor, original_size):
71
 
72
  def generate_adversarial_image(image, model_name, attack_types, eps_value):
73
  original_size = image.size
74
- img_tensor = preprocess_image(image)
 
75
 
76
  classifier = models_dict[model_name]
77
 
@@ -159,17 +193,23 @@ def process_image(image, model_name, attack_types, eps_value, wm_text, password_
159
  watermarked_image.save(output_path, format="PNG")
160
  return image, adv_image, watermarked_image, extracted_wm_text, output_path
161
 
162
- def download_image_as_png(image_path):
163
- with open(image_path, "rb") as file:
164
- return file.read(), "image/png"
165
-
166
  interface = gr.Interface(
167
  fn=process_image,
168
  inputs=[
169
  gr.Image(type="pil", label="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”"),
170
  gr.Dropdown(choices=["ResNet50", "VGG16"], label="๋ชจ๋ธ ์„ ํƒ"),
171
- gr.CheckboxGroup(choices=["FGSM", "C&W", "DeepFool", "AutoAttack", "PGD", "BIM", "STA", "MIM", "JSMA", "NewtonFool"], label="๊ณต๊ฒฉ ์œ ํ˜• ์„ ํƒ"),
172
- gr.Slider(0.001, 0.9, step=0.001, value=0.005, label="Epsilon ๊ฐ’ ์„ค์ • (๋…ธ์ด์ฆˆ ๊ฐ•๋„)"),
 
 
 
 
 
 
 
 
 
 
173
  gr.Textbox(label="์›Œํ„ฐ๋งˆํฌ ํ…์ŠคํŠธ ์ž…๋ ฅ", value="ํ…์ŠคํŠธ ์‚ฝ์ž…"),
174
  gr.Number(label="์ด๋ฏธ์ง€ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=0),
175
  gr.Number(label="์›Œํ„ฐ๋งˆํฌ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=0)
@@ -183,4 +223,4 @@ interface = gr.Interface(
183
  ]
184
  )
185
 
186
- interface.launch(debug=True, share=True)
 
9
  MomentumIterativeMethod, SaliencyMapMethod, NewtonFool
10
  )
11
  from art.estimators.classification import PyTorchClassifier
12
+ from PIL import Image, ImageOps, ImageEnhance, ImageDraw
13
  import numpy as np
14
  import os
15
  from blind_watermark import WaterMark
16
  from torchvision.models import resnet50, vgg16, ResNet50_Weights, VGG16_Weights
17
  import tempfile
18
+ import cv2
19
+ import dlib
20
 
21
  resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)
22
  num_ftrs_resnet = resnet_model.fc.in_features
 
53
  "VGG16": vgg_classifier
54
  }
55
 
56
+ face_detector = dlib.get_frontal_face_detector()
57
+ landmark_predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
58
+
59
+ def detect_face_landmarks(image):
60
+ gray_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
61
+ faces = face_detector(gray_image)
62
+
63
+ landmarks = []
64
+ for face in faces:
65
+ shape = landmark_predictor(gray_image, face)
66
+
67
+ landmarks.extend([(shape.part(i).x, shape.part(i).y) for i in range(36, 48)]) # ๋ˆˆ
68
+ landmarks.extend([(shape.part(i).x, shape.part(i).y) for i in range(27, 36)]) # ์ฝ”
69
+ landmarks.extend([(shape.part(i).x, shape.part(i).y) for i in range(48, 68)]) # ์ž…
70
+ return landmarks
71
+
72
+ def apply_focus_mask(image, landmarks):
73
+ mask = Image.new("L", image.size, 0)
74
+ draw = ImageDraw.Draw(mask)
75
+ for (x, y) in landmarks:
76
+ draw.ellipse((x-10, y-10, x+10, y+10), fill=255)
77
+ return mask
78
+
79
+ def preprocess_image_with_landmark_focus(image, downscale_factor=0.5):
80
+ landmarks = detect_face_landmarks(image)
81
+ original_size = image.size
82
+ low_res_size = (int(original_size[0] * downscale_factor), int(original_size[1] * downscale_factor))
83
+
84
+
85
+ mask = apply_focus_mask(image, landmarks)
86
+ low_res_image = image.resize(low_res_size, Image.BILINEAR).resize(original_size, Image.BILINEAR)
87
+ masked_image = Image.composite(low_res_image, image, mask)
88
+
89
  transform = transforms.Compose([
 
90
  transforms.ToTensor(),
91
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
92
  ])
93
+ return transform(masked_image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
94
 
95
  def postprocess_image(tensor, original_size):
96
  adv_img_np = tensor.squeeze(0).cpu().numpy()
 
104
 
105
  def generate_adversarial_image(image, model_name, attack_types, eps_value):
106
  original_size = image.size
107
+
108
+ img_tensor = preprocess_image_with_landmark_focus(image)
109
 
110
  classifier = models_dict[model_name]
111
 
 
193
  watermarked_image.save(output_path, format="PNG")
194
  return image, adv_image, watermarked_image, extracted_wm_text, output_path
195
 
 
 
 
 
196
  interface = gr.Interface(
197
  fn=process_image,
198
  inputs=[
199
  gr.Image(type="pil", label="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”"),
200
  gr.Dropdown(choices=["ResNet50", "VGG16"], label="๋ชจ๋ธ ์„ ํƒ"),
201
+ gr.CheckboxGroup(
202
+ choices=["FGSM", "C&W", "DeepFool", "AutoAttack", "PGD", "BIM", "STA", "MIM", "JSMA", "NewtonFool"],
203
+ label="๊ณต๊ฒฉ ์œ ํ˜• ์„ ํƒ",
204
+ value=["PGD"] # ๊ธฐ๋ณธ๊ฐ’์œผ๋กœ PGD ์„ ํƒ
205
+ ),
206
+ gr.Slider(
207
+ minimum=0.001,
208
+ maximum=0.9,
209
+ step=0.001,
210
+ value=0.01, # ๊ธฐ๋ณธ๊ฐ’ EPS 0.01 ์„ค์ •
211
+ label="Epsilon ๊ฐ’ ์„ค์ • (๋…ธ์ด์ฆˆ ๊ฐ•๋„)"
212
+ ),
213
  gr.Textbox(label="์›Œํ„ฐ๋งˆํฌ ํ…์ŠคํŠธ ์ž…๋ ฅ", value="ํ…์ŠคํŠธ ์‚ฝ์ž…"),
214
  gr.Number(label="์ด๋ฏธ์ง€ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=0),
215
  gr.Number(label="์›Œํ„ฐ๋งˆํฌ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=0)
 
223
  ]
224
  )
225
 
226
+ interface.launch(debug=True, share=True)