EditGuard / utils /jpegtest.py
Ricoooo's picture
'folder'
5d21dd2
raw
history blame
1.25 kB
import io
import os
import random
import string

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
class JpegTest(nn.Module):
    """Noise layer that applies real JPEG compression to a batch of images.

    Each image is quantised to 8 bits, encoded with Pillow's JPEG encoder and
    decoded again, so the output carries genuine JPEG artefacts. The
    encode/decode round trip runs in memory and is not differentiable: the
    returned tensor has no autograd connection to the input.

    Args:
        Q: JPEG quality, forwarded to Pillow as ``quality``.
        subsample: chroma subsampling, forwarded to Pillow as ``subsampling``
            (Pillow documents 0 = 4:4:4, 1 = 4:2:2, 2 = 4:2:0).
        path: scratch directory. Compression no longer goes through the file
            system. The directory is still created (when non-empty) and
            ``get_path`` still builds names inside it, so existing callers
            keep working.
    """

    def __init__(self, Q=50, subsample=0, path="temp/"):
        super().__init__()
        self.Q = Q
        self.subsample = subsample
        self.path = path
        if path:
            # makedirs handles nested paths; exist_ok avoids the race between
            # an existence check and the actual creation.
            os.makedirs(path, exist_ok=True)
        # ToTensor maps a uint8 HxWxC / HxW array to a float CxHxW tensor in
        # [0, 1]. There is no Normalize, so outputs stay in the input range.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def get_path(self):
        """Return a random 16-character ``.jpg`` file path inside ``self.path``."""
        stem = "".join(random.sample(string.ascii_letters + string.digits, 16))
        # os.path.join also copes with a ``path`` lacking a trailing separator.
        return os.path.join(self.path, stem + ".jpg")

    def _jpeg_roundtrip(self, pixels):
        """Encode a uint8 HxW or HxWxC array as JPEG in memory and decode it."""
        if pixels.ndim == 3 and pixels.shape[2] == 1:
            # Image.fromarray has no mode for HxWx1 uint8 arrays; encode the
            # single channel as a grayscale ("L") image instead.
            pixels = pixels[:, :, 0]
        buffer = io.BytesIO()
        Image.fromarray(pixels).save(
            buffer, format="JPEG", quality=self.Q, subsampling=self.subsample
        )
        buffer.seek(0)
        # The context manager closes the decoder; np.array copies the pixels
        # before that happens.
        with Image.open(buffer) as decoded:
            return np.array(decoded, dtype=np.uint8)

    def forward(self, image_cover_mask):
        """Return a JPEG-compressed copy of ``image_cover_mask``.

        Args:
            image_cover_mask: batched image tensor. Each ``image_cover_mask[i]``
                is laid out channel-first (C, H, W) and is clamped to [0, 1]
                before encoding. Presumably C is 3 (RGB) or 1 (grayscale);
                Pillow cannot write other channel counts as JPEG.

        Returns:
            Tensor with the same shape, dtype and device as the input, holding
            the decoded JPEG images scaled to [0, 1].
        """
        noised_image = torch.zeros_like(image_cover_mask)
        for i, single in enumerate(image_cover_mask):
            # Adding 0.5 before the truncating uint8 cast rounds to nearest.
            pixels = (
                (single.detach().clamp(0, 1).permute(1, 2, 0) * 255)
                .add(0.5)
                .clamp(0, 255)
                .to("cpu", torch.uint8)
                .numpy()
            )
            decoded = self._jpeg_roundtrip(pixels)
            noised_image[i] = self.transform(decoded).to(image_cover_mask.device)
        return noised_image