onehowon's picture
Update app.py
c10f63e verified
raw
history blame
4.11 kB
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from art.attacks.evasion import FastGradientMethod
from art.estimators.classification import PyTorchClassifier
from PIL import Image
import numpy as np
import os
import io
from blind_watermark import WaterMark
# Pretrained ResNet50 ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ (ImageNet ์‚ฌ์ „ ํ›ˆ๋ จ)
model = models.resnet50(pretrained=True)
# CIFAR-10์— ๋งž์ถฐ ๋งˆ์ง€๋ง‰ ๋ถ„๋ฅ˜ ๋ ˆ์ด์–ด ์ˆ˜์ •
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
# ๋ชจ๋ธ์„ GPU๋กœ ์ด๋™
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# ์†์‹ค ํ•จ์ˆ˜์™€ ์˜ตํ‹ฐ๋งˆ์ด์ € ์„ค์ •
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# PyTorchClassifier ์ƒ์„ฑ
classifier = PyTorchClassifier(
model=model,
loss=criterion,
optimizer=optimizer,
input_shape=(3, 64, 64),
nb_classes=10,
)
# ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜
def preprocess_image(image):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
return transform(image).unsqueeze(0).to(device)
# FGSM ๊ณต๊ฒฉ ์ ์šฉ ๋ฐ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
def generate_adversarial_image(image, eps_value):
img_tensor = preprocess_image(image)
# FGSM ๊ณต๊ฒฉ ์„ค์ •
attack = FastGradientMethod(estimator=classifier, eps=eps_value)
# ์ ๋Œ€์  ์˜ˆ์ œ ์ƒ์„ฑ
adv_img_tensor = attack.generate(x=img_tensor.cpu().numpy())
adv_img_tensor = torch.tensor(adv_img_tensor).to(device)
# ์ ๋Œ€์  ์ด๋ฏธ์ง€ ๋ณ€ํ™˜
adv_img_np = adv_img_tensor.squeeze(0).cpu().numpy()
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
adv_img_np = (adv_img_np * std[:, None, None]) + mean[:, None, None]
adv_img_np = np.clip(adv_img_np, 0, 1)
adv_img_np = adv_img_np.transpose(1, 2, 0)
# PIL ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
adv_image_pil = Image.fromarray((adv_img_np * 255).astype(np.uint8))
return adv_image_pil
# ์›Œํ„ฐ๋งˆํฌ ์‚ฝ์ž… ํ•จ์ˆ˜
def apply_watermark(image_pil, wm_text="ํ…์ŠคํŠธ ์‚ฝ์ž…", password_img=000, password_wm=000):
bwm = WaterMark(password_img=password_img, password_wm=password_wm)
# ์ด๋ฏธ์ง€ ๋ฐ”์ดํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ์ž„์‹œ ํŒŒ์ผ๋กœ ์ €์žฅ
temp_image_path = "temp_image.png"
image_pil.save(temp_image_path)
# temp_image_path ๊ฒฝ๋กœ๋กœ ์›Œํ„ฐ๋งˆํฌ ์‚ฝ์ž… ์ฒ˜๋ฆฌ
bwm.read_img(temp_image_path)
bwm.read_wm(wm_text, mode='str')
# ์›Œํ„ฐ๋งˆํฌ ์‚ฝ์ž…
output_path = "watermarked_image.png"
bwm.embed(output_path)
# ์‚ฝ์ž…๋œ ์›Œํ„ฐ๋งˆํฌ ์ด๋ฏธ์ง€ ํŒŒ์ผ์„ ๋‹ค์‹œ ์ฝ์–ด์„œ PIL ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
result_image = Image.open(output_path)
# ์ž„์‹œ ํŒŒ์ผ ์‚ญ์ œ
os.remove(temp_image_path)
os.remove(output_path)
return result_image
# ์ „์ฒด ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
def process_image(image, eps_value, wm_text, password_img, password_wm):
# ์ ๋Œ€์  ์ด๋ฏธ์ง€ ์ƒ์„ฑ
adv_image = generate_adversarial_image(image, eps_value)
# ์ ๋Œ€์  ์ด๋ฏธ์ง€์— ์›Œํ„ฐ๋งˆํฌ ์‚ฝ์ž…
watermarked_image = apply_watermark(adv_image, wm_text, int(password_img), int(password_wm))
return watermarked_image
# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ •์˜
gr.Interface(
fn=process_image,
inputs=[gr.Image(type="pil", label="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”"), # ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ ํ•„๋“œ
gr.Slider(0.1, 1.0, step=0.1, value=0.3, label="Epsilon ๊ฐ’ ์„ค์ • (๋…ธ์ด์ฆˆ ๊ฐ•๋„)"), # epsilon ๊ฐ’ ์Šฌ๋ผ์ด๋”
gr.Textbox(label="์›Œํ„ฐ๋งˆํฌ ํ…์ŠคํŠธ ์ž…๋ ฅ", value="ํ…์ŠคํŠธ ์‚ฝ์ž…"), # ์›Œํ„ฐ๋งˆํฌ ํ…์ŠคํŠธ ์ž…๋ ฅ ํ•„๋“œ
gr.Number(label="์ด๋ฏธ์ง€ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=000), # ์ด๋ฏธ์ง€ ๋น„๋ฐ€๋ฒˆํ˜ธ ์ž…๋ ฅ ํ•„๋“œ
gr.Number(label="์›Œํ„ฐ๋งˆํฌ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=000) # ์›Œํ„ฐ๋งˆํฌ ๋น„๋ฐ€๋ฒˆํ˜ธ ์ž…๋ ฅ ํ•„๋“œ
],
outputs=gr.Image(type="pil", label="์›Œํ„ฐ๋งˆํฌ๊ฐ€ ์‚ฝ์ž…๋œ ์ด๋ฏธ์ง€") # ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ ์ถœ๋ ฅ ํ•„๋“œ
).launch()