# fptvton1 / ldm / data / image_dresscode.py
# Uploaded by basso4 ("Upload 1471 files", commit adf1965, verified)
import os
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms.functional as F
import numpy as np
from PIL import Image
class OpenImageDataset(data.Dataset):
    """DressCode dataset yielding person, garment, mask and DensePose tensors.

    The pair file selected by ``state``/``type`` is read once in ``__init__``.
    Each non-blank line must hold three whitespace-separated fields:
    ``<person image> <garment image> <category label>``.  Labels ``"0"``,
    ``"1"`` and ``"2"`` are mapped to the ``upper_body``, ``lower_body`` and
    ``dresses`` sub-directories; any other label is used as the directory
    name verbatim.

    Args:
        state: ``"train"`` or ``"test"``.
        dataset_dir: Dataset root, e.g. ``/home/sd/zjh/Dataset/DressCode``.
        type: ``"paired"`` or ``"unpaired"``; only consulted when ``state``
            is ``"test"``.

    Raises:
        ValueError: If ``state``/``type`` is not a supported combination, or
            a pair-file line does not contain exactly three fields.
    """

    # Numeric category labels used in the pair files -> sub-directory names.
    _CATEGORY_DIRS = {"0": "upper_body", "1": "lower_body", "2": "dresses"}

    # Spatial sizes (passed to PIL ``resize`` / torch ``interpolate``).
    _IMAGE_SIZE = (512, 512)
    _REFERENCE_SIZE = (224, 224)

    def __init__(self, state, dataset_dir, type="paired"):
        super().__init__()
        self.state = state  # train or test
        self.dataset_dir = dataset_dir  # e.g. /home/sd/zjh/Dataset/DressCode

        # Pick the pair file; fails fast on an unsupported state/type instead
        # of leaving ``dataset_file`` undefined.
        self.dataset_file = os.path.join(
            dataset_dir, self._pairs_file_name(state, type)
        )

        # Load the dataset index.
        self.people_list = []
        self.clothes_list = []
        with open(self.dataset_file, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                if len(fields) != 3:
                    raise ValueError(
                        f"{self.dataset_file}:{line_no}: expected "
                        f"'<person> <clothes> <category>', got {line!r}"
                    )
                people, clothes, category = fields
                # Unknown labels fall through unchanged.
                category = self._CATEGORY_DIRS.get(category, category)
                images_dir = os.path.join(self.dataset_dir, category, "images")
                self.people_list.append(os.path.join(images_dir, people))
                self.clothes_list.append(os.path.join(images_dir, clothes))

    @staticmethod
    def _pairs_file_name(state, pairing):
        """Return the pair-file name for ``state``/``pairing`` or raise ValueError."""
        if state == "train":
            return "train_pairs.txt"
        if state == "test":
            if pairing == "unpaired":
                return "test_pairs_unpaired.txt"
            if pairing == "paired":
                return "test_pairs_paired.txt"
            raise ValueError(
                f"type must be 'paired' or 'unpaired', got {pairing!r}"
            )
        raise ValueError(f"state must be 'train' or 'test', got {state!r}")

    @staticmethod
    def _swap_images_dir(path, new_dir):
        """Replace the last ``images`` directory component of ``path`` with ``new_dir``."""
        marker = os.sep + "images" + os.sep
        head, found, tail = path.rpartition(marker)
        if not found:
            # Path was not built by __init__; keep the legacy substitution.
            return path.replace("images", new_dir)
        return head + os.sep + new_dir + os.sep + tail

    @classmethod
    def _auxiliary_paths(cls, people_path):
        """Return ``(dense_path, mask_path)`` derived from a person image path."""
        # .../images/000000_0.jpg -> .../dense/000000_5_uv.npz
        # (drops the trailing "0.jpg", i.e. the last five characters)
        dense_path = cls._swap_images_dir(people_path, "dense")[:-5] + "5_uv.npz"
        # .../images/000000_0.jpg -> .../mask/000000_0.png
        # (swaps a three-character extension for "png")
        mask_path = cls._swap_images_dir(people_path, "mask")[:-3] + "png"
        return dense_path, mask_path

    @staticmethod
    def _load_image_tensor(path, mode, size):
        """Open ``path``, convert to ``mode``, resize to ``size``, return a tensor."""
        with Image.open(path) as raw:
            resized = raw.convert(mode).resize(size)
        return torchvision.transforms.ToTensor()(resized)

    def __len__(self):
        """Return the number of person/garment pairs."""
        return len(self.people_list)

    def __getitem__(self, index):
        """Load and preprocess the sample at ``index``.

        Returns:
            dict with keys ``"GT"``, ``"inpaint_image"``, ``"inpaint_mask"``,
            ``"ref_imgs"`` and ``"hint"``.
        """
        # e.g. /home/sd/zjh/Dataset/DressCode/upper_body/images/000000_0.jpg
        people_path = self.people_list[index]
        # e.g. /home/sd/zjh/Dataset/DressCode/upper_body/images/000000_1.jpg
        clothes_path = self.clothes_list[index]
        dense_path, mask_path = self._auxiliary_paths(people_path)

        # Load images.
        img = self._load_image_tensor(people_path, "RGB", self._IMAGE_SIZE)
        reference = self._load_image_tensor(
            clothes_path, "RGB", self._REFERENCE_SIZE
        )
        mask = self._load_image_tensor(mask_path, "L", self._IMAGE_SIZE)
        mask = 1 - mask  # invert the stored mask

        # NpzFile keeps the archive open until closed; read the array inside
        # a context manager so one file handle is not leaked per sample.
        with np.load(dense_path) as archive:
            uv = archive['uv']
        densepose = torch.from_numpy(uv)
        # NOTE(review): bilinear interpolate needs a floating-point (C, H, W)
        # array here -- confirm against the .npz producer.
        densepose = torch.nn.functional.interpolate(
            densepose.unsqueeze(0),
            size=self._IMAGE_SIZE,
            mode='bilinear',
            align_corners=True,
        ).squeeze(0)

        # Normalization: person image to [-1, 1]; reference with per-channel
        # statistics (values match the CLIP image-preprocessing mean/std).
        img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
        reference = torchvision.transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        )(reference)

        # Build the inpainting input and the conditioning hint.
        inpaint = img * mask
        hint = torchvision.transforms.Resize(self._IMAGE_SIZE)(reference)
        hint = torch.cat((hint, densepose), dim=0)

        return {"GT": img,                  # [3, 512, 512]
                "inpaint_image": inpaint,   # [3, 512, 512]
                "inpaint_mask": mask,       # [1, 512, 512]
                "ref_imgs": reference,      # [3, 224, 224]
                "hint": hint                # presumably [5, 512, 512] (3 + UV channels)
                }