import os
import torch
import torchvision
import torch.utils.data as data
import numpy as np
from PIL import Image
class OpenImageDataset(data.Dataset):
    def __init__(self, state, dataset_dir, type="paired"):
        # state: "train" or "test"; type: "paired" or "unpaired" (test only)
        self.state = state
        self.dataset_dir = dataset_dir  # e.g. /home/sd/zjh/Dataset/DressCode

        # Select the pair list file for the requested split
        if state == "train":
            self.dataset_file = os.path.join(dataset_dir, "train_pairs.txt")
        elif state == "test":
            if type == "unpaired":
                self.dataset_file = os.path.join(dataset_dir, "test_pairs_unpaired.txt")
            elif type == "paired":
                self.dataset_file = os.path.join(dataset_dir, "test_pairs_paired.txt")

        # Load the person/garment pairs; the third column encodes the category
        self.people_list = []
        self.clothes_list = []
        with open(self.dataset_file, "r") as f:
            for line in f:
                people, clothes, category = line.strip().split()
                if category == "0":
                    category = "upper_body"
                elif category == "1":
                    category = "lower_body"
                elif category == "2":
                    category = "dresses"
                people_path = os.path.join(self.dataset_dir, category, "images", people)
                clothes_path = os.path.join(self.dataset_dir, category, "images", clothes)
                self.people_list.append(people_path)
                self.clothes_list.append(clothes_path)
    def __len__(self):
        return len(self.people_list)
    def __getitem__(self, index):
        people_path = self.people_list[index]
        # /home/sd/zjh/Dataset/DressCode/upper_body/images/000000_0.jpg
        clothes_path = self.clothes_list[index]
        # /home/sd/zjh/Dataset/DressCode/upper_body/images/000000_1.jpg
        dense_path = people_path.replace("images", "dense")[:-5] + "5_uv.npz"
        # /home/sd/zjh/Dataset/DressCode/upper_body/dense/000000_5_uv.npz
        mask_path = people_path.replace("images", "mask")[:-3] + "png"
        # /home/sd/Harddisk/zjh/DressCode/upper_body/mask/000000_0.png

        # Load the person image, garment reference, mask, and DensePose UV map
        img = Image.open(people_path).convert("RGB").resize((512, 512))
        img = torchvision.transforms.ToTensor()(img)
        reference = Image.open(clothes_path).convert("RGB").resize((224, 224))
        reference = torchvision.transforms.ToTensor()(reference)
        mask = Image.open(mask_path).convert("L").resize((512, 512))
        mask = torchvision.transforms.ToTensor()(mask)
        mask = 1 - mask  # invert so the garment region to inpaint is zeroed out
        densepose = np.load(dense_path)
        densepose = torch.from_numpy(densepose["uv"])
        densepose = torch.nn.functional.interpolate(
            densepose.unsqueeze(0), size=(512, 512), mode="bilinear", align_corners=True
        ).squeeze(0)

        # Normalize: person image to [-1, 1], reference with CLIP statistics
        img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
        reference = torchvision.transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        )(reference)
        # densepose = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(densepose)

        # Build the masked inpainting input and the 5-channel hint (garment RGB + UV)
        inpaint = img * mask
        hint = torchvision.transforms.Resize((512, 512))(reference)
        hint = torch.cat((hint, densepose), dim=0)

        return {
            "GT": img,                 # [3, 512, 512]
            "inpaint_image": inpaint,  # [3, 512, 512]
            "inpaint_mask": mask,      # [1, 512, 512]
            "ref_imgs": reference,     # [3, 224, 224]
            "hint": hint,              # [5, 512, 512]
        }
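

# Minimal usage sketch (illustrative, not part of the original script): wraps the
# dataset in a standard DataLoader. The dataset path below is a placeholder and
# assumes the DressCode directory layout expected by __init__ above.
if __name__ == "__main__":
    dataset = OpenImageDataset(state="train", dataset_dir="/path/to/DressCode")
    loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    batch = next(iter(loader))
    print(batch["GT"].shape)    # torch.Size([4, 3, 512, 512])
    print(batch["hint"].shape)  # torch.Size([4, 5, 512, 512])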