Spaces:
Build error
Build error
from torch.utils.data import Dataset | |
import numpy as np | |
import os | |
import random | |
import torchvision.transforms as transforms | |
from PIL import Image, ImageOps | |
import cv2 | |
import torch | |
from PIL.ImageFilter import GaussianBlur | |
import trimesh | |
import logging | |
log = logging.getLogger('trimesh') | |
log.setLevel(40) | |
def load_trimesh(root_dir): | |
folders = os.listdir(root_dir) | |
meshs = {} | |
for i, f in enumerate(folders): | |
sub_name = f | |
meshs[sub_name] = trimesh.load(os.path.join(root_dir, f, '%s_100k.obj' % sub_name)) | |
return meshs | |
def save_samples_truncted_prob(fname, points, prob): | |
''' | |
Save the visualization of sampling to a ply file. | |
Red points represent positive predictions. | |
Green points represent negative predictions. | |
:param fname: File name to save | |
:param points: [N, 3] array of points | |
:param prob: [N, 1] array of predictions in the range [0~1] | |
:return: | |
''' | |
r = (prob > 0.5).reshape([-1, 1]) * 255 | |
g = (prob < 0.5).reshape([-1, 1]) * 255 | |
b = np.zeros(r.shape) | |
to_save = np.concatenate([points, r, g, b], axis=-1) | |
return np.savetxt(fname, | |
to_save, | |
fmt='%.6f %.6f %.6f %d %d %d', | |
comments='', | |
header=( | |
'ply\nformat ascii 1.0\nelement vertex {:d}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header').format( | |
points.shape[0]) | |
) | |
class TrainDataset(Dataset): | |
def modify_commandline_options(parser, is_train): | |
return parser | |
def __init__(self, opt, phase='train'): | |
self.opt = opt | |
self.projection_mode = 'orthogonal' | |
# Path setup | |
self.root = self.opt.dataroot | |
self.RENDER = os.path.join(self.root, 'RENDER') | |
self.MASK = os.path.join(self.root, 'MASK') | |
self.PARAM = os.path.join(self.root, 'PARAM') | |
self.UV_MASK = os.path.join(self.root, 'UV_MASK') | |
self.UV_NORMAL = os.path.join(self.root, 'UV_NORMAL') | |
self.UV_RENDER = os.path.join(self.root, 'UV_RENDER') | |
self.UV_POS = os.path.join(self.root, 'UV_POS') | |
self.OBJ = os.path.join(self.root, 'GEO', 'OBJ') | |
self.B_MIN = np.array([-128, -28, -128]) | |
self.B_MAX = np.array([128, 228, 128]) | |
self.is_train = (phase == 'train') | |
self.load_size = self.opt.loadSize | |
self.num_views = self.opt.num_views | |
self.num_sample_inout = self.opt.num_sample_inout | |
self.num_sample_color = self.opt.num_sample_color | |
self.yaw_list = list(range(0,360,1)) | |
self.pitch_list = [0] | |
self.subjects = self.get_subjects() | |
# PIL to tensor | |
self.to_tensor = transforms.Compose([ | |
transforms.Resize(self.load_size), | |
transforms.ToTensor(), | |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | |
]) | |
# augmentation | |
self.aug_trans = transforms.Compose([ | |
transforms.ColorJitter(brightness=opt.aug_bri, contrast=opt.aug_con, saturation=opt.aug_sat, | |
hue=opt.aug_hue) | |
]) | |
self.mesh_dic = load_trimesh(self.OBJ) | |
def get_subjects(self): | |
all_subjects = os.listdir(self.RENDER) | |
var_subjects = np.loadtxt(os.path.join(self.root, 'val.txt'), dtype=str) | |
if len(var_subjects) == 0: | |
return all_subjects | |
if self.is_train: | |
return sorted(list(set(all_subjects) - set(var_subjects))) | |
else: | |
return sorted(list(var_subjects)) | |
def __len__(self): | |
return len(self.subjects) * len(self.yaw_list) * len(self.pitch_list) | |
def get_render(self, subject, num_views, yid=0, pid=0, random_sample=False): | |
''' | |
Return the render data | |
:param subject: subject name | |
:param num_views: how many views to return | |
:param view_id: the first view_id. If None, select a random one. | |
:return: | |
'img': [num_views, C, W, H] images | |
'calib': [num_views, 4, 4] calibration matrix | |
'extrinsic': [num_views, 4, 4] extrinsic matrix | |
'mask': [num_views, 1, W, H] masks | |
''' | |
pitch = self.pitch_list[pid] | |
# The ids are an even distribution of num_views around view_id | |
view_ids = [self.yaw_list[(yid + len(self.yaw_list) // num_views * offset) % len(self.yaw_list)] | |
for offset in range(num_views)] | |
if random_sample: | |
view_ids = np.random.choice(self.yaw_list, num_views, replace=False) | |
calib_list = [] | |
render_list = [] | |
mask_list = [] | |
extrinsic_list = [] | |
for vid in view_ids: | |
param_path = os.path.join(self.PARAM, subject, '%d_%d_%02d.npy' % (vid, pitch, 0)) | |
render_path = os.path.join(self.RENDER, subject, '%d_%d_%02d.jpg' % (vid, pitch, 0)) | |
mask_path = os.path.join(self.MASK, subject, '%d_%d_%02d.png' % (vid, pitch, 0)) | |
# loading calibration data | |
param = np.load(param_path, allow_pickle=True) | |
# pixel unit / world unit | |
ortho_ratio = param.item().get('ortho_ratio') | |
# world unit / model unit | |
scale = param.item().get('scale') | |
# camera center world coordinate | |
center = param.item().get('center') | |
# model rotation | |
R = param.item().get('R') | |
translate = -np.matmul(R, center).reshape(3, 1) | |
extrinsic = np.concatenate([R, translate], axis=1) | |
extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0) | |
# Match camera space to image pixel space | |
scale_intrinsic = np.identity(4) | |
scale_intrinsic[0, 0] = scale / ortho_ratio | |
scale_intrinsic[1, 1] = -scale / ortho_ratio | |
scale_intrinsic[2, 2] = scale / ortho_ratio | |
# Match image pixel space to image uv space | |
uv_intrinsic = np.identity(4) | |
uv_intrinsic[0, 0] = 1.0 / float(self.opt.loadSize // 2) | |
uv_intrinsic[1, 1] = 1.0 / float(self.opt.loadSize // 2) | |
uv_intrinsic[2, 2] = 1.0 / float(self.opt.loadSize // 2) | |
# Transform under image pixel space | |
trans_intrinsic = np.identity(4) | |
mask = Image.open(mask_path).convert('L') | |
render = Image.open(render_path).convert('RGB') | |
if self.is_train: | |
# Pad images | |
pad_size = int(0.1 * self.load_size) | |
render = ImageOps.expand(render, pad_size, fill=0) | |
mask = ImageOps.expand(mask, pad_size, fill=0) | |
w, h = render.size | |
th, tw = self.load_size, self.load_size | |
# random flip | |
if self.opt.random_flip and np.random.rand() > 0.5: | |
scale_intrinsic[0, 0] *= -1 | |
render = transforms.RandomHorizontalFlip(p=1.0)(render) | |
mask = transforms.RandomHorizontalFlip(p=1.0)(mask) | |
# random scale | |
if self.opt.random_scale: | |
rand_scale = random.uniform(0.9, 1.1) | |
w = int(rand_scale * w) | |
h = int(rand_scale * h) | |
render = render.resize((w, h), Image.BILINEAR) | |
mask = mask.resize((w, h), Image.NEAREST) | |
scale_intrinsic *= rand_scale | |
scale_intrinsic[3, 3] = 1 | |
# random translate in the pixel space | |
if self.opt.random_trans: | |
dx = random.randint(-int(round((w - tw) / 10.)), | |
int(round((w - tw) / 10.))) | |
dy = random.randint(-int(round((h - th) / 10.)), | |
int(round((h - th) / 10.))) | |
else: | |
dx = 0 | |
dy = 0 | |
trans_intrinsic[0, 3] = -dx / float(self.opt.loadSize // 2) | |
trans_intrinsic[1, 3] = -dy / float(self.opt.loadSize // 2) | |
x1 = int(round((w - tw) / 2.)) + dx | |
y1 = int(round((h - th) / 2.)) + dy | |
render = render.crop((x1, y1, x1 + tw, y1 + th)) | |
mask = mask.crop((x1, y1, x1 + tw, y1 + th)) | |
render = self.aug_trans(render) | |
# random blur | |
if self.opt.aug_blur > 0.00001: | |
blur = GaussianBlur(np.random.uniform(0, self.opt.aug_blur)) | |
render = render.filter(blur) | |
intrinsic = np.matmul(trans_intrinsic, np.matmul(uv_intrinsic, scale_intrinsic)) | |
calib = torch.Tensor(np.matmul(intrinsic, extrinsic)).float() | |
extrinsic = torch.Tensor(extrinsic).float() | |
mask = transforms.Resize(self.load_size)(mask) | |
mask = transforms.ToTensor()(mask).float() | |
mask_list.append(mask) | |
render = self.to_tensor(render) | |
render = mask.expand_as(render) * render | |
render_list.append(render) | |
calib_list.append(calib) | |
extrinsic_list.append(extrinsic) | |
return { | |
'img': torch.stack(render_list, dim=0), | |
'calib': torch.stack(calib_list, dim=0), | |
'extrinsic': torch.stack(extrinsic_list, dim=0), | |
'mask': torch.stack(mask_list, dim=0) | |
} | |
def select_sampling_method(self, subject): | |
if not self.is_train: | |
random.seed(1991) | |
np.random.seed(1991) | |
torch.manual_seed(1991) | |
mesh = self.mesh_dic[subject] | |
surface_points, _ = trimesh.sample.sample_surface(mesh, 4 * self.num_sample_inout) | |
sample_points = surface_points + np.random.normal(scale=self.opt.sigma, size=surface_points.shape) | |
# add random points within image space | |
length = self.B_MAX - self.B_MIN | |
random_points = np.random.rand(self.num_sample_inout // 4, 3) * length + self.B_MIN | |
sample_points = np.concatenate([sample_points, random_points], 0) | |
np.random.shuffle(sample_points) | |
inside = mesh.contains(sample_points) | |
inside_points = sample_points[inside] | |
outside_points = sample_points[np.logical_not(inside)] | |
nin = inside_points.shape[0] | |
inside_points = inside_points[ | |
:self.num_sample_inout // 2] if nin > self.num_sample_inout // 2 else inside_points | |
outside_points = outside_points[ | |
:self.num_sample_inout // 2] if nin > self.num_sample_inout // 2 else outside_points[ | |
:(self.num_sample_inout - nin)] | |
samples = np.concatenate([inside_points, outside_points], 0).T | |
labels = np.concatenate([np.ones((1, inside_points.shape[0])), np.zeros((1, outside_points.shape[0]))], 1) | |
# save_samples_truncted_prob('out.ply', samples.T, labels.T) | |
# exit() | |
samples = torch.Tensor(samples).float() | |
labels = torch.Tensor(labels).float() | |
del mesh | |
return { | |
'samples': samples, | |
'labels': labels | |
} | |
def get_color_sampling(self, subject, yid, pid=0): | |
yaw = self.yaw_list[yid] | |
pitch = self.pitch_list[pid] | |
uv_render_path = os.path.join(self.UV_RENDER, subject, '%d_%d_%02d.jpg' % (yaw, pitch, 0)) | |
uv_mask_path = os.path.join(self.UV_MASK, subject, '%02d.png' % (0)) | |
uv_pos_path = os.path.join(self.UV_POS, subject, '%02d.exr' % (0)) | |
uv_normal_path = os.path.join(self.UV_NORMAL, subject, '%02d.png' % (0)) | |
# Segmentation mask for the uv render. | |
# [H, W] bool | |
uv_mask = cv2.imread(uv_mask_path) | |
uv_mask = uv_mask[:, :, 0] != 0 | |
# UV render. each pixel is the color of the point. | |
# [H, W, 3] 0 ~ 1 float | |
uv_render = cv2.imread(uv_render_path) | |
uv_render = cv2.cvtColor(uv_render, cv2.COLOR_BGR2RGB) / 255.0 | |
# Normal render. each pixel is the surface normal of the point. | |
# [H, W, 3] -1 ~ 1 float | |
uv_normal = cv2.imread(uv_normal_path) | |
uv_normal = cv2.cvtColor(uv_normal, cv2.COLOR_BGR2RGB) / 255.0 | |
uv_normal = 2.0 * uv_normal - 1.0 | |
# Position render. each pixel is the xyz coordinates of the point | |
uv_pos = cv2.imread(uv_pos_path, 2 | 4)[:, :, ::-1] | |
### In these few lines we flattern the masks, positions, and normals | |
uv_mask = uv_mask.reshape((-1)) | |
uv_pos = uv_pos.reshape((-1, 3)) | |
uv_render = uv_render.reshape((-1, 3)) | |
uv_normal = uv_normal.reshape((-1, 3)) | |
surface_points = uv_pos[uv_mask] | |
surface_colors = uv_render[uv_mask] | |
surface_normal = uv_normal[uv_mask] | |
if self.num_sample_color: | |
sample_list = random.sample(range(0, surface_points.shape[0] - 1), self.num_sample_color) | |
surface_points = surface_points[sample_list].T | |
surface_colors = surface_colors[sample_list].T | |
surface_normal = surface_normal[sample_list].T | |
# Samples are around the true surface with an offset | |
normal = torch.Tensor(surface_normal).float() | |
samples = torch.Tensor(surface_points).float() \ | |
+ torch.normal(mean=torch.zeros((1, normal.size(1))), std=self.opt.sigma).expand_as(normal) * normal | |
# Normalized to [-1, 1] | |
rgbs_color = 2.0 * torch.Tensor(surface_colors).float() - 1.0 | |
return { | |
'color_samples': samples, | |
'rgbs': rgbs_color | |
} | |
def get_item(self, index): | |
# In case of a missing file or IO error, switch to a random sample instead | |
# try: | |
sid = index % len(self.subjects) | |
tmp = index // len(self.subjects) | |
yid = tmp % len(self.yaw_list) | |
pid = tmp // len(self.yaw_list) | |
# name of the subject 'rp_xxxx_xxx' | |
subject = self.subjects[sid] | |
res = { | |
'name': subject, | |
'mesh_path': os.path.join(self.OBJ, subject + '.obj'), | |
'sid': sid, | |
'yid': yid, | |
'pid': pid, | |
'b_min': self.B_MIN, | |
'b_max': self.B_MAX, | |
} | |
render_data = self.get_render(subject, num_views=self.num_views, yid=yid, pid=pid, | |
random_sample=self.opt.random_multiview) | |
res.update(render_data) | |
if self.opt.num_sample_inout: | |
sample_data = self.select_sampling_method(subject) | |
res.update(sample_data) | |
# img = np.uint8((np.transpose(render_data['img'][0].numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0) | |
# rot = render_data['calib'][0,:3, :3] | |
# trans = render_data['calib'][0,:3, 3:4] | |
# pts = torch.addmm(trans, rot, sample_data['samples'][:, sample_data['labels'][0] > 0.5]) # [3, N] | |
# pts = 0.5 * (pts.numpy().T + 1.0) * render_data['img'].size(2) | |
# for p in pts: | |
# img = cv2.circle(img, (p[0], p[1]), 2, (0,255,0), -1) | |
# cv2.imshow('test', img) | |
# cv2.waitKey(1) | |
if self.num_sample_color: | |
color_data = self.get_color_sampling(subject, yid=yid, pid=pid) | |
res.update(color_data) | |
return res | |
# except Exception as e: | |
# print(e) | |
# return self.get_item(index=random.randint(0, self.__len__() - 1)) | |
def __getitem__(self, index): | |
return self.get_item(index) |