Spaces:
Runtime error
Runtime error
File size: 6,245 Bytes
2d5f249 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import numpy as np
import json
import os
import itertools
import trimesh
from matplotlib.path import Path
from collections import Counter
from sklearn.neighbors import KNeighborsClassifier
def load_segmentation(path, shape):
    """
    Load the polygon segmentations of every item stored in a segmentation json file
    Arguments:
        path: path to the segmentation json file
        shape: unused; kept so existing callers keep working
    Returns:
        A list with one dict per top-level json entry whose key starts with 'item':
            'type': the entry's 'category_name'
            'type_id': the entry's 'category_id'
            'coordinates': list of (N, 2) arrays, one per polygon, built from the
                flat [x1, y1, x2, y2, ...] lists in the entry's 'segmentation' field
    """
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(path, encoding='utf-8') as json_file:
        data = json.load(json_file)

    segmentations = []
    for key, item in data.items():
        # Other top-level keys hold non-item metadata.
        if not key.startswith('item'):
            continue
        # Each item can have multiple polygons; keep them as separate arrays.
        # A flat [x1, y1, x2, y2, ...] list becomes rows of (x, y).
        coordinates = [
            np.column_stack((polygon[::2], polygon[1::2]))
            for polygon in item['segmentation']
        ]
        segmentations.append({
            'type': item['category_name'],
            'type_id': item['category_id'],
            'coordinates': coordinates,
        })
    return segmentations
def smpl_to_recon_labels(recon, smpl, k=1):
    """
    Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
    Arguments:
        recon: trimesh object (fully clothed model)
        smpl: trimesh object (smpl model)
        k: number of nearest neighbours to use (labels are chosen by the
            k-nearest-neighbour classifier, i.e. a vote when k > 1)
    Returns:
        Returns a dictionary mapping each bodypart name found in
        smpl_vert_segmentation.json to a list of int indices into recon.vertices
    """
    segmentation_path = os.path.join(
        os.path.dirname(__file__), 'smpl_vert_segmentation.json')
    # Use a context manager so the file handle is closed deterministically.
    with open(segmentation_path, encoding='utf-8') as segmentation_file:
        smpl_vert_segmentation = json.load(segmentation_file)

    # One label per SMPL vertex. Vertices not listed in the json stay None.
    # NOTE(review): mixing None and str labels is presumably never exercised
    # because the json covers every SMPL vertex — confirm, since sklearn cannot
    # sort such mixed labels.
    labels = np.array([None] * smpl.vertices.shape[0])
    for part, vertex_ids in smpl_vert_segmentation.items():
        labels[vertex_ids] = part

    # Honour the documented `k` (previously hard-coded to 1).
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(smpl.vertices, labels)
    predicted = classifier.predict(recon.vertices)

    return {
        part: np.flatnonzero(predicted == part).astype(int).tolist()
        for part in smpl_vert_segmentation
    }
def _segmentation_vertex_mask(recon, seg_coord, K, R, t):
    """
    Mark the vertices of recon whose x/y position lies inside any segmentation polygon
    Arguments:
        recon: mesh whose vertices are tested (reads recon.vertices)
        seg_coord: iterable of polygons, each an (n, 2) sequence of 2D points
        K: intrinsic matrix of the projection (only K[:3, :3] is used)
        R: rotation matrix of the projection
        t: translation vector of the projection
    Returns:
        Boolean array with one entry per vertex of recon
    """
    extrinsic = np.zeros((3, 4))
    extrinsic[:3, :3] = R
    extrinsic[:, 3] = t
    # P maps homogeneous 3D points to homogeneous 2D points. When P has full row
    # rank, P @ pinv(P) = I, so pinv(P) lifts a 2D point to a homogeneous 3D
    # point that projects back onto it.
    P_inv = np.linalg.pinv(K[:3, :3] @ extrinsic)

    vertices_xy = recon.vertices[:, :2]
    mask = np.zeros(len(recon.vertices), dtype=bool)
    # Each segmentation can contain multiple polygons; test them separately.
    for polygon in seg_coord:
        polygon = np.asarray(polygon, dtype=float)
        coords_h = np.hstack((polygon, np.ones((len(polygon), 1))))
        XYZ_h = coords_h @ P_inv.T  # (n, 4) homogeneous 3D points
        XYZ = XYZ_h[:, :3] / XYZ_h[:, 3, None]
        # Only x/y are compared, so depth is ignored: every vertex whose x/y
        # falls inside the outline is selected, front or back.
        mask |= Path(XYZ[:, :2]).contains_points(vertices_xy)
    return mask


def _body_parts_to_remove(type_id):
    """
    List the SMPL bodypart names whose vertices are excluded for a garment category
    Arguments:
        type_id: garment category id (DeepFashion2 numbering,
            https://github.com/switchablenorms/DeepFashion2)
    Returns:
        A new list of bodypart names
    """
    # Removed for every category.
    parts = ['rightHand', 'leftHand', 'leftHandIndex1', 'rightHandIndex1',
             'leftFoot', 'rightFoot', 'leftToeBase', 'rightToeBase', 'head']
    # Short sleeve clothes
    if type_id in (1, 3, 10):
        parts += ['leftForeArm', 'rightForeArm']
    # No sleeves at all or lower body clothes
    elif type_id in (5, 6, 8, 9, 12, 13):
        parts += ['leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
    # Shorts
    elif type_id == 7:
        parts += ['leftLeg', 'rightLeg',
                  'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
    return parts


def extract_cloth(recon, segmentation, K, R, t, smpl=None):
    """
    Extract a portion of a mesh using 2d segmentation coordinates
    Arguments:
        recon: fully clothed mesh (reads recon.vertices and recon.faces)
        segmentation: dict with 'coord_normalized' (list of 2D polygons; NDC per
            the original docs — TODO confirm) and, when smpl is given, 'type_id'
            (garment category id)
        K: intrinsic matrix of the projection
        R: rotation matrix of the projection
        t: translation vector of the projection
        smpl: optional SMPL mesh. When given, selected vertices that belong to
            bodyparts the garment category should not cover are dropped, e.g. a
            hand in front of the torso. Without it, no bodypart filtering is done.
    Returns:
        A trimesh.Trimesh with every face of recon that touches at least one
        selected vertex, or None when no face is selected
    """
    seg_mask = _segmentation_vertex_mask(
        recon, segmentation['coord_normalized'], K, R, t)

    if smpl is not None:
        recon_labels = smpl_to_recon_labels(recon, smpl)
        label_mask = np.zeros(len(recon.vertices), dtype=bool)
        for part in _body_parts_to_remove(segmentation['type_id']):
            label_mask[recon_labels[part]] = True
        # Elementwise "selected AND NOT unwanted bodypart".
        seg_mask &= ~label_mask

    # A face is kept when any of its three vertices is selected.
    face_mask = seg_mask[recon.faces].any(axis=1)
    if not face_mask.any():
        return None

    mesh = trimesh.Trimesh(recon.vertices, recon.faces)
    mesh.update_faces(face_mask)
    mesh.remove_unreferenced_vertices()
    return mesh
|