|
import sys |
|
|
|
sys.path.append('../../') |
|
import argparse |
|
import base64 |
|
from io import BytesIO |
|
from data.file_dataset import FileDataset |
|
from PIL import Image, ImageFile |
|
from torchvision import transforms |
|
from omegaconf import OmegaConf |
|
from models.taming.models.vqgan import GumbelVQ |
|
import os |
|
|
|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
import numpy as np |
|
|
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
ImageFile.MAX_IMAGE_PIXELS = None |
|
Image.MAX_IMAGE_PIXELS = None |
|
|
|
from tqdm import tqdm |
|
|
|
class VQGANDataset(Dataset): |
|
def __init__(self, file, selected_cols, skip_convert_images=True, image_root=None, pretraininig=True): |
|
self.reader = FileDataset( |
|
file, |
|
selected_col_ids=selected_cols, |
|
) |
|
|
|
self.skip_convert_images = skip_convert_images |
|
self.image_root = image_root |
|
|
|
if self.skip_convert_images: |
|
self.code_resize_transform = transforms.Compose([ |
|
lambda image: image.convert("RGB"), |
|
transforms.Resize((args.code_image_size,args.code_image_size),interpolation=Image.BICUBIC), |
|
transforms.ToTensor(), |
|
preprocess_vqgan |
|
]) |
|
if pretraininig: |
|
self.code_resize_transform = transforms.Compose([ |
|
lambda image: image.convert("RGB"), |
|
transforms.Resize((args.code_image_size,args.code_image_size),interpolation=Image.BICUBIC), |
|
transforms.CenterCrop(int(0.5*args.code_image_size)), |
|
transforms.ToTensor(), |
|
preprocess_vqgan |
|
]) |
|
else: |
|
self.code_resize_transform = transforms.Compose([ |
|
lambda image: image.convert("RGB"), |
|
transforms.Resize(args.code_image_size, interpolation=Image.LANCZOS), |
|
transforms.ToTensor(), |
|
preprocess_vqgan |
|
]) |
|
if pretraininig: |
|
self.code_resize_transform = transforms.Compose([ |
|
lambda image: image.convert("RGB"), |
|
transforms.Resize(args.code_image_size, interpolation=Image.LANCZOS), |
|
transforms.CenterCrop(int(0.5*args.code_image_size)), |
|
transforms.ToTensor(), |
|
preprocess_vqgan |
|
]) |
|
|
|
def __len__(self): |
|
return len(self.reader) |
|
|
|
def __getitem__(self, item): |
|
column_l = self.reader[item] |
|
if len(column_l) == 4: |
|
pair_id, image_id, image, text = column_l |
|
elif len(column_l) == 2: |
|
image_id, image = column_l |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
if not self.skip_convert_images: |
|
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))) |
|
else: |
|
if self.image_root is not None: |
|
image = os.path.join(self.image_root, image) |
|
try: |
|
image = Image.open(image) |
|
except PIL.UnidentifiedImageError: |
|
column_l = self.reader[0] |
|
if len(column_l) == 4: |
|
pair_id, image_id, image, text = column_l |
|
elif len(column_l) == 2: |
|
image_id, image = column_l |
|
else: |
|
raise NotImplementedError |
|
image = Image.open(image) |
|
|
|
code_image = self.code_resize_transform(image) |
|
if len(column_l) == 4: |
|
return {"code_image": code_image, "pair_id": pair_id, "image_id": image_id, "text": text} |
|
elif len(column_l) == 2: |
|
return {"code_image": code_image, "image_id": image_id} |
|
|
|
|
|
def custom_to_pil(x): |
|
x = x.detach().cpu() |
|
x = torch.clamp(x, -1., 1.) |
|
x = (x + 1.) / 2. |
|
x = x.permute(1, 2, 0).numpy() |
|
x = (255 * x).astype(np.uint8) |
|
x = Image.fromarray(x) |
|
if not x.mode == "RGB": |
|
x = x.convert("RGB") |
|
return x |
|
|
|
|
|
def map_pixels(x, eps=0.1): |
|
return (1 - 2 * eps) * x + eps |
|
|
|
|
|
def preprocess_vqgan(x): |
|
x = 2. * x - 1. |
|
return x |
|
|
|
|
|
def image_to_base64(img, format): |
|
output_buffer = BytesIO() |
|
img.save(output_buffer, format=format) |
|
byte_data = output_buffer.getvalue() |
|
base64_str = base64.b64encode(byte_data) |
|
base64_str = str(base64_str, encoding='utf-8') |
|
return base64_str |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--file", type=str, default="") |
|
parser.add_argument("--outputs", type=str, default="") |
|
parser.add_argument("--selected_cols", type=str, required=True) |
|
parser.add_argument("--code_image_size", type=int, required=True) |
|
parser.add_argument("--vq_model", type=str, required=True) |
|
parser.add_argument("--vqgan_model_path", type=str, default=None) |
|
parser.add_argument("--vqgan_config_path", type=str, default=None) |
|
parser.add_argument("--log_interval", default=100, type=int, help="log interval") |
|
parser.add_argument("--worker_cnt", type=int, default=1) |
|
parser.add_argument("--local_rank", type=int, default=0) |
|
parser.add_argument("--batch_size", type=int, default=32) |
|
parser.add_argument("--skip_convert_images", type=bool, default=False) |
|
parser.add_argument("--image_root", type=str, default=None) |
|
parser.add_argument("--pretraininig", type=bool, default=True) |
|
|
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
|
vqgan_config = OmegaConf.load(args.vqgan_config_path) |
|
vqgan = GumbelVQ(**vqgan_config.model.params) |
|
sd = torch.load(args.vqgan_model_path, map_location="cpu")["state_dict"] |
|
missing, unexpected = vqgan.load_state_dict(sd, strict=False) |
|
for k, v in vqgan.named_parameters(): |
|
v.requires_grad = False |
|
image_tokenizer = vqgan.cuda().eval() |
|
|
|
writer = open(args.outputs, 'w') |
|
|
|
print("begin process") |
|
|
|
data_cnt = 0 |
|
|
|
dataset = VQGANDataset(args.file, args.selected_cols, skip_convert_images=args.skip_convert_images, |
|
image_root=args.image_root, pretraininig=args.pretraininig) |
|
dataloader = DataLoader(dataset, batch_size=args.batch_size) |
|
num_corrupted = 0 |
|
processed_ids = {} |
|
for data in tqdm(dataloader): |
|
batch_size = data["code_image"].size()[0] |
|
with torch.no_grad(): |
|
z, _, [_, _, image_codes] = image_tokenizer.encode(data["code_image"].cuda()) |
|
image_codes = image_codes.view(batch_size, -1).detach() |
|
|
|
for i, image_code in enumerate(image_codes): |
|
code = ' '.join([str(num) for num in image_code.tolist()]) |
|
if data['image_id'][i] in processed_ids: |
|
continue |
|
|
|
processed_ids[data['image_id'][i]] = 0 |
|
|
|
if len(data.keys()) == 4: |
|
writer.write('\t'.join([data['pair_id'][i], data['image_id'][i], data['text'][i], code])+'\n') |
|
elif len(data.keys()) == 2: |
|
writer.write('\t'.join([data['image_id'][i], code])+'\n') |
|
else: |
|
raise NotImplementedError |
|
|
|
print(len(processed_ids)) |
|
writer.close() |
|
|
|
print("finish") |
|
print('num_corrupted:', num_corrupted) |
|
|