Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torchvision import transforms, models | |
from art.attacks.evasion import FastGradientMethod | |
from art.estimators.classification import PyTorchClassifier | |
from PIL import Image | |
import numpy as np | |
import os | |
import io | |
from blind_watermark import WaterMark | |
# Pretrained ResNet50 ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ (ImageNet ์ฌ์ ํ๋ จ) | |
model = models.resnet50(pretrained=True) | |
# CIFAR-10์ ๋ง์ถฐ ๋ง์ง๋ง ๋ถ๋ฅ ๋ ์ด์ด ์์ | |
num_ftrs = model.fc.in_features | |
model.fc = nn.Linear(num_ftrs, 10) | |
# ๋ชจ๋ธ์ GPU๋ก ์ด๋ | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model = model.to(device) | |
# ์์ค ํจ์์ ์ตํฐ๋ง์ด์ ์ค์ | |
criterion = nn.CrossEntropyLoss() | |
optimizer = optim.Adam(model.parameters(), lr=0.001) | |
# PyTorchClassifier ์์ฑ | |
classifier = PyTorchClassifier( | |
model=model, | |
loss=criterion, | |
optimizer=optimizer, | |
input_shape=(3, 64, 64), | |
nb_classes=10, | |
) | |
# ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ ํจ์ | |
def preprocess_image(image): | |
transform = transforms.Compose([ | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
std=[0.229, 0.224, 0.225]) | |
]) | |
return transform(image).unsqueeze(0).to(device) | |
# FGSM ๊ณต๊ฒฉ ์ ์ฉ ๋ฐ ์ด๋ฏธ์ง ์ฒ๋ฆฌ ํจ์ | |
def generate_adversarial_image(image, eps_value): | |
img_tensor = preprocess_image(image) | |
# FGSM ๊ณต๊ฒฉ ์ค์ | |
attack = FastGradientMethod(estimator=classifier, eps=eps_value) | |
# ์ ๋์ ์์ ์์ฑ | |
adv_img_tensor = attack.generate(x=img_tensor.cpu().numpy()) | |
adv_img_tensor = torch.tensor(adv_img_tensor).to(device) | |
# ์ ๋์ ์ด๋ฏธ์ง ๋ณํ | |
adv_img_np = adv_img_tensor.squeeze(0).cpu().numpy() | |
mean = np.array([0.485, 0.456, 0.406]) | |
std = np.array([0.229, 0.224, 0.225]) | |
adv_img_np = (adv_img_np * std[:, None, None]) + mean[:, None, None] | |
adv_img_np = np.clip(adv_img_np, 0, 1) | |
adv_img_np = adv_img_np.transpose(1, 2, 0) | |
# PIL ์ด๋ฏธ์ง๋ก ๋ณํ | |
adv_image_pil = Image.fromarray((adv_img_np * 255).astype(np.uint8)) | |
return adv_image_pil | |
# ์ํฐ๋งํฌ ์ฝ์ ํจ์ | |
def apply_watermark(image_pil, wm_text="ํ ์คํธ ์ฝ์ ", password_img=000, password_wm=000): | |
bwm = WaterMark(password_img=password_img, password_wm=password_wm) | |
# ์ด๋ฏธ์ง ๋ฐ์ดํธ ๋ฐ์ดํฐ๋ฅผ ์์ ํ์ผ๋ก ์ ์ฅ | |
temp_image_path = "temp_image.png" | |
image_pil.save(temp_image_path) | |
# temp_image_path ๊ฒฝ๋ก๋ก ์ํฐ๋งํฌ ์ฝ์ ์ฒ๋ฆฌ | |
bwm.read_img(temp_image_path) | |
bwm.read_wm(wm_text, mode='str') | |
# ์ํฐ๋งํฌ ์ฝ์ | |
output_path = "watermarked_image.png" | |
bwm.embed(output_path) | |
# ์ฝ์ ๋ ์ํฐ๋งํฌ ์ด๋ฏธ์ง ํ์ผ์ ๋ค์ ์ฝ์ด์ PIL ์ด๋ฏธ์ง๋ก ๋ณํ | |
result_image = Image.open(output_path) | |
# ์์ ํ์ผ ์ญ์ | |
os.remove(temp_image_path) | |
os.remove(output_path) | |
return result_image | |
# ์ ์ฒด ์ด๋ฏธ์ง ์ฒ๋ฆฌ ํจ์ | |
def process_image(image, eps_value, wm_text, password_img, password_wm): | |
# ์ ๋์ ์ด๋ฏธ์ง ์์ฑ | |
adv_image = generate_adversarial_image(image, eps_value) | |
# ์ ๋์ ์ด๋ฏธ์ง์ ์ํฐ๋งํฌ ์ฝ์ | |
watermarked_image = apply_watermark(adv_image, wm_text, int(password_img), int(password_wm)) | |
return watermarked_image | |
# Gradio ์ธํฐํ์ด์ค ์ ์ | |
gr.Interface( | |
fn=process_image, | |
inputs=[gr.Image(type="pil", label="์ด๋ฏธ์ง๋ฅผ ์ ๋ก๋ํ์ธ์"), # ์ด๋ฏธ์ง ์ ๋ก๋ ํ๋ | |
gr.Slider(0.1, 1.0, step=0.1, value=0.3, label="Epsilon ๊ฐ ์ค์ (๋ ธ์ด์ฆ ๊ฐ๋)"), # epsilon ๊ฐ ์ฌ๋ผ์ด๋ | |
gr.Textbox(label="์ํฐ๋งํฌ ํ ์คํธ ์ ๋ ฅ", value="ํ ์คํธ ์ฝ์ "), # ์ํฐ๋งํฌ ํ ์คํธ ์ ๋ ฅ ํ๋ | |
gr.Number(label="์ด๋ฏธ์ง ๋น๋ฐ๋ฒํธ", value=000), # ์ด๋ฏธ์ง ๋น๋ฐ๋ฒํธ ์ ๋ ฅ ํ๋ | |
gr.Number(label="์ํฐ๋งํฌ ๋น๋ฐ๋ฒํธ", value=000) # ์ํฐ๋งํฌ ๋น๋ฐ๋ฒํธ ์ ๋ ฅ ํ๋ | |
], | |
outputs=gr.Image(type="pil", label="์ํฐ๋งํฌ๊ฐ ์ฝ์ ๋ ์ด๋ฏธ์ง") # ๊ฒฐ๊ณผ ์ด๋ฏธ์ง ์ถ๋ ฅ ํ๋ | |
).launch() | |