"""Encode images with a GumbelVQ (VQGAN) tokenizer and write their code ids to a TSV file."""
import sys
sys.path.append('../../')

import argparse
import base64
from io import BytesIO
from data.file_dataset import FileDataset
from PIL import Image, ImageFile
from torchvision import transforms
from omegaconf import OmegaConf
from models.taming.models.vqgan import GumbelVQ
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None


class VQGANDataset(Dataset):
    """Read image rows through ``FileDataset`` and turn them into VQGAN input tensors.

    A row with 4 columns is unpacked as ``(pair_id, image_id, image, text)``;
    a row with 2 columns as ``(image_id, image)``. Any other width raises
    ``NotImplementedError``.

    Args:
        file: path handed to ``FileDataset``.
        selected_cols: forwarded to ``FileDataset`` as ``selected_col_ids``.
        skip_convert_images: if True, the image column is a file path
            (joined onto ``image_root`` when given) and is resized to a square
            ``code_image_size`` x ``code_image_size`` with bicubic
            interpolation. If False, the image column is URL-safe base64 and
            its shorter side is resized to ``code_image_size`` with Lanczos.
        image_root: optional directory prepended to image paths; only used
            when ``skip_convert_images`` is True.
        code_image_size: resize target. When None, falls back to the
            module-level ``args.code_image_size`` (the original behaviour,
            which only works when this file runs as a script).
    """

    def __init__(self, file, selected_cols, skip_convert_images=True, image_root=None,
                 code_image_size=None):
        self.reader = FileDataset(
            file,
            selected_col_ids=selected_cols,
        )
        self.skip_convert_images = skip_convert_images
        self.image_root = image_root
        if code_image_size is None:
            # Backward compatibility: older callers relied on the global
            # ``args`` defined in the ``__main__`` section below.
            code_image_size = args.code_image_size
        if self.skip_convert_images:
            resize = transforms.Resize((code_image_size, code_image_size), interpolation=Image.BICUBIC)
        else:
            # NOTE(review): an int size keeps the aspect ratio, so non-square
            # inputs produce differently shaped tensors that the default
            # DataLoader collate cannot stack when batch_size > 1. Confirm
            # whether a square resize was intended here.
            resize = transforms.Resize(code_image_size, interpolation=Image.LANCZOS)
        self.code_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            resize,
            transforms.ToTensor(),
            preprocess_vqgan,
        ])

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, item):
        """Return a dict with ``code_image`` plus the row's id/text columns."""
        column_l = self.reader[item]
        if len(column_l) == 4:
            pair_id, image_id, image, text = column_l
        elif len(column_l) == 2:
            image_id, image = column_l
        else:
            raise NotImplementedError

        if not self.skip_convert_images:
            source = BytesIO(base64.urlsafe_b64decode(image))
        elif self.image_root is not None:
            source = os.path.join(self.image_root, image)
        else:
            source = image
        # Close the underlying file once the tensor has been produced;
        # convert("RGB") in the transform already returns a separate image.
        with Image.open(source) as pil_image:
            code_image = self.code_resize_transform(pil_image)

        if len(column_l) == 4:
            return {"code_image": code_image, "pair_id": pair_id, "image_id": image_id, "text": text}
        return {"code_image": code_image, "image_id": image_id}


def custom_to_pil(x):
    """Convert a tensor in [-1, 1] back into an RGB PIL image.

    ``permute(1, 2, 0)`` requires a 3-D tensor; it is assumed to be
    (C, H, W) — TODO confirm against callers.
    """
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x


def map_pixels(x, eps=0.1):
    """Affinely map values from [0, 1] into [eps, 1 - eps]."""
    return (1 - 2 * eps) * x + eps


def preprocess_vqgan(x):
    """Rescale a [0, 1] tensor (as produced by ``ToTensor``) to [-1, 1]."""
    x = 2. * x - 1.
    return x


def image_to_base64(img, format):
    """Serialize a PIL image in ``format`` and return it as a standard base64 string.

    Note: this uses the standard alphabet, while ``VQGANDataset`` decodes with
    the URL-safe alphabet.
    """
    output_buffer = BytesIO()
    img.save(output_buffer, format=format)
    byte_data = output_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)
    base64_str = str(base64_str, encoding='utf-8')
    return base64_str


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str, default="")
    parser.add_argument("--outputs", type=str, default="")
    parser.add_argument("--selected_cols", type=str, required=True)
    parser.add_argument("--code_image_size", type=int, required=True)
    parser.add_argument("--vq_model", type=str, required=True)
    parser.add_argument("--vqgan_model_path", type=str, default=None)
    parser.add_argument("--vqgan_config_path", type=str, default=None)
    parser.add_argument("--log_interval", default=100, type=int, help="log interval")
    parser.add_argument("--worker_cnt", type=int, default=1)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--skip_convert_images", type=str2bool, default=False)
    parser.add_argument("--image_root", type=str, default=None)
    args = parser.parse_args()

    vqgan_config = OmegaConf.load(args.vqgan_config_path)
    vqgan = GumbelVQ(**vqgan_config.model.params)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    sd = torch.load(args.vqgan_model_path, map_location="cpu")["state_dict"]
    missing, unexpected = vqgan.load_state_dict(sd, strict=False)
    if missing or unexpected:
        print("load_state_dict: %d missing keys, %d unexpected keys" % (len(missing), len(unexpected)))
    for param in vqgan.parameters():
        param.requires_grad = False
    image_tokenizer = vqgan.cuda().eval()

    dataset = VQGANDataset(
        args.file,
        args.selected_cols,
        skip_convert_images=args.skip_convert_images,
        image_root=args.image_root,
        code_image_size=args.code_image_size,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    print("begin process")
    data_cnt = 0
    # utf-8 explicitly: the text column may contain non-ASCII characters.
    with open(args.outputs, 'w', encoding='utf-8') as writer:
        for batch_idx, data in enumerate(dataloader):
            batch_size = data["code_image"].size(0)
            with torch.no_grad():
                _, _, [_, _, image_codes] = image_tokenizer.encode(data["code_image"].cuda())
                image_codes = image_codes.view(batch_size, -1).detach()
            for i, image_code in enumerate(image_codes):
                code = ' '.join(str(num) for num in image_code.tolist())
                # The dataset returns 4 keys for 4-column rows, 2 keys for 2-column rows.
                if len(data) == 4:
                    fields = [data['pair_id'][i], data['image_id'][i], data['text'][i], code]
                elif len(data) == 2:
                    fields = [data['image_id'][i], code]
                else:
                    raise NotImplementedError
                writer.write('\t'.join(fields) + '\n')
            data_cnt += batch_size
            if args.log_interval > 0 and (batch_idx + 1) % args.log_interval == 0:
                print("processed %d images" % data_cnt)
    print("finish")