Spaces:
Sleeping
Sleeping
import torch | |
from torchvision import transforms, models | |
from art.attacks.evasion import FastGradientMethod | |
from art.estimators.classification import PyTorchClassifier | |
from PIL import Image | |
import numpy as np | |
import io | |
import base64 | |
from blind_watermark import WaterMark | |
def load_model():
    """Build the fine-tuned ResNet-50 classifier and put it in eval mode.

    A ResNet-50 is created without pretrained weights, its final fully
    connected layer is replaced by a 10-output linear layer, and the
    parameters stored in ``model.pt`` (resolved against the current working
    directory) are loaded onto the CPU.

    Returns:
        The ``torch.nn.Module`` in evaluation mode.
    """
    net = models.resnet50(pretrained=False)
    # Swap the stock classification head for a 10-class one so the layer
    # shapes match the checkpoint being loaded below.
    net.fc = torch.nn.Linear(net.fc.in_features, 10)
    cpu = torch.device('cpu')
    state = torch.load("model.pt", map_location=cpu)
    net.load_state_dict(state)
    # ``eval()`` returns the module itself.
    return net.eval()
def process_image(inputs: dict):
    """Create an FGSM adversarial version of an image and watermark it.

    The image is decoded, normalized with the ImageNet mean/std, perturbed
    with the Fast Gradient Method against the model from ``load_model()``,
    de-normalized back to 8-bit RGB, stamped with a blind watermark and
    returned as a base64-encoded PNG.

    Args:
        inputs: Request payload. ``inputs["inputs"]`` is the base64-encoded
            image file. The optional ``inputs["eps"]`` (default ``0.3``) is
            the FGSM step size; it is applied to the *normalized* tensor,
            not to raw 0-1 pixel values.

    Returns:
        ``{"image": <base64 string of the watermarked adversarial PNG>}``.

    Raises:
        KeyError: If ``inputs`` has no ``"inputs"`` entry.
    """
    input_image = inputs["inputs"]
    eps_value = inputs.get("eps", 0.3)

    # Single definition of the normalization constants so the forward
    # transform and its inverse below cannot drift apart.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    device = torch.device("cpu")
    # NOTE(review): the checkpoint is re-read from disk on every call;
    # consider caching the model at module level if this becomes a hot path.
    model = load_model().to(device)

    # No optimizer is passed: the classifier is never trained here, the
    # attack only needs loss gradients with respect to the input.
    # NOTE(review): input_shape says 64x64 but the image is not resized
    # before the attack — confirm the intended input resolution.
    classifier = PyTorchClassifier(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        input_shape=(3, 64, 64),
        nb_classes=10,
    )

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    img = Image.open(io.BytesIO(base64.b64decode(input_image))).convert('RGB')
    batch = transform(img).unsqueeze(0).to(device).cpu().numpy()

    attack = FastGradientMethod(estimator=classifier, eps=eps_value)
    # ``generate`` already returns a numpy array; take the single sample
    # out of the batch, giving a (channels, height, width) array.
    adv = attack.generate(x=batch)[0]

    # Undo the normalization, clamp to the valid range and go to HWC uint8.
    adv = adv * np.array(std)[:, None, None] + np.array(mean)[:, None, None]
    adv = np.clip(adv, 0, 1).transpose(1, 2, 0)
    adv_rgb = (adv * 255).astype(np.uint8)

    wm_text = "123"
    bwm = WaterMark(password_img=123, password_wm=456)
    # blind_watermark reads/writes images through OpenCV, i.e. from file
    # paths or BGR arrays. An in-memory buffer is not a valid path, so the
    # pixels are handed over directly (RGB -> BGR keeps the channel order
    # identical to what a later file-based extraction would see).
    bwm.read_img(img=np.ascontiguousarray(adv_rgb[:, :, ::-1]))
    bwm.read_wm(wm_text, mode='str')
    # With no filename ``embed`` writes nothing and returns the watermarked
    # image as a float array in the 0-255 range.
    marked_bgr = bwm.embed()
    marked_rgb = np.clip(np.rint(marked_bgr), 0, 255).astype(np.uint8)[:, :, ::-1]

    # Encode the *watermarked* pixels (not the pre-watermark image) as PNG.
    img_bytes = io.BytesIO()
    Image.fromarray(np.ascontiguousarray(marked_rgb)).save(img_bytes, format='PNG')
    result_image = base64.b64encode(img_bytes.getvalue()).decode('utf-8')
    return {"image": result_image}