File size: 6,914 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
# Copyright (c) Facebook, Inc. and its affiliates.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch
@dataclass
class DensePoseChartResult:
    """
    Chart-based DensePose output stored as per-pixel chart labels together
    with the inner (U, V) coordinates inside each chart. Every chart is a 2D
    manifold identified by a label and parameterized by two coordinates, U and
    V, both of which lie in [0, 1].

    Fields:
     - labels (tensor [H, W] of long): estimated chart label for every pixel
       of the detection bounding box of size (H, W)
     - uv (tensor [2, H, W] of float): estimated U and V coordinates for every
       pixel of the detection bounding box of size (H, W)
    """

    labels: torch.Tensor
    uv: torch.Tensor

    def to(self, device: torch.device):
        """
        Build a new DensePoseChartResult whose `labels` and `uv` tensors have
        been moved to `device`.
        """
        return DensePoseChartResult(
            labels=self.labels.to(device),
            uv=self.uv.to(device),
        )
@dataclass
class DensePoseChartResultWithConfidences:
    """
    A DensePoseChartResult extended with optional per-pixel confidence values.

    Core fields:
     - labels (tensor [H, W] of long): estimated chart label for every pixel
       of the detection bounding box of size (H, W)
     - uv (tensor [2, H, W] of float): estimated U and V coordinates for every
       pixel of the detection bounding box of size (H, W)
    Each confidence field (sigma_1, sigma_2, kappa_u, kappa_v,
    fine_segm_confidence, coarse_segm_confidence) is either None or an
    [H, W] tensor of float.
    """

    labels: torch.Tensor
    uv: torch.Tensor
    sigma_1: Optional[torch.Tensor] = None
    sigma_2: Optional[torch.Tensor] = None
    kappa_u: Optional[torch.Tensor] = None
    kappa_v: Optional[torch.Tensor] = None
    fine_segm_confidence: Optional[torch.Tensor] = None
    coarse_segm_confidence: Optional[torch.Tensor] = None

    def to(self, device: torch.device):
        """
        Build a new result with every tensor moved to `device`.

        `labels` and `uv` are always moved. Each optional confidence field is
        moved only when it holds a torch.Tensor; any other value (e.g. None)
        is passed through unchanged.
        """
        moved_labels = self.labels.to(device)
        moved_uv = self.uv.to(device)
        optional_names = (
            "sigma_1",
            "sigma_2",
            "kappa_u",
            "kappa_v",
            "fine_segm_confidence",
            "coarse_segm_confidence",
        )
        moved_optionals = {}
        for field_name in optional_names:
            value = getattr(self, field_name)
            if isinstance(value, torch.Tensor):
                value = value.to(device)
            moved_optionals[field_name] = value
        return DensePoseChartResultWithConfidences(
            labels=moved_labels, uv=moved_uv, **moved_optionals
        )
@dataclass
class DensePoseChartResultQuantized:
    """
    Chart-based DensePose output with labels and quantized inner coordinates
    packed into one tensor. Each chart is a 2D manifold identified by a label
    and parameterized by coordinates U and V, both in [0, 1]. The quantized
    coordinates Uq and Vq are uint8 values:
        Uq = U * 255 (so 0 <= Uq <= 255)
        Vq = V * 255 (so 0 <= Vq <= 255)

    Field:
     - labels_uv_uint8 (tensor [3, H, W] of uint8): for every pixel of the
       detection bounding box of size (H, W), holds the estimated label and
       the quantized coordinates Uq and Vq
    """

    labels_uv_uint8: torch.Tensor

    def to(self, device: torch.device):
        """
        Build a new DensePoseChartResultQuantized whose packed tensor has been
        moved to `device`.
        """
        return DensePoseChartResultQuantized(
            labels_uv_uint8=self.labels_uv_uint8.to(device)
        )
@dataclass
class DensePoseChartResultCompressed:
    """
    Chart-based DensePose output stored as a PNG-encoded, Base64 string.

    The [3, H, W] quantized result tensor is treated as a 3-channel image,
    compressed as PNG, and the PNG bytes are stored Base64-encoded.

    Fields:
     - shape_chw (tuple of 3 int): shape of the result tensor as
       (number of channels, height, width)
     - labels_uv_str (str): Base64 text of the PNG-compressed
       [3, H, W] result tensor
    """

    # Field order matters: instances may be constructed positionally.
    shape_chw: Tuple[int, int, int]
    labels_uv_str: str
def quantize_densepose_chart_result(result: DensePoseChartResult) -> DensePoseChartResultQuantized:
    """
    Pack a chart-based DensePose result into a single uint8 tensor.

    Channel 0 receives the labels. Channels 1 and 2 receive U and V multiplied
    by 255, clamped to [0, 255] and converted with `.byte()`.

    Args:
        result (DensePoseChartResult): chart-based result to quantize
    Return:
        DensePoseChartResultQuantized wrapping a [3, H, W] uint8 tensor that
        lives on the same device as `result.labels`
    """
    height, width = result.labels.shape
    packed = torch.zeros(
        (3, height, width), dtype=torch.uint8, device=result.labels.device
    )
    packed[0] = result.labels
    scaled_uv = result.uv * 255
    packed[1:] = scaled_uv.clamp(0, 255).byte()
    return DensePoseChartResultQuantized(labels_uv_uint8=packed)
def compress_quantized_densepose_chart_result(
    result: DensePoseChartResultQuantized,
) -> DensePoseChartResultCompressed:
    """
    Encode a quantized chart-based DensePose result as a Base64 PNG string.

    The packed uint8 tensor is copied to CPU and its channel axis is moved
    last (CHW -> HWC). It is then saved as a PNG image with `optimize=True`,
    and the PNG bytes are Base64-encoded via `base64.encodebytes`.

    Args:
        result (DensePoseChartResultQuantized): quantized chart-based result
    Return:
        DensePoseChartResultCompressed holding the encoded string and the
        (C, H, W) shape of the original tensor
    """
    import base64
    from io import BytesIO

    import numpy as np
    from PIL import Image

    chw = result.labels_uv_uint8.cpu().numpy()
    image = Image.fromarray(np.moveaxis(chw, 0, -1))
    buffer = BytesIO()
    image.save(buffer, format="png", optimize=True)
    encoded = base64.encodebytes(buffer.getvalue()).decode()
    return DensePoseChartResultCompressed(labels_uv_str=encoded, shape_chw=chw.shape)
def decompress_compressed_densepose_chart_result(
    result: DensePoseChartResultCompressed,
) -> DensePoseChartResultQuantized:
    """
    Decode a Base64 PNG string back into a quantized chart-based result.

    The string is Base64-decoded and opened as an image with PIL. The image is
    converted to a uint8 array, its last axis is moved first (HWC -> CHW), and
    the array is reshaped to `result.shape_chw`.

    Args:
        result (DensePoseChartResultCompressed): compressed chart-based result
    Return:
        DensePoseChartResultQuantized wrapping a CPU uint8 tensor built with
        `torch.from_numpy`
    """
    import base64
    from io import BytesIO

    import numpy as np
    from PIL import Image

    png_bytes = base64.decodebytes(result.labels_uv_str.encode())
    image = Image.open(BytesIO(png_bytes))
    hwc = np.array(image, dtype=np.uint8)
    chw = np.moveaxis(hwc, -1, 0).reshape(result.shape_chw)
    return DensePoseChartResultQuantized(labels_uv_uint8=torch.from_numpy(chw))
|