# Source: onehowon — "file uploaded", commit 05c1605
# (page residue: raw | history | blame · 2.28 kB)
import base64
import functools
import io

import numpy as np
import torch
from art.attacks.evasion import FastGradientMethod
from art.estimators.classification import PyTorchClassifier
from blind_watermark import WaterMark
from PIL import Image
from torchvision import transforms, models
def load_model():
    """Build the 10-class ResNet-50 classifier and load its weights.

    The stock ResNet-50 (no pretrained weights) has its final ``fc`` layer
    replaced by a fresh ``Linear`` layer with 10 outputs. The state dict is
    then read from ``model.pt``, with tensors mapped onto the CPU.

    Returns:
        The model, switched to eval mode.
    """
    net = models.resnet50(pretrained=False)
    head_inputs = net.fc.in_features
    net.fc = torch.nn.Linear(head_inputs, 10)

    cpu = torch.device('cpu')
    weights = torch.load("model.pt", map_location=cpu)
    net.load_state_dict(weights)

    net.eval()
    return net
@functools.lru_cache(maxsize=1)
def _get_model():
    """Return the classifier from ``load_model()``, loading it only on the first call."""
    # Reading the ResNet-50 checkpoint is expensive; reuse one eval-mode model
    # across requests instead of reloading it for every image.
    return load_model()


def process_image(inputs: dict):
    """Perturb an image with FGSM, embed a blind watermark, and return it as PNG.

    Args:
        inputs: Request payload. ``inputs["inputs"]`` is a base64-encoded image
            in any format PIL can open (it is converted to RGB). The optional
            ``inputs["eps"]`` is the FGSM step size (default 0.3). It is applied
            in ImageNet-normalized space, before de-normalization.

    Returns:
        ``{"image": str}``: the base64-encoded PNG of the perturbed RGB image
        with the text watermark ``"123"`` embedded.

    Raises:
        KeyError: if ``"inputs"`` is missing from ``inputs``.
        ValueError: if ``eps`` cannot be converted to ``float``.
    """
    input_image = inputs["inputs"]
    # JSON payloads may carry eps as a string; ART needs a number.
    eps_value = float(inputs.get("eps", 0.3))

    device = torch.device("cpu")
    model = _get_model().to(device)

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])

    img = Image.open(io.BytesIO(base64.b64decode(input_image))).convert('RGB')
    img_tensor = transform(img).unsqueeze(0)  # (1, 3, H, W)

    # Declare the actual image shape rather than a fixed (3, 64, 64): uploads
    # can be any size. No optimizer is passed; ART only needs one for training.
    classifier = PyTorchClassifier(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        input_shape=tuple(img_tensor.shape[1:]),
        nb_classes=10,
    )

    attack = FastGradientMethod(estimator=classifier, eps=eps_value)
    adv = attack.generate(x=img_tensor.numpy())[0]  # (3, H, W), normalized space

    # Undo Normalize, clip to [0, 1], go CHW -> HWC, and round (not truncate)
    # back to 8-bit.
    adv = adv * std[:, None, None] + mean[:, None, None]
    adv = np.clip(adv, 0, 1).transpose(1, 2, 0)
    adv_rgb = np.round(adv * 255).astype(np.uint8)

    # blind_watermark's read_img/embed `filename` arguments are file paths for
    # cv2.imread/cv2.imwrite, so a BytesIO cannot be used there. Pass the array
    # directly and encode the array that embed() returns. The library follows
    # OpenCV's BGR channel order, hence the [..., ::-1] flips.
    bwm = WaterMark(password_img=123, password_wm=456)
    bwm.read_img(img=adv_rgb[:, :, ::-1])
    bwm.read_wm("123", mode='str')
    embedded_bgr = bwm.embed()
    embedded_rgb = np.clip(np.round(embedded_bgr), 0, 255).astype(np.uint8)[:, :, ::-1]

    img_bytes = io.BytesIO()
    Image.fromarray(np.ascontiguousarray(embedded_rgb)).save(img_bytes, format='PNG')
    result_image = base64.b64encode(img_bytes.getvalue()).decode('utf-8')
    return {"image": result_image}