# FPT-VTON: ldm/data/image_vitonhd.py
import os
import json
import random
import torch
import torchvision
from torchvision import transforms
import torch.utils.data as data
import torchvision.transforms.functional as TF
from PIL import Image
from typing import Literal
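

# Dataset for the VITON-HD virtual try-on benchmark. Each sample pairs a
# person image with a garment image and yields the ground-truth image, a
# masked "inpaint" image, the inverted agnostic mask, a CLIP-sized garment
# reference, a 6-channel hint (upsampled garment + DensePose), and a text
# caption built from the garment's tag annotations.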
class OpenImageDataset(data.Dataset):
def __init__(
self,
state: Literal["train", "test"],
dataset_dir: str,
type: Literal["paired", "unpaired"] = "paired",
):
        self.state = state
self.dataset_dir = dataset_dir
        # p=1 makes the flip deterministic; it is gated by a coin toss below.
        self.flip_transform = transforms.RandomHorizontalFlip(p=1)
with open(
os.path.join(dataset_dir, state, "vitonhd_" + state + "_tagged.json"), "r"
) as file1:
data1 = json.load(file1)
annotation_list = [
# "colors",
# "textures",
"sleeveLength",
"neckLine",
"item",
]
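        # Build a per-garment text annotation by concatenating the tag
        # categories selected above (sleeve length, neckline, item type).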
self.annotations_pair = {}
for k, v in data1.items():
for elem in v:
annotation_str = ""
for template in annotation_list:
for tag in elem["tag_info"]:
if (
tag["tag_name"] == template
and tag["tag_category"] is not None
):
annotation_str += tag["tag_category"]
annotation_str += " "
self.annotations_pair[elem["file_name"]] = annotation_str
        im_names = []
        c_names = []
        # Both splits read their pairing list from {state}_pairs.txt.
        filename = os.path.join(dataset_dir, f"{state}_pairs.txt")
with open(filename, "r") as f:
for line in f.readlines():
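                # Train always pairs a person with their own garment; at test
                # time "unpaired" takes the garment from the second column.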
if state == "train":
im_name, _ = line.strip().split()
c_name = im_name
else:
if type == "paired":
im_name, _ = line.strip().split()
c_name = im_name
else:
im_name, c_name = line.strip().split()
im_names.append(im_name)
c_names.append(c_name)
self.im_names = im_names
self.c_names = c_names
def __len__(self):
return len(self.im_names)
def __getitem__(self, index):
c_name = self.c_names[index]
im_name = self.im_names[index]
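        # Fall back to a generic garment word when no tag annotation exists.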
if c_name in self.annotations_pair:
cloth_annotation = self.annotations_pair[c_name]
else:
cloth_annotation = "shirts"
        # Resolve file paths for this sample.
        img_path = os.path.join(self.dataset_dir, self.state, "image", im_name)
        reference_path = os.path.join(self.dataset_dir, self.state, "cloth", c_name)
        mask_path = os.path.join(self.dataset_dir, self.state, "agnostic-mask", im_name[:-4] + "_mask.png")
        densepose_path = os.path.join(self.dataset_dir, self.state, "image-densepose", im_name)
        # Load images.
        img = Image.open(img_path).convert("RGB").resize((512, 512))
        img = torchvision.transforms.ToTensor()(img)
        reference = Image.open(reference_path).convert("RGB").resize((224, 224))
        reference = torchvision.transforms.ToTensor()(reference)
        mask = Image.open(mask_path).convert("L").resize((512, 512))
        mask = torchvision.transforms.ToTensor()(mask)
        # Invert: 1 = region to keep, 0 = garment region to be inpainted.
        mask = 1 - mask
        densepose = Image.open(densepose_path).convert("RGB").resize((512, 512))
        densepose = torchvision.transforms.ToTensor()(densepose)
        # Data augmentation for the training phase
        if self.state == "train":
            # Random horizontal flip, applied jointly to all modalities
if random.random() > 0.5:
img = self.flip_transform(img)
mask = self.flip_transform(mask)
densepose = self.flip_transform(densepose)
reference = self.flip_transform(reference)
            # Color jittering: sample one set of jitter factors and apply it
            # identically to the person image and the garment reference.
            # Note: fn_idx (the sampled op order) is unused here; the
            # adjustments are applied in a fixed order.
            if random.random() > 0.5:
                color_jitter = transforms.ColorJitter(brightness=0.5, contrast=0.3, saturation=0.5, hue=0.5)
                fn_idx, b, c, s, h = transforms.ColorJitter.get_params(
                    color_jitter.brightness,
                    color_jitter.contrast,
                    color_jitter.saturation,
                    color_jitter.hue,
                )
                img = TF.adjust_contrast(img, c)
                img = TF.adjust_brightness(img, b)
                img = TF.adjust_hue(img, h)
                img = TF.adjust_saturation(img, s)
                reference = TF.adjust_contrast(reference, c)
                reference = TF.adjust_brightness(reference, b)
                reference = TF.adjust_hue(reference, h)
                reference = TF.adjust_saturation(reference, s)
            # Random scaling; only the spatially aligned person, mask, and
            # DensePose maps are transformed, not the garment reference.
            if random.random() > 0.5:
                scale_val = random.uniform(0.8, 1.2)
img = transforms.functional.affine(
img, angle=0, translate=[0, 0], scale=scale_val, shear=0
)
mask = transforms.functional.affine(
mask, angle=0, translate=[0, 0], scale=scale_val, shear=0
)
densepose = transforms.functional.affine(
densepose, angle=0, translate=[0, 0], scale=scale_val, shear=0
)
            # Random shift by up to ±20% of the image width/height.
            if random.random() > 0.5:
                shift_valx = random.uniform(-0.2, 0.2)
                shift_valy = random.uniform(-0.2, 0.2)
img = transforms.functional.affine(
img,
angle=0,
translate=[shift_valx * img.shape[-1], shift_valy * img.shape[-2]],
scale=1,
shear=0
)
mask = transforms.functional.affine(
mask,
angle=0,
translate=[shift_valx * mask.shape[-1], shift_valy * mask.shape[-2]],
scale=1,
shear=0
)
densepose = transforms.functional.affine(
densepose,
angle=0,
translate=[
shift_valx * densepose.shape[-1],
shift_valy * densepose.shape[-2]
],
scale=1,
shear=0
)
        # Normalize: person and DensePose maps to [-1, 1]; the garment
        # reference with CLIP's image-encoder statistics.
        img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
        reference = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                     (0.26862954, 0.26130258, 0.27577711))(reference)
        densepose = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(densepose)
# ็”Ÿๆˆ inpaint ๅ’Œ hint
inpaint = img * mask
hint = torchvision.transforms.Resize((512, 512))(reference)
hint = torch.cat((hint,densepose),dim = 0)
cloth_annotation = "a photo of " + cloth_annotation
return {"GT": img, # [3, 512, 512]
"inpaint_image": inpaint, # [3, 512, 512]
"inpaint_mask": mask, # [1, 512, 512]
"ref_imgs": reference, # [3, 224, 224]
"hint": hint, # [6, 512, 512]
"caption_cloth": cloth_annotation,
# "caption": "model is wearing " + cloth_annotation,
}
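

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original pipeline. The
    # dataset root below is an assumed path; it should contain the standard
    # VITON-HD layout with train/test splits, {state}_pairs.txt files, and
    # vitonhd_{state}_tagged.json annotations.
    from torch.utils.data import DataLoader

    dataset = OpenImageDataset(state="test", dataset_dir="./data/vitonhd", type="paired")
    loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0)
    batch = next(iter(loader))
    print(batch["GT"].shape)            # torch.Size([2, 3, 512, 512])
    print(batch["inpaint_mask"].shape)  # torch.Size([2, 1, 512, 512])
    print(batch["hint"].shape)          # torch.Size([2, 6, 512, 512])
    print(batch["caption_cloth"][0])    # e.g. "a photo of long sleeve ..."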