# File heavily based on https://github.com/aimagelab/dress-code/blob/main/data/dataset.py
import json
import os
import pathlib
import random
import sys
from typing import Tuple

PROJECT_ROOT = pathlib.Path(__file__).absolute().parents[2].absolute()
sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageOps
from torchvision.ops import masks_to_boxes

from src.utils.posemap import get_coco_body25_mapping
from src.utils.posemap import kpoint_to_heatmap


class VitonHDDataset(data.Dataset):
    def __init__(
            self,
            dataroot_path: str,
            phase: str,
            tokenizer,
            radius=5,
            caption_folder='captions.json',
            sketch_threshold_range: Tuple[int, int] = (20, 127),
            order: str = 'paired',
            outputlist: Tuple[str, ...] = ('c_name', 'im_name', 'image', 'im_cloth', 'shape', 'pose_map',
                                           'parse_array', 'im_mask', 'inpaint_mask', 'parse_mask_total',
                                           'im_sketch', 'captions', 'original_captions'),
            size: Tuple[int, int] = (512, 384),
    ):
        super(VitonHDDataset, self).__init__()
        self.dataroot = dataroot_path
        self.phase = phase
        self.caption_folder = caption_folder
        self.sketch_threshold_range = sketch_threshold_range
        self.category = 'upper_body'  # VITON-HD only contains upper-body garments
        self.outputlist = outputlist
        self.height = size[0]
        self.width = size[1]
        self.radius = radius
        self.tokenizer = tokenizer
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.transform2D = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        self.order = order

        im_names = []
        c_names = []
        dataroot_names = []

        possible_outputs = ['c_name', 'im_name', 'image', 'im_cloth', 'shape', 'im_head', 'im_pose',
                            'pose_map', 'parse_array', 'im_mask', 'inpaint_mask', 'parse_mask_total',
                            'im_sketch', 'captions', 'original_captions', 'category']

        assert all(x in possible_outputs for x in outputlist)

        # Load captions, keeping only the items with at least 3 caption entries
        with open(os.path.join(self.dataroot, self.caption_folder)) as f:
            self.captions_dict = json.load(f)
        self.captions_dict = {k: v for k, v in self.captions_dict.items() if len(v) >= 3}

        dataroot = self.dataroot
        filename = os.path.join(dataroot, f"{phase}_pairs.txt")

        with open(filename, 'r') as f:
            for line in f.readlines():
                if phase == 'train':
                    im_name, _ = line.strip().split()
                    c_name = im_name
                else:
                    if order == 'paired':
                        im_name, _ = line.strip().split()
                        c_name = im_name
                    else:
                        im_name, c_name = line.strip().split()

                im_names.append(im_name)
                c_names.append(c_name)
                dataroot_names.append(dataroot)

        self.im_names = im_names
        self.c_names = c_names
        self.dataroot_names = dataroot_names
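    # Expected dataset layout, inferred from the paths accessed in __getitem__
    # (folder names follow the public VITON-HD release plus the sketch folders
    # used by this repository; adjust if your copy is organized differently):
    #
    #   <dataroot>/
    #       captions.json
    #       train_pairs.txt / test_pairs.txt
    #       <phase>/
    #           image/               person images
    #           im_sketch/           paired garment sketches
    #           im_sketch_unpaired/  unpaired garment sketches
    #           image-parse-v3/      human-parsing label maps
    #           openpose_json/       OpenPose BODY_25 keypoints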

    def __getitem__(self, index):
        """
        For each index return the corresponding sample in the dataset

        :param index: data index
        :type index: int
        :return: dict containing dataset samples
        :rtype: dict
        """
        c_name = self.c_names[index]
        im_name = self.im_names[index]
        dataroot = self.dataroot_names[index]

        sketch_threshold = random.randint(self.sketch_threshold_range[0], self.sketch_threshold_range[1])

        if "captions" in self.outputlist or "original_captions" in self.outputlist:
            captions = self.captions_dict[c_name.split('_')[0]]
            # randomize the caption order at training time
            if self.phase == 'train':
                random.shuffle(captions)
            captions = ", ".join(captions)
            original_captions = captions

        if "captions" in self.outputlist:
            cond_input = self.tokenizer([captions], max_length=self.tokenizer.model_max_length,
                                        padding="max_length", truncation=True,
                                        return_tensors="pt").input_ids
            cond_input = cond_input.squeeze(0)
            max_length = cond_input.shape[-1]
            uncond_input = self.tokenizer(
                [""], padding="max_length", max_length=max_length, return_tensors="pt"
            ).input_ids.squeeze(0)

            captions = cond_input
            captions_uncond = uncond_input

        if "image" in self.outputlist or "im_head" in self.outputlist or "im_cloth" in self.outputlist:
            # Person image
            image = Image.open(os.path.join(dataroot, self.phase, 'image', im_name))
            image = image.resize((self.width, self.height))
            image = self.transform(image)  # [-1,1]

        if "im_sketch" in self.outputlist:
            # Garment sketch
            if self.order == 'unpaired':
                im_sketch = Image.open(
                    os.path.join(dataroot, self.phase, 'im_sketch_unpaired',
                                 os.path.splitext(im_name)[0] + '_' + c_name.replace(".jpg", ".png")))
            elif self.order == 'paired':
                im_sketch = Image.open(os.path.join(dataroot, self.phase, 'im_sketch',
                                                    im_name.replace(".jpg", ".png")))
            else:
                raise ValueError("Order should be either paired or unpaired")
            im_sketch = im_sketch.resize((self.width, self.height))
            im_sketch = ImageOps.invert(im_sketch)
            # threshold the grayscale PIL image
            im_sketch = im_sketch.point(lambda p: 255 if p > sketch_threshold else 0)
            im_sketch = transforms.functional.to_tensor(im_sketch)  # [0,1]
            im_sketch = 1 - im_sketch

        if "im_pose" in self.outputlist or "parser_mask" in self.outputlist or "im_mask" in self.outputlist \
                or "parse_mask_total" in self.outputlist or "parse_array" in self.outputlist \
                or "pose_map" in self.outputlist or "shape" in self.outputlist or "im_head" in self.outputlist:
            # Label map
            parse_name = im_name.replace('.jpg', '.png')
            im_parse = Image.open(os.path.join(dataroot, self.phase, 'image-parse-v3', parse_name))
            im_parse = im_parse.resize((self.width, self.height), Image.NEAREST)
            im_parse_final = transforms.ToTensor()(im_parse) * 255
            parse_array = np.array(im_parse)

            parse_shape = (parse_array > 0).astype(np.float32)

            parse_head = (parse_array == 1).astype(np.float32) + \
                         (parse_array == 2).astype(np.float32) + \
                         (parse_array == 4).astype(np.float32) + \
                         (parse_array == 13).astype(np.float32)

            parser_mask_fixed = (parse_array == 1).astype(np.float32) + \
                                (parse_array == 2).astype(np.float32) + \
                                (parse_array == 18).astype(np.float32) + \
                                (parse_array == 19).astype(np.float32)

            parser_mask_changeable = (parse_array == 0).astype(np.float32)

            arms = (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32)

            parse_cloth = (parse_array == 5).astype(np.float32) + \
                          (parse_array == 6).astype(np.float32) + \
                          (parse_array == 7).astype(np.float32)
            parse_mask = (parse_array == 5).astype(np.float32) + \
                         (parse_array == 6).astype(np.float32) + \
                         (parse_array == 7).astype(np.float32)

            parser_mask_fixed = parser_mask_fixed + (parse_array == 9).astype(np.float32) + \
                                (parse_array == 12).astype(np.float32)  # the lower body is fixed

            parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))

            parse_head = torch.from_numpy(parse_head)  # [0,1]
            parse_cloth = torch.from_numpy(parse_cloth)  # [0,1]
            parse_mask = torch.from_numpy(parse_mask)  # [0,1]
            parser_mask_fixed = torch.from_numpy(parser_mask_fixed)
            parser_mask_changeable = torch.from_numpy(parser_mask_changeable)

            parse_mask = parse_mask.cpu().numpy()

            if "im_head" in self.outputlist:
                # Masked head
                im_head = image * parse_head - (1 - parse_head)
            if "im_cloth" in self.outputlist:
                im_cloth = image * parse_cloth + (1 - parse_cloth)

            # Shape: downsample then upsample the body silhouette to get a blurred shape prior
            parse_shape = Image.fromarray((parse_shape * 255).astype(np.uint8))
            parse_shape = parse_shape.resize((self.width // 16, self.height // 16), Image.BILINEAR)
            parse_shape = parse_shape.resize((self.width, self.height), Image.BILINEAR)
            shape = self.transform2D(parse_shape)  # [-1,1]

            # Load pose points
            pose_name = im_name.replace('.jpg', '_keypoints.json')
            with open(os.path.join(dataroot, self.phase, 'openpose_json', pose_name), 'r') as f:
                pose_label = json.load(f)
                pose_data = pose_label['people'][0]['pose_keypoints_2d']
                pose_data = np.array(pose_data)
                pose_data = pose_data.reshape((-1, 3))[:, :2]

            # rescale keypoints from the annotation resolution (768x1024) to the target size
            pose_data[:, 0] = pose_data[:, 0] * (self.width / 768)
            pose_data[:, 1] = pose_data[:, 1] * (self.height / 1024)

            pose_mapping = get_coco_body25_mapping()
            point_num = len(pose_mapping)

            pose_map = torch.zeros(point_num, self.height, self.width)
            r = self.radius * (self.height / 512.0)
            im_pose = Image.new('L', (self.width, self.height))
            pose_draw = ImageDraw.Draw(im_pose)
            neck = Image.new('L', (self.width, self.height))
            neck_draw = ImageDraw.Draw(neck)
            for i in range(point_num):
                one_map = Image.new('L', (self.width, self.height))
                draw = ImageDraw.Draw(one_map)
                point_x = pose_data[pose_mapping[i], 0]
                point_y = pose_data[pose_mapping[i], 1]

                if point_x > 1 and point_y > 1:
                    draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white')
                    pose_draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white')
                    if i == 2 or i == 5:
                        # shoulders: enlarged ellipses mark the neck region
                        neck_draw.ellipse((point_x - r * 4, point_y - r * 4, point_x + r * 4, point_y + r * 4),
                                          'white', 'white')
                one_map = self.transform2D(one_map)
                pose_map[i] = one_map[0]

            # replace the rectangle-based pose map with gaussian keypoint heatmaps
            d = []
            for idx in range(point_num):
                ux = pose_data[pose_mapping[idx], 0]
                uy = pose_data[pose_mapping[idx], 1]
                d.append(kpoint_to_heatmap(np.array([ux, uy]), (self.height, self.width), 9))
            pose_map = torch.stack(d)

            # just for visualization
            im_pose = self.transform2D(im_pose)

            im_arms = Image.new('L', (self.width, self.height))
            arms_draw = ImageDraw.Draw(im_arms)

            # always handle the arms, since VITON-HD only contains upper-body garments
            with open(os.path.join(dataroot, self.phase, 'openpose_json', pose_name), 'r') as f:
                data = json.load(f)
                data = data['people'][0]['pose_keypoints_2d']
                data = np.array(data)
                data = data.reshape((-1, 3))[:, :2]

            # rescale keypoints from the annotation resolution (768x1024) to the target size
            data[:, 0] = data[:, 0] * (self.width / 768)
            data[:, 1] = data[:, 1] * (self.height / 1024)

            shoulder_right = data[pose_mapping[2]]
            shoulder_left = data[pose_mapping[5]]
            elbow_right = data[pose_mapping[3]]
            elbow_left = data[pose_mapping[6]]
            wrist_right = data[pose_mapping[4]]
            wrist_left = data[pose_mapping[7]]

            ARM_LINE_WIDTH = int(90 / 512 * self.height)
            # OpenPose reports (0, 0) for undetected joints, so coordinates <= 1
            # mean the joint is missing: draw the arm polyline without it
            if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
                if elbow_right[0] <= 1. and elbow_right[1] <= 1.:
                    arms_draw.line(
                        np.concatenate((wrist_left, elbow_left, shoulder_left, shoulder_right)).astype(
                            np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
                else:
                    arms_draw.line(np.concatenate(
                        (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right)).astype(
                        np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
            elif wrist_left[0] <= 1. and wrist_left[1] <= 1.:
                if elbow_left[0] <= 1. and elbow_left[1] <= 1.:
                    arms_draw.line(
                        np.concatenate((shoulder_left, shoulder_right, elbow_right, wrist_right)).astype(
                            np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
                else:
                    arms_draw.line(np.concatenate(
                        (elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype(
                        np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
            else:
                arms_draw.line(np.concatenate(
                    (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype(
                    np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')

            hands = np.logical_and(np.logical_not(im_arms), arms)
            parse_mask += im_arms
            parser_mask_fixed += hands

            # delete neck
            parse_head_2 = torch.clone(parse_head)

            parser_mask_fixed = np.logical_or(parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16))
            parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16),
                                                                   np.logical_not(
                                                                       np.array(parse_head_2, dtype=np.uint16))))

            parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
            parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
            inpaint_mask = 1 - parse_mask_total

            # extend the inpainting mask to its bounding box, keeping the fixed regions visible
            bboxes = masks_to_boxes(inpaint_mask.unsqueeze(0))
            bboxes = bboxes.type(torch.int32)  # (xmin, ymin, xmax, ymax) format
            xmin = bboxes[0, 0]
            xmax = bboxes[0, 2]
            ymin = bboxes[0, 1]
            ymax = bboxes[0, 3]

            inpaint_mask[ymin:ymax + 1, xmin:xmax + 1] = torch.logical_and(
                torch.ones_like(inpaint_mask[ymin:ymax + 1, xmin:xmax + 1]),
                torch.logical_not(parser_mask_fixed[ymin:ymax + 1, xmin:xmax + 1]))

            inpaint_mask = inpaint_mask.unsqueeze(0)
            im_mask = image * np.logical_not(inpaint_mask.repeat(3, 1, 1))
            parse_mask_total = parse_mask_total.numpy()
            parse_mask_total = parse_array * parse_mask_total
            parse_mask_total = torch.from_numpy(parse_mask_total)

        # gather the requested outputs from the local variables by name
        result = {}
        for k in self.outputlist:
            result[k] = vars()[k]

        result['im_parse'] = im_parse_final
        result['hands'] = torch.from_numpy(hands)

        # Output interpretation
        # "c_name" -> filename of the in-shop cloth
        # "im_name" -> filename of the model wearing the cloth
        # "image" -> image of the model wearing the cloth
        # "im_cloth" -> cloth cut out from the model image
        # "im_mask" -> model image with the inpainting region blacked out
        # "im_sketch" -> sketch of "im_cloth"
        # "inpaint_mask" -> bounding-box mask of the cloth region in the model image
        # ...

        return result

    def __len__(self):
        return len(self.c_names)
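

# Minimal usage sketch. Assumptions (not fixed by this file): the dataset lives in
# "./data/vitonhd" and the tokenizer is the CLIP tokenizer commonly paired with
# Stable Diffusion; substitute whatever tokenizer your pipeline actually uses.
if __name__ == "__main__":
    from transformers import CLIPTokenizer  # assumed dependency

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = VitonHDDataset(
        dataroot_path="./data/vitonhd",  # hypothetical dataset location
        phase="test",
        tokenizer=tokenizer,
        order="paired",
    )
    loader = data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)
    batch = next(iter(loader))
    # tensors are batched by the default collate; string fields come back as lists
    print({k: (tuple(v.shape) if isinstance(v, torch.Tensor) else type(v)) for k, v in batch.items()})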