File size: 16,709 Bytes
d7e58f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
import io
import os
import shutil
from pathlib import Path
from typing import Iterable, List, Optional, Union

import cv2
import mmcv
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D

from detrsmpl.core.conventions.cameras.convert_convention import \
    enc_camera_convention  # prevent yapf isort conflict
from detrsmpl.utils.demo_utils import get_different_colors
from detrsmpl.utils.ffmpeg_utils import images_to_video
from detrsmpl.utils.path_utils import check_path_suffix


class Axes3dBaseRenderer(object):
    """Camera-route and empty-scene helpers shared by Axes3D renderers."""
    def init_camera(self,
                    cam_elev_angle=10,
                    cam_elev_speed=0.0,
                    cam_hori_angle=45,
                    cam_hori_speed=0.5):
        """Record the start angle and per-frame step of both camera angles.

        Args:
            cam_elev_angle (int, optional):
                Pitch angle used for the first frame.
                Defaults to 10.
            cam_elev_speed (float, optional):
                Pitch change applied between consecutive frames.
                The direction is reversed whenever the next pitch would
                be `<= cam_elev_speed` or `>= 30`.
                Defaults to 0.0.
            cam_hori_angle (int, optional):
                Yaw angle used for the first frame. Defaults to 45.
            cam_hori_speed (float, optional):
                Yaw change applied between consecutive frames.
                The direction is reversed whenever the next yaw would be
                `>= 90 - 2 * cam_hori_speed` or `<= 2 * cam_hori_speed`.
                Defaults to 0.5.
        """
        self.if_camera_init = True
        self.cam_horizon_args = [cam_hori_angle, cam_hori_speed]
        self.cam_elevation_args = [cam_elev_angle, cam_elev_speed]

    def _get_camera_vector_list(self, frame_number):
        """Build one [elevation, horizon] camera vector per frame.

        The result is also stored in self.cam_vector_list. The first
        entry is the pair of start angles; at least one entry is always
        produced.

        Args:
            frame_number (int):
                Number of frames.

        Returns:
            List[List[float, float]]:
                A list of [elevation, horizon] angle pairs.
        """
        elev_step = self.cam_elevation_args[1]
        hori_step = self.cam_horizon_args[1]
        # Yaw bounds at which the sweep direction is reversed.
        hori_upper = 90 - 2 * hori_step
        hori_lower = 2 * hori_step

        route = [[self.cam_elevation_args[0], self.cam_horizon_args[0]]]
        self.cam_vector_list = route
        elev_dir, hori_dir = 1, 1
        for _ in range(frame_number - 1):
            prev_elev, prev_hori = route[-1]

            next_elev = elev_dir * elev_step + prev_elev
            if next_elev >= 30 or next_elev <= elev_step:
                # out of range: step the other way from the same origin
                elev_dir = -elev_dir
                next_elev = elev_dir * elev_step + prev_elev

            next_hori = hori_dir * hori_step + prev_hori
            if next_hori <= hori_lower or next_hori >= hori_upper:
                # out of range: step the other way from the same origin
                hori_dir = -hori_dir
                next_hori = hori_dir * hori_step + prev_hori

            route.append([next_elev, next_hori])
        return self.cam_vector_list

    @staticmethod
    def _get_visual_range(points: np.ndarray) -> np.ndarray:
        """Compute a cube-shaped visual range enclosing every point.

        Every axis is centred on the middle of its own data and given
        the span of the widest axis, so no point falls outside.

        Args:
            points (np.ndarray):
                An array of points, coordinates along the last dim.

        Returns:
            np.ndarray:
                A float array in shape [points.shape[-1], 2] holding
                the lower bound and the upper bound of each axis.
        """
        centers = []
        spans = []
        for axis_index in range(points.shape[-1]):
            column = points[..., axis_index]
            lowest = np.min(column)
            highest = np.max(column)
            centers.append((lowest + highest) / 2.0)
            spans.append(highest - lowest)
        centers = np.array(centers, dtype=np.float64)
        spans = np.array(spans, dtype=np.float64)
        half_extent = np.max(spans) / 2.0
        return np.column_stack((centers - half_extent,
                                centers + half_extent))

    def _draw_scene(self,
                    visual_range,
                    axis_len=1.0,
                    cam_elev_angle=10,
                    cam_hori_angle=45):
        """Create a figure with an empty 3D scene and its x/y/z axis lines.

        Three lines are drawn from the origin along +x, +y and +z in red,
        green and blue. Each one is as long as the distance between the
        upper bound and the midpoint of that axis, scaled by axis_len.

        Args:
            visual_range (np.ndarray):
                Return value of _get_visual_range().
            axis_len (float, optional):
                Scale factor applied to the drawn axis lines.
                Defaults to 1.0.
            cam_elev_angle (int, optional):
                Pitch angle of the camera.
                Defaults to 10.
            cam_hori_angle (int, optional):
                Yaw angle of the camera.
                Defaults to 45.

        Returns:
            tuple: Figure and Axes3D
        """
        fig = plt.figure()
        ax = Axes3D(fig, auto_add_to_figure=False)
        fig.add_axes(ax)
        ax.set_xlim(*visual_range[0])
        ax.set_ylim(*visual_range[1])
        ax.set_zlim(*visual_range[2])
        ax.view_init(cam_elev_angle, cam_hori_angle)

        # draw axis
        origin = np.array([0, 0, 0])
        for axis_index, line_color in enumerate(('r', 'g', 'b')):
            axis_mid = np.average(visual_range[axis_index])
            tip = [0, 0, 0]
            tip[axis_index] = \
                (visual_range[axis_index][1] - axis_mid) * axis_len
            ax = _plot_line_on_fig(ax, origin, np.array(tip), line_color)
        return fig, ax


class Axes3dJointsRenderer(Axes3dBaseRenderer):
    """Render 3D joints and limbs frame by frame with matplotlib.

    Call init_camera() and set_connections() first, then
    render_kp3d_to_video().
    """
    def __init__(self):
        self.if_camera_init = False
        self.cam_vector_list = None
        self.if_connection_setup = False
        self.if_frame_updated = False
        self.temp_path = ''
        # Defined up front so that cleanup (including __del__) never hits
        # a missing attribute when no rendering has happened yet.
        self.remove_temp = False

    def set_connections(self, limbs_connection, limbs_palette):
        """Set body limbs and their colors.

        Args:
            limbs_connection (dict):
                Maps a part name to a list of limbs, each limb being a
                pair of joint indexes. The key 'body' is required: its
                joints are the ones drawn as scatter points, and its
                limbs are drawn with a thicker line.
            limbs_palette (np.ndarray or dict):
                Limb colors in the 0-255 range. Either an array that can
                be reshaped to (-1, 3) and is shared by all parts, or a
                dict mapping each part name to its colors. When a part
                has more limbs than colors, the last color is reused.
        """
        self.limbs_connection = limbs_connection
        self.limbs_palette = limbs_palette
        self.if_connection_setup = True

    def render_kp3d_to_video(
        self,
        keypoints_np: np.ndarray,
        output_path: Optional[str] = None,
        convention='opencv',
        fps: Union[float, int] = 30,
        resolution: Iterable[int] = (720, 720),
        visual_range: Iterable[int] = (-100, 100),
        frame_names: Optional[List[str]] = None,
        disable_limbs: bool = False,
        return_array: bool = False,
    ) -> Optional[np.ndarray]:
        """Render 3d keypoints to a video, a frame folder or an array.

        Frames are rendered only on the first call of an instance
        (guarded by self.if_frame_updated). Later calls skip rendering,
        only re-run the video encoding and return None.

        Args:
            keypoints_np (np.ndarray): shape of input array should be
                    (f * n * J * 3).
            output_path (str, optional): output video path or frame
                    folder. A path ending with '.mp4' or '.gif' gets its
                    frames written to a temporary folder next to it,
                    any other path is used as the frame folder itself.
                    No file is written if None.
                    Defaults to None.
            convention (str, optional): camera convention, passed to
                    enc_camera_convention() to get the sign and the
                    order of the input axes.
                    Defaults to 'opencv'.
            fps (Union[float, int], optional): fps.
                    Defaults to 30.
            resolution (Iterable[int], optional): (width, height) of
                    output video.
                    Defaults to (720, 720).
            visual_range (Iterable[int], optional): range of axis value.
                    A 1-d (lower, upper) pair is applied to all three
                    axes. If None, the range is computed from the
                    keypoints.
                    Defaults to (-100, 100).
            frame_names (Optional[List[str]], optional):  List of string
                    for frame title, no title if None. Defaults to None.
            disable_limbs (bool, optional): whether need to disable drawing
                limbs.
                Defaults to False.
            return_array (bool, optional): whether to return the rendered
                frames.
                Defaults to False.
        Returns:
            Optional[np.ndarray]: the rendered frames stacked along a new
                leading axis if return_array is True and frames were
                rendered by this call, else None.
        """
        assert self.if_camera_init is True, \
            'init_camera() has to be called before rendering.'
        assert self.if_connection_setup is True, \
            'set_connections() has to be called before rendering.'
        sign, axis = enc_camera_convention(convention)
        if output_path is not None:
            if check_path_suffix(output_path, ['.mp4', '.gif']):
                self.temp_path = os.path.join(
                    Path(output_path).parent,
                    Path(output_path).name + '_output_temp')
                mmcv.mkdir_or_exist(self.temp_path)
                print('make dir', self.temp_path)
                self.remove_temp = True
            else:
                self.temp_path = output_path
                # cv2.imwrite() fails silently on a missing folder.
                mmcv.mkdir_or_exist(self.temp_path)
                self.remove_temp = False
        else:
            self.temp_path = None
            self.remove_temp = False
        keypoints_np = _set_new_pose(keypoints_np, sign, axis)
        # Stays None when rendering is skipped, so the return statement
        # below never reads an unbound name.
        image_array = None
        if not self.if_frame_updated:
            if self.cam_vector_list is None:
                self._get_camera_vector_list(
                    frame_number=keypoints_np.shape[0])
            assert len(self.cam_vector_list) == keypoints_np.shape[0]
            if visual_range is None:
                visual_range = self._get_visual_range(keypoints_np)
            else:
                visual_range = np.asarray(visual_range)
                if len(visual_range.shape) == 1:
                    one_dim_visual_range = np.expand_dims(visual_range, 0)
                    visual_range = one_dim_visual_range.repeat(3, axis=0)
            image_array = self._export_frames(keypoints_np, resolution,
                                              visual_range, frame_names,
                                              disable_limbs, return_array)
            self.if_frame_updated = True

        # NOTE(review): frames of a '.gif' target are rendered into the
        # temp folder, but only '.mp4' is encoded here - confirm whether
        # gif export is expected to happen elsewhere.
        if output_path is not None:
            if check_path_suffix(output_path, '.mp4'):
                images_to_video(self.temp_path,
                                output_path,
                                img_format='frame_%06d.png',
                                fps=fps)
        return image_array

    def _export_frames(self, keypoints_np, resolution, visual_range,
                       frame_names, disable_limbs, return_array):
        """Render every frame, write it to self.temp_path and collect it.

        Args:
            keypoints_np (np.ndarray): keypoints in shape (f, n, J, 3),
                already converted by _set_new_pose().
            resolution (Iterable[int]): (width, height) of each frame.
            visual_range (np.ndarray): lower and upper bound per axis.
            frame_names (Optional[List[str]]): titles, one per frame.
            disable_limbs (bool): skip drawing limbs if True.
            return_array (bool): collect and return the frames if True.

        Returns:
            Optional[np.ndarray]: the stacked frames if return_array is
                True, else None.
        """
        resolution = tuple(resolution)
        image_array = []
        # Joints taking part in any body limb, drawn as scatter points.
        scatter_points_index = list(
            set(
                np.array(self.limbs_connection['body']).reshape(
                    -1).tolist()))
        for frame_index in range(keypoints_np.shape[0]):
            keypoints_frame = keypoints_np[frame_index]
            cam_ele, cam_hor = self.cam_vector_list[frame_index]
            fig, ax = \
                self._draw_scene(visual_range=visual_range, axis_len=0.5,
                                 cam_elev_angle=cam_ele,
                                 cam_hori_angle=cam_hor)
            try:
                num_person = keypoints_frame.shape[0]
                multi_person = num_person >= 2
                # One lookup per frame, shared by limbs and legend so
                # that both always show the same color for a person.
                person_colors = get_different_colors(
                    num_person) if multi_person else None
                for person_index, keypoints_person in enumerate(
                        keypoints_frame):
                    # A local palette keeps the user's limbs_palette
                    # untouched in multi-person scenes.
                    if multi_person:
                        palette = person_colors[person_index].reshape(-1, 3)
                    else:
                        palette = self.limbs_palette
                    if not disable_limbs:
                        self._draw_limbs(ax, keypoints_person, palette)
                    ax.scatter(keypoints_person[scatter_points_index, 0],
                               keypoints_person[scatter_points_index, 1],
                               keypoints_person[scatter_points_index, 2],
                               c=np.array([0, 0, 0]).reshape(1, -1),
                               s=10,
                               marker='o')
                if multi_person:
                    self._draw_person_legend(ax, person_colors)
                rgb_mat = _get_cv2mat_from_buf(fig)
            finally:
                # Release only our own figure, also when drawing failed.
                plt.close(fig)
            resized_mat = cv2.resize(rgb_mat, resolution)
            if frame_names is not None:
                cv2.putText(resized_mat, str(frame_names[frame_index]),
                            (resolution[0] // 10, resolution[1] // 10),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.5 * resolution[0] / 500, [255, 255, 255], 2)
            if self.temp_path is not None:
                frame_path = os.path.join(self.temp_path,
                                          'frame_%06d.png' % frame_index)
                cv2.imwrite(frame_path, resized_mat)
            if return_array:
                image_array.append(resized_mat[None])
        if return_array:
            return np.concatenate(image_array)
        return None

    def _draw_limbs(self, ax, keypoints_person, palette):
        """Draw all limbs of one person on ax with colors from palette."""
        for part_name, limbs in self.limbs_connection.items():
            linewidth = 2 if part_name == 'body' else 1
            if isinstance(palette, dict):
                color = np.array(palette[part_name]).astype(np.int32)
            else:
                # np.ndarray, or any sequence convertible to one
                color = np.asarray(palette).astype(np.int32).reshape(-1, 3)
            for limb_index, limb in enumerate(limbs):
                # reuse the last color when limbs outnumber colors
                color_index = min(limb_index, len(color) - 1)
                _plot_line_on_fig(ax,
                                  keypoints_person[limb[0]],
                                  keypoints_person[limb[1]],
                                  color=np.array(color[color_index]) / 255.0,
                                  linewidth=linewidth)

    @staticmethod
    def _draw_person_legend(ax, person_colors):
        """Hide tick labels and add a legend with one entry per person."""
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        labels = []
        custom_lines = []
        for person_index in range(len(person_colors)):
            color = person_colors[person_index].reshape(1, 3) / 255.0
            custom_lines.append(
                Line2D([0], [0],
                       linestyle='-',
                       color=color[0],
                       lw=2,
                       marker='',
                       markeredgecolor='k',
                       markeredgewidth=.1,
                       markersize=20))
            labels.append(f'person_{person_index + 1}')
        ax.legend(
            handles=custom_lines,
            labels=labels,
            loc='upper left',
        )

    def __del__(self):
        """remove temp images."""
        self.remove_temp_frames()

    def remove_temp_frames(self):
        """Remove the temporary frame folder, if this renderer made one.

        Nothing happens when no temporary folder is in use, which
        includes instances that never rendered anything.
        """
        temp_path = getattr(self, 'temp_path', None)
        if not temp_path or not getattr(self, 'remove_temp', False):
            return
        if Path(temp_path).is_dir():
            shutil.rmtree(temp_path)


def _set_new_pose(pose_np, sign, axis):
    """Convert a pose to the axis order and signs used for plotting.

    The last dim of the returned copy is ordered as x, z, y. After the
    reordering, every channel of that dim is scaled by the ratio between
    its entry in `sign` and the plotting sign (-1, 1, -1). The input
    array is left untouched.

    Args:
        pose_np (np.ndarray): poses, coordinates along the last dim.
        sign (Sequence[int]): one sign per channel of the last dim.
        axis (Sequence[str]): axis names of the input in order,
            has to contain 'x', 'y' and 'z'.

    Returns:
        np.ndarray: the converted copy, same shape and dtype as pose_np.
    """
    plot_axis_order = ('x', 'z', 'y')
    plot_sign = (-1, 1, -1)

    converted = pose_np.copy()
    # move every source channel to its slot in the plotting order
    for dst_index, name in enumerate(plot_axis_order):
        converted[..., dst_index] = pose_np[..., axis.index(name)]

    # flip the channels whose direction differs from the plotting one
    channel_num = converted.shape[-1]
    for channel in range(channel_num):
        factor = sign[channel] / plot_sign[channel]
        converted[..., channel] = factor * converted[..., channel]
    return converted


def _plot_line_on_fig(ax,
                      point1_location,
                      point2_location,
                      color,
                      linewidth=1):
    """Plot a straight 3D segment between two points on ax.

    Args:
        ax: object with a matplotlib-style `plot(xs, ys, zs, ...)`.
        point1_location: start point, indexable by 0, 1 and 2.
        point2_location: end point, indexable by 0, 1 and 2.
        color: line color handed to `ax.plot`.
        linewidth (int, optional): line width handed to `ax.plot`.
            Defaults to 1.

    Returns:
        The same ax that was passed in.
    """
    # one [start, end] pair per coordinate, in x, y, z order
    xs, ys, zs = ([point1_location[dim], point2_location[dim]]
                  for dim in range(3))
    ax.plot(xs, ys, zs, color=color, linewidth=linewidth)
    return ax


def _get_cv2mat_from_buf(fig, dpi=180):
    """Rasterize a figure into an RGB image array.

    The figure is saved as PNG into an in-memory buffer, decoded with
    OpenCV as a 3-channel image and converted from BGR to RGB.

    Args:
        fig: object with a matplotlib-style `savefig()`.
        dpi (int, optional): resolution handed to `savefig()`.
            Defaults to 180.

    Returns:
        The decoded image in RGB channel order.
    """
    with io.BytesIO() as png_stream:
        fig.savefig(png_stream, format='png', dpi=dpi)
        # getvalue() copies the bytes, so they outlive the closed buffer
        encoded = np.frombuffer(png_stream.getvalue(), dtype=np.uint8)
    bgr_image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)