File size: 2,794 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
def get_tomato_slice(idx):
    """Return the feature-vector indices associated with entry ``idx``.

    The indices are drawn from three contiguous regions of the feature
    vector:

    * a 3-wide group per entry starting at offset 4 (entries 1 and up),
    * a 6-wide group per entry starting at offset 157 (entries 1 and up),
    * a 3-wide group per entry starting at offset 463 (entries 0 and up).

    Entry 0 is special: it owns the four leading features 0-3 and the
    first 3-wide group of the last region (463-465).

    NOTE(review): ``idx`` is presumably a joint index (0 = root) and the
    three regions presumably hold positions, 6D rotations and velocities of
    a "tomato"-style motion representation -- confirm against the data
    format before relying on those names.

    Args:
        idx: Non-negative integer entry index. With 52 entries (0..51) the
            returned indices tile 0..618 exactly once.

    Returns:
        A list of 7 indices for ``idx == 0``, otherwise a list of 12.
    """
    if idx == 0:
        return [0, 1, 2, 3, 463, 464, 465]

    # Entries >= 1 are stored (idx - 1)-based in the first two regions,
    # because entry 0 has no slot there.
    first_base = 4 + (idx - 1) * 3
    second_base = 157 + (idx - 1) * 6
    # The last region also has a slot for entry 0, hence idx-based.
    third_base = 463 + idx * 3

    result = [first_base + k for k in range(3)]
    result += [second_base + k for k in range(6)]
    result += [third_base + k for k in range(3)]
    return result
def get_part_slice(idx_list, func):
    """Concatenate the index lists produced by ``func`` for several entries.

    Args:
        idx_list: Iterable of entry indices; each one is passed to ``func``.
        func: Callable mapping one entry index to an iterable of feature
            indices (for example ``get_tomato_slice``).

    Returns:
        A single flat list holding the indices of every entry, in the order
        the entries appear in ``idx_list``. Empty if ``idx_list`` is empty.
    """
    return [feature for idx in idx_list for feature in func(idx)]
def expand_mask_to_all(mask, body_scale, hand_scale, face_scale):
    """Expand a per-part mask into a scaled per-feature mask of width 669.

    ``mask`` is reshaped to ``(B, T, C)`` and its first ten channels are read
    in this order: root, head, stem, left arm, right arm, left leg,
    right leg, left hand, right hand, face. Each channel value is copied to
    every feature index belonging to that part and then multiplied by the
    scale of the part's group.

    Args:
        mask: Tensor whose first two dimensions are ``(B, T)``; the remaining
            dimensions are flattened and must provide at least ten channels.
        body_scale: Multiplier for the root, head, stem, arm and leg features.
        hand_scale: Multiplier for the left-hand and right-hand features.
        face_scale: Multiplier for the face features (indices 619-668).

    Returns:
        A tensor of shape ``(B, T, 669)`` with the dtype and device of
        ``mask``.
    """
    func = get_tomato_slice

    # One (feature indices, scale) pair per mask channel; the position in
    # this list is the mask channel the part reads from. The index sets are
    # pairwise disjoint, so the parts can be filled independently.
    parts = [
        (get_part_slice([0], func), body_scale),               # 0: root
        (get_part_slice([12, 15], func), body_scale),          # 1: head
        (get_part_slice([3, 6, 9], func), body_scale),         # 2: stem
        (get_part_slice([14, 17, 19, 21], func), body_scale),  # 3: left arm
        (get_part_slice([13, 16, 18, 20], func), body_scale),  # 4: right arm
        (get_part_slice([2, 5, 8, 11], func), body_scale),     # 5: left leg
        (get_part_slice([1, 4, 7, 10], func), body_scale),     # 6: right leg
        (get_part_slice(range(22, 37), func), hand_scale),     # 7: left hand
        (get_part_slice(range(37, 52), func), hand_scale),     # 8: right hand
        (list(range(619, 669)), face_scale),                   # 9: face
    ]

    B, T = mask.shape[0], mask.shape[1]
    # reshape (rather than view) so non-contiguous masks are accepted too.
    mask = mask.reshape(B, T, -1)
    # Allocate directly with the mask's dtype and device instead of creating
    # a default tensor and converting it afterwards.
    all_mask = mask.new_zeros(B, T, 669)

    for channel, (part_slice, scale) in enumerate(parts):
        # The trailing singleton dimension broadcasts over all features of
        # the part.
        all_mask[:, :, part_slice] = mask[:, :, channel].unsqueeze(-1)
        all_mask[:, :, part_slice] *= scale
    return all_mask
|