# NOTE: removed non-code web-page residue ("Spaces: Running on Zero") that was
# pasted above the module; it was not valid Python.
import os

import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms.functional as F
from PIL import Image
class OpenImageDataset(data.Dataset):
    """Virtual try-on dataset over a VITON-HD-style directory layout.

    Expects under ``dataset_dir``:
        <state>/image/            person images
        <state>/cloth/            garment images
        <state>/mask/             inpainting masks (stored as .png)
        <state>/image-densepose/  densepose renderings (same filename as person)
    plus ``train_pairs.txt`` / ``test_pairs.txt`` at the root, one
    ``person garment`` pair per line.

    Args:
        state: ``"train"`` or ``"test"``; selects the pair list and subdirectory.
        dataset_dir: dataset root directory.
        type: ``"paired"`` or ``"unpaired"``; only consulted when
            ``state == "test"``. NOTE: shadows the builtin ``type`` — kept
            as-is for caller compatibility.

    In "paired" mode (and always for training) the garment column of the
    pair file is ignored: the garment sharing the person's filename is used.
    """

    def __init__(self, state, dataset_dir, type="paired"):
        self.state = state
        self.dataset_dir = dataset_dir
        self.dataset_list = []

        if state == "train":
            pairs_file = os.path.join(dataset_dir, "train_pairs.txt")
            with open(pairs_file, "r") as f:
                for line in f:  # iterate lazily; no need for readlines()
                    person, _garment = line.strip().split()
                    # Training is always paired: garment file shares the
                    # person's filename, so the garment column is ignored.
                    self.dataset_list.append([person, person])
        elif state == "test":
            pairs_file = os.path.join(dataset_dir, "test_pairs.txt")
            with open(pairs_file, "r") as f:
                for line in f:
                    person, garment = line.strip().split()
                    if type == "unpaired":
                        self.dataset_list.append([person, garment])
                    elif type == "paired":
                        self.dataset_list.append([person, person])

    def __len__(self):
        """Number of (person, garment) pairs."""
        return len(self.dataset_list)

    def __getitem__(self, index):
        """Load, resize, and normalize one sample.

        Returns a dict:
            GT:            [3, 512, 512] person image, normalized to [-1, 1]
            inpaint_image: [3, 512, 512] person image with masked region zeroed
            inpaint_mask:  [1, 512, 512] inverted mask (0 = region to inpaint)
            ref_imgs:      [3, 224, 224] garment image, CLIP-normalized
            hint:          [6, 512, 512] upscaled garment + densepose, concatenated
        """
        person, garment = self.dataset_list[index]

        # Resolve file paths (mask is stored as .png regardless of image ext).
        img_path = os.path.join(self.dataset_dir, self.state, "image", person)
        reference_path = os.path.join(self.dataset_dir, self.state, "cloth", garment)
        mask_path = os.path.join(self.dataset_dir, self.state, "mask", person[:-4] + ".png")
        densepose_path = os.path.join(self.dataset_dir, self.state, "image-densepose", person)

        to_tensor = torchvision.transforms.ToTensor()

        # Load images.
        img = to_tensor(Image.open(img_path).convert("RGB").resize((512, 512)))
        reference = to_tensor(Image.open(reference_path).convert("RGB").resize((224, 224)))
        mask = to_tensor(Image.open(mask_path).convert("L").resize((512, 512)))
        mask = 1 - mask  # invert so the region to fill becomes 0
        densepose = to_tensor(Image.open(densepose_path).convert("RGB").resize((512, 512)))

        # Normalize: person/densepose to [-1, 1]; reference with CLIP statistics.
        normalize_half = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        img = normalize_half(img)
        densepose = normalize_half(densepose)
        reference = torchvision.transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        )(reference)

        # Build the inpainting input and the 6-channel conditioning hint.
        inpaint = img * mask
        hint = torchvision.transforms.Resize((512, 512))(reference)
        hint = torch.cat((hint, densepose), dim=0)

        return {
            "GT": img,                  # [3, 512, 512]
            "inpaint_image": inpaint,   # [3, 512, 512]
            "inpaint_mask": mask,       # [1, 512, 512]
            "ref_imgs": reference,      # [3, 224, 224]
            "hint": hint,               # [6, 512, 512]
        }
# if __name__ == "__main__":