minjung-s committed
Commit c5c6bad (0 parents)
This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. .gitattributes +37 -0
  2. .gitignore +2 -0
  3. README.md +13 -0
  4. app.py +436 -0
  5. densepose/__init__.py +22 -0
  6. densepose/config.py +277 -0
  7. densepose/converters/__init__.py +17 -0
  8. densepose/converters/base.py +95 -0
  9. densepose/converters/builtin.py +33 -0
  10. densepose/converters/chart_output_hflip.py +73 -0
  11. densepose/converters/chart_output_to_chart_result.py +190 -0
  12. densepose/converters/hflip.py +36 -0
  13. densepose/converters/segm_to_mask.py +152 -0
  14. densepose/converters/to_chart_result.py +72 -0
  15. densepose/converters/to_mask.py +51 -0
  16. densepose/data/__init__.py +27 -0
  17. densepose/data/build.py +738 -0
  18. densepose/data/combined_loader.py +46 -0
  19. densepose/data/dataset_mapper.py +170 -0
  20. densepose/data/datasets/__init__.py +7 -0
  21. densepose/data/datasets/builtin.py +18 -0
  22. densepose/data/datasets/chimpnsee.py +31 -0
  23. densepose/data/datasets/coco.py +434 -0
  24. densepose/data/datasets/dataset_type.py +13 -0
  25. densepose/data/datasets/lvis.py +259 -0
  26. densepose/data/image_list_dataset.py +74 -0
  27. densepose/data/inference_based_loader.py +174 -0
  28. densepose/data/meshes/__init__.py +7 -0
  29. densepose/data/meshes/builtin.py +103 -0
  30. densepose/data/meshes/catalog.py +73 -0
  31. densepose/data/samplers/__init__.py +10 -0
  32. densepose/data/samplers/densepose_base.py +205 -0
  33. densepose/data/samplers/densepose_confidence_based.py +110 -0
  34. densepose/data/samplers/densepose_cse_base.py +141 -0
  35. densepose/data/samplers/densepose_cse_confidence_based.py +121 -0
  36. densepose/data/samplers/densepose_cse_uniform.py +14 -0
  37. densepose/data/samplers/densepose_uniform.py +43 -0
  38. densepose/data/samplers/mask_from_densepose.py +30 -0
  39. densepose/data/samplers/prediction_to_gt.py +100 -0
  40. densepose/data/transform/__init__.py +5 -0
  41. densepose/data/transform/image.py +41 -0
  42. densepose/data/utils.py +40 -0
  43. densepose/data/video/__init__.py +19 -0
  44. densepose/data/video/frame_selector.py +89 -0
  45. densepose/data/video/video_keyframe_dataset.py +304 -0
  46. densepose/engine/__init__.py +5 -0
  47. densepose/engine/trainer.py +260 -0
  48. densepose/evaluation/__init__.py +5 -0
  49. densepose/evaluation/d2_evaluator_adapter.py +52 -0
  50. densepose/evaluation/densepose_coco_evaluation.py +1305 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ detectron2/_C.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ playground.py
2
+ __pycache__
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: CatVTON
3
+ emoji: 🐈
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.40.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-sa-4.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,436 @@
1
+ import argparse
2
+ import os
3
+ os.environ['CUDA_HOME'] = '/usr/local/cuda'
4
+ os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
5
+ from datetime import datetime
6
+
7
+ import gradio as gr
8
+ import spaces
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.image_processor import VaeImageProcessor
12
+ from huggingface_hub import snapshot_download
13
+ from PIL import Image
14
+ torch.jit.script = lambda f: f
15
+ # from model.cloth_masker import AutoMasker, vis_mask
16
+ from model.cloth_masker2 import AutoMasker, vis_mask
17
+ from model.pipeline import CatVTONPipeline
18
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
19
+
20
+
21
+ def parse_args():
22
+ parser = argparse.ArgumentParser(description="Gradio demo script for CatVTON virtual try-on.")
23
+ parser.add_argument(
24
+ "--base_model_path",
25
+ type=str,
26
+ default="booksforcharlie/stable-diffusion-inpainting",
27
+ # default="runwayml/stable-diffusion-inpainting",
28
+ help=(
29
+ "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
30
+ ),
31
+ )
32
+ parser.add_argument(
33
+ "--resume_path",
34
+ type=str,
35
+ default="zhengchong/CatVTON",
36
+ help=(
37
+ "The path to the checkpoint of the trained try-on model."
38
+ ),
39
+ )
40
+ parser.add_argument(
41
+ "--output_dir",
42
+ type=str,
43
+ default="resource/demo/output",
44
+ help="The output directory where the model predictions will be written.",
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--width",
49
+ type=int,
50
+ default=768,
51
+ help=(
52
+ "The resolution for input images; all images in the train/validation dataset will be resized to this"
53
+ " resolution"
54
+ ),
55
+ )
56
+ parser.add_argument(
57
+ "--height",
58
+ type=int,
59
+ default=1024,
60
+ help=(
61
+ "The resolution for input images; all images in the train/validation dataset will be resized to this"
62
+ " resolution"
63
+ ),
64
+ )
65
+ parser.add_argument(
66
+ "--repaint",
67
+ action="store_true",
68
+ help="Whether to repaint the result image with the original background."
69
+ )
70
+ parser.add_argument(
71
+ "--allow_tf32",
72
+ action="store_true",
73
+ default=True,
74
+ help=(
75
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
76
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
77
+ ),
78
+ )
79
+ parser.add_argument(
80
+ "--mixed_precision",
81
+ type=str,
82
+ default="no",
83
+ choices=["no", "fp16", "bf16"],
84
+ help=(
85
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
86
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
87
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
88
+ ),
89
+ )
90
+
91
+ args = parser.parse_args()
92
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
93
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
94
+ args.local_rank = env_local_rank
95
+
96
+ return args
97
+
98
+ def image_grid(imgs, rows, cols):
99
+ assert len(imgs) == rows * cols
100
+
101
+ w, h = imgs[0].size
102
+ grid = Image.new("RGB", size=(cols * w, rows * h))
103
+
104
+ for i, img in enumerate(imgs):
105
+ grid.paste(img, box=(i % cols * w, i // cols * h))
106
+ return grid
107
+
108
+
109
+ args = parse_args()
110
+ repo_path = snapshot_download(repo_id=args.resume_path)
111
+ # Pipeline
112
+ pipeline = CatVTONPipeline(
113
+ base_ckpt=args.base_model_path,
114
+ attn_ckpt=repo_path,
115
+ attn_ckpt_version="mix",
116
+ weight_dtype=init_weight_dtype(args.mixed_precision),
117
+ use_tf32=args.allow_tf32,
118
+ device='cuda'
119
+ )
120
+ # AutoMasker
121
+ mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
122
+ automasker = AutoMasker(
123
+ densepose_ckpt=os.path.join(repo_path, "DensePose"),
124
+ schp_ckpt=os.path.join(repo_path, "SCHP"),
125
+ device='cuda',
126
+ )
127
+
128
+ @spaces.GPU(duration=120)
129
+ # Added fitting_type as a parameter.
130
+ def submit_function(
131
+ person_image,
132
+ cloth_image,
133
+ cloth_type,
134
+ fitting_type,
135
+ num_inference_steps,
136
+ guidance_scale,
137
+ seed,
138
+ show_type
139
+ ):
140
+ person_image, mask = person_image["background"], person_image["layers"][0] # person_image["layers"][0] is the mask layer drawn by the user.
141
+ mask = Image.open(mask).convert("L")
142
+ if len(np.unique(np.array(mask))) == 1:
143
+ mask = None # The user did not draw a mask.
144
+ else:
145
+ mask = np.array(mask)
146
+ mask[mask > 0] = 255 # The background is black.
147
+ mask = Image.fromarray(mask)
148
+
149
+ tmp_folder = args.output_dir
150
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
151
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
152
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
153
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
154
+
155
+ generator = None
156
+ if seed != -1:
157
+ generator = torch.Generator(device='cuda').manual_seed(seed)
158
+
159
+ person_image = Image.open(person_image).convert("RGB")
160
+ cloth_image = Image.open(cloth_image).convert("RGB")
161
+ person_image = resize_and_crop(person_image, (args.width, args.height))
162
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
163
+
164
+ # Process mask
165
+ if mask is not None:
166
+ mask = resize_and_crop(mask, (args.width, args.height))
167
+ else:
168
+ mask = automasker(
169
+ person_image,
170
+ cloth_type
171
+ )['mask']
172
+
173
+ # Additionally, process the mask according to the Fitting Type (inside the else block)
174
+ if fitting_type == "standard":
175
+ mask = mask # TODO: needs implementation; bring in the implemented method.
176
+ elif fitting_type == "loose":
177
+ mask = mask # TODO: needs implementation.
178
+
179
+ mask = mask_processor.blur(mask, blur_factor=9)
180
+
181
+ # Inference
182
+ # try:
183
+ result_image = pipeline(
184
+ image=person_image,
185
+ condition_image=cloth_image,
186
+ mask=mask,
187
+ num_inference_steps=num_inference_steps,
188
+ guidance_scale=guidance_scale,
189
+ generator=generator
190
+ )[0]
191
+ # except Exception as e:
192
+ # raise gr.Error(
193
+ # "An error occurred. Please try again later: {}".format(e)
194
+ # )
195
+
196
+ # Post-process
197
+ masked_person = vis_mask(person_image, mask)
198
+ save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
199
+ save_result_image.save(result_save_path)
200
+ if show_type == "result only":
201
+ return result_image
202
+ else:
203
+ width, height = person_image.size
204
+ if show_type == "input & result":
205
+ condition_width = width // 2
206
+ conditions = image_grid([person_image, cloth_image], 2, 1)
207
+ else:
208
+ condition_width = width // 3
209
+ conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
210
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
211
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
212
+ new_result_image.paste(conditions, (0, 0))
213
+ new_result_image.paste(result_image, (condition_width + 5, 0))
214
+ return new_result_image
215
+
216
+
217
+ def person_example_fn(image_path):
218
+ return image_path
219
+
220
+ HEADER = """
221
+ <h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
222
+ <div style="display: flex; justify-content: center; align-items: center;">
223
+ <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
224
+ <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
225
+ </a>
226
+ <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
227
+ <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
228
+ </a>
229
+ <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
230
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
231
+ </a>
232
+ <a href="http://120.76.142.206:8888" style="margin: 0 2px;">
233
+ <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
234
+ </a>
235
+ <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
236
+ <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
237
+ </a>
238
+ <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
239
+ <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
240
+ </a>
241
+ <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
242
+ <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
243
+ </a>
244
+ </div>
245
+ <br>
246
+ · This demo and our weights are only for Non-commercial Use. <br>
247
+ · You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
248
+ · Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
249
+ · SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
250
+ """
251
+
252
+ def app_gradio():
253
+ with gr.Blocks(title="CatVTON") as demo:
254
+ gr.Markdown(HEADER)
255
+ with gr.Row():
256
+ with gr.Column(scale=1, min_width=350):
257
+ with gr.Row():
258
+ image_path = gr.Image(
259
+ type="filepath",
260
+ interactive=True,
261
+ visible=False,
262
+ )
263
+ person_image = gr.ImageEditor(
264
+ interactive=True, label="Person Image", type="filepath"
265
+ )
266
+
267
+ with gr.Row():
268
+ with gr.Column(scale=1, min_width=230):
269
+ cloth_image = gr.Image(
270
+ interactive=True, label="Condition Image", type="filepath"
271
+ )
272
+ with gr.Column(scale=1, min_width=120):
273
+ gr.Markdown(
274
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically <br> </span>'
275
+ )
276
+ cloth_type = gr.Radio(
277
+ label="Try-On Cloth Type",
278
+ choices=['long sleeve', 'short sleeve', 'long pants', 'short pants', 'long dress', 'short dress'],
279
+ value="long sleeve", # default must be one of the Radio choices above
280
+ )
281
+ with gr.Column(scale=1, min_width=120):
282
+ gr.Markdown(
283
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Fitting Type` to generate automatically </span>'
284
+ )
285
+ fitting_type = gr.Radio(
286
+ label="Try-On Fitting Type",
287
+ choices=["fit", "standard", "loose"],
288
+ value="fit", # default
289
+ )
290
+
291
+
292
+ submit = gr.Button("Submit")
293
+ gr.Markdown(
294
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
295
+ )
296
+
297
+ gr.Markdown(
298
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
299
+ )
300
+ with gr.Accordion("Advanced Options", open=False):
301
+ num_inference_steps = gr.Slider(
302
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
303
+ )
304
+ # Guidance Scale
305
+ guidance_scale = gr.Slider(
306
+ label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
307
+ )
308
+ # Random Seed
309
+ seed = gr.Slider(
310
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
311
+ )
312
+ show_type = gr.Radio(
313
+ label="Show Type",
314
+ choices=["result only", "input & result", "input & mask & result"],
315
+ value="input & mask & result",
316
+ )
317
+
318
+ with gr.Column(scale=2, min_width=500):
319
+ result_image = gr.Image(interactive=False, label="Result")
320
+ with gr.Row():
321
+ # Photo Examples
322
+ root_path = "resource/demo/example"
323
+ with gr.Column():
324
+ men_exm = gr.Examples(
325
+ examples=[
326
+ os.path.join(root_path, "person", "men", _)
327
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
328
+ ],
329
+ examples_per_page=4,
330
+ inputs=image_path,
331
+ label="Person Examples ①",
332
+ )
333
+ women_exm = gr.Examples(
334
+ examples=[
335
+ os.path.join(root_path, "person", "women", _)
336
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
337
+ ],
338
+ examples_per_page=4,
339
+ inputs=image_path,
340
+ label="Person Examples ②",
341
+ )
342
+ gr.Markdown(
343
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
344
+ )
345
+ with gr.Column():
346
+ condition_long_sleeve_exm = gr.Examples(
347
+ examples=[
348
+ os.path.join(root_path, "condition", "long_sleeve", _)
349
+ for _ in os.listdir(os.path.join(root_path, "condition", "long_sleeve"))
350
+ ],
351
+ examples_per_page=4,
352
+ inputs=cloth_image,
353
+ label="Condition long sleeve Examples (sweaters, jackets)",
354
+ )
355
+ condition_short_sleeve_exm = gr.Examples(
356
+ examples=[
357
+ os.path.join(root_path, "condition", "short_sleeve", _)
358
+ for _ in os.listdir(os.path.join(root_path, "condition", "short_sleeve"))
359
+ ],
360
+ examples_per_page=4,
361
+ inputs=cloth_image,
362
+ label="Condition short sleeve Examples (vest, sleeveless, etc.)",
363
+ )
364
+ condition_long_pants_exm = gr.Examples(
365
+ examples=[
366
+ os.path.join(root_path, "condition", "long_pants", _)
367
+ for _ in os.listdir(os.path.join(root_path, "condition", "long_pants"))
368
+ ],
369
+ examples_per_page=4,
370
+ inputs=cloth_image,
371
+ label="Condition long_pants Examples",
372
+ )
373
+ condition_short_pants_exm = gr.Examples(
374
+ examples=[
375
+ os.path.join(root_path, "condition", "short_pants", _)
376
+ for _ in os.listdir(os.path.join(root_path, "condition", "short_pants"))
377
+ ],
378
+ examples_per_page=4,
379
+ inputs=cloth_image,
380
+ label="Condition short_pants Examples",
381
+ )
382
+ condition_long_dress_exm = gr.Examples(
383
+ examples=[
384
+ os.path.join(root_path, "condition", "long_dress", _)
385
+ for _ in os.listdir(os.path.join(root_path, "condition", "long_dress"))
386
+ ],
387
+ examples_per_page=4,
388
+ inputs=cloth_image,
389
+ label="Condition long_dress Examples (!! uses default mask generator !!)",
390
+ )
391
+ condition_short_dress_exm = gr.Examples(
392
+ examples=[
393
+ os.path.join(root_path, "condition", "short_dress", _)
394
+ for _ in os.listdir(os.path.join(root_path, "condition", "short_dress"))
395
+ ],
396
+ examples_per_page=4,
397
+ inputs=cloth_image,
398
+ label="Condition short_dress Examples",
399
+ )
400
+ condition_person_exm = gr.Examples(
401
+ examples=[
402
+ os.path.join(root_path, "condition", "person", _)
403
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
404
+ ],
405
+ examples_per_page=4,
406
+ inputs=cloth_image,
407
+ label="Condition Reference Person Examples",
408
+ )
409
+ gr.Markdown(
410
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
411
+ )
412
+
413
+ image_path.change(
414
+ person_example_fn, inputs=image_path, outputs=person_image
415
+ )
416
+
417
+ # fitting_type also needs to be added as a parameter here.
418
+ submit.click(
419
+ submit_function,
420
+ [
421
+ person_image,
422
+ cloth_image,
423
+ cloth_type,
424
+ fitting_type,
425
+ num_inference_steps,
426
+ guidance_scale,
427
+ seed,
428
+ show_type,
429
+ ],
430
+ result_image,
431
+ )
432
+ demo.queue().launch(share=True, show_error=True)
433
+
434
+
435
+ if __name__ == "__main__":
436
+ app_gradio()
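
As a small aside on the `image_grid` helper used in app.py above: it simply tiles equally sized PIL images row by row. The snippet below is an independent, minimal sketch of the same idea (the placeholder tiles and output file name are illustrative, not part of this commit):

```python
from PIL import Image

def image_grid(imgs, rows, cols):
    # All images are assumed to share the same size.
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Column index is i % cols, row index is i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

# Example: a 1x4 strip, mirroring how the demo concatenates
# person / masked person / cloth / result images before saving.
tiles = [Image.new("RGB", (64, 96), c) for c in ("red", "green", "blue", "gray")]
image_grid(tiles, rows=1, cols=4).save("grid_preview.png")
```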
densepose/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from .data.datasets import builtin # just to register data
5
+ from .converters import builtin as builtin_converters # register converters
6
+ from .config import (
7
+ add_densepose_config,
8
+ add_densepose_head_config,
9
+ add_hrnet_config,
10
+ add_dataset_category_config,
11
+ add_bootstrap_config,
12
+ load_bootstrap_config,
13
+ )
14
+ from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
15
+ from .evaluation import DensePoseCOCOEvaluator
16
+ from .modeling.roi_heads import DensePoseROIHeads
17
+ from .modeling.test_time_augmentation import (
18
+ DensePoseGeneralizedRCNNWithTTA,
19
+ DensePoseDatasetMapperTTA,
20
+ )
21
+ from .utils.transform import load_from_cfg
22
+ from .modeling.hrfpn import build_hrfpn_backbone
densepose/config.py ADDED
@@ -0,0 +1,277 @@
1
+ # -*- coding = utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # pyre-ignore-all-errors
4
+
5
+ from detectron2.config import CfgNode as CN
6
+
7
+
8
+ def add_dataset_category_config(cfg: CN) -> None:
9
+ """
10
+ Add config for additional category-related dataset options
11
+ - category whitelisting
12
+ - category mapping
13
+ """
14
+ _C = cfg
15
+ _C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
16
+ _C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
17
+ # class to mesh mapping
18
+ _C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
19
+
20
+
21
+ def add_evaluation_config(cfg: CN) -> None:
22
+ _C = cfg
23
+ _C.DENSEPOSE_EVALUATION = CN()
24
+ # evaluator type, possible values:
25
+ # - "iou": evaluator for models that produce iou data
26
+ # - "cse": evaluator for models that produce cse data
27
+ _C.DENSEPOSE_EVALUATION.TYPE = "iou"
28
+ # storage for DensePose results, possible values:
29
+ # - "none": no explicit storage, all the results are stored in the
30
+ # dictionary with predictions, memory intensive;
31
+ # historically the default storage type
32
+ # - "ram": RAM storage, uses per-process RAM storage, which is
33
+ # reduced to a single process storage on later stages,
34
+ # less memory intensive
35
+ # - "file": file storage, uses per-process file-based storage,
36
+ # the least memory intensive, but may create bottlenecks
37
+ # on file system accesses
38
+ _C.DENSEPOSE_EVALUATION.STORAGE = "none"
39
+ # minimum threshold for IOU values: the lower its values is,
40
+ # the more matches are produced (and the higher the AP score)
41
+ _C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5
42
+ # Non-distributed inference is slower (at inference time) but can avoid RAM OOM
43
+ _C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True
44
+ # evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context
45
+ _C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False
46
+ # meshes to compute mesh alignment for
47
+ _C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = []
48
+
49
+
50
+ def add_bootstrap_config(cfg: CN) -> None:
51
+ """Add config options for bootstrap datasets and the bootstrap model."""
52
+ _C = cfg
53
+ _C.BOOTSTRAP_DATASETS = []
54
+ _C.BOOTSTRAP_MODEL = CN()
55
+ _C.BOOTSTRAP_MODEL.WEIGHTS = ""
56
+ _C.BOOTSTRAP_MODEL.DEVICE = "cuda"
57
+
58
+
59
+ def get_bootstrap_dataset_config() -> CN:
60
+ _C = CN()
61
+ _C.DATASET = ""
62
+ # ratio used to mix data loaders
63
+ _C.RATIO = 0.1
64
+ # image loader
65
+ _C.IMAGE_LOADER = CN(new_allowed=True)
66
+ _C.IMAGE_LOADER.TYPE = ""
67
+ _C.IMAGE_LOADER.BATCH_SIZE = 4
68
+ _C.IMAGE_LOADER.NUM_WORKERS = 4
69
+ _C.IMAGE_LOADER.CATEGORIES = []
70
+ _C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
71
+ _C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
72
+ # inference
73
+ _C.INFERENCE = CN()
74
+ # batch size for model inputs
75
+ _C.INFERENCE.INPUT_BATCH_SIZE = 4
76
+ # batch size to group model outputs
77
+ _C.INFERENCE.OUTPUT_BATCH_SIZE = 2
78
+ # sampled data
79
+ _C.DATA_SAMPLER = CN(new_allowed=True)
80
+ _C.DATA_SAMPLER.TYPE = ""
81
+ _C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
82
+ # filter
83
+ _C.FILTER = CN(new_allowed=True)
84
+ _C.FILTER.TYPE = ""
85
+ return _C
86
+
87
+
88
+ def load_bootstrap_config(cfg: CN) -> None:
89
+ """
90
+ Bootstrap datasets are given as a list of `dict` that are not automatically
91
+ converted into CfgNode. This method processes all bootstrap dataset entries
92
+ and ensures that they are in CfgNode format and comply with the specification
93
+ """
94
+ if not cfg.BOOTSTRAP_DATASETS:
95
+ return
96
+
97
+ bootstrap_datasets_cfgnodes = []
98
+ for dataset_cfg in cfg.BOOTSTRAP_DATASETS:
99
+ _C = get_bootstrap_dataset_config().clone()
100
+ _C.merge_from_other_cfg(CN(dataset_cfg))
101
+ bootstrap_datasets_cfgnodes.append(_C)
102
+ cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes
103
+
104
+
105
+ def add_densepose_head_cse_config(cfg: CN) -> None:
106
+ """
107
+ Add configuration options for Continuous Surface Embeddings (CSE)
108
+ """
109
+ _C = cfg
110
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN()
111
+ # Dimensionality D of the embedding space
112
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16
113
+ # Embedder specifications for various mesh IDs
114
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True)
115
+ # normalization coefficient for embedding distances
116
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
117
+ # normalization coefficient for geodesic distances
118
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01
119
+ # embedding loss weight
120
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6
121
+ # embedding loss name, currently the following options are supported:
122
+ # - EmbeddingLoss: cross-entropy on vertex labels
123
+ # - SoftEmbeddingLoss: cross-entropy on vertex label combined with
124
+ # Gaussian penalty on distance between vertices
125
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss"
126
+ # optimizer hyperparameters
127
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0
128
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0
129
+ # Shape to shape cycle consistency loss parameters:
130
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
131
+ # shape to shape cycle consistency loss weight
132
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025
133
+ # norm type used for loss computation
134
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
135
+ # normalization term for embedding similarity matrices
136
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05
137
+ # maximum number of vertices to include into shape to shape cycle loss
138
+ # if negative or zero, all vertices are considered
139
+ # if positive, random subset of vertices of given size is considered
140
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936
141
+ # Pixel to shape cycle consistency loss parameters:
142
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
143
+ # pixel to shape cycle consistency loss weight
144
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001
145
+ # norm type used for loss computation
146
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
147
+ # map images to all meshes and back (if false, use only gt meshes from the batch)
148
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False
149
+ # Randomly select at most this number of pixels from every instance
150
+ # if negative or zero, all vertices are considered
151
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100
152
+ # normalization factor for pixel to pixel distances (higher value = smoother distribution)
153
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0
154
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
155
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
156
+
157
+
158
+ def add_densepose_head_config(cfg: CN) -> None:
159
+ """
160
+ Add config for densepose head.
161
+ """
162
+ _C = cfg
163
+
164
+ _C.MODEL.DENSEPOSE_ON = True
165
+
166
+ _C.MODEL.ROI_DENSEPOSE_HEAD = CN()
167
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NAME = ""
168
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
169
+ # Number of parts used for point labels
170
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24
171
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4
172
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
173
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
174
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2
175
+ _C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112
176
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2"
177
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28
178
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2
179
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2 # 15 or 2
180
+ # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
181
+ _C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7
182
+ # Loss weights for annotation masks.(14 Parts)
183
+ _C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0
184
+ # Loss weights for surface parts. (24 Parts)
185
+ _C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0
186
+ # Loss weights for UV regression.
187
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01
188
+ # Coarse segmentation is trained using instance segmentation task data
189
+ _C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False
190
+ # For Decoder
191
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True
192
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256
193
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256
194
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = ""
195
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4
196
+ # For DeepLab head
197
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN()
198
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
199
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0
200
+ # Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY
201
+ # Some registered predictors:
202
+ # "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts
203
+ # "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates
204
+ # and associated confidences for predefined charts (default)
205
+ # "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings
206
+ # and associated confidences for CSE
207
+ _C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
208
+ # Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY
209
+ # Some registered losses:
210
+ # "DensePoseChartLoss": loss for chart-based models that estimate
211
+ # segmentation and UV coordinates
212
+ # "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate
213
+ # segmentation, UV coordinates and the corresponding confidences (default)
214
+ _C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
215
+ # Confidences
216
+ # Enable learning UV confidences (variances) along with the actual values
217
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False})
218
+ # UV confidence lower bound
219
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01
220
+ # Enable learning segmentation confidences (variances) along with the actual values
221
+ _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False})
222
+ # Segmentation confidence lower bound
223
+ _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01
224
+ # Statistical model type for confidence learning, possible values:
225
+ # - "iid_iso": statistically independent identically distributed residuals
226
+ # with isotropic covariance
227
+ # - "indep_aniso": statistically independent residuals with anisotropic
228
+ # covariances
229
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso"
230
+ # List of angles for rotation in data augmentation during training
231
+ _C.INPUT.ROTATION_ANGLES = [0]
232
+ _C.TEST.AUG.ROTATION_ANGLES = () # Rotation TTA
233
+
234
+ add_densepose_head_cse_config(cfg)
235
+
236
+
237
+ def add_hrnet_config(cfg: CN) -> None:
238
+ """
239
+ Add config for HRNet backbone.
240
+ """
241
+ _C = cfg
242
+
243
+ # For HigherHRNet w32
244
+ _C.MODEL.HRNET = CN()
245
+ _C.MODEL.HRNET.STEM_INPLANES = 64
246
+ _C.MODEL.HRNET.STAGE2 = CN()
247
+ _C.MODEL.HRNET.STAGE2.NUM_MODULES = 1
248
+ _C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2
249
+ _C.MODEL.HRNET.STAGE2.BLOCK = "BASIC"
250
+ _C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4]
251
+ _C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64]
252
+ _C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM"
253
+ _C.MODEL.HRNET.STAGE3 = CN()
254
+ _C.MODEL.HRNET.STAGE3.NUM_MODULES = 4
255
+ _C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3
256
+ _C.MODEL.HRNET.STAGE3.BLOCK = "BASIC"
257
+ _C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
258
+ _C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128]
259
+ _C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM"
260
+ _C.MODEL.HRNET.STAGE4 = CN()
261
+ _C.MODEL.HRNET.STAGE4.NUM_MODULES = 3
262
+ _C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4
263
+ _C.MODEL.HRNET.STAGE4.BLOCK = "BASIC"
264
+ _C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
265
+ _C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
266
+ _C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM"
267
+
268
+ _C.MODEL.HRNET.HRFPN = CN()
269
+ _C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256
270
+
271
+
272
+ def add_densepose_config(cfg: CN) -> None:
273
+ add_densepose_head_config(cfg)
274
+ add_hrnet_config(cfg)
275
+ add_bootstrap_config(cfg)
276
+ add_dataset_category_config(cfg)
277
+ add_evaluation_config(cfg)
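
For orientation, these `add_*_config` helpers are normally applied to a detectron2 config before merging a DensePose YAML file. A minimal usage sketch, assuming detectron2 and this package are installed (the config and weight paths below are illustrative):

```python
from detectron2.config import get_cfg
from densepose import add_densepose_config

cfg = get_cfg()            # base detectron2 config
add_densepose_config(cfg)  # register the DensePose keys defined above
cfg.merge_from_file("configs/densepose_rcnn_R_50_FPN_s1x.yaml")  # illustrative path
cfg.MODEL.WEIGHTS = "model_final.pkl"                            # illustrative path
cfg.freeze()
print(cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES)  # 24 unless the YAML overrides it
```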
densepose/converters/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .hflip import HFlipConverter
6
+ from .to_mask import ToMaskConverter
7
+ from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
8
+ from .segm_to_mask import (
9
+ predictor_output_with_fine_and_coarse_segm_to_mask,
10
+ predictor_output_with_coarse_segm_to_mask,
11
+ resample_fine_and_coarse_segm_to_bbox,
12
+ )
13
+ from .chart_output_to_chart_result import (
14
+ densepose_chart_predictor_output_to_result,
15
+ densepose_chart_predictor_output_to_result_with_confidences,
16
+ )
17
+ from .chart_output_hflip import densepose_chart_predictor_output_hflip
densepose/converters/base.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Tuple, Type
6
+ import torch
7
+
8
+
9
+ class BaseConverter:
10
+ """
11
+ Converter base class to be reused by various converters.
12
+ Converter allows one to convert data from various source types to a particular
13
+ destination type. Each source type needs to register its converter. The
14
+ registration for each source type is valid for all descendants of that type.
15
+ """
16
+
17
+ @classmethod
18
+ def register(cls, from_type: Type, converter: Any = None):
19
+ """
20
+ Registers a converter for the specified type.
21
+ Can be used as a decorator (if converter is None), or called as a method.
22
+
23
+ Args:
24
+ from_type (type): type to register the converter for;
25
+ all instances of this type will use the same converter
26
+ converter (callable): converter to be registered for the given
27
+ type; if None, this method is assumed to be a decorator for the converter
28
+ """
29
+
30
+ if converter is not None:
31
+ cls._do_register(from_type, converter)
32
+
33
+ def wrapper(converter: Any) -> Any:
34
+ cls._do_register(from_type, converter)
35
+ return converter
36
+
37
+ return wrapper
38
+
39
+ @classmethod
40
+ def _do_register(cls, from_type: Type, converter: Any):
41
+ cls.registry[from_type] = converter # pyre-ignore[16]
42
+
43
+ @classmethod
44
+ def _lookup_converter(cls, from_type: Type) -> Any:
45
+ """
46
+ Perform recursive lookup for the given type
47
+ to find registered converter. If a converter was found for some base
48
+ class, it gets registered for this class to save on further lookups.
49
+
50
+ Args:
51
+ from_type: type for which to find a converter
52
+ Return:
53
+ callable or None - registered converter or None
54
+ if no suitable entry was found in the registry
55
+ """
56
+ if from_type in cls.registry: # pyre-ignore[16]
57
+ return cls.registry[from_type]
58
+ for base in from_type.__bases__:
59
+ converter = cls._lookup_converter(base)
60
+ if converter is not None:
61
+ cls._do_register(from_type, converter)
62
+ return converter
63
+ return None
64
+
65
+ @classmethod
66
+ def convert(cls, instance: Any, *args, **kwargs):
67
+ """
68
+ Convert an instance to the destination type using some registered
69
+ converter. Does recursive lookup for base classes, so there's no need
70
+ for explicit registration for derived classes.
71
+
72
+ Args:
73
+ instance: source instance to convert to the destination type
74
+ Return:
75
+ An instance of the destination type obtained from the source instance
76
+ Raises KeyError, if no suitable converter found
77
+ """
78
+ instance_type = type(instance)
79
+ converter = cls._lookup_converter(instance_type)
80
+ if converter is None:
81
+ if cls.dst_type is None: # pyre-ignore[16]
82
+ output_type_str = "itself"
83
+ else:
84
+ output_type_str = cls.dst_type
85
+ raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}")
86
+ return converter(instance, *args, **kwargs)
87
+
88
+
89
+ IntTupleBox = Tuple[int, int, int, int]
90
+
91
+
92
+ def make_int_box(box: torch.Tensor) -> IntTupleBox:
93
+ int_box = [0, 0, 0, 0]
94
+ int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
95
+ return int_box[0], int_box[1], int_box[2], int_box[3]
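
To illustrate how `BaseConverter`'s registry and recursive lookup are meant to be used, here is a small self-contained sketch with a toy converter and toy classes (all names below are hypothetical, not part of the repository):

```python
from densepose.converters.base import BaseConverter

class ToStringConverter(BaseConverter):
    # Each concrete converter carries its own registry and destination type.
    registry = {}
    dst_type = str

class Animal: ...
class Cat(Animal): ...

@ToStringConverter.register(Animal)      # register() used as a decorator
def animal_to_str(instance, *args, **kwargs):
    return f"animal:{type(instance).__name__}"

# Cat is never registered explicitly; _lookup_converter walks the base
# classes, finds the Animal converter and caches it for Cat.
print(ToStringConverter.convert(Cat()))  # -> "animal:Cat"
```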
densepose/converters/builtin.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
6
+ from . import (
7
+ HFlipConverter,
8
+ ToChartResultConverter,
9
+ ToChartResultConverterWithConfidences,
10
+ ToMaskConverter,
11
+ densepose_chart_predictor_output_hflip,
12
+ densepose_chart_predictor_output_to_result,
13
+ densepose_chart_predictor_output_to_result_with_confidences,
14
+ predictor_output_with_coarse_segm_to_mask,
15
+ predictor_output_with_fine_and_coarse_segm_to_mask,
16
+ )
17
+
18
+ ToMaskConverter.register(
19
+ DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
20
+ )
21
+ ToMaskConverter.register(
22
+ DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
23
+ )
24
+
25
+ ToChartResultConverter.register(
26
+ DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
27
+ )
28
+
29
+ ToChartResultConverterWithConfidences.register(
30
+ DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
31
+ )
32
+
33
+ HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
densepose/converters/chart_output_hflip.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from dataclasses import fields
5
+ import torch
6
+
7
+ from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
8
+
9
+
10
+ def densepose_chart_predictor_output_hflip(
11
+ densepose_predictor_output: DensePoseChartPredictorOutput,
12
+ transform_data: DensePoseTransformData,
13
+ ) -> DensePoseChartPredictorOutput:
14
+ """
15
+ Change the predictor output to take a horizontal flip into account.
16
+ """
17
+ if len(densepose_predictor_output) > 0:
18
+
19
+ PredictorOutput = type(densepose_predictor_output)
20
+ output_dict = {}
21
+
22
+ for field in fields(densepose_predictor_output):
23
+ field_value = getattr(densepose_predictor_output, field.name)
24
+ # flip tensors
25
+ if isinstance(field_value, torch.Tensor):
26
+ setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))
27
+
28
+ densepose_predictor_output = _flip_iuv_semantics_tensor(
29
+ densepose_predictor_output, transform_data
30
+ )
31
+ densepose_predictor_output = _flip_segm_semantics_tensor(
32
+ densepose_predictor_output, transform_data
33
+ )
34
+
35
+ for field in fields(densepose_predictor_output):
36
+ output_dict[field.name] = getattr(densepose_predictor_output, field.name)
37
+
38
+ return PredictorOutput(**output_dict)
39
+ else:
40
+ return densepose_predictor_output
41
+
42
+
43
+ def _flip_iuv_semantics_tensor(
44
+ densepose_predictor_output: DensePoseChartPredictorOutput,
45
+ dp_transform_data: DensePoseTransformData,
46
+ ) -> DensePoseChartPredictorOutput:
47
+ point_label_symmetries = dp_transform_data.point_label_symmetries
48
+ uv_symmetries = dp_transform_data.uv_symmetries
49
+
50
+ N, C, H, W = densepose_predictor_output.u.shape
51
+ u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
52
+ v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
53
+ Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
54
+ None, :, None, None
55
+ ].expand(N, C - 1, H, W)
56
+ densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
57
+ densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]
58
+
59
+ for el in ["fine_segm", "u", "v"]:
60
+ densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
61
+ :, point_label_symmetries, :, :
62
+ ]
63
+ return densepose_predictor_output
64
+
65
+
66
+ def _flip_segm_semantics_tensor(
67
+ densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
68
+ ):
69
+ if densepose_predictor_output.coarse_segm.shape[1] > 2:
70
+ densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[
71
+ :, dp_transform_data.mask_label_symmetries, :, :
72
+ ]
73
+ return densepose_predictor_output
densepose/converters/chart_output_to_chart_result.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Dict
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures.boxes import Boxes, BoxMode
10
+
11
+ from ..structures import (
12
+ DensePoseChartPredictorOutput,
13
+ DensePoseChartResult,
14
+ DensePoseChartResultWithConfidences,
15
+ )
16
+ from . import resample_fine_and_coarse_segm_to_bbox
17
+ from .base import IntTupleBox, make_int_box
18
+
19
+
20
+ def resample_uv_tensors_to_bbox(
21
+ u: torch.Tensor,
22
+ v: torch.Tensor,
23
+ labels: torch.Tensor,
24
+ box_xywh_abs: IntTupleBox,
25
+ ) -> torch.Tensor:
26
+ """
27
+ Resamples U and V coordinate estimates for the given bounding box
28
+
29
+ Args:
30
+ u (tensor [1, C, H, W] of float): U coordinates
31
+ v (tensor [1, C, H, W] of float): V coordinates
32
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
33
+ outputs for the given bounding box
34
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
35
+ Return:
36
+ Resampled U and V coordinates - a tensor [2, H, W] of float
37
+ """
38
+ x, y, w, h = box_xywh_abs
39
+ w = max(int(w), 1)
40
+ h = max(int(h), 1)
41
+ u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
42
+ v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
43
+ uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
44
+ for part_id in range(1, u_bbox.size(1)):
45
+ uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
46
+ uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
47
+ return uv
48
+
49
+
50
+ def resample_uv_to_bbox(
51
+ predictor_output: DensePoseChartPredictorOutput,
52
+ labels: torch.Tensor,
53
+ box_xywh_abs: IntTupleBox,
54
+ ) -> torch.Tensor:
55
+ """
56
+ Resamples U and V coordinate estimates for the given bounding box
57
+
58
+ Args:
59
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
60
+ output to be resampled
61
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
62
+ outputs for the given bounding box
63
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
64
+ Return:
65
+ Resampled U and V coordinates - a tensor [2, H, W] of float
66
+ """
67
+ return resample_uv_tensors_to_bbox(
68
+ predictor_output.u,
69
+ predictor_output.v,
70
+ labels,
71
+ box_xywh_abs,
72
+ )
73
+
74
+
75
+ def densepose_chart_predictor_output_to_result(
76
+ predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
77
+ ) -> DensePoseChartResult:
78
+ """
79
+ Convert densepose chart predictor outputs to results
80
+
81
+ Args:
82
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
83
+ output to be converted to results, must contain only 1 output
84
+ boxes (Boxes): bounding box that corresponds to the predictor output,
85
+ must contain only 1 bounding box
86
+ Return:
87
+ DensePose chart-based result (DensePoseChartResult)
88
+ """
89
+ assert len(predictor_output) == 1 and len(boxes) == 1, (
90
+ f"Predictor output to result conversion can operate only single outputs"
91
+ f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
92
+ )
93
+
94
+ boxes_xyxy_abs = boxes.tensor.clone()
95
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
96
+ box_xywh = make_int_box(boxes_xywh_abs[0])
97
+
98
+ labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
99
+ uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
100
+ return DensePoseChartResult(labels=labels, uv=uv)
101
+
102
+
103
+ def resample_confidences_to_bbox(
104
+ predictor_output: DensePoseChartPredictorOutput,
105
+ labels: torch.Tensor,
106
+ box_xywh_abs: IntTupleBox,
107
+ ) -> Dict[str, torch.Tensor]:
108
+ """
109
+ Resamples confidences for the given bounding box
110
+
111
+ Args:
112
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
113
+ output to be resampled
114
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
115
+ outputs for the given bounding box
116
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
117
+ Return:
118
+ Resampled confidences - a dict of [H, W] tensors of float
119
+ """
120
+
121
+ x, y, w, h = box_xywh_abs
122
+ w = max(int(w), 1)
123
+ h = max(int(h), 1)
124
+
125
+ confidence_names = [
126
+ "sigma_1",
127
+ "sigma_2",
128
+ "kappa_u",
129
+ "kappa_v",
130
+ "fine_segm_confidence",
131
+ "coarse_segm_confidence",
132
+ ]
133
+ confidence_results = {key: None for key in confidence_names}
134
+ confidence_names = [
135
+ key for key in confidence_names if getattr(predictor_output, key) is not None
136
+ ]
137
+ confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device)
138
+
139
+ # assign data from channels that correspond to the labels
140
+ for key in confidence_names:
141
+ resampled_confidence = F.interpolate(
142
+ getattr(predictor_output, key),
143
+ (h, w),
144
+ mode="bilinear",
145
+ align_corners=False,
146
+ )
147
+ result = confidence_base.clone()
148
+ for part_id in range(1, predictor_output.u.size(1)):
149
+ if resampled_confidence.size(1) != predictor_output.u.size(1):
150
+ # confidence is not part-based, don't try to fill it part by part
151
+ continue
152
+ result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
153
+
154
+ if resampled_confidence.size(1) != predictor_output.u.size(1):
155
+ # confidence is not part-based, fill the data with the first channel
156
+ # (targeted for segmentation confidences that have only 1 channel)
157
+ result = resampled_confidence[0, 0]
158
+
159
+ confidence_results[key] = result
160
+
161
+ return confidence_results # pyre-ignore[7]
162
+
163
+
164
+ def densepose_chart_predictor_output_to_result_with_confidences(
165
+ predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
166
+ ) -> DensePoseChartResultWithConfidences:
167
+ """
168
+ Convert densepose chart predictor outputs to results
169
+
170
+ Args:
171
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
172
+ output with confidences to be converted to results, must contain only 1 output
173
+ boxes (Boxes): bounding box that corresponds to the predictor output,
174
+ must contain only 1 bounding box
175
+ Return:
176
+ DensePose chart-based result with confidences (DensePoseChartResultWithConfidences)
177
+ """
178
+ assert len(predictor_output) == 1 and len(boxes) == 1, (
179
+ f"Predictor output to result conversion can operate only single outputs"
180
+ f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
181
+ )
182
+
183
+ boxes_xyxy_abs = boxes.tensor.clone()
184
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
185
+ box_xywh = make_int_box(boxes_xywh_abs[0])
186
+
187
+ labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
188
+ uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
189
+ confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh)
190
+ return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
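
As a quick shape check for `resample_uv_tensors_to_bbox`, the sketch below feeds it random tensors of the expected sizes (in real use, `labels` comes from `resample_fine_and_coarse_segm_to_bbox`; the random values here are purely for illustration):

```python
import torch
from densepose.converters.chart_output_to_chart_result import resample_uv_tensors_to_bbox

C = 25                                   # 24 body parts plus a background channel
u = torch.rand(1, C, 112, 112)           # predictor-resolution U estimates
v = torch.rand(1, C, 112, 112)           # predictor-resolution V estimates

box_xywh = (10, 20, 64, 128)             # x, y, w, h in absolute pixels
labels = torch.randint(0, C, (128, 64))  # per-pixel part labels inside the box (h, w)

uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
print(uv.shape)  # torch.Size([2, 128, 64])
```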
densepose/converters/hflip.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+
7
+ from .base import BaseConverter
8
+
9
+
10
+ class HFlipConverter(BaseConverter):
11
+ """
12
+ Performs horizontal flips of various DensePose predictor outputs.
13
+ Each DensePose predictor output type has to register its conversion strategy.
14
+ """
15
+
16
+ registry = {}
17
+ dst_type = None
18
+
19
+ @classmethod
20
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
21
+ # inconsistently.
22
+ def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
23
+ """
24
+ Performs a horizontal flip on DensePose predictor outputs.
25
+ Does recursive lookup for base classes, so there's no need
26
+ for explicit registration for derived classes.
27
+
28
+ Args:
29
+ predictor_outputs: DensePose predictor output to be converted to BitMasks
30
+ transform_data: Anything useful for the flip
31
+ Return:
32
+ An instance of the same type as predictor_outputs
33
+ """
34
+ return super(HFlipConverter, cls).convert(
35
+ predictor_outputs, transform_data, *args, **kwargs
36
+ )
densepose/converters/segm_to_mask.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures import BitMasks, Boxes, BoxMode
10
+
11
+ from .base import IntTupleBox, make_int_box
12
+ from .to_mask import ImageSizeType
13
+
14
+
15
+ def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
16
+ """
17
+ Resample coarse segmentation tensor to the given
18
+ bounding box and derive labels for each pixel of the bounding box
19
+
20
+ Args:
21
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
22
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
23
+ corner coordinates, width (W) and height (H)
24
+ Return:
25
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
26
+ """
27
+ x, y, w, h = box_xywh_abs
28
+ w = max(int(w), 1)
29
+ h = max(int(h), 1)
30
+ labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
31
+ return labels
32
+
33
+
34
+ def resample_fine_and_coarse_segm_tensors_to_bbox(
35
+ fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
36
+ ):
37
+ """
38
+ Resample fine and coarse segmentation tensors to the given
39
+ bounding box and derive labels for each pixel of the bounding box
40
+
41
+ Args:
42
+ fine_segm: float tensor of shape [1, C, Hout, Wout]
43
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
44
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
45
+ corner coordinates, width (W) and height (H)
46
+ Return:
47
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
48
+ """
49
+ x, y, w, h = box_xywh_abs
50
+ w = max(int(w), 1)
51
+ h = max(int(h), 1)
52
+ # coarse segmentation
53
+ coarse_segm_bbox = F.interpolate(
54
+ coarse_segm,
55
+ (h, w),
56
+ mode="bilinear",
57
+ align_corners=False,
58
+ ).argmax(dim=1)
59
+ # combined coarse and fine segmentation
60
+ labels = (
61
+ F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
62
+ * (coarse_segm_bbox > 0).long()
63
+ )
64
+ return labels
65
+
66
+
67
+ def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
68
+ """
69
+ Resample fine and coarse segmentation outputs from a predictor to the given
70
+ bounding box and derive labels for each pixel of the bounding box
71
+
72
+ Args:
73
+ predictor_output: DensePose predictor output that contains segmentation
74
+ results to be resampled
75
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
76
+ corner coordinates, width (W) and height (H)
77
+ Return:
78
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
79
+ """
80
+ return resample_fine_and_coarse_segm_tensors_to_bbox(
81
+ predictor_output.fine_segm,
82
+ predictor_output.coarse_segm,
83
+ box_xywh_abs,
84
+ )
85
+
86
+
87
+ def predictor_output_with_coarse_segm_to_mask(
88
+ predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
89
+ ) -> BitMasks:
90
+ """
91
+ Convert predictor output with coarse segmentation to a mask.
92
+ Assumes that predictor output has the following attributes:
93
+ - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
94
+ unnormalized scores for N instances; D is the number of coarse
95
+ segmentation labels, H and W is the resolution of the estimate
96
+
97
+ Args:
98
+ predictor_output: DensePose predictor output to be converted to mask
99
+ boxes (Boxes): bounding boxes that correspond to the DensePose
100
+ predictor outputs
101
+ image_size_hw (tuple [int, int]): image height Himg and width Wimg
102
+ Return:
103
+ BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
104
+ a mask of the size of the image for each instance
105
+ """
106
+ H, W = image_size_hw
107
+ boxes_xyxy_abs = boxes.tensor.clone()
108
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
109
+ N = len(boxes_xywh_abs)
110
+ masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
111
+ for i in range(len(boxes_xywh_abs)):
112
+ box_xywh = make_int_box(boxes_xywh_abs[i])
113
+ box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
114
+ x, y, w, h = box_xywh
115
+ masks[i, y : y + h, x : x + w] = box_mask
116
+
117
+ return BitMasks(masks)
118
+
119
+
120
+ def predictor_output_with_fine_and_coarse_segm_to_mask(
121
+ predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
122
+ ) -> BitMasks:
123
+ """
124
+ Convert predictor output with coarse and fine segmentation to a mask.
125
+ Assumes that predictor output has the following attributes:
126
+ - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
127
+ unnormalized scores for N instances; D is the number of coarse
128
+ segmentation labels, H and W is the resolution of the estimate
129
+ - fine_segm (tensor of size [N, C, H, W]): fine segmentation
130
+ unnormalized scores for N instances; C is the number of fine
131
+ segmentation labels, H and W is the resolution of the estimate
132
+
133
+ Args:
134
+ predictor_output: DensePose predictor output to be converted to mask
135
+ boxes (Boxes): bounding boxes that correspond to the DensePose
136
+ predictor outputs
137
+ image_size_hw (tuple [int, int]): image height Himg and width Wimg
138
+ Return:
139
+ BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
140
+ a mask of the size of the image for each instance
141
+ """
142
+ H, W = image_size_hw
143
+ boxes_xyxy_abs = boxes.tensor.clone()
144
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
145
+ N = len(boxes_xywh_abs)
146
+ masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
147
+ for i in range(len(boxes_xywh_abs)):
148
+ box_xywh = make_int_box(boxes_xywh_abs[i])
149
+ labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
150
+ x, y, w, h = box_xywh
151
+ masks[i, y : y + h, x : x + w] = labels_i > 0
152
+ return BitMasks(masks)
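A minimal usage sketch for the mask converters above, assuming detectron2 and this densepose package are importable. The FakeSegmOutput container below is a purely illustrative stand-in for a real DensePose predictor output; only its coarse_segm / fine_segm attributes and per-instance indexing are needed by predictor_output_with_fine_and_coarse_segm_to_mask.

import torch
from detectron2.structures import Boxes

from densepose.converters.segm_to_mask import predictor_output_with_fine_and_coarse_segm_to_mask


class FakeSegmOutput:
    """Illustrative stand-in for a DensePose predictor output carrying segmentation tensors."""

    def __init__(self, coarse_segm: torch.Tensor, fine_segm: torch.Tensor):
        self.coarse_segm = coarse_segm  # [N, D, Hout, Wout]
        self.fine_segm = fine_segm      # [N, C, Hout, Wout]

    def __getitem__(self, i: int) -> "FakeSegmOutput":
        # per-instance view, keeping a leading batch dimension of 1
        return FakeSegmOutput(self.coarse_segm[i : i + 1], self.fine_segm[i : i + 1])


outputs = FakeSegmOutput(torch.randn(2, 2, 112, 112), torch.randn(2, 25, 112, 112))
boxes = Boxes(torch.tensor([[10.0, 10.0, 60.0, 90.0], [30.0, 20.0, 80.0, 100.0]]))  # XYXY_ABS
masks = predictor_output_with_fine_and_coarse_segm_to_mask(outputs, boxes, (120, 160))
print(masks.tensor.shape)  # torch.Size([2, 120, 160])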
densepose/converters/to_chart_result.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+
7
+ from detectron2.structures import Boxes
8
+
9
+ from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
10
+ from .base import BaseConverter
11
+
12
+
13
+ class ToChartResultConverter(BaseConverter):
14
+ """
15
+ Converts various DensePose predictor outputs to DensePose results.
16
+ Each DensePose predictor output type has to register its conversion strategy.
17
+ """
18
+
19
+ registry = {}
20
+ dst_type = DensePoseChartResult
21
+
22
+ @classmethod
23
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
24
+ # inconsistently.
25
+ def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
26
+ """
27
+ Convert DensePose predictor outputs to DensePoseChartResult using some registered
28
+ converter. Does recursive lookup for base classes, so there's no need
29
+ for explicit registration for derived classes.
30
+
31
+ Args:
32
+ predictor_outputs: DensePose predictor output to be
33
+ converted to DensePose chart results
34
+ boxes (Boxes): bounding boxes that correspond to the DensePose
35
+ predictor outputs
36
+ Return:
37
+ An instance of DensePoseChartResult. If no suitable converter was found, raises KeyError
38
+ """
39
+ return super(ToChartResultConverter, cls).convert(predictor_outputs, boxes, *args, **kwargs)
40
+
41
+
42
+ class ToChartResultConverterWithConfidences(BaseConverter):
43
+ """
44
+ Converts various DensePose predictor outputs to DensePose results.
45
+ Each DensePose predictor output type has to register its conversion strategy.
46
+ """
47
+
48
+ registry = {}
49
+ dst_type = DensePoseChartResultWithConfidences
50
+
51
+ @classmethod
52
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
53
+ # inconsistently.
54
+ def convert(
55
+ cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
56
+ ) -> DensePoseChartResultWithConfidences:
57
+ """
58
+ Convert DensePose predictor outputs to DensePoseChartResultWithConfidences
59
+ using some registered converter. Does recursive lookup for base classes,
60
+ so there's no need for explicit registration for derived classes.
61
+
62
+ Args:
63
+ predictor_outputs: DensePose predictor output with confidences
64
+ to be converted to DensePose chart results
65
+ boxes (Boxes): bounding boxes that correspond to the DensePose
66
+ predictor outputs
67
+ Return:
68
+ An instance of DensePoseChartResultWithConfidences. If no suitable converter was found, raises KeyError
69
+ """
70
+ return super(ToChartResultConverterWithConfidences, cls).convert(
71
+ predictor_outputs, boxes, *args, **kwargs
72
+ )
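A hedged usage sketch for the chart-result converters, assuming instances are the Instances produced by a chart-based DensePose model (so pred_densepose has a registered conversion strategy); the helper name extract_chart_results is illustrative.

from detectron2.structures import Instances

from densepose.converters.to_chart_result import ToChartResultConverter


def extract_chart_results(instances: Instances):
    # pred_densepose / pred_boxes are the fields produced by the DensePose R-CNN heads;
    # convert() raises KeyError if no converter is registered for type(outputs)
    outputs = instances.pred_densepose
    boxes = instances.pred_boxes
    return [ToChartResultConverter.convert(outputs[i], boxes[i]) for i in range(len(instances))]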
densepose/converters/to_mask.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Tuple
6
+
7
+ from detectron2.structures import BitMasks, Boxes
8
+
9
+ from .base import BaseConverter
10
+
11
+ ImageSizeType = Tuple[int, int]
12
+
13
+
14
+ class ToMaskConverter(BaseConverter):
15
+ """
16
+ Converts various DensePose predictor outputs to masks
17
+ in bit mask format (see `BitMasks`). Each DensePose predictor output type
18
+ has to register its conversion strategy.
19
+ """
20
+
21
+ registry = {}
22
+ dst_type = BitMasks
23
+
24
+ @classmethod
25
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
26
+ # inconsistently.
27
+ def convert(
28
+ cls,
29
+ densepose_predictor_outputs: Any,
30
+ boxes: Boxes,
31
+ image_size_hw: ImageSizeType,
32
+ *args,
33
+ **kwargs
34
+ ) -> BitMasks:
35
+ """
36
+ Convert DensePose predictor outputs to BitMasks using some registered
37
+ converter. Does recursive lookup for base classes, so there's no need
38
+ for explicit registration for derived classes.
39
+
40
+ Args:
41
+ densepose_predictor_outputs: DensePose predictor output to be
42
+ converted to BitMasks
43
+ boxes (Boxes): bounding boxes that correspond to the DensePose
44
+ predictor outputs
45
+ image_size_hw (tuple [int, int]): image height and width
46
+ Return:
47
+ An instance of `BitMasks`. If no suitable converter was found, raises KeyError
48
+ """
49
+ return super(ToMaskConverter, cls).convert(
50
+ densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
51
+ )
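A similar hedged sketch for ToMaskConverter: given model Instances whose pred_densepose type has a registered converter, the whole batch is turned into image-sized BitMasks in one call (roughly the pattern used by the mask samplers under densepose/data/samplers).

from detectron2.structures import BitMasks, Instances

from densepose.converters.to_mask import ToMaskConverter


def instances_to_bitmasks(instances: Instances) -> BitMasks:
    # image_size is the (height, width) pair stored on detectron2 Instances
    return ToMaskConverter.convert(
        instances.pred_densepose, instances.pred_boxes, instances.image_size
    )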
densepose/data/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .meshes import builtin
6
+ from .build import (
7
+ build_detection_test_loader,
8
+ build_detection_train_loader,
9
+ build_combined_loader,
10
+ build_frame_selector,
11
+ build_inference_based_loaders,
12
+ has_inference_based_loaders,
13
+ BootstrapDatasetFactoryCatalog,
14
+ )
15
+ from .combined_loader import CombinedDataLoader
16
+ from .dataset_mapper import DatasetMapper
17
+ from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
18
+ from .image_list_dataset import ImageListDataset
19
+ from .utils import is_relative_local_path, maybe_prepend_base_path
20
+
21
+ # ensure the builtin datasets are registered
22
+ from . import datasets
23
+
24
+ # ensure the bootstrap datasets builders are registered
25
+ from . import build
26
+
27
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
densepose/data/build.py ADDED
@@ -0,0 +1,738 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import itertools
6
+ import logging
7
+ import numpy as np
8
+ from collections import UserDict, defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple
11
+ import torch
12
+ from torch.utils.data.dataset import Dataset
13
+
14
+ from detectron2.config import CfgNode
15
+ from detectron2.data.build import build_detection_test_loader as d2_build_detection_test_loader
16
+ from detectron2.data.build import build_detection_train_loader as d2_build_detection_train_loader
17
+ from detectron2.data.build import (
18
+ load_proposals_into_dataset,
19
+ print_instances_class_histogram,
20
+ trivial_batch_collator,
21
+ worker_init_reset_seed,
22
+ )
23
+ from detectron2.data.catalog import DatasetCatalog, Metadata, MetadataCatalog
24
+ from detectron2.data.samplers import TrainingSampler
25
+ from detectron2.utils.comm import get_world_size
26
+
27
+ from densepose.config import get_bootstrap_dataset_config
28
+ from densepose.modeling import build_densepose_embedder
29
+
30
+ from .combined_loader import CombinedDataLoader, Loader
31
+ from .dataset_mapper import DatasetMapper
32
+ from .datasets.coco import DENSEPOSE_CSE_KEYS_WITHOUT_MASK, DENSEPOSE_IUV_KEYS_WITHOUT_MASK
33
+ from .datasets.dataset_type import DatasetType
34
+ from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
35
+ from .samplers import (
36
+ DensePoseConfidenceBasedSampler,
37
+ DensePoseCSEConfidenceBasedSampler,
38
+ DensePoseCSEUniformSampler,
39
+ DensePoseUniformSampler,
40
+ MaskFromDensePoseSampler,
41
+ PredictionToGroundTruthSampler,
42
+ )
43
+ from .transform import ImageResizeTransform
44
+ from .utils import get_category_to_class_mapping, get_class_to_mesh_name_mapping
45
+ from .video import (
46
+ FirstKFramesSelector,
47
+ FrameSelectionStrategy,
48
+ LastKFramesSelector,
49
+ RandomKFramesSelector,
50
+ VideoKeyframeDataset,
51
+ video_list_from_file,
52
+ )
53
+
54
+ __all__ = ["build_detection_train_loader", "build_detection_test_loader"]
55
+
56
+
57
+ Instance = Dict[str, Any]
58
+ InstancePredicate = Callable[[Instance], bool]
59
+
60
+
61
+ def _compute_num_images_per_worker(cfg: CfgNode) -> int:
62
+ num_workers = get_world_size()
63
+ images_per_batch = cfg.SOLVER.IMS_PER_BATCH
64
+ assert (
65
+ images_per_batch % num_workers == 0
66
+ ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
67
+ images_per_batch, num_workers
68
+ )
69
+ assert (
70
+ images_per_batch >= num_workers
71
+ ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
72
+ images_per_batch, num_workers
73
+ )
74
+ images_per_worker = images_per_batch // num_workers
75
+ return images_per_worker
76
+
77
+
78
+ def _map_category_id_to_contiguous_id(dataset_name: str, dataset_dicts: Iterable[Instance]) -> None:
79
+ meta = MetadataCatalog.get(dataset_name)
80
+ for dataset_dict in dataset_dicts:
81
+ for ann in dataset_dict["annotations"]:
82
+ ann["category_id"] = meta.thing_dataset_id_to_contiguous_id[ann["category_id"]]
83
+
84
+
85
+ @dataclass
86
+ class _DatasetCategory:
87
+ """
88
+ Class representing category data in a dataset:
89
+ - id: category ID, as specified in the dataset annotations file
90
+ - name: category name, as specified in the dataset annotations file
91
+ - mapped_id: category ID after applying category maps (DATASETS.CATEGORY_MAPS config option)
92
+ - mapped_name: category name after applying category maps
93
+ - dataset_name: dataset in which the category is defined
94
+
95
+ For example, when training models in a class-agnostic manner, one could take LVIS 1.0
96
+ dataset and map the animal categories to the same category as human data from COCO:
97
+ id = 225
98
+ name = "cat"
99
+ mapped_id = 1
100
+ mapped_name = "person"
101
+ dataset_name = "lvis_v1_animals_dp_train"
102
+ """
103
+
104
+ id: int
105
+ name: str
106
+ mapped_id: int
107
+ mapped_name: str
108
+ dataset_name: str
109
+
110
+
111
+ _MergedCategoriesT = Dict[int, List[_DatasetCategory]]
112
+
113
+
114
+ def _add_category_id_to_contiguous_id_maps_to_metadata(
115
+ merged_categories: _MergedCategoriesT,
116
+ ) -> None:
117
+ merged_categories_per_dataset = {}
118
+ for contiguous_cat_id, cat_id in enumerate(sorted(merged_categories.keys())):
119
+ for cat in merged_categories[cat_id]:
120
+ if cat.dataset_name not in merged_categories_per_dataset:
121
+ merged_categories_per_dataset[cat.dataset_name] = defaultdict(list)
122
+ merged_categories_per_dataset[cat.dataset_name][cat_id].append(
123
+ (
124
+ contiguous_cat_id,
125
+ cat,
126
+ )
127
+ )
128
+
129
+ logger = logging.getLogger(__name__)
130
+ for dataset_name, merged_categories in merged_categories_per_dataset.items():
131
+ meta = MetadataCatalog.get(dataset_name)
132
+ if not hasattr(meta, "thing_classes"):
133
+ meta.thing_classes = []
134
+ meta.thing_dataset_id_to_contiguous_id = {}
135
+ meta.thing_dataset_id_to_merged_id = {}
136
+ else:
137
+ meta.thing_classes.clear()
138
+ meta.thing_dataset_id_to_contiguous_id.clear()
139
+ meta.thing_dataset_id_to_merged_id.clear()
140
+ logger.info(f"Dataset {dataset_name}: category ID to contiguous ID mapping:")
141
+ for _cat_id, categories in sorted(merged_categories.items()):
142
+ added_to_thing_classes = False
143
+ for contiguous_cat_id, cat in categories:
144
+ if not added_to_thing_classes:
145
+ meta.thing_classes.append(cat.mapped_name)
146
+ added_to_thing_classes = True
147
+ meta.thing_dataset_id_to_contiguous_id[cat.id] = contiguous_cat_id
148
+ meta.thing_dataset_id_to_merged_id[cat.id] = cat.mapped_id
149
+ logger.info(f"{cat.id} ({cat.name}) -> {contiguous_cat_id}")
150
+
151
+
152
+ def _maybe_create_general_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
153
+ def has_annotations(instance: Instance) -> bool:
154
+ return "annotations" in instance
155
+
156
+ def has_only_crowd_annotations(instance: Instance) -> bool:
157
+ for ann in instance["annotations"]:
158
+ if ann.get("iscrowd", 0) == 0:
159
+ return False
160
+ return True
161
+
162
+ def general_keep_instance_predicate(instance: Instance) -> bool:
163
+ return has_annotations(instance) and not has_only_crowd_annotations(instance)
164
+
165
+ if not cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS:
166
+ return None
167
+ return general_keep_instance_predicate
168
+
169
+
170
+ def _maybe_create_keypoints_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
171
+
172
+ min_num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
173
+
174
+ def has_sufficient_num_keypoints(instance: Instance) -> bool:
175
+ num_kpts = sum(
176
+ (np.array(ann["keypoints"][2::3]) > 0).sum()
177
+ for ann in instance["annotations"]
178
+ if "keypoints" in ann
179
+ )
180
+ return num_kpts >= min_num_keypoints
181
+
182
+ if cfg.MODEL.KEYPOINT_ON and (min_num_keypoints > 0):
183
+ return has_sufficient_num_keypoints
184
+ return None
185
+
186
+
187
+ def _maybe_create_mask_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
188
+ if not cfg.MODEL.MASK_ON:
189
+ return None
190
+
191
+ def has_mask_annotations(instance: Instance) -> bool:
192
+ return any("segmentation" in ann for ann in instance["annotations"])
193
+
194
+ return has_mask_annotations
195
+
196
+
197
+ def _maybe_create_densepose_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
198
+ if not cfg.MODEL.DENSEPOSE_ON:
199
+ return None
200
+
201
+ use_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
202
+
203
+ def has_densepose_annotations(instance: Instance) -> bool:
204
+ for ann in instance["annotations"]:
205
+ if all(key in ann for key in DENSEPOSE_IUV_KEYS_WITHOUT_MASK) or all(
206
+ key in ann for key in DENSEPOSE_CSE_KEYS_WITHOUT_MASK
207
+ ):
208
+ return True
209
+ if use_masks and "segmentation" in ann:
210
+ return True
211
+ return False
212
+
213
+ return has_densepose_annotations
214
+
215
+
216
+ def _maybe_create_specific_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
217
+ specific_predicate_creators = [
218
+ _maybe_create_keypoints_keep_instance_predicate,
219
+ _maybe_create_mask_keep_instance_predicate,
220
+ _maybe_create_densepose_keep_instance_predicate,
221
+ ]
222
+ predicates = [creator(cfg) for creator in specific_predicate_creators]
223
+ predicates = [p for p in predicates if p is not None]
224
+ if not predicates:
225
+ return None
226
+
227
+ def combined_predicate(instance: Instance) -> bool:
228
+ return any(p(instance) for p in predicates)
229
+
230
+ return combined_predicate
231
+
232
+
233
+ def _get_train_keep_instance_predicate(cfg: CfgNode):
234
+ general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
235
+ combined_specific_keep_predicate = _maybe_create_specific_keep_instance_predicate(cfg)
236
+
237
+ def combined_general_specific_keep_predicate(instance: Instance) -> bool:
238
+ return general_keep_predicate(instance) and combined_specific_keep_predicate(instance)
239
+
240
+ if (general_keep_predicate is None) and (combined_specific_keep_predicate is None):
241
+ return None
242
+ if general_keep_predicate is None:
243
+ return combined_specific_keep_predicate
244
+ if combined_specific_keep_predicate is None:
245
+ return general_keep_predicate
246
+ return combined_general_specific_keep_predicate
247
+
248
+
249
+ def _get_test_keep_instance_predicate(cfg: CfgNode):
250
+ general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
251
+ return general_keep_predicate
252
+
253
+
254
+ def _maybe_filter_and_map_categories(
255
+ dataset_name: str, dataset_dicts: List[Instance]
256
+ ) -> List[Instance]:
257
+ meta = MetadataCatalog.get(dataset_name)
258
+ category_id_map = meta.thing_dataset_id_to_contiguous_id
259
+ filtered_dataset_dicts = []
260
+ for dataset_dict in dataset_dicts:
261
+ anns = []
262
+ for ann in dataset_dict["annotations"]:
263
+ cat_id = ann["category_id"]
264
+ if cat_id not in category_id_map:
265
+ continue
266
+ ann["category_id"] = category_id_map[cat_id]
267
+ anns.append(ann)
268
+ dataset_dict["annotations"] = anns
269
+ filtered_dataset_dicts.append(dataset_dict)
270
+ return filtered_dataset_dicts
271
+
272
+
273
+ def _add_category_whitelists_to_metadata(cfg: CfgNode) -> None:
274
+ for dataset_name, whitelisted_cat_ids in cfg.DATASETS.WHITELISTED_CATEGORIES.items():
275
+ meta = MetadataCatalog.get(dataset_name)
276
+ meta.whitelisted_categories = whitelisted_cat_ids
277
+ logger = logging.getLogger(__name__)
278
+ logger.info(
279
+ "Whitelisted categories for dataset {}: {}".format(
280
+ dataset_name, meta.whitelisted_categories
281
+ )
282
+ )
283
+
284
+
285
+ def _add_category_maps_to_metadata(cfg: CfgNode) -> None:
286
+ for dataset_name, category_map in cfg.DATASETS.CATEGORY_MAPS.items():
287
+ category_map = {
288
+ int(cat_id_src): int(cat_id_dst) for cat_id_src, cat_id_dst in category_map.items()
289
+ }
290
+ meta = MetadataCatalog.get(dataset_name)
291
+ meta.category_map = category_map
292
+ logger = logging.getLogger(__name__)
293
+ logger.info("Category maps for dataset {}: {}".format(dataset_name, meta.category_map))
294
+
295
+
296
+ def _add_category_info_to_bootstrapping_metadata(dataset_name: str, dataset_cfg: CfgNode) -> None:
297
+ meta = MetadataCatalog.get(dataset_name)
298
+ meta.category_to_class_mapping = get_category_to_class_mapping(dataset_cfg)
299
+ meta.categories = dataset_cfg.CATEGORIES
300
+ meta.max_count_per_category = dataset_cfg.MAX_COUNT_PER_CATEGORY
301
+ logger = logging.getLogger(__name__)
302
+ logger.info(
303
+ "Category to class mapping for dataset {}: {}".format(
304
+ dataset_name, meta.category_to_class_mapping
305
+ )
306
+ )
307
+
308
+
309
+ def _maybe_add_class_to_mesh_name_map_to_metadata(dataset_names: List[str], cfg: CfgNode) -> None:
310
+ for dataset_name in dataset_names:
311
+ meta = MetadataCatalog.get(dataset_name)
312
+ if not hasattr(meta, "class_to_mesh_name"):
313
+ meta.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
314
+
315
+
316
+ def _merge_categories(dataset_names: Collection[str]) -> _MergedCategoriesT:
317
+ merged_categories = defaultdict(list)
318
+ category_names = {}
319
+ for dataset_name in dataset_names:
320
+ meta = MetadataCatalog.get(dataset_name)
321
+ whitelisted_categories = meta.get("whitelisted_categories")
322
+ category_map = meta.get("category_map", {})
323
+ cat_ids = (
324
+ whitelisted_categories if whitelisted_categories is not None else meta.categories.keys()
325
+ )
326
+ for cat_id in cat_ids:
327
+ cat_name = meta.categories[cat_id]
328
+ cat_id_mapped = category_map.get(cat_id, cat_id)
329
+ if cat_id_mapped == cat_id or cat_id_mapped in cat_ids:
330
+ category_names[cat_id] = cat_name
331
+ else:
332
+ category_names[cat_id] = str(cat_id_mapped)
333
+ # assign temporary mapped category name, this name can be changed
334
+ # during the second pass, since mapped ID can correspond to a category
335
+ # from a different dataset
336
+ cat_name_mapped = meta.categories[cat_id_mapped]
337
+ merged_categories[cat_id_mapped].append(
338
+ _DatasetCategory(
339
+ id=cat_id,
340
+ name=cat_name,
341
+ mapped_id=cat_id_mapped,
342
+ mapped_name=cat_name_mapped,
343
+ dataset_name=dataset_name,
344
+ )
345
+ )
346
+ # second pass to assign proper mapped category names
347
+ for cat_id, categories in merged_categories.items():
348
+ for cat in categories:
349
+ if cat_id in category_names and cat.mapped_name != category_names[cat_id]:
350
+ cat.mapped_name = category_names[cat_id]
351
+
352
+ return merged_categories
353
+
354
+
355
+ def _warn_if_merged_different_categories(merged_categories: _MergedCategoriesT) -> None:
356
+ logger = logging.getLogger(__name__)
357
+ for cat_id in merged_categories:
358
+ merged_categories_i = merged_categories[cat_id]
359
+ first_cat_name = merged_categories_i[0].name
360
+ if len(merged_categories_i) > 1 and not all(
361
+ cat.name == first_cat_name for cat in merged_categories_i[1:]
362
+ ):
363
+ cat_summary_str = ", ".join(
364
+ [f"{cat.id} ({cat.name}) from {cat.dataset_name}" for cat in merged_categories_i]
365
+ )
366
+ logger.warning(
367
+ f"Merged category {cat_id} corresponds to the following categories: "
368
+ f"{cat_summary_str}"
369
+ )
370
+
371
+
372
+ def combine_detection_dataset_dicts(
373
+ dataset_names: Collection[str],
374
+ keep_instance_predicate: Optional[InstancePredicate] = None,
375
+ proposal_files: Optional[Collection[str]] = None,
376
+ ) -> List[Instance]:
377
+ """
378
+ Load and prepare dataset dicts for training / testing
379
+
380
+ Args:
381
+ dataset_names (Collection[str]): a list of dataset names
382
+ keep_instance_predicate (Callable: Dict[str, Any] -> bool): predicate
383
+ applied to instance dicts which defines whether to keep the instance
384
+ proposal_files (Collection[str]): if given, a list of object proposal files
385
+ that match each dataset in `dataset_names`.
386
+ """
387
+ assert len(dataset_names)
388
+ if proposal_files is None:
389
+ proposal_files = [None] * len(dataset_names)
390
+ assert len(dataset_names) == len(proposal_files)
391
+ # load datasets and metadata
392
+ dataset_name_to_dicts = {}
393
+ for dataset_name in dataset_names:
394
+ dataset_name_to_dicts[dataset_name] = DatasetCatalog.get(dataset_name)
395
+ assert len(dataset_name_to_dicts[dataset_name]), f"Dataset '{dataset_name}' is empty!"
396
+ # merge categories, requires category metadata to be loaded
397
+ # cat_id -> [(orig_cat_id, cat_name, dataset_name)]
398
+ merged_categories = _merge_categories(dataset_names)
399
+ _warn_if_merged_different_categories(merged_categories)
400
+ merged_category_names = [
401
+ merged_categories[cat_id][0].mapped_name for cat_id in sorted(merged_categories)
402
+ ]
403
+ # map to contiguous category IDs
404
+ _add_category_id_to_contiguous_id_maps_to_metadata(merged_categories)
405
+ # load annotations and dataset metadata
406
+ for dataset_name, proposal_file in zip(dataset_names, proposal_files):
407
+ dataset_dicts = dataset_name_to_dicts[dataset_name]
408
+ assert len(dataset_dicts), f"Dataset '{dataset_name}' is empty!"
409
+ if proposal_file is not None:
410
+ dataset_dicts = load_proposals_into_dataset(dataset_dicts, proposal_file)
411
+ dataset_dicts = _maybe_filter_and_map_categories(dataset_name, dataset_dicts)
412
+ print_instances_class_histogram(dataset_dicts, merged_category_names)
413
+ dataset_name_to_dicts[dataset_name] = dataset_dicts
414
+
415
+ if keep_instance_predicate is not None:
416
+ all_datasets_dicts_plain = [
417
+ d
418
+ for d in itertools.chain.from_iterable(dataset_name_to_dicts.values())
419
+ if keep_instance_predicate(d)
420
+ ]
421
+ else:
422
+ all_datasets_dicts_plain = list(
423
+ itertools.chain.from_iterable(dataset_name_to_dicts.values())
424
+ )
425
+ return all_datasets_dicts_plain
426
+
427
+
428
+ def build_detection_train_loader(cfg: CfgNode, mapper=None):
429
+ """
430
+ A data loader is created in a way similar to that of Detectron2.
431
+ The main differences are:
432
+ - it allows to combine datasets with different but compatible object category sets
433
+
434
+ The data loader is created by the following steps:
435
+ 1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
436
+ 2. Start workers to work on the dicts. Each worker will:
437
+ * Map each metadata dict into another format to be consumed by the model.
438
+ * Batch them by simply putting dicts into a list.
439
+ The batched ``list[mapped_dict]`` is what this dataloader will return.
440
+
441
+ Args:
442
+ cfg (CfgNode): the config
443
+ mapper (callable): a callable which takes a sample (dict) from dataset and
444
+ returns the format to be consumed by the model.
445
+ By default it will be `DatasetMapper(cfg, True)`.
446
+
447
+ Returns:
448
+ an infinite iterator of training data
449
+ """
450
+
451
+ _add_category_whitelists_to_metadata(cfg)
452
+ _add_category_maps_to_metadata(cfg)
453
+ _maybe_add_class_to_mesh_name_map_to_metadata(cfg.DATASETS.TRAIN, cfg)
454
+ dataset_dicts = combine_detection_dataset_dicts(
455
+ cfg.DATASETS.TRAIN,
456
+ keep_instance_predicate=_get_train_keep_instance_predicate(cfg),
457
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
458
+ )
459
+ if mapper is None:
460
+ mapper = DatasetMapper(cfg, True)
461
+ return d2_build_detection_train_loader(cfg, dataset=dataset_dicts, mapper=mapper)
462
+
463
+
464
+ def build_detection_test_loader(cfg, dataset_name, mapper=None):
465
+ """
466
+ Similar to `build_detection_train_loader`.
467
+ But this function uses the given `dataset_name` argument (instead of the names in cfg),
468
+ and uses batch size 1.
469
+
470
+ Args:
471
+ cfg: a detectron2 CfgNode
472
+ dataset_name (str): a name of the dataset that's available in the DatasetCatalog
473
+ mapper (callable): a callable which takes a sample (dict) from dataset
474
+ and returns the format to be consumed by the model.
475
+ By default it will be `DatasetMapper(cfg, False)`.
476
+
477
+ Returns:
478
+ DataLoader: a torch DataLoader, that loads the given detection
479
+ dataset, with test-time transformation and batching.
480
+ """
481
+ _add_category_whitelists_to_metadata(cfg)
482
+ _add_category_maps_to_metadata(cfg)
483
+ _maybe_add_class_to_mesh_name_map_to_metadata([dataset_name], cfg)
484
+ dataset_dicts = combine_detection_dataset_dicts(
485
+ [dataset_name],
486
+ keep_instance_predicate=_get_test_keep_instance_predicate(cfg),
487
+ proposal_files=(
488
+ [cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]]
489
+ if cfg.MODEL.LOAD_PROPOSALS
490
+ else None
491
+ ),
492
+ )
493
+ sampler = None
494
+ if not cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE:
495
+ sampler = torch.utils.data.SequentialSampler(dataset_dicts)
496
+ if mapper is None:
497
+ mapper = DatasetMapper(cfg, False)
498
+ return d2_build_detection_test_loader(
499
+ dataset_dicts, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, sampler=sampler
500
+ )
501
+
502
+
503
+ def build_frame_selector(cfg: CfgNode):
504
+ strategy = FrameSelectionStrategy(cfg.STRATEGY)
505
+ if strategy == FrameSelectionStrategy.RANDOM_K:
506
+ frame_selector = RandomKFramesSelector(cfg.NUM_IMAGES)
507
+ elif strategy == FrameSelectionStrategy.FIRST_K:
508
+ frame_selector = FirstKFramesSelector(cfg.NUM_IMAGES)
509
+ elif strategy == FrameSelectionStrategy.LAST_K:
510
+ frame_selector = LastKFramesSelector(cfg.NUM_IMAGES)
511
+ elif strategy == FrameSelectionStrategy.ALL:
512
+ frame_selector = None
513
+ # pyre-fixme[61]: `frame_selector` may not be initialized here.
514
+ return frame_selector
515
+
516
+
517
+ def build_transform(cfg: CfgNode, data_type: str):
518
+ if cfg.TYPE == "resize":
519
+ if data_type == "image":
520
+ return ImageResizeTransform(cfg.MIN_SIZE, cfg.MAX_SIZE)
521
+ raise ValueError(f"Unknown transform {cfg.TYPE} for data type {data_type}")
522
+
523
+
524
+ def build_combined_loader(cfg: CfgNode, loaders: Collection[Loader], ratios: Sequence[float]):
525
+ images_per_worker = _compute_num_images_per_worker(cfg)
526
+ return CombinedDataLoader(loaders, images_per_worker, ratios)
527
+
528
+
529
+ def build_bootstrap_dataset(dataset_name: str, cfg: CfgNode) -> Sequence[torch.Tensor]:
530
+ """
531
+ Build dataset that provides data to bootstrap on
532
+
533
+ Args:
534
+ dataset_name (str): Name of the dataset, needs to have associated metadata
535
+ to load the data
536
+ cfg (CfgNode): bootstrapping config
537
+ Returns:
538
+ Sequence[Tensor] - dataset that provides image batches, Tensors of size
539
+ [N, C, H, W] of type float32
540
+ """
541
+ logger = logging.getLogger(__name__)
542
+ _add_category_info_to_bootstrapping_metadata(dataset_name, cfg)
543
+ meta = MetadataCatalog.get(dataset_name)
544
+ factory = BootstrapDatasetFactoryCatalog.get(meta.dataset_type)
545
+ dataset = None
546
+ if factory is not None:
547
+ dataset = factory(meta, cfg)
548
+ if dataset is None:
549
+ logger.warning(f"Failed to create dataset {dataset_name} of type {meta.dataset_type}")
550
+ return dataset
551
+
552
+
553
+ def build_data_sampler(cfg: CfgNode, sampler_cfg: CfgNode, embedder: Optional[torch.nn.Module]):
554
+ if sampler_cfg.TYPE == "densepose_uniform":
555
+ data_sampler = PredictionToGroundTruthSampler()
556
+ # transform densepose pred -> gt
557
+ data_sampler.register_sampler(
558
+ "pred_densepose",
559
+ "gt_densepose",
560
+ DensePoseUniformSampler(count_per_class=sampler_cfg.COUNT_PER_CLASS),
561
+ )
562
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
563
+ return data_sampler
564
+ elif sampler_cfg.TYPE == "densepose_UV_confidence":
565
+ data_sampler = PredictionToGroundTruthSampler()
566
+ # transform densepose pred -> gt
567
+ data_sampler.register_sampler(
568
+ "pred_densepose",
569
+ "gt_densepose",
570
+ DensePoseConfidenceBasedSampler(
571
+ confidence_channel="sigma_2",
572
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
573
+ search_proportion=0.5,
574
+ ),
575
+ )
576
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
577
+ return data_sampler
578
+ elif sampler_cfg.TYPE == "densepose_fine_segm_confidence":
579
+ data_sampler = PredictionToGroundTruthSampler()
580
+ # transform densepose pred -> gt
581
+ data_sampler.register_sampler(
582
+ "pred_densepose",
583
+ "gt_densepose",
584
+ DensePoseConfidenceBasedSampler(
585
+ confidence_channel="fine_segm_confidence",
586
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
587
+ search_proportion=0.5,
588
+ ),
589
+ )
590
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
591
+ return data_sampler
592
+ elif sampler_cfg.TYPE == "densepose_coarse_segm_confidence":
593
+ data_sampler = PredictionToGroundTruthSampler()
594
+ # transform densepose pred -> gt
595
+ data_sampler.register_sampler(
596
+ "pred_densepose",
597
+ "gt_densepose",
598
+ DensePoseConfidenceBasedSampler(
599
+ confidence_channel="coarse_segm_confidence",
600
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
601
+ search_proportion=0.5,
602
+ ),
603
+ )
604
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
605
+ return data_sampler
606
+ elif sampler_cfg.TYPE == "densepose_cse_uniform":
607
+ assert embedder is not None
608
+ data_sampler = PredictionToGroundTruthSampler()
609
+ # transform densepose pred -> gt
610
+ data_sampler.register_sampler(
611
+ "pred_densepose",
612
+ "gt_densepose",
613
+ DensePoseCSEUniformSampler(
614
+ cfg=cfg,
615
+ use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
616
+ embedder=embedder,
617
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
618
+ ),
619
+ )
620
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
621
+ return data_sampler
622
+ elif sampler_cfg.TYPE == "densepose_cse_coarse_segm_confidence":
623
+ assert embedder is not None
624
+ data_sampler = PredictionToGroundTruthSampler()
625
+ # transform densepose pred -> gt
626
+ data_sampler.register_sampler(
627
+ "pred_densepose",
628
+ "gt_densepose",
629
+ DensePoseCSEConfidenceBasedSampler(
630
+ cfg=cfg,
631
+ use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
632
+ embedder=embedder,
633
+ confidence_channel="coarse_segm_confidence",
634
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
635
+ search_proportion=0.5,
636
+ ),
637
+ )
638
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
639
+ return data_sampler
640
+
641
+ raise ValueError(f"Unknown data sampler type {sampler_cfg.TYPE}")
642
+
643
+
644
+ def build_data_filter(cfg: CfgNode):
645
+ if cfg.TYPE == "detection_score":
646
+ min_score = cfg.MIN_VALUE
647
+ return ScoreBasedFilter(min_score=min_score)
648
+ raise ValueError(f"Unknown data filter type {cfg.TYPE}")
649
+
650
+
651
+ def build_inference_based_loader(
652
+ cfg: CfgNode,
653
+ dataset_cfg: CfgNode,
654
+ model: torch.nn.Module,
655
+ embedder: Optional[torch.nn.Module] = None,
656
+ ) -> InferenceBasedLoader:
657
+ """
658
+ Constructs data loader based on inference results of a model.
659
+ """
660
+ dataset = build_bootstrap_dataset(dataset_cfg.DATASET, dataset_cfg.IMAGE_LOADER)
661
+ meta = MetadataCatalog.get(dataset_cfg.DATASET)
662
+ training_sampler = TrainingSampler(len(dataset))
663
+ data_loader = torch.utils.data.DataLoader(
664
+ dataset, # pyre-ignore[6]
665
+ batch_size=dataset_cfg.IMAGE_LOADER.BATCH_SIZE,
666
+ sampler=training_sampler,
667
+ num_workers=dataset_cfg.IMAGE_LOADER.NUM_WORKERS,
668
+ collate_fn=trivial_batch_collator,
669
+ worker_init_fn=worker_init_reset_seed,
670
+ )
671
+ return InferenceBasedLoader(
672
+ model,
673
+ data_loader=data_loader,
674
+ data_sampler=build_data_sampler(cfg, dataset_cfg.DATA_SAMPLER, embedder),
675
+ data_filter=build_data_filter(dataset_cfg.FILTER),
676
+ shuffle=True,
677
+ batch_size=dataset_cfg.INFERENCE.OUTPUT_BATCH_SIZE,
678
+ inference_batch_size=dataset_cfg.INFERENCE.INPUT_BATCH_SIZE,
679
+ category_to_class_mapping=meta.category_to_class_mapping,
680
+ )
681
+
682
+
683
+ def has_inference_based_loaders(cfg: CfgNode) -> bool:
684
+ """
685
+ Returns True, if at least one inferense-based loader must
686
+ be instantiated for training
687
+ """
688
+ return len(cfg.BOOTSTRAP_DATASETS) > 0
689
+
690
+
691
+ def build_inference_based_loaders(
692
+ cfg: CfgNode, model: torch.nn.Module
693
+ ) -> Tuple[List[InferenceBasedLoader], List[float]]:
694
+ loaders = []
695
+ ratios = []
696
+ embedder = build_densepose_embedder(cfg).to(device=model.device) # pyre-ignore[16]
697
+ for dataset_spec in cfg.BOOTSTRAP_DATASETS:
698
+ dataset_cfg = get_bootstrap_dataset_config().clone()
699
+ dataset_cfg.merge_from_other_cfg(CfgNode(dataset_spec))
700
+ loader = build_inference_based_loader(cfg, dataset_cfg, model, embedder)
701
+ loaders.append(loader)
702
+ ratios.append(dataset_cfg.RATIO)
703
+ return loaders, ratios
704
+
705
+
706
+ def build_video_list_dataset(meta: Metadata, cfg: CfgNode):
707
+ video_list_fpath = meta.video_list_fpath
708
+ video_base_path = meta.video_base_path
709
+ category = meta.category
710
+ if cfg.TYPE == "video_keyframe":
711
+ frame_selector = build_frame_selector(cfg.SELECT)
712
+ transform = build_transform(cfg.TRANSFORM, data_type="image")
713
+ video_list = video_list_from_file(video_list_fpath, video_base_path)
714
+ keyframe_helper_fpath = getattr(cfg, "KEYFRAME_HELPER", None)
715
+ return VideoKeyframeDataset(
716
+ video_list, category, frame_selector, transform, keyframe_helper_fpath
717
+ )
718
+
719
+
720
+ class _BootstrapDatasetFactoryCatalog(UserDict):
721
+ """
722
+ A global dictionary that stores dataset creation functions used for bootstrapping,
723
+ keyed by DatasetType; each factory builds a dataset from Metadata and config
724
+ """
725
+
726
+ def register(self, dataset_type: DatasetType, factory: Callable[[Metadata, CfgNode], Dataset]):
727
+ """
728
+ Args:
729
+ dataset_type (DatasetType): a DatasetType e.g. DatasetType.VIDEO_LIST
730
+ factory (Callable[Metadata, CfgNode]): a callable which takes Metadata and cfg
731
+ arguments and returns a dataset object.
732
+ """
733
+ assert dataset_type not in self, "Dataset '{}' is already registered!".format(dataset_type)
734
+ self[dataset_type] = factory
735
+
736
+
737
+ BootstrapDatasetFactoryCatalog = _BootstrapDatasetFactoryCatalog()
738
+ BootstrapDatasetFactoryCatalog.register(DatasetType.VIDEO_LIST, build_video_list_dataset)
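A hedged sketch of how the builders above are typically wired together for training, assuming cfg is a full DensePose config and model the corresponding detection model; the 1.0 weight given to the main loader is illustrative, and all four helpers are re-exported from densepose.data.

from densepose.data import (
    build_combined_loader,
    build_detection_train_loader,
    build_inference_based_loaders,
    has_inference_based_loaders,
)


def build_training_loader(cfg, model):
    data_loader = build_detection_train_loader(cfg)
    if not has_inference_based_loaders(cfg):
        return data_loader
    # bootstrap loaders sample extra (pseudo-)ground truth from model predictions
    inference_loaders, ratios = build_inference_based_loaders(cfg, model)
    return build_combined_loader(cfg, [data_loader] + inference_loaders, [1.0] + ratios)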
densepose/data/combined_loader.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from collections import deque
7
+ from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
8
+
9
+ Loader = Iterable[Any]
10
+
11
+
12
+ def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
13
+ if not pool:
14
+ pool.extend(next(iterator))
15
+ return pool.popleft()
16
+
17
+
18
+ class CombinedDataLoader:
19
+ """
20
+ Combines data loaders using the provided sampling ratios
21
+ """
22
+
23
+ BATCH_COUNT = 100
24
+
25
+ def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
26
+ self.loaders = loaders
27
+ self.batch_size = batch_size
28
+ self.ratios = ratios
29
+
30
+ def __iter__(self) -> Iterator[List[Any]]:
31
+ iters = [iter(loader) for loader in self.loaders]
32
+ indices = []
33
+ pool = [deque() for _ in iters]  # one independent buffer per loader (a shared deque would mix loaders)
34
+ # infinite iterator, as in D2
35
+ while True:
36
+ if not indices:
37
+ # just a buffer of indices, its size doesn't matter
38
+ # as long as it's a multiple of batch_size
39
+ k = self.batch_size * self.BATCH_COUNT
40
+ indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
41
+ try:
42
+ batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
43
+ except StopIteration:
44
+ break
45
+ indices = indices[self.batch_size :]
46
+ yield batch
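A self-contained toy sketch of CombinedDataLoader: two infinite generators stand in for real loaders and are mixed with roughly 3:1 sampling ratios; since the combined loader is an infinite iterator, only a few batches are pulled.

import itertools

from densepose.data.combined_loader import CombinedDataLoader


def toy_loader(tag: str, batch_size: int = 2):
    # an infinite "loader": every item it yields is itself a batch (a list of elements)
    for i in itertools.count():
        yield [f"{tag}{i}_{j}" for j in range(batch_size)]


combined = CombinedDataLoader([toy_loader("a"), toy_loader("b")], batch_size=4, ratios=[0.75, 0.25])
for batch in itertools.islice(combined, 3):
    print(batch)  # 4 elements per batch, drawn from the two loaders with ~3:1 frequency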
densepose/data/dataset_mapper.py ADDED
@@ -0,0 +1,170 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # pyre-unsafe
5
+
6
+ import copy
7
+ import logging
8
+ from typing import Any, Dict, List, Tuple
9
+ import torch
10
+
11
+ from detectron2.data import MetadataCatalog
12
+ from detectron2.data import detection_utils as utils
13
+ from detectron2.data import transforms as T
14
+ from detectron2.layers import ROIAlign
15
+ from detectron2.structures import BoxMode
16
+ from detectron2.utils.file_io import PathManager
17
+
18
+ from densepose.structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
19
+
20
+
21
+ def build_augmentation(cfg, is_train):
22
+ logger = logging.getLogger(__name__)
23
+ result = utils.build_augmentation(cfg, is_train)
24
+ if is_train:
25
+ random_rotation = T.RandomRotation(
26
+ cfg.INPUT.ROTATION_ANGLES, expand=False, sample_style="choice"
27
+ )
28
+ result.append(random_rotation)
29
+ logger.info("DensePose-specific augmentation used in training: " + str(random_rotation))
30
+ return result
31
+
32
+
33
+ class DatasetMapper:
34
+ """
35
+ A customized version of `detectron2.data.DatasetMapper`
36
+ """
37
+
38
+ def __init__(self, cfg, is_train=True):
39
+ self.augmentation = build_augmentation(cfg, is_train)
40
+
41
+ # fmt: off
42
+ self.img_format = cfg.INPUT.FORMAT
43
+ self.mask_on = (
44
+ cfg.MODEL.MASK_ON or (
45
+ cfg.MODEL.DENSEPOSE_ON
46
+ and cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS)
47
+ )
48
+ self.keypoint_on = cfg.MODEL.KEYPOINT_ON
49
+ self.densepose_on = cfg.MODEL.DENSEPOSE_ON
50
+ assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
51
+ # fmt: on
52
+ if self.keypoint_on and is_train:
53
+ # Flip only makes sense in training
54
+ self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
55
+ else:
56
+ self.keypoint_hflip_indices = None
57
+
58
+ if self.densepose_on:
59
+ densepose_transform_srcs = [
60
+ MetadataCatalog.get(ds).densepose_transform_src
61
+ for ds in cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
62
+ ]
63
+ assert len(densepose_transform_srcs) > 0
64
+ # TODO: check that DensePose transformation data is the same for
65
+ # all the datasets. Otherwise one would have to pass DB ID with
66
+ # each entry to select proper transformation data. For now, since
67
+ # all DensePose annotated data uses the same data semantics, we
68
+ # omit this check.
69
+ densepose_transform_data_fpath = PathManager.get_local_path(densepose_transform_srcs[0])
70
+ self.densepose_transform_data = DensePoseTransformData.load(
71
+ densepose_transform_data_fpath
72
+ )
73
+
74
+ self.is_train = is_train
75
+
76
+ def __call__(self, dataset_dict):
77
+ """
78
+ Args:
79
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
80
+
81
+ Returns:
82
+ dict: a format that builtin models in detectron2 accept
83
+ """
84
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
85
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
86
+ utils.check_image_size(dataset_dict, image)
87
+
88
+ image, transforms = T.apply_transform_gens(self.augmentation, image)
89
+ image_shape = image.shape[:2] # h, w
90
+ dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
91
+
92
+ if not self.is_train:
93
+ dataset_dict.pop("annotations", None)
94
+ return dataset_dict
95
+
96
+ for anno in dataset_dict["annotations"]:
97
+ if not self.mask_on:
98
+ anno.pop("segmentation", None)
99
+ if not self.keypoint_on:
100
+ anno.pop("keypoints", None)
101
+
102
+ # USER: Implement additional transformations if you have other types of data
103
+ # USER: Don't call transpose_densepose if you don't need
104
+ annos = [
105
+ self._transform_densepose(
106
+ utils.transform_instance_annotations(
107
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
108
+ ),
109
+ transforms,
110
+ )
111
+ for obj in dataset_dict.pop("annotations")
112
+ if obj.get("iscrowd", 0) == 0
113
+ ]
114
+
115
+ if self.mask_on:
116
+ self._add_densepose_masks_as_segmentation(annos, image_shape)
117
+
118
+ instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
119
+ densepose_annotations = [obj.get("densepose") for obj in annos]
120
+ if densepose_annotations and not all(v is None for v in densepose_annotations):
121
+ instances.gt_densepose = DensePoseList(
122
+ densepose_annotations, instances.gt_boxes, image_shape
123
+ )
124
+
125
+ dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
126
+ return dataset_dict
127
+
128
+ def _transform_densepose(self, annotation, transforms):
129
+ if not self.densepose_on:
130
+ return annotation
131
+
132
+ # Handle densepose annotations
133
+ is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
134
+ if is_valid:
135
+ densepose_data = DensePoseDataRelative(annotation, cleanup=True)
136
+ densepose_data.apply_transform(transforms, self.densepose_transform_data)
137
+ annotation["densepose"] = densepose_data
138
+ else:
139
+ # logger = logging.getLogger(__name__)
140
+ # logger.debug("Could not load DensePose annotation: {}".format(reason_not_valid))
141
+ DensePoseDataRelative.cleanup_annotation(annotation)
142
+ # NOTE: annotations for certain instances may be unavailable.
143
+ # 'None' is accepted by the DensePostList data structure.
144
+ annotation["densepose"] = None
145
+ return annotation
146
+
147
+ def _add_densepose_masks_as_segmentation(
148
+ self, annotations: List[Dict[str, Any]], image_shape_hw: Tuple[int, int]
149
+ ):
150
+ for obj in annotations:
151
+ if ("densepose" not in obj) or ("segmentation" in obj):
152
+ continue
153
+ # DP segmentation: torch.Tensor [S, S] of float32, S=256
154
+ segm_dp = torch.zeros_like(obj["densepose"].segm)
155
+ segm_dp[obj["densepose"].segm > 0] = 1
156
+ segm_h, segm_w = segm_dp.shape
157
+ bbox_segm_dp = torch.tensor((0, 0, segm_h - 1, segm_w - 1), dtype=torch.float32)
158
+ # image bbox
159
+ x0, y0, x1, y1 = (
160
+ v.item() for v in BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
161
+ )
162
+ segm_aligned = (
163
+ ROIAlign((y1 - y0, x1 - x0), 1.0, 0, aligned=True)
164
+ .forward(segm_dp.view(1, 1, *segm_dp.shape), bbox_segm_dp)
165
+ .squeeze()
166
+ )
167
+ image_mask = torch.zeros(*image_shape_hw, dtype=torch.float32)
168
+ image_mask[y0:y1, x0:x1] = segm_aligned
169
+ # segmentation for BitMask: np.array [H, W] of bool
170
+ obj["segmentation"] = image_mask >= 0.5
densepose/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from . import builtin # ensure the builtin datasets are registered
6
+
7
+ __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
densepose/data/datasets/builtin.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from .chimpnsee import register_dataset as register_chimpnsee_dataset
5
+ from .coco import BASE_DATASETS as BASE_COCO_DATASETS
6
+ from .coco import DATASETS as COCO_DATASETS
7
+ from .coco import register_datasets as register_coco_datasets
8
+ from .lvis import DATASETS as LVIS_DATASETS
9
+ from .lvis import register_datasets as register_lvis_datasets
10
+
11
+ DEFAULT_DATASETS_ROOT = "datasets"
12
+
13
+
14
+ register_coco_datasets(COCO_DATASETS, DEFAULT_DATASETS_ROOT)
15
+ register_coco_datasets(BASE_COCO_DATASETS, DEFAULT_DATASETS_ROOT)
16
+ register_lvis_datasets(LVIS_DATASETS, DEFAULT_DATASETS_ROOT)
17
+
18
+ register_chimpnsee_dataset(DEFAULT_DATASETS_ROOT) # pyre-ignore[19]
densepose/data/datasets/chimpnsee.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Optional
6
+
7
+ from detectron2.data import DatasetCatalog, MetadataCatalog
8
+
9
+ from ..utils import maybe_prepend_base_path
10
+ from .dataset_type import DatasetType
11
+
12
+ CHIMPNSEE_DATASET_NAME = "chimpnsee"
13
+
14
+
15
+ def register_dataset(datasets_root: Optional[str] = None) -> None:
16
+ def empty_load_callback():
17
+ pass
18
+
19
+ video_list_fpath = maybe_prepend_base_path(
20
+ datasets_root,
21
+ "chimpnsee/cdna.eva.mpg.de/video_list.txt",
22
+ )
23
+ video_base_path = maybe_prepend_base_path(datasets_root, "chimpnsee/cdna.eva.mpg.de")
24
+
25
+ DatasetCatalog.register(CHIMPNSEE_DATASET_NAME, empty_load_callback)
26
+ MetadataCatalog.get(CHIMPNSEE_DATASET_NAME).set(
27
+ dataset_type=DatasetType.VIDEO_LIST,
28
+ video_list_fpath=video_list_fpath,
29
+ video_base_path=video_base_path,
30
+ category="chimpanzee",
31
+ )
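A small sketch showing how the registered metadata is read back; importing the builtin module (as densepose.data does) triggers the registration, and build_video_list_dataset() in densepose/data/build.py consumes exactly these fields.

from detectron2.data import MetadataCatalog

import densepose.data.datasets.builtin  # noqa: F401  -- registers the builtin datasets

meta = MetadataCatalog.get("chimpnsee")
print(meta.dataset_type, meta.video_list_fpath, meta.video_base_path, meta.category)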
densepose/data/datasets/coco.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ import contextlib
5
+ import io
6
+ import logging
7
+ import os
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Any, Dict, Iterable, List, Optional
11
+ from fvcore.common.timer import Timer
12
+
13
+ from detectron2.data import DatasetCatalog, MetadataCatalog
14
+ from detectron2.structures import BoxMode
15
+ from detectron2.utils.file_io import PathManager
16
+
17
+ from ..utils import maybe_prepend_base_path
18
+
19
+ DENSEPOSE_MASK_KEY = "dp_masks"
20
+ DENSEPOSE_IUV_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_I", "dp_U", "dp_V"]
21
+ DENSEPOSE_CSE_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_vertex", "ref_model"]
22
+ DENSEPOSE_ALL_POSSIBLE_KEYS = set(
23
+ DENSEPOSE_IUV_KEYS_WITHOUT_MASK + DENSEPOSE_CSE_KEYS_WITHOUT_MASK + [DENSEPOSE_MASK_KEY]
24
+ )
25
+ DENSEPOSE_METADATA_URL_PREFIX = "https://dl.fbaipublicfiles.com/densepose/data/"
26
+
27
+
28
+ @dataclass
29
+ class CocoDatasetInfo:
30
+ name: str
31
+ images_root: str
32
+ annotations_fpath: str
33
+
34
+
35
+ DATASETS = [
36
+ CocoDatasetInfo(
37
+ name="densepose_coco_2014_train",
38
+ images_root="coco/train2014",
39
+ annotations_fpath="coco/annotations/densepose_train2014.json",
40
+ ),
41
+ CocoDatasetInfo(
42
+ name="densepose_coco_2014_minival",
43
+ images_root="coco/val2014",
44
+ annotations_fpath="coco/annotations/densepose_minival2014.json",
45
+ ),
46
+ CocoDatasetInfo(
47
+ name="densepose_coco_2014_minival_100",
48
+ images_root="coco/val2014",
49
+ annotations_fpath="coco/annotations/densepose_minival2014_100.json",
50
+ ),
51
+ CocoDatasetInfo(
52
+ name="densepose_coco_2014_valminusminival",
53
+ images_root="coco/val2014",
54
+ annotations_fpath="coco/annotations/densepose_valminusminival2014.json",
55
+ ),
56
+ CocoDatasetInfo(
57
+ name="densepose_coco_2014_train_cse",
58
+ images_root="coco/train2014",
59
+ annotations_fpath="coco_cse/densepose_train2014_cse.json",
60
+ ),
61
+ CocoDatasetInfo(
62
+ name="densepose_coco_2014_minival_cse",
63
+ images_root="coco/val2014",
64
+ annotations_fpath="coco_cse/densepose_minival2014_cse.json",
65
+ ),
66
+ CocoDatasetInfo(
67
+ name="densepose_coco_2014_minival_100_cse",
68
+ images_root="coco/val2014",
69
+ annotations_fpath="coco_cse/densepose_minival2014_100_cse.json",
70
+ ),
71
+ CocoDatasetInfo(
72
+ name="densepose_coco_2014_valminusminival_cse",
73
+ images_root="coco/val2014",
74
+ annotations_fpath="coco_cse/densepose_valminusminival2014_cse.json",
75
+ ),
76
+ CocoDatasetInfo(
77
+ name="densepose_chimps",
78
+ images_root="densepose_chimps/images",
79
+ annotations_fpath="densepose_chimps/densepose_chimps_densepose.json",
80
+ ),
81
+ CocoDatasetInfo(
82
+ name="densepose_chimps_cse_train",
83
+ images_root="densepose_chimps/images",
84
+ annotations_fpath="densepose_chimps/densepose_chimps_cse_train.json",
85
+ ),
86
+ CocoDatasetInfo(
87
+ name="densepose_chimps_cse_val",
88
+ images_root="densepose_chimps/images",
89
+ annotations_fpath="densepose_chimps/densepose_chimps_cse_val.json",
90
+ ),
91
+ CocoDatasetInfo(
92
+ name="posetrack2017_train",
93
+ images_root="posetrack2017/posetrack_data_2017",
94
+ annotations_fpath="posetrack2017/densepose_posetrack_train2017.json",
95
+ ),
96
+ CocoDatasetInfo(
97
+ name="posetrack2017_val",
98
+ images_root="posetrack2017/posetrack_data_2017",
99
+ annotations_fpath="posetrack2017/densepose_posetrack_val2017.json",
100
+ ),
101
+ CocoDatasetInfo(
102
+ name="lvis_v05_train",
103
+ images_root="coco/train2017",
104
+ annotations_fpath="lvis/lvis_v0.5_plus_dp_train.json",
105
+ ),
106
+ CocoDatasetInfo(
107
+ name="lvis_v05_val",
108
+ images_root="coco/val2017",
109
+ annotations_fpath="lvis/lvis_v0.5_plus_dp_val.json",
110
+ ),
111
+ ]
112
+
113
+
114
+ BASE_DATASETS = [
115
+ CocoDatasetInfo(
116
+ name="base_coco_2017_train",
117
+ images_root="coco/train2017",
118
+ annotations_fpath="coco/annotations/instances_train2017.json",
119
+ ),
120
+ CocoDatasetInfo(
121
+ name="base_coco_2017_val",
122
+ images_root="coco/val2017",
123
+ annotations_fpath="coco/annotations/instances_val2017.json",
124
+ ),
125
+ CocoDatasetInfo(
126
+ name="base_coco_2017_val_100",
127
+ images_root="coco/val2017",
128
+ annotations_fpath="coco/annotations/instances_val2017_100.json",
129
+ ),
130
+ ]
131
+
132
+
133
+ def get_metadata(base_path: Optional[str]) -> Dict[str, Any]:
134
+ """
135
+ Returns metadata associated with COCO DensePose datasets
136
+
137
+ Args:
138
+ base_path: Optional[str]
139
+ Base path used to load metadata from
140
+
141
+ Returns:
142
+ Dict[str, Any]
143
+ Metadata in the form of a dictionary
144
+ """
145
+ meta = {
146
+ "densepose_transform_src": maybe_prepend_base_path(base_path, "UV_symmetry_transforms.mat"),
147
+ "densepose_smpl_subdiv": maybe_prepend_base_path(base_path, "SMPL_subdiv.mat"),
148
+ "densepose_smpl_subdiv_transform": maybe_prepend_base_path(
149
+ base_path,
150
+ "SMPL_SUBDIV_TRANSFORM.mat",
151
+ ),
152
+ }
153
+ return meta
154
+
155
+
156
+ def _load_coco_annotations(json_file: str):
157
+ """
158
+ Load COCO annotations from a JSON file
159
+
160
+ Args:
161
+ json_file: str
162
+ Path to the file to load annotations from
163
+ Returns:
164
+ Instance of `pycocotools.coco.COCO` that provides access to annotations
165
+ data
166
+ """
167
+ from pycocotools.coco import COCO
168
+
169
+ logger = logging.getLogger(__name__)
170
+ timer = Timer()
171
+ with contextlib.redirect_stdout(io.StringIO()):
172
+ coco_api = COCO(json_file)
173
+ if timer.seconds() > 1:
174
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
175
+ return coco_api
176
+
177
+
178
+ def _add_categories_metadata(dataset_name: str, categories: List[Dict[str, Any]]):
179
+ meta = MetadataCatalog.get(dataset_name)
180
+ meta.categories = {c["id"]: c["name"] for c in categories}
181
+ logger = logging.getLogger(__name__)
182
+ logger.info("Dataset {} categories: {}".format(dataset_name, meta.categories))
183
+
184
+
185
+ def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]):
186
+ if "minival" in json_file:
187
+ # Skip validation on COCO2014 valminusminival and minival annotations
188
+ # The ratio of buggy annotations there is tiny and does not affect accuracy
189
+ # Therefore we explicitly white-list them
190
+ return
191
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
192
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
193
+ json_file
194
+ )
195
+
196
+
197
+ def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
198
+ if "bbox" not in ann_dict:
199
+ return
200
+ obj["bbox"] = ann_dict["bbox"]
201
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
202
+
203
+
204
+ def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
205
+ if "segmentation" not in ann_dict:
206
+ return
207
+ segm = ann_dict["segmentation"]
208
+ if not isinstance(segm, dict):
209
+ # filter out invalid polygons (< 3 points)
210
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
211
+ if len(segm) == 0:
212
+ return
213
+ obj["segmentation"] = segm
214
+
215
+
216
+ def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
217
+ if "keypoints" not in ann_dict:
218
+ return
219
+ keypts = ann_dict["keypoints"] # list[int]
220
+ for idx, v in enumerate(keypts):
221
+ if idx % 3 != 2:
222
+ # COCO's segmentation coordinates are floating points in [0, H or W],
223
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
224
+ # Therefore we assume the coordinates are "pixel indices" and
225
+ # add 0.5 to convert to floating point coordinates.
226
+ keypts[idx] = v + 0.5
227
+ obj["keypoints"] = keypts
228
+
229
+
230
+ def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
231
+ for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
232
+ if key in ann_dict:
233
+ obj[key] = ann_dict[key]
234
+
235
+
236
+ def _combine_images_with_annotations(
237
+ dataset_name: str,
238
+ image_root: str,
239
+ img_datas: Iterable[Dict[str, Any]],
240
+ ann_datas: Iterable[Iterable[Dict[str, Any]]],
241
+ ):
242
+
243
+ ann_keys = ["iscrowd", "category_id"]
244
+ dataset_dicts = []
245
+ contains_video_frame_info = False
246
+
247
+ for img_dict, ann_dicts in zip(img_datas, ann_datas):
248
+ record = {}
249
+ record["file_name"] = os.path.join(image_root, img_dict["file_name"])
250
+ record["height"] = img_dict["height"]
251
+ record["width"] = img_dict["width"]
252
+ record["image_id"] = img_dict["id"]
253
+ record["dataset"] = dataset_name
254
+ if "frame_id" in img_dict:
255
+ record["frame_id"] = img_dict["frame_id"]
256
+ record["video_id"] = img_dict.get("vid_id", None)
257
+ contains_video_frame_info = True
258
+ objs = []
259
+ for ann_dict in ann_dicts:
260
+ assert ann_dict["image_id"] == record["image_id"]
261
+ assert ann_dict.get("ignore", 0) == 0
262
+ obj = {key: ann_dict[key] for key in ann_keys if key in ann_dict}
263
+ _maybe_add_bbox(obj, ann_dict)
264
+ _maybe_add_segm(obj, ann_dict)
265
+ _maybe_add_keypoints(obj, ann_dict)
266
+ _maybe_add_densepose(obj, ann_dict)
267
+ objs.append(obj)
268
+ record["annotations"] = objs
269
+ dataset_dicts.append(record)
270
+ if contains_video_frame_info:
271
+ create_video_frame_mapping(dataset_name, dataset_dicts)
272
+ return dataset_dicts
273
+
274
+
275
+ def get_contiguous_id_to_category_id_map(metadata):
276
+ cat_id_2_cont_id = metadata.thing_dataset_id_to_contiguous_id
277
+ cont_id_2_cat_id = {}
278
+ for cat_id, cont_id in cat_id_2_cont_id.items():
279
+ if cont_id in cont_id_2_cat_id:
280
+ continue
281
+ cont_id_2_cat_id[cont_id] = cat_id
282
+ return cont_id_2_cat_id
283
+
284
+
285
+ def maybe_filter_categories_cocoapi(dataset_name, coco_api):
286
+ meta = MetadataCatalog.get(dataset_name)
287
+ cont_id_2_cat_id = get_contiguous_id_to_category_id_map(meta)
288
+ cat_id_2_cont_id = meta.thing_dataset_id_to_contiguous_id
289
+ # filter categories
290
+ cats = []
291
+ for cat in coco_api.dataset["categories"]:
292
+ cat_id = cat["id"]
293
+ if cat_id not in cat_id_2_cont_id:
294
+ continue
295
+ cont_id = cat_id_2_cont_id[cat_id]
296
+ if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
297
+ cats.append(cat)
298
+ coco_api.dataset["categories"] = cats
299
+ # filter annotations, if multiple categories are mapped to a single
300
+ # contiguous ID, use only one category ID and map all annotations to that category ID
301
+ anns = []
302
+ for ann in coco_api.dataset["annotations"]:
303
+ cat_id = ann["category_id"]
304
+ if cat_id not in cat_id_2_cont_id:
305
+ continue
306
+ cont_id = cat_id_2_cont_id[cat_id]
307
+ ann["category_id"] = cont_id_2_cat_id[cont_id]
308
+ anns.append(ann)
309
+ coco_api.dataset["annotations"] = anns
310
+ # recreate index
311
+ coco_api.createIndex()
312
+
313
+
314
+ def maybe_filter_and_map_categories_cocoapi(dataset_name, coco_api):
315
+ meta = MetadataCatalog.get(dataset_name)
316
+ category_id_map = meta.thing_dataset_id_to_contiguous_id
317
+ # map categories
318
+ cats = []
319
+ for cat in coco_api.dataset["categories"]:
320
+ cat_id = cat["id"]
321
+ if cat_id not in category_id_map:
322
+ continue
323
+ cat["id"] = category_id_map[cat_id]
324
+ cats.append(cat)
325
+ coco_api.dataset["categories"] = cats
326
+ # map annotation categories
327
+ anns = []
328
+ for ann in coco_api.dataset["annotations"]:
329
+ cat_id = ann["category_id"]
330
+ if cat_id not in category_id_map:
331
+ continue
332
+ ann["category_id"] = category_id_map[cat_id]
333
+ anns.append(ann)
334
+ coco_api.dataset["annotations"] = anns
335
+ # recreate index
336
+ coco_api.createIndex()
337
+
338
+
339
+ def create_video_frame_mapping(dataset_name, dataset_dicts):
340
+ mapping = defaultdict(dict)
341
+ for d in dataset_dicts:
342
+ video_id = d.get("video_id")
343
+ if video_id is None:
344
+ continue
345
+ mapping[video_id].update({d["frame_id"]: d["file_name"]})
346
+ MetadataCatalog.get(dataset_name).set(video_frame_mapping=mapping)
347
+
348
+
349
+ def load_coco_json(annotations_json_file: str, image_root: str, dataset_name: str):
350
+ """
351
+ Loads a JSON file with annotations in COCO instances format.
352
+ Replaces `detectron2.data.datasets.coco.load_coco_json` to handle metadata
353
+ in a more flexible way. Postpones category mapping to a later stage to be
354
+ able to combine several datasets with different (but coherent) sets of
355
+ categories.
356
+
357
+ Args:
358
+
359
+ annotations_json_file: str
360
+ Path to the JSON file with annotations in COCO instances format.
361
+ image_root: str
362
+ directory that contains all the images
363
+ dataset_name: str
364
+ the name that identifies a dataset, e.g. "densepose_coco_2014_train"
365
+ extra_annotation_keys: Optional[List[str]]
366
+ If provided, these keys are used to extract additional data from
367
+ the annotations.
368
+ """
369
+ coco_api = _load_coco_annotations(PathManager.get_local_path(annotations_json_file))
370
+ _add_categories_metadata(dataset_name, coco_api.loadCats(coco_api.getCatIds()))
371
+ # sort indices for reproducible results
372
+ img_ids = sorted(coco_api.imgs.keys())
373
+ # imgs is a list of dicts, each looks something like:
374
+ # {'license': 4,
375
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
376
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
377
+ # 'height': 427,
378
+ # 'width': 640,
379
+ # 'date_captured': '2013-11-17 05:57:24',
380
+ # 'id': 1268}
381
+ imgs = coco_api.loadImgs(img_ids)
382
+ logger = logging.getLogger(__name__)
383
+ logger.info("Loaded {} images in COCO format from {}".format(len(imgs), annotations_json_file))
384
+ # anns is a list[list[dict]], where each dict is an annotation
385
+ # record for an object. The inner list enumerates the objects in an image
386
+ # and the outer list enumerates over images.
387
+ anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
388
+ _verify_annotations_have_unique_ids(annotations_json_file, anns)
389
+ dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
390
+ return dataset_records
391
+
392
+
393
+ def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None):
394
+ """
395
+ Registers provided COCO DensePose dataset
396
+
397
+ Args:
398
+ dataset_data: CocoDatasetInfo
399
+ Dataset data
400
+ datasets_root: Optional[str]
401
+ Datasets root folder (default: None)
402
+ """
403
+ annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
404
+ images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
405
+
406
+ def load_annotations():
407
+ return load_coco_json(
408
+ annotations_json_file=annotations_fpath,
409
+ image_root=images_root,
410
+ dataset_name=dataset_data.name,
411
+ )
412
+
413
+ DatasetCatalog.register(dataset_data.name, load_annotations)
414
+ MetadataCatalog.get(dataset_data.name).set(
415
+ json_file=annotations_fpath,
416
+ image_root=images_root,
417
+ **get_metadata(DENSEPOSE_METADATA_URL_PREFIX)
418
+ )
419
+
420
+
421
+ def register_datasets(
422
+ datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
423
+ ):
424
+ """
425
+ Registers provided COCO DensePose datasets
426
+
427
+ Args:
428
+ datasets_data: Iterable[CocoDatasetInfo]
429
+ An iterable of dataset datas
430
+ datasets_root: Optional[str]
431
+ Datasets root folder (default: None)
432
+ """
433
+ for dataset_data in datasets_data:
434
+ register_dataset(dataset_data, datasets_root)
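For orientation, a minimal usage sketch (not part of the diff): the built-in entries in DATASETS are registered automatically elsewhere in the package, so the sketch registers a hypothetical extra dataset; its name, file paths and the "./datasets" root are placeholders, not files shipped with this commit.

from detectron2.data import DatasetCatalog
from densepose.data.datasets.coco import CocoDatasetInfo, register_dataset

# Hypothetical dataset entry; the name and file paths are placeholders.
my_dataset = CocoDatasetInfo(
    name="densepose_coco_2014_minival_custom",
    images_root="coco/val2014",
    annotations_fpath="coco/annotations/densepose_minival2014_custom.json",
)
register_dataset(my_dataset, datasets_root="./datasets")
# Loading is lazy: the JSON is only parsed when the dataset is first requested.
records = DatasetCatalog.get("densepose_coco_2014_minival_custom")
print(len(records), records[0]["file_name"])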
densepose/data/datasets/dataset_type.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from enum import Enum
6
+
7
+
8
+ class DatasetType(Enum):
9
+ """
10
+ Dataset type, mostly used for datasets that contain data to bootstrap models on
11
+ """
12
+
13
+ VIDEO_LIST = "video_list"
densepose/data/datasets/lvis.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ import logging
5
+ import os
6
+ from typing import Any, Dict, Iterable, List, Optional
7
+ from fvcore.common.timer import Timer
8
+
9
+ from detectron2.data import DatasetCatalog, MetadataCatalog
10
+ from detectron2.data.datasets.lvis import get_lvis_instances_meta
11
+ from detectron2.structures import BoxMode
12
+ from detectron2.utils.file_io import PathManager
13
+
14
+ from ..utils import maybe_prepend_base_path
15
+ from .coco import (
16
+ DENSEPOSE_ALL_POSSIBLE_KEYS,
17
+ DENSEPOSE_METADATA_URL_PREFIX,
18
+ CocoDatasetInfo,
19
+ get_metadata,
20
+ )
21
+
22
+ DATASETS = [
23
+ CocoDatasetInfo(
24
+ name="densepose_lvis_v1_ds1_train_v1",
25
+ images_root="coco_",
26
+ annotations_fpath="lvis/densepose_lvis_v1_ds1_train_v1.json",
27
+ ),
28
+ CocoDatasetInfo(
29
+ name="densepose_lvis_v1_ds1_val_v1",
30
+ images_root="coco_",
31
+ annotations_fpath="lvis/densepose_lvis_v1_ds1_val_v1.json",
32
+ ),
33
+ CocoDatasetInfo(
34
+ name="densepose_lvis_v1_ds2_train_v1",
35
+ images_root="coco_",
36
+ annotations_fpath="lvis/densepose_lvis_v1_ds2_train_v1.json",
37
+ ),
38
+ CocoDatasetInfo(
39
+ name="densepose_lvis_v1_ds2_val_v1",
40
+ images_root="coco_",
41
+ annotations_fpath="lvis/densepose_lvis_v1_ds2_val_v1.json",
42
+ ),
43
+ CocoDatasetInfo(
44
+ name="densepose_lvis_v1_ds1_val_animals_100",
45
+ images_root="coco_",
46
+ annotations_fpath="lvis/densepose_lvis_v1_val_animals_100_v2.json",
47
+ ),
48
+ ]
49
+
50
+
51
+ def _load_lvis_annotations(json_file: str):
52
+ """
53
+ Load LVIS annotations from a JSON file
54
+
55
+ Args:
56
+ json_file: str
57
+ Path to the file to load annotations from
58
+ Returns:
59
+ Instance of `lvis.LVIS` that provides access to annotations
60
+ data
61
+ """
62
+ from lvis import LVIS
63
+
64
+ json_file = PathManager.get_local_path(json_file)
65
+ logger = logging.getLogger(__name__)
66
+ timer = Timer()
67
+ lvis_api = LVIS(json_file)
68
+ if timer.seconds() > 1:
69
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
70
+ return lvis_api
71
+
72
+
73
+ def _add_categories_metadata(dataset_name: str) -> None:
74
+ metadict = get_lvis_instances_meta(dataset_name)
75
+ categories = metadict["thing_classes"]
76
+ metadata = MetadataCatalog.get(dataset_name)
77
+ metadata.categories = {i + 1: categories[i] for i in range(len(categories))}
78
+ logger = logging.getLogger(__name__)
79
+ logger.info(f"Dataset {dataset_name} has {len(categories)} categories")
80
+
81
+
82
+ def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]) -> None:
83
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
84
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
85
+ json_file
86
+ )
87
+
88
+
89
+ def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
90
+ if "bbox" not in ann_dict:
91
+ return
92
+ obj["bbox"] = ann_dict["bbox"]
93
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
94
+
95
+
96
+ def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
97
+ if "segmentation" not in ann_dict:
98
+ return
99
+ segm = ann_dict["segmentation"]
100
+ if not isinstance(segm, dict):
101
+ # filter out invalid polygons (< 3 points)
102
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
103
+ if len(segm) == 0:
104
+ return
105
+ obj["segmentation"] = segm
106
+
107
+
108
+ def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
109
+ if "keypoints" not in ann_dict:
110
+ return
111
+ keypts = ann_dict["keypoints"] # list[int]
112
+ for idx, v in enumerate(keypts):
113
+ if idx % 3 != 2:
114
+ # COCO's segmentation coordinates are floating points in [0, H or W],
115
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
116
+ # Therefore we assume the coordinates are "pixel indices" and
117
+ # add 0.5 to convert to floating point coordinates.
118
+ keypts[idx] = v + 0.5
119
+ obj["keypoints"] = keypts
120
+
121
+
122
+ def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
123
+ for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
124
+ if key in ann_dict:
125
+ obj[key] = ann_dict[key]
126
+
127
+
128
+ def _combine_images_with_annotations(
129
+ dataset_name: str,
130
+ image_root: str,
131
+ img_datas: Iterable[Dict[str, Any]],
132
+ ann_datas: Iterable[Iterable[Dict[str, Any]]],
133
+ ):
134
+
135
+ dataset_dicts = []
136
+
137
+ def get_file_name(img_root, img_dict):
138
+ # Determine the path including the split folder ("train2017", "val2017", "test2017") from
139
+ # the coco_url field. Example:
140
+ # 'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
141
+ split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
142
+ return os.path.join(img_root + split_folder, file_name)
143
+
144
+ for img_dict, ann_dicts in zip(img_datas, ann_datas):
145
+ record = {}
146
+ record["file_name"] = get_file_name(image_root, img_dict)
147
+ record["height"] = img_dict["height"]
148
+ record["width"] = img_dict["width"]
149
+ record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
150
+ record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
151
+ record["image_id"] = img_dict["id"]
152
+ record["dataset"] = dataset_name
153
+
154
+ objs = []
155
+ for ann_dict in ann_dicts:
156
+ assert ann_dict["image_id"] == record["image_id"]
157
+ obj = {}
158
+ _maybe_add_bbox(obj, ann_dict)
159
+ obj["iscrowd"] = ann_dict.get("iscrowd", 0)
160
+ obj["category_id"] = ann_dict["category_id"]
161
+ _maybe_add_segm(obj, ann_dict)
162
+ _maybe_add_keypoints(obj, ann_dict)
163
+ _maybe_add_densepose(obj, ann_dict)
164
+ objs.append(obj)
165
+ record["annotations"] = objs
166
+ dataset_dicts.append(record)
167
+ return dataset_dicts
168
+
169
+
170
+ def load_lvis_json(annotations_json_file: str, image_root: str, dataset_name: str):
171
+ """
172
+ Loads a JSON file with annotations in LVIS instances format.
173
+ Replaces `detectron2.data.datasets.lvis.load_lvis_json` to handle metadata
174
+ in a more flexible way. Postpones category mapping to a later stage to be
175
+ able to combine several datasets with different (but coherent) sets of
176
+ categories.
177
+
178
+ Args:
179
+
180
+ annotations_json_file: str
181
+ Path to the JSON file with annotations in LVIS instances format.
182
+ image_root: str
183
+ directory that contains all the images
184
+ dataset_name: str
185
+ the name that identifies a dataset, e.g. "densepose_lvis_v1_ds1_train_v1"
186
+ extra_annotation_keys: Optional[List[str]]
187
+ If provided, these keys are used to extract additional data from
188
+ the annotations.
189
+ """
190
+ lvis_api = _load_lvis_annotations(PathManager.get_local_path(annotations_json_file))
191
+
192
+ _add_categories_metadata(dataset_name)
193
+
194
+ # sort indices for reproducible results
195
+ img_ids = sorted(lvis_api.imgs.keys())
196
+ # imgs is a list of dicts, each looks something like:
197
+ # {'license': 4,
198
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
199
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
200
+ # 'height': 427,
201
+ # 'width': 640,
202
+ # 'date_captured': '2013-11-17 05:57:24',
203
+ # 'id': 1268}
204
+ imgs = lvis_api.load_imgs(img_ids)
205
+ logger = logging.getLogger(__name__)
206
+ logger.info("Loaded {} images in LVIS format from {}".format(len(imgs), annotations_json_file))
207
+ # anns is a list[list[dict]], where each dict is an annotation
208
+ # record for an object. The inner list enumerates the objects in an image
209
+ # and the outer list enumerates over images.
210
+ anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
211
+
212
+ _verify_annotations_have_unique_ids(annotations_json_file, anns)
213
+ dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
214
+ return dataset_records
215
+
216
+
217
+ def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None) -> None:
218
+ """
219
+ Registers provided LVIS DensePose dataset
220
+
221
+ Args:
222
+ dataset_data: CocoDatasetInfo
223
+ Dataset data
224
+ datasets_root: Optional[str]
225
+ Datasets root folder (default: None)
226
+ """
227
+ annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
228
+ images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
229
+
230
+ def load_annotations():
231
+ return load_lvis_json(
232
+ annotations_json_file=annotations_fpath,
233
+ image_root=images_root,
234
+ dataset_name=dataset_data.name,
235
+ )
236
+
237
+ DatasetCatalog.register(dataset_data.name, load_annotations)
238
+ MetadataCatalog.get(dataset_data.name).set(
239
+ json_file=annotations_fpath,
240
+ image_root=images_root,
241
+ evaluator_type="lvis",
242
+ **get_metadata(DENSEPOSE_METADATA_URL_PREFIX),
243
+ )
244
+
245
+
246
+ def register_datasets(
247
+ datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
248
+ ) -> None:
249
+ """
250
+ Registers provided LVIS DensePose datasets
251
+
252
+ Args:
253
+ datasets_data: Iterable[CocoDatasetInfo]
254
+ An iterable of dataset datas
255
+ datasets_root: Optional[str]
256
+ Datasets root folder (default: None)
257
+ """
258
+ for dataset_data in datasets_data:
259
+ register_dataset(dataset_data, datasets_root)
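The "coco_" images_root values used in this file's DATASETS only make sense together with the path derivation in _combine_images_with_annotations; a standalone illustration follows (the datasets root and URL are made-up examples of the documented format).

import os

images_root = "./datasets/coco_"  # hypothetical datasets root plus the "coco_" prefix used above
coco_url = "http://images.cocodataset.org/train2017/000000155379.jpg"
# Mirrors get_file_name(): the split folder comes from the URL, so the image
# directory becomes "coco_" + "train2017" -> "coco_train2017".
split_folder, file_name = coco_url.split("/")[-2:]
print(os.path.join(images_root + split_folder, file_name))
# -> ./datasets/coco_train2017/000000155379.jpg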
densepose/data/image_list_dataset.py ADDED
@@ -0,0 +1,74 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # pyre-unsafe
5
+
6
+ import logging
7
+ import numpy as np
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+ import torch
10
+ from torch.utils.data.dataset import Dataset
11
+
12
+ from detectron2.data.detection_utils import read_image
13
+
14
+ ImageTransform = Callable[[torch.Tensor], torch.Tensor]
15
+
16
+
17
+ class ImageListDataset(Dataset):
18
+ """
19
+ Dataset that provides images from a list.
20
+ """
21
+
22
+ _EMPTY_IMAGE = torch.empty((0, 3, 1, 1))
23
+
24
+ def __init__(
25
+ self,
26
+ image_list: List[str],
27
+ category_list: Union[str, List[str], None] = None,
28
+ transform: Optional[ImageTransform] = None,
29
+ ):
30
+ """
31
+ Args:
32
+ image_list (List[str]): list of paths to image files
33
+ category_list (Union[str, List[str], None]): list of animal categories for
34
+ each image. If it is a string, or None, this applies to all images
35
+ """
36
+ if type(category_list) is list:
37
+ self.category_list = category_list
38
+ else:
39
+ self.category_list = [category_list] * len(image_list)
40
+ assert len(image_list) == len(
41
+ self.category_list
42
+ ), "length of image and category lists must be equal"
43
+ self.image_list = image_list
44
+ self.transform = transform
45
+
46
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
47
+ """
48
+ Gets selected images from the list
49
+
50
+ Args:
51
+ idx (int): video index in the video list file
52
+ Returns:
53
+ A dictionary containing two keys:
54
+ images (torch.Tensor): tensor of size [N, 3, H, W] (N = 1, or 0 for _EMPTY_IMAGE)
55
+ categories (List[str]): categories of the image (empty if reading failed)
56
+ """
57
+ categories = [self.category_list[idx]]
58
+ fpath = self.image_list[idx]
59
+ transform = self.transform
60
+
61
+ try:
62
+ image = torch.from_numpy(np.ascontiguousarray(read_image(fpath, format="BGR")))
63
+ image = image.permute(2, 0, 1).unsqueeze(0).float() # HWC -> NCHW
64
+ if transform is not None:
65
+ image = transform(image)
66
+ return {"images": image, "categories": categories}
67
+ except (OSError, RuntimeError) as e:
68
+ logger = logging.getLogger(__name__)
69
+ logger.warning(f"Error opening image file {fpath}: {e}")
70
+
71
+ return {"images": self._EMPTY_IMAGE, "categories": []}
72
+
73
+ def __len__(self):
74
+ return len(self.image_list)
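A short usage sketch (the file paths are placeholders): ImageListDataset plugs into a plain torch DataLoader, and an identity collate function keeps the per-item dictionaries exactly as __getitem__ returns them.

from torch.utils.data import DataLoader

from densepose.data.image_list_dataset import ImageListDataset

dataset = ImageListDataset(
    image_list=["img_0001.jpg", "img_0002.jpg"],  # placeholder paths
    category_list="dog",  # a single category string applied to every image
)
loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch)
for batch in loader:
    for item in batch:
        # item["images"] is [1, 3, H, W], or the empty [0, 3, 1, 1] tensor on read errors
        print(item["images"].shape, item["categories"])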
densepose/data/inference_based_loader.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
7
+ import torch
8
+ from torch import nn
9
+
10
+ SampledData = Any
11
+ ModelOutput = Any
12
+
13
+
14
+ def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]:
15
+ """
16
+ Group elements of an iterable by chunks of size `n`, e.g.
17
+ grouper(range(9), 4) ->
18
+ (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)
19
+ """
20
+ it = iter(iterable)
21
+ while True:
22
+ values = []
23
+ for _ in range(n):
24
+ try:
25
+ value = next(it)
26
+ except StopIteration:
27
+ if values:
28
+ values.extend([fillvalue] * (n - len(values)))
29
+ yield tuple(values)
30
+ return
31
+ values.append(value)
32
+ yield tuple(values)
33
+
34
+
35
+ class ScoreBasedFilter:
36
+ """
37
+ Filters entries in model output based on their scores
38
+ Discards all entries with score less than the specified minimum
39
+ """
40
+
41
+ def __init__(self, min_score: float = 0.8):
42
+ self.min_score = min_score
43
+
44
+ def __call__(self, model_output: ModelOutput) -> ModelOutput:
45
+ for model_output_i in model_output:
46
+ instances = model_output_i["instances"]
47
+ if not instances.has("scores"):
48
+ continue
49
+ instances_filtered = instances[instances.scores >= self.min_score]
50
+ model_output_i["instances"] = instances_filtered
51
+ return model_output
52
+
53
+
54
+ class InferenceBasedLoader:
55
+ """
56
+ Data loader based on results inferred by a model. Consists of:
57
+ - a data loader that provides batches of images
58
+ - a model that is used to infer the results
59
+ - a data sampler that converts inferred results to annotations
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ model: nn.Module,
65
+ data_loader: Iterable[List[Dict[str, Any]]],
66
+ data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
67
+ data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
68
+ shuffle: bool = True,
69
+ batch_size: int = 4,
70
+ inference_batch_size: int = 4,
71
+ drop_last: bool = False,
72
+ category_to_class_mapping: Optional[dict] = None,
73
+ ):
74
+ """
75
+ Constructor
76
+
77
+ Args:
78
+ model (torch.nn.Module): model used to produce data
79
+ data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
80
+ dictionaries with "images" and "categories" fields to perform inference on
81
+ data_sampler (Callable: ModelOutput -> SampledData): functor
82
+ that produces annotation data from inference results;
83
+ (optional, default: None)
84
+ data_filter (Callable: ModelOutput -> ModelOutput): filter
85
+ that selects model outputs for further processing
86
+ (optional, default: None)
87
+ shuffle (bool): if True, the input images get shuffled
88
+ batch_size (int): batch size for the produced annotation data
89
+ inference_batch_size (int): batch size for input images
90
+ drop_last (bool): if True, drop the last batch if it is undersized
91
+ category_to_class_mapping (dict): category to class mapping
92
+ """
93
+ self.model = model
94
+ self.model.eval()
95
+ self.data_loader = data_loader
96
+ self.data_sampler = data_sampler
97
+ self.data_filter = data_filter
98
+ self.shuffle = shuffle
99
+ self.batch_size = batch_size
100
+ self.inference_batch_size = inference_batch_size
101
+ self.drop_last = drop_last
102
+ if category_to_class_mapping is not None:
103
+ self.category_to_class_mapping = category_to_class_mapping
104
+ else:
105
+ self.category_to_class_mapping = {}
106
+
107
+ def __iter__(self) -> Iterator[List[SampledData]]:
108
+ for batch in self.data_loader:
109
+ # batch : List[Dict[str: Tensor[N, C, H, W], str: Optional[str]]]
110
+ # images_batch : Tensor[N, C, H, W]
111
+ # image : Tensor[C, H, W]
112
+ images_and_categories = [
113
+ {"image": image, "category": category}
114
+ for element in batch
115
+ for image, category in zip(element["images"], element["categories"])
116
+ ]
117
+ if not images_and_categories:
118
+ continue
119
+ if self.shuffle:
120
+ random.shuffle(images_and_categories)
121
+ yield from self._produce_data(images_and_categories) # pyre-ignore[6]
122
+
123
+ def _produce_data(
124
+ self, images_and_categories: List[Tuple[torch.Tensor, Optional[str]]]
125
+ ) -> Iterator[List[SampledData]]:
126
+ """
127
+ Produce batches of data from images
128
+
129
+ Args:
130
+ images_and_categories (List[Tuple[torch.Tensor, Optional[str]]]):
131
+ list of images and corresponding categories to process
132
+
133
+ Returns:
134
+ Iterator over batches of data sampled from model outputs
135
+ """
136
+ data_batches: List[SampledData] = []
137
+ category_to_class_mapping = self.category_to_class_mapping
138
+ batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size)
139
+ for batch in batched_images_and_categories:
140
+ batch = [
141
+ {
142
+ "image": image_and_category["image"].to(self.model.device),
143
+ "category": image_and_category["category"],
144
+ }
145
+ for image_and_category in batch
146
+ if image_and_category is not None
147
+ ]
148
+ if not batch:
149
+ continue
150
+ with torch.no_grad():
151
+ model_output = self.model(batch)
152
+ for model_output_i, batch_i in zip(model_output, batch):
153
+ assert len(batch_i["image"].shape) == 3
154
+ model_output_i["image"] = batch_i["image"]
155
+ instance_class = category_to_class_mapping.get(batch_i["category"], 0)
156
+ model_output_i["instances"].dataset_classes = torch.tensor(
157
+ [instance_class] * len(model_output_i["instances"])
158
+ )
159
+ model_output_filtered = (
160
+ model_output if self.data_filter is None else self.data_filter(model_output)
161
+ )
162
+ data = (
163
+ model_output_filtered
164
+ if self.data_sampler is None
165
+ else self.data_sampler(model_output_filtered)
166
+ )
167
+ for data_i in data:
168
+ if len(data_i["instances"]):
169
+ data_batches.append(data_i)
170
+ if len(data_batches) >= self.batch_size:
171
+ yield data_batches[: self.batch_size]
172
+ data_batches = data_batches[self.batch_size :]
173
+ if not self.drop_last and data_batches:
174
+ yield data_batches
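The chunking helper defined above is easy to verify in isolation; the output matches the example given in its docstring.

from densepose.data.inference_based_loader import _grouper

print(list(_grouper(range(9), 4)))
# [(0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)]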
densepose/data/meshes/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ from . import builtin
6
+
7
+ __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
densepose/data/meshes/builtin.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ from .catalog import MeshInfo, register_meshes
6
+
7
+ DENSEPOSE_MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"
8
+
9
+ MESHES = [
10
+ MeshInfo(
11
+ name="smpl_27554",
12
+ data="smpl_27554.pkl",
13
+ geodists="geodists/geodists_smpl_27554.pkl",
14
+ symmetry="symmetry/symmetry_smpl_27554.pkl",
15
+ texcoords="texcoords/texcoords_smpl_27554.pkl",
16
+ ),
17
+ MeshInfo(
18
+ name="chimp_5029",
19
+ data="chimp_5029.pkl",
20
+ geodists="geodists/geodists_chimp_5029.pkl",
21
+ symmetry="symmetry/symmetry_chimp_5029.pkl",
22
+ texcoords="texcoords/texcoords_chimp_5029.pkl",
23
+ ),
24
+ MeshInfo(
25
+ name="cat_5001",
26
+ data="cat_5001.pkl",
27
+ geodists="geodists/geodists_cat_5001.pkl",
28
+ symmetry="symmetry/symmetry_cat_5001.pkl",
29
+ texcoords="texcoords/texcoords_cat_5001.pkl",
30
+ ),
31
+ MeshInfo(
32
+ name="cat_7466",
33
+ data="cat_7466.pkl",
34
+ geodists="geodists/geodists_cat_7466.pkl",
35
+ symmetry="symmetry/symmetry_cat_7466.pkl",
36
+ texcoords="texcoords/texcoords_cat_7466.pkl",
37
+ ),
38
+ MeshInfo(
39
+ name="sheep_5004",
40
+ data="sheep_5004.pkl",
41
+ geodists="geodists/geodists_sheep_5004.pkl",
42
+ symmetry="symmetry/symmetry_sheep_5004.pkl",
43
+ texcoords="texcoords/texcoords_sheep_5004.pkl",
44
+ ),
45
+ MeshInfo(
46
+ name="zebra_5002",
47
+ data="zebra_5002.pkl",
48
+ geodists="geodists/geodists_zebra_5002.pkl",
49
+ symmetry="symmetry/symmetry_zebra_5002.pkl",
50
+ texcoords="texcoords/texcoords_zebra_5002.pkl",
51
+ ),
52
+ MeshInfo(
53
+ name="horse_5004",
54
+ data="horse_5004.pkl",
55
+ geodists="geodists/geodists_horse_5004.pkl",
56
+ symmetry="symmetry/symmetry_horse_5004.pkl",
57
+ texcoords="texcoords/texcoords_zebra_5002.pkl",
58
+ ),
59
+ MeshInfo(
60
+ name="giraffe_5002",
61
+ data="giraffe_5002.pkl",
62
+ geodists="geodists/geodists_giraffe_5002.pkl",
63
+ symmetry="symmetry/symmetry_giraffe_5002.pkl",
64
+ texcoords="texcoords/texcoords_giraffe_5002.pkl",
65
+ ),
66
+ MeshInfo(
67
+ name="elephant_5002",
68
+ data="elephant_5002.pkl",
69
+ geodists="geodists/geodists_elephant_5002.pkl",
70
+ symmetry="symmetry/symmetry_elephant_5002.pkl",
71
+ texcoords="texcoords/texcoords_elephant_5002.pkl",
72
+ ),
73
+ MeshInfo(
74
+ name="dog_5002",
75
+ data="dog_5002.pkl",
76
+ geodists="geodists/geodists_dog_5002.pkl",
77
+ symmetry="symmetry/symmetry_dog_5002.pkl",
78
+ texcoords="texcoords/texcoords_dog_5002.pkl",
79
+ ),
80
+ MeshInfo(
81
+ name="dog_7466",
82
+ data="dog_7466.pkl",
83
+ geodists="geodists/geodists_dog_7466.pkl",
84
+ symmetry="symmetry/symmetry_dog_7466.pkl",
85
+ texcoords="texcoords/texcoords_dog_7466.pkl",
86
+ ),
87
+ MeshInfo(
88
+ name="cow_5002",
89
+ data="cow_5002.pkl",
90
+ geodists="geodists/geodists_cow_5002.pkl",
91
+ symmetry="symmetry/symmetry_cow_5002.pkl",
92
+ texcoords="texcoords/texcoords_cow_5002.pkl",
93
+ ),
94
+ MeshInfo(
95
+ name="bear_4936",
96
+ data="bear_4936.pkl",
97
+ geodists="geodists/geodists_bear_4936.pkl",
98
+ symmetry="symmetry/symmetry_bear_4936.pkl",
99
+ texcoords="texcoords/texcoords_bear_4936.pkl",
100
+ ),
101
+ ]
102
+
103
+ register_meshes(MESHES, DENSEPOSE_MESHES_DIR)
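A quick sanity-check sketch: once this package is imported, every entry above is available through MeshCatalog (defined in catalog.py below) with its file names resolved against DENSEPOSE_MESHES_DIR.

from densepose.data.meshes.catalog import MeshCatalog

info = MeshCatalog["smpl_27554"]
print(info.data)  # https://dl.fbaipublicfiles.com/densepose/meshes/smpl_27554.pkl
print(MeshCatalog.get_mesh_id("smpl_27554"))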
densepose/data/meshes/catalog.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ import logging
6
+ from collections import UserDict
7
+ from dataclasses import dataclass
8
+ from typing import Iterable, Optional
9
+
10
+ from ..utils import maybe_prepend_base_path
11
+
12
+
13
+ @dataclass
14
+ class MeshInfo:
15
+ name: str
16
+ data: str
17
+ geodists: Optional[str] = None
18
+ symmetry: Optional[str] = None
19
+ texcoords: Optional[str] = None
20
+
21
+
22
+ class _MeshCatalog(UserDict):
23
+ def __init__(self, *args, **kwargs):
24
+ super().__init__(*args, **kwargs)
25
+ self.mesh_ids = {}
26
+ self.mesh_names = {}
27
+ self.max_mesh_id = -1
28
+
29
+ def __setitem__(self, key, value):
30
+ if key in self:
31
+ logger = logging.getLogger(__name__)
32
+ logger.warning(
33
+ f"Overwriting mesh catalog entry '{key}': old value {self[key]}"
34
+ f", new value {value}"
35
+ )
36
+ mesh_id = self.mesh_ids[key]
37
+ else:
38
+ self.max_mesh_id += 1
39
+ mesh_id = self.max_mesh_id
40
+ super().__setitem__(key, value)
41
+ self.mesh_ids[key] = mesh_id
42
+ self.mesh_names[mesh_id] = key
43
+
44
+ def get_mesh_id(self, shape_name: str) -> int:
45
+ return self.mesh_ids[shape_name]
46
+
47
+ def get_mesh_name(self, mesh_id: int) -> str:
48
+ return self.mesh_names[mesh_id]
49
+
50
+
51
+ MeshCatalog = _MeshCatalog()
52
+
53
+
54
+ def register_mesh(mesh_info: MeshInfo, base_path: Optional[str]) -> None:
55
+ geodists, symmetry, texcoords = mesh_info.geodists, mesh_info.symmetry, mesh_info.texcoords
56
+ if geodists:
57
+ geodists = maybe_prepend_base_path(base_path, geodists)
58
+ if symmetry:
59
+ symmetry = maybe_prepend_base_path(base_path, symmetry)
60
+ if texcoords:
61
+ texcoords = maybe_prepend_base_path(base_path, texcoords)
62
+ MeshCatalog[mesh_info.name] = MeshInfo(
63
+ name=mesh_info.name,
64
+ data=maybe_prepend_base_path(base_path, mesh_info.data),
65
+ geodists=geodists,
66
+ symmetry=symmetry,
67
+ texcoords=texcoords,
68
+ )
69
+
70
+
71
+ def register_meshes(mesh_infos: Iterable[MeshInfo], base_path: Optional[str]) -> None:
72
+ for mesh_info in mesh_infos:
73
+ register_mesh(mesh_info, base_path)
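A minimal sketch of registering a custom mesh (the mesh name, file name and base URL are hypothetical): maybe_prepend_base_path only prefixes the fields that are set, so optional entries left as None stay None.

from densepose.data.meshes.catalog import MeshCatalog, MeshInfo, register_mesh

register_mesh(
    MeshInfo(name="toy_mesh_10", data="toy_mesh_10.pkl"),  # hypothetical mesh
    base_path="https://example.org/meshes",
)
mesh_id = MeshCatalog.get_mesh_id("toy_mesh_10")
assert MeshCatalog.get_mesh_name(mesh_id) == "toy_mesh_10"
print(MeshCatalog["toy_mesh_10"].data)  # expected: https://example.org/meshes/toy_mesh_10.pkl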
densepose/data/samplers/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .densepose_uniform import DensePoseUniformSampler
6
+ from .densepose_confidence_based import DensePoseConfidenceBasedSampler
7
+ from .densepose_cse_uniform import DensePoseCSEUniformSampler
8
+ from .densepose_cse_confidence_based import DensePoseCSEConfidenceBasedSampler
9
+ from .mask_from_densepose import MaskFromDensePoseSampler
10
+ from .prediction_to_gt import PredictionToGroundTruthSampler
densepose/data/samplers/densepose_base.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Dict, List, Tuple
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures import BoxMode, Instances
10
+
11
+ from densepose.converters import ToChartResultConverter
12
+ from densepose.converters.base import IntTupleBox, make_int_box
13
+ from densepose.structures import DensePoseDataRelative, DensePoseList
14
+
15
+
16
+ class DensePoseBaseSampler:
17
+ """
18
+ Base DensePose sampler to produce DensePose data from DensePose predictions.
19
+ Samples for each class are drawn according to some distribution over all pixels estimated
20
+ to belong to that class.
21
+ """
22
+
23
+ def __init__(self, count_per_class: int = 8):
24
+ """
25
+ Constructor
26
+
27
+ Args:
28
+ count_per_class (int): the sampler produces at most `count_per_class`
29
+ samples for each category
30
+ """
31
+ self.count_per_class = count_per_class
32
+
33
+ def __call__(self, instances: Instances) -> DensePoseList:
34
+ """
35
+ Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`)
36
+ into DensePose annotations data (an instance of `DensePoseList`)
37
+ """
38
+ boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
39
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
40
+ dp_datas = []
41
+ for i in range(len(boxes_xywh_abs)):
42
+ annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i]))
43
+ annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask( # pyre-ignore[6]
44
+ instances[i].pred_densepose
45
+ )
46
+ dp_datas.append(DensePoseDataRelative(annotation_i))
47
+ # create densepose annotations on CPU
48
+ dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)
49
+ return dp_list
50
+
51
+ def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
52
+ """
53
+ Sample DensPoseDataRelative from estimation results
54
+ """
55
+ labels, dp_result = self._produce_labels_and_results(instance)
56
+ annotation = {
57
+ DensePoseDataRelative.X_KEY: [],
58
+ DensePoseDataRelative.Y_KEY: [],
59
+ DensePoseDataRelative.U_KEY: [],
60
+ DensePoseDataRelative.V_KEY: [],
61
+ DensePoseDataRelative.I_KEY: [],
62
+ }
63
+ n, h, w = dp_result.shape
64
+ for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
65
+ # indices - tuple of 3 1D tensors of size k
66
+ # 0: index along the first dimension N
67
+ # 1: index along H dimension
68
+ # 2: index along W dimension
69
+ indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
70
+ # values - an array of size [n, k]
71
+ # n: number of channels (U, V, confidences)
72
+ # k: number of points labeled with part_id
73
+ values = dp_result[indices].view(n, -1)
74
+ k = values.shape[1]
75
+ count = min(self.count_per_class, k)
76
+ if count <= 0:
77
+ continue
78
+ index_sample = self._produce_index_sample(values, count)
79
+ sampled_values = values[:, index_sample]
80
+ sampled_y = indices[1][index_sample] + 0.5
81
+ sampled_x = indices[2][index_sample] + 0.5
82
+ # prepare / normalize data
83
+ x = (sampled_x / w * 256.0).cpu().tolist()
84
+ y = (sampled_y / h * 256.0).cpu().tolist()
85
+ u = sampled_values[0].clamp(0, 1).cpu().tolist()
86
+ v = sampled_values[1].clamp(0, 1).cpu().tolist()
87
+ fine_segm_labels = [part_id] * count
88
+ # extend annotations
89
+ annotation[DensePoseDataRelative.X_KEY].extend(x)
90
+ annotation[DensePoseDataRelative.Y_KEY].extend(y)
91
+ annotation[DensePoseDataRelative.U_KEY].extend(u)
92
+ annotation[DensePoseDataRelative.V_KEY].extend(v)
93
+ annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
94
+ return annotation
95
+
96
+ def _produce_index_sample(self, values: torch.Tensor, count: int):
97
+ """
98
+ Abstract method to produce a sample of indices to select data
99
+ To be implemented in descendants
100
+
101
+ Args:
102
+ values (torch.Tensor): an array of size [n, k] that contains
103
+ estimated values (U, V, confidences);
104
+ n: number of channels (U, V, confidences)
105
+ k: number of points labeled with part_id
106
+ count (int): number of samples to produce, should be positive and <= k
107
+
108
+ Return:
109
+ list(int): indices of values (along axis 1) selected as a sample
110
+ """
111
+ raise NotImplementedError
112
+
113
+ def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
114
+ """
115
+ Method to get labels and DensePose results from an instance
116
+
117
+ Args:
118
+ instance (Instances): an instance of `DensePoseChartPredictorOutput`
119
+
120
+ Return:
121
+ labels (torch.Tensor): shape [H, W], DensePose segmentation labels
122
+ dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
123
+ """
124
+ converter = ToChartResultConverter
125
+ chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
126
+ labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
127
+ return labels, dp_result
128
+
129
+ def _resample_mask(self, output: Any) -> torch.Tensor:
130
+ """
131
+ Convert DensePose predictor output to segmentation annotation - tensors of size
132
+ (256, 256) and type `int64`.
133
+
134
+ Args:
135
+ output: DensePose predictor output with the following attributes:
136
+ - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
137
+ segmentation scores
138
+ - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
139
+ segmentation scores
140
+ Return:
141
+ Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
142
+ where S = DensePoseDataRelative.MASK_SIZE
143
+ """
144
+ sz = DensePoseDataRelative.MASK_SIZE
145
+ S = (
146
+ F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
147
+ .argmax(dim=1)
148
+ .long()
149
+ )
150
+ I = (
151
+ (
152
+ F.interpolate(
153
+ output.fine_segm,
154
+ (sz, sz),
155
+ mode="bilinear",
156
+ align_corners=False,
157
+ ).argmax(dim=1)
158
+ * (S > 0).long()
159
+ )
160
+ .squeeze()
161
+ .cpu()
162
+ )
163
+ # Map fine segmentation results to coarse segmentation ground truth
164
+ # TODO: extract this into separate classes
165
+ # coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
166
+ # 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
167
+ # 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
168
+ # 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right,
169
+ # 14 = Head
170
+ # fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
171
+ # 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right,
172
+ # 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right,
173
+ # 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
174
+ # 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left,
175
+ # 20, 22 = Lower Arm Right, 23, 24 = Head
176
+ FINE_TO_COARSE_SEGMENTATION = {
177
+ 1: 1,
178
+ 2: 1,
179
+ 3: 2,
180
+ 4: 3,
181
+ 5: 4,
182
+ 6: 5,
183
+ 7: 6,
184
+ 8: 7,
185
+ 9: 6,
186
+ 10: 7,
187
+ 11: 8,
188
+ 12: 9,
189
+ 13: 8,
190
+ 14: 9,
191
+ 15: 10,
192
+ 16: 11,
193
+ 17: 10,
194
+ 18: 11,
195
+ 19: 12,
196
+ 20: 13,
197
+ 21: 12,
198
+ 22: 13,
199
+ 23: 14,
200
+ 24: 14,
201
+ }
202
+ mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
203
+ for i in range(DensePoseDataRelative.N_PART_LABELS):
204
+ mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1]
205
+ return mask
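The only abstract piece in DensePoseBaseSampler is _produce_index_sample. The repository ships DensePoseUniformSampler and DensePoseConfidenceBasedSampler for this purpose; the class below is an illustrative re-implementation of uniform sampling, not the project's own code.

import random

import torch

from densepose.data.samplers.densepose_base import DensePoseBaseSampler


class ToyUniformSampler(DensePoseBaseSampler):
    """Illustrative sampler: picks `count` point indices uniformly at random."""

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        k = values.shape[1]  # number of candidate points for the current part
        return random.sample(range(k), count)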
densepose/data/samplers/densepose_confidence_based.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from typing import Optional, Tuple
7
+ import torch
8
+
9
+ from densepose.converters import ToChartResultConverterWithConfidences
10
+
11
+ from .densepose_base import DensePoseBaseSampler
12
+
13
+
14
+ class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
15
+ """
16
+ Samples DensePose data from DensePose predictions.
17
+ Samples for each class are drawn using confidence value estimates.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ confidence_channel: str,
23
+ count_per_class: int = 8,
24
+ search_count_multiplier: Optional[float] = None,
25
+ search_proportion: Optional[float] = None,
26
+ ):
27
+ """
28
+ Constructor
29
+
30
+ Args:
31
+ confidence_channel (str): confidence channel to use for sampling;
32
+ possible values:
33
+ "sigma_2": confidences for UV values
34
+ "fine_segm_confidence": confidences for fine segmentation
35
+ "coarse_segm_confidence": confidences for coarse segmentation
36
+ (default: "sigma_2")
37
+ count_per_class (int): the sampler produces at most `count_per_class`
38
+ samples for each category (default: 8)
39
+ search_count_multiplier (float or None): if not None, the total number
40
+ of the most confident estimates of a given class to consider is
41
+ defined as `min(search_count_multiplier * count_per_class, N)`,
42
+ where `N` is the total number of estimates of the class; cannot be
43
+ specified together with `search_proportion` (default: None)
44
+ search_proportion (float or None): if not None, the total number
45
+ of the most confident estimates of a given class to consider is
46
+ defined as `min(max(search_proportion * N, count_per_class), N)`,
47
+ where `N` is the total number of estimates of the class; cannot be
48
+ specified together with `search_count_multiplier` (default: None)
49
+ """
50
+ super().__init__(count_per_class)
51
+ self.confidence_channel = confidence_channel
52
+ self.search_count_multiplier = search_count_multiplier
53
+ self.search_proportion = search_proportion
54
+ assert (search_count_multiplier is None) or (search_proportion is None), (
55
+ f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
56
+ f"and search_proportion (={search_proportion})"
57
+ )
58
+
59
+ def _produce_index_sample(self, values: torch.Tensor, count: int):
60
+ """
61
+ Produce a sample of indices to select data based on confidences
62
+
63
+ Args:
64
+ values (torch.Tensor): an array of size [n, k] that contains
65
+ estimated values (U, V, confidences);
66
+ n: number of channels (U, V, confidences)
67
+ k: number of points labeled with part_id
68
+ count (int): number of samples to produce, should be positive and <= k
69
+
70
+ Return:
71
+ list(int): indices of values (along axis 1) selected as a sample
72
+ """
73
+ k = values.shape[1]
74
+ if k == count:
75
+ index_sample = list(range(k))
76
+ else:
77
+ # take the best count * search_count_multiplier pixels,
78
+ # sample from them uniformly
79
+ # (here best = smallest variance)
80
+ _, sorted_confidence_indices = torch.sort(values[2])
81
+ if self.search_count_multiplier is not None:
82
+ search_count = min(int(count * self.search_count_multiplier), k)
83
+ elif self.search_proportion is not None:
84
+ search_count = min(max(int(k * self.search_proportion), count), k)
85
+ else:
86
+ search_count = min(count, k)
87
+ sample_from_top = random.sample(range(search_count), count)
88
+ index_sample = sorted_confidence_indices[:search_count][sample_from_top]
89
+ return index_sample
90
+
91
+ def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
92
+ """
93
+ Method to get labels and DensePose results from an instance, with confidences
94
+
95
+ Args:
96
+ instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences`
97
+
98
+ Return:
99
+ labels (torch.Tensor): shape [H, W], DensePose segmentation labels
100
+ dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
101
+ stacked with the confidence channel
102
+ """
103
+ converter = ToChartResultConverterWithConfidences
104
+ chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
105
+ labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
106
+ dp_result = torch.cat(
107
+ (dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
108
+ )
109
+
110
+ return labels, dp_result
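The two mutually exclusive knobs above only change the size of the candidate pool that gets sampled from; a standalone numeric illustration of that arithmetic follows (mirroring, not importing, the logic in _produce_index_sample).

def search_count(k, count, search_count_multiplier=None, search_proportion=None):
    # k: number of labeled points, count: number of samples to draw
    if search_count_multiplier is not None:
        return min(int(count * search_count_multiplier), k)
    if search_proportion is not None:
        return min(max(int(k * search_proportion), count), k)
    return min(count, k)

print(search_count(k=1000, count=8, search_count_multiplier=4.0))  # 32
print(search_count(k=1000, count=8, search_proportion=0.1))        # 100
print(search_count(k=5, count=8))                                  # 5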
densepose/data/samplers/densepose_cse_base.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Dict, List, Tuple
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import CfgNode
10
+ from detectron2.structures import Instances
11
+
12
+ from densepose.converters.base import IntTupleBox
13
+ from densepose.data.utils import get_class_to_mesh_name_mapping
14
+ from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
15
+ from densepose.structures import DensePoseDataRelative
16
+
17
+ from .densepose_base import DensePoseBaseSampler
18
+
19
+
20
+ class DensePoseCSEBaseSampler(DensePoseBaseSampler):
21
+ """
22
+ Base DensePose sampler to produce DensePose data from DensePose predictions.
23
+ Samples for each class are drawn according to some distribution over all pixels estimated
24
+ to belong to that class.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ cfg: CfgNode,
30
+ use_gt_categories: bool,
31
+ embedder: torch.nn.Module,
32
+ count_per_class: int = 8,
33
+ ):
34
+ """
35
+ Constructor
36
+
37
+ Args:
38
+ cfg (CfgNode): the config of the model
39
+ embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
40
+ count_per_class (int): the sampler produces at most `count_per_class`
41
+ samples for each category
42
+ """
43
+ super().__init__(count_per_class)
44
+ self.embedder = embedder
45
+ self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
46
+ self.use_gt_categories = use_gt_categories
47
+
48
+ def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
49
+ """
50
+ Sample DensePoseDataRelative from estimation results
51
+ """
52
+ if self.use_gt_categories:
53
+ instance_class = instance.dataset_classes.tolist()[0]
54
+ else:
55
+ instance_class = instance.pred_classes.tolist()[0]
56
+ mesh_name = self.class_to_mesh_name[instance_class]
57
+
58
+ annotation = {
59
+ DensePoseDataRelative.X_KEY: [],
60
+ DensePoseDataRelative.Y_KEY: [],
61
+ DensePoseDataRelative.VERTEX_IDS_KEY: [],
62
+ DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
63
+ }
64
+
65
+ mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
66
+ indices = torch.nonzero(mask, as_tuple=True)
67
+ selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
68
+ values = other_values[:, indices[0], indices[1]]
69
+ k = values.shape[1]
70
+
71
+ count = min(self.count_per_class, k)
72
+ if count <= 0:
73
+ return annotation
74
+
75
+ index_sample = self._produce_index_sample(values, count)
76
+ closest_vertices = squared_euclidean_distance_matrix(
77
+ selected_embeddings[index_sample], self.embedder(mesh_name)
78
+ )
79
+ closest_vertices = torch.argmin(closest_vertices, dim=1)
80
+
81
+ sampled_y = indices[0][index_sample] + 0.5
82
+ sampled_x = indices[1][index_sample] + 0.5
83
+ # prepare / normalize data
84
+ _, _, w, h = bbox_xywh
85
+ x = (sampled_x / w * 256.0).cpu().tolist()
86
+ y = (sampled_y / h * 256.0).cpu().tolist()
87
+ # extend annotations
88
+ annotation[DensePoseDataRelative.X_KEY].extend(x)
89
+ annotation[DensePoseDataRelative.Y_KEY].extend(y)
90
+ annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
91
+ return annotation
92
+
93
+ def _produce_mask_and_results(
94
+ self, instance: Instances, bbox_xywh: IntTupleBox
95
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
96
+ """
97
+ Method to get labels and DensePose results from an instance
98
+
99
+ Args:
100
+ instance (Instances): an instance of `DensePoseEmbeddingPredictorOutput`
101
+ bbox_xywh (IntTupleBox): the corresponding bounding box
102
+
103
+ Return:
104
+ mask (torch.Tensor): shape [H, W], DensePose segmentation mask
105
+ embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W],
106
+ DensePose CSE Embeddings
107
+ other_values (Tuple[torch.Tensor]): a tensor of shape [0, H, W],
108
+ for potential other values
109
+ """
110
+ densepose_output = instance.pred_densepose
111
+ S = densepose_output.coarse_segm
112
+ E = densepose_output.embedding
113
+ _, _, w, h = bbox_xywh
114
+ embeddings = F.interpolate(E, size=(h, w), mode="bilinear")[0]
115
+ coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0]
116
+ mask = coarse_segm_resized.argmax(0) > 0
117
+ other_values = torch.empty((0, h, w), device=E.device)
118
+ return mask, embeddings, other_values
119
+
120
+ def _resample_mask(self, output: Any) -> torch.Tensor:
121
+ """
122
+ Convert DensePose predictor output to segmentation annotation - tensors of size
123
+ (256, 256) and type `int64`.
124
+
125
+ Args:
126
+ output: DensePose predictor output with the following attributes:
127
+ - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
128
+ segmentation scores
129
+ Return:
130
+ Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
131
+ where S = DensePoseDataRelative.MASK_SIZE
132
+ """
133
+ sz = DensePoseDataRelative.MASK_SIZE
134
+ mask = (
135
+ F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
136
+ .argmax(dim=1)
137
+ .long()
138
+ .squeeze()
139
+ .cpu()
140
+ )
141
+ return mask
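A self-contained illustration of the closest-vertex lookup performed in _sample above: each sampled pixel embedding is assigned the mesh vertex with the smallest squared Euclidean distance. torch.cdist stands in for the project's squared_euclidean_distance_matrix helper; the argmin over distances gives the same vertex indices.

import torch

pixel_embeddings = torch.randn(8, 16)     # 8 sampled pixels with 16-dim embeddings (made-up sizes)
vertex_embeddings = torch.randn(100, 16)  # hypothetical mesh with 100 vertices
closest_vertices = torch.cdist(pixel_embeddings, vertex_embeddings).argmin(dim=1)
print(closest_vertices.shape)  # torch.Size([8])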
densepose/data/samplers/densepose_cse_confidence_based.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from typing import Optional, Tuple
7
+ import torch
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import CfgNode
11
+ from detectron2.structures import Instances
12
+
13
+ from densepose.converters.base import IntTupleBox
14
+
15
+ from .densepose_cse_base import DensePoseCSEBaseSampler
16
+
17
+
18
+ class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
19
+ """
20
+ Samples DensePose data from DensePose predictions.
21
+ Samples for each class are drawn using confidence value estimates.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ cfg: CfgNode,
27
+ use_gt_categories: bool,
28
+ embedder: torch.nn.Module,
29
+ confidence_channel: str,
30
+ count_per_class: int = 8,
31
+ search_count_multiplier: Optional[float] = None,
32
+ search_proportion: Optional[float] = None,
33
+ ):
34
+ """
35
+ Constructor
36
+
37
+ Args:
38
+ cfg (CfgNode): the config of the model
39
+ embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
40
+ confidence_channel (str): confidence channel to use for sampling;
41
+ possible values:
42
+ "coarse_segm_confidence": confidences for coarse segmentation
43
+ (default: "coarse_segm_confidence")
44
+ count_per_class (int): the sampler produces at most `count_per_class`
45
+ samples for each category (default: 8)
46
+ search_count_multiplier (float or None): if not None, the total number
47
+ of the most confident estimates of a given class to consider is
48
+ defined as `min(search_count_multiplier * count_per_class, N)`,
49
+ where `N` is the total number of estimates of the class; cannot be
50
+ specified together with `search_proportion` (default: None)
51
+ search_proportion (float or None): if not None, the total number
52
+ of the most confident estimates of a given class to consider is
53
+ defined as `min(max(search_proportion * N, count_per_class), N)`,
54
+ where `N` is the total number of estimates of the class; cannot be
55
+ specified together with `search_count_multiplier` (default: None)
56
+ """
57
+ super().__init__(cfg, use_gt_categories, embedder, count_per_class)
58
+ self.confidence_channel = confidence_channel
59
+ self.search_count_multiplier = search_count_multiplier
60
+ self.search_proportion = search_proportion
61
+ assert (search_count_multiplier is None) or (search_proportion is None), (
62
+ f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
63
+ f"and search_proportion (={search_proportion})"
64
+ )
65
+
66
+ def _produce_index_sample(self, values: torch.Tensor, count: int):
67
+ """
68
+ Produce a sample of indices to select data based on confidences
69
+
70
+ Args:
71
+ values (torch.Tensor): a tensor of size [n, k] with confidences in the first row
72
+ k: number of points labeled with part_id
73
+ count (int): number of samples to produce, should be positive and <= k
74
+
75
+ Return:
76
+ list(int): indices of values (along axis 1) selected as a sample
77
+ """
78
+ k = values.shape[1]
79
+ if k == count:
80
+ index_sample = list(range(k))
81
+ else:
82
+ # take the best count * search_count_multiplier pixels,
83
+ # sample from them uniformly
84
+ # (here best = smallest variance)
85
+ _, sorted_confidence_indices = torch.sort(values[0])
86
+ if self.search_count_multiplier is not None:
87
+ search_count = min(int(count * self.search_count_multiplier), k)
88
+ elif self.search_proportion is not None:
89
+ search_count = min(max(int(k * self.search_proportion), count), k)
90
+ else:
91
+ search_count = min(count, k)
92
+ sample_from_top = random.sample(range(search_count), count)
93
+ index_sample = sorted_confidence_indices[-search_count:][sample_from_top]
94
+ return index_sample
95
+
96
+ def _produce_mask_and_results(
97
+ self, instance: Instances, bbox_xywh: IntTupleBox
98
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
99
+ """
100
+ Method to get labels and DensePose results from an instance
101
+
102
+ Args:
103
+ instance (Instances): an instance of
104
+ `DensePoseEmbeddingPredictorOutputWithConfidences`
105
+ bbox_xywh (IntTupleBox): the corresponding bounding box
106
+
107
+ Return:
108
+ mask (torch.Tensor): shape [H, W], DensePose segmentation mask
109
+ embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W]
110
+ DensePose CSE Embeddings
111
+ other_values: a tensor of shape [1, H, W], DensePose CSE confidence
112
+ """
113
+ _, _, w, h = bbox_xywh
114
+ densepose_output = instance.pred_densepose
115
+ mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
116
+ other_values = F.interpolate(
117
+ getattr(densepose_output, self.confidence_channel),
118
+ size=(h, w),
119
+ mode="bilinear",
120
+ )[0].cpu()
121
+ return mask, embeddings, other_values
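A minimal standalone sketch of the confidence-based index sampling used by `_produce_index_sample` above: sort the confidence channel, keep the `search_count` highest-valued entries, then draw `count` of them uniformly. The tensor values and multiplier below are illustrative only.

# Illustrative, self-contained version of the confidence-based sampling scheme.
import random
import torch

def sample_indices(confidences: torch.Tensor, count: int, search_count_multiplier: float = 2.0):
    k = confidences.shape[0]
    if k == count:
        return list(range(k))
    _, sorted_idx = torch.sort(confidences)               # ascending sort
    search_count = min(int(count * search_count_multiplier), k)
    picks = random.sample(range(search_count), count)     # uniform among retained entries
    return sorted_idx[-search_count:][picks].tolist()

print(sample_indices(torch.rand(12), count=4))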
densepose/data/samplers/densepose_cse_uniform.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .densepose_cse_base import DensePoseCSEBaseSampler
6
+ from .densepose_uniform import DensePoseUniformSampler
7
+
8
+
9
+ class DensePoseCSEUniformSampler(DensePoseCSEBaseSampler, DensePoseUniformSampler):
10
+ """
11
+ Uniform Sampler for CSE
12
+ """
13
+
14
+ pass
densepose/data/samplers/densepose_uniform.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ import torch
7
+
8
+ from .densepose_base import DensePoseBaseSampler
9
+
10
+
11
+ class DensePoseUniformSampler(DensePoseBaseSampler):
12
+ """
13
+ Samples DensePose data from DensePose predictions.
14
+ Samples for each class are drawn uniformly over all pixels estimated
15
+ to belong to that class.
16
+ """
17
+
18
+ def __init__(self, count_per_class: int = 8):
19
+ """
20
+ Constructor
21
+
22
+ Args:
23
+ count_per_class (int): the sampler produces at most `count_per_class`
24
+ samples for each category
25
+ """
26
+ super().__init__(count_per_class)
27
+
28
+ def _produce_index_sample(self, values: torch.Tensor, count: int):
29
+ """
30
+ Produce a uniform sample of indices to select data
31
+
32
+ Args:
33
+ values (torch.Tensor): an array of size [n, k] that contains
34
+ estimated values (U, V, confidences);
35
+ n: number of channels (U, V, confidences)
36
+ k: number of points labeled with part_id
37
+ count (int): number of samples to produce, should be positive and <= k
38
+
39
+ Return:
40
+ list(int): indices of values (along axis 1) selected as a sample
41
+ """
42
+ k = values.shape[1]
43
+ return random.sample(range(k), count)
densepose/data/samplers/mask_from_densepose.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from detectron2.structures import BitMasks, Instances
6
+
7
+ from densepose.converters import ToMaskConverter
8
+
9
+
10
+ class MaskFromDensePoseSampler:
11
+ """
12
+ Produce mask GT from DensePose predictions
13
+ This sampler simply converts DensePose predictions to BitMasks
14
+ that contain a bool tensor of the size of the input image
15
+ """
16
+
17
+ def __call__(self, instances: Instances) -> BitMasks:
18
+ """
19
+ Converts predicted data from `instances` into the GT mask data
20
+
21
+ Args:
22
+ instances (Instances): predicted results, expected to have `pred_densepose` field
23
+
24
+ Returns:
25
+ Boolean Tensor of the size of the input image that has non-zero
26
+ values at pixels that are estimated to belong to the detected object
27
+ """
28
+ return ToMaskConverter.convert(
29
+ instances.pred_densepose, instances.pred_boxes, instances.image_size
30
+ )
densepose/data/samplers/prediction_to_gt.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable, Dict, List, Optional
7
+
8
+ from detectron2.structures import Instances
9
+
10
+ ModelOutput = Dict[str, Any]
11
+ SampledData = Dict[str, Any]
12
+
13
+
14
+ @dataclass
15
+ class _Sampler:
16
+ """
17
+ Sampler registry entry that contains:
18
+ - src (str): source field to sample from (deleted after sampling)
19
+ - dst (Optional[str]): destination field to sample to, if not None
20
+ - func (Optional[Callable: Any -> Any]): function that performs sampling,
21
+ if None, reference copy is performed
22
+ """
23
+
24
+ src: str
25
+ dst: Optional[str]
26
+ func: Optional[Callable[[Any], Any]]
27
+
28
+
29
+ class PredictionToGroundTruthSampler:
30
+ """
31
+ Sampler implementation that converts predictions to GT using registered
32
+ samplers for different fields of `Instances`.
33
+ """
34
+
35
+ def __init__(self, dataset_name: str = ""):
36
+ self.dataset_name = dataset_name
37
+ self._samplers = {}
38
+ self.register_sampler("pred_boxes", "gt_boxes", None)
39
+ self.register_sampler("pred_classes", "gt_classes", None)
40
+ # delete scores
41
+ self.register_sampler("scores")
42
+
43
+ def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
44
+ """
45
+ Transform model output into ground truth data through sampling
46
+
47
+ Args:
48
+ model_output (Dict[str, Any]): model output
49
+ Returns:
50
+ Dict[str, Any]: sampled data
51
+ """
52
+ for model_output_i in model_output:
53
+ instances: Instances = model_output_i["instances"]
54
+ # transform data in each field
55
+ for _, sampler in self._samplers.items():
56
+ if not instances.has(sampler.src) or sampler.dst is None:
57
+ continue
58
+ if sampler.func is None:
59
+ instances.set(sampler.dst, instances.get(sampler.src))
60
+ else:
61
+ instances.set(sampler.dst, sampler.func(instances))
62
+ # delete model output data that was transformed
63
+ for _, sampler in self._samplers.items():
64
+ if sampler.src != sampler.dst and instances.has(sampler.src):
65
+ instances.remove(sampler.src)
66
+ model_output_i["dataset"] = self.dataset_name
67
+ return model_output
68
+
69
+ def register_sampler(
70
+ self,
71
+ prediction_attr: str,
72
+ gt_attr: Optional[str] = None,
73
+ func: Optional[Callable[[Any], Any]] = None,
74
+ ):
75
+ """
76
+ Register sampler for a field
77
+
78
+ Args:
79
+ prediction_attr (str): field to replace with a sampled value
80
+ gt_attr (Optional[str]): field to store the sampled value to, if not None
81
+ func (Optional[Callable: Any -> Any]): sampler function
82
+ """
83
+ self._samplers[(prediction_attr, gt_attr)] = _Sampler(
84
+ src=prediction_attr, dst=gt_attr, func=func
85
+ )
86
+
87
+ def remove_sampler(
88
+ self,
89
+ prediction_attr: str,
90
+ gt_attr: Optional[str] = None,
91
+ ):
92
+ """
93
+ Remove sampler for a field
94
+
95
+ Args:
96
+ prediction_attr (str): field to replace with a sampled value
97
+ gt_attr (Optional[str]): field to store the sampled value to, if not None
98
+ """
99
+ assert (prediction_attr, gt_attr) in self._samplers
100
+ del self._samplers[(prediction_attr, gt_attr)]
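A hedged usage sketch of `PredictionToGroundTruthSampler`: it assumes detectron2 is installed and builds a toy `Instances` object by hand; the dataset name and box values are made up.

# Illustrative use of PredictionToGroundTruthSampler on a hand-built Instances object.
import torch
from detectron2.structures import Boxes, Instances
from densepose.data.samplers.prediction_to_gt import PredictionToGroundTruthSampler

sampler = PredictionToGroundTruthSampler(dataset_name="my_dataset")  # hypothetical name
instances = Instances((480, 640))
instances.pred_boxes = Boxes(torch.tensor([[10.0, 10.0, 100.0, 200.0]]))
instances.pred_classes = torch.tensor([0])
instances.scores = torch.tensor([0.9])

data = sampler([{"instances": instances}])
# pred_* fields were copied to gt_* fields, scores were dropped,
# and each entry is tagged with the dataset name.
print(data[0]["instances"].get("gt_classes"), data[0]["dataset"])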
densepose/data/transform/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .image import ImageResizeTransform
densepose/data/transform/image.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import torch
6
+
7
+
8
+ class ImageResizeTransform:
9
+ """
10
+ Transform that resizes images loaded from a dataset
11
+ (BGR data in NCHW channel order, typically uint8) to a format ready to be
12
+ consumed by DensePose training (BGR float32 data in NCHW channel order)
13
+ """
14
+
15
+ def __init__(self, min_size: int = 800, max_size: int = 1333):
16
+ self.min_size = min_size
17
+ self.max_size = max_size
18
+
19
+ def __call__(self, images: torch.Tensor) -> torch.Tensor:
20
+ """
21
+ Args:
22
+ images (torch.Tensor): tensor of size [N, 3, H, W] that contains
23
+ BGR data (typically in uint8)
24
+ Returns:
25
+ images (torch.Tensor): tensor of size [N, 3, H1, W1] where
26
+ H1 and W1 are chosen to respect the specified min and max sizes
27
+ and preserve the original aspect ratio, the data channels
28
+ follow BGR order and the data type is `torch.float32`
29
+ """
30
+ # resize with min size
31
+ images = images.float()
32
+ min_size = min(images.shape[-2:])
33
+ max_size = max(images.shape[-2:])
34
+ scale = min(self.min_size / min_size, self.max_size / max_size)
35
+ images = torch.nn.functional.interpolate(
36
+ images,
37
+ scale_factor=scale,
38
+ mode="bilinear",
39
+ align_corners=False,
40
+ )
41
+ return images
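A short sketch of the scaling rule used by `ImageResizeTransform`, applied to a fake 480x640 batch; the expected output size in the comment assumes the default min/max settings.

# Illustrative check of the resize rule on a fake NCHW batch (values are irrelevant).
import torch
from torch.nn import functional as F

images = torch.zeros(1, 3, 480, 640)
min_size, max_size = 800, 1333
short_side, long_side = min(images.shape[-2:]), max(images.shape[-2:])
scale = min(min_size / short_side, max_size / long_side)   # min(800/480, 1333/640) = 5/3
resized = F.interpolate(images.float(), scale_factor=scale, mode="bilinear", align_corners=False)
print(resized.shape)  # expected: torch.Size([1, 3, 800, 1066])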
densepose/data/utils.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import os
6
+ from typing import Dict, Optional
7
+
8
+ from detectron2.config import CfgNode
9
+
10
+
11
+ def is_relative_local_path(path: str) -> bool:
12
+ path_str = os.fsdecode(path)
13
+ return ("://" not in path_str) and not os.path.isabs(path)
14
+
15
+
16
+ def maybe_prepend_base_path(base_path: Optional[str], path: str):
17
+ """
18
+ Prepends the provided path with a base path prefix if:
19
+ 1) base path is not None;
20
+ 2) path is a local path
21
+ """
22
+ if base_path is None:
23
+ return path
24
+ if is_relative_local_path(path):
25
+ return os.path.join(base_path, path)
26
+ return path
27
+
28
+
29
+ def get_class_to_mesh_name_mapping(cfg: CfgNode) -> Dict[int, str]:
30
+ return {
31
+ int(class_id): mesh_name
32
+ for class_id, mesh_name in cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING.items()
33
+ }
34
+
35
+
36
+ def get_category_to_class_mapping(dataset_cfg: CfgNode) -> Dict[str, int]:
37
+ return {
38
+ category: int(class_id)
39
+ for category, class_id in dataset_cfg.CATEGORY_TO_CLASS_MAPPING.items()
40
+ }
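The path helpers above behave as follows on a few made-up paths (absolute paths and URIs pass through unchanged):

# Illustrative behaviour of maybe_prepend_base_path (paths below are made up).
from densepose.data.utils import maybe_prepend_base_path

print(maybe_prepend_base_path("/datasets", "videos/clip_0001.mp4"))  # /datasets/videos/clip_0001.mp4
print(maybe_prepend_base_path("/datasets", "/abs/clip.mp4"))         # /abs/clip.mp4 (absolute, unchanged)
print(maybe_prepend_base_path("/datasets", "s3://bucket/clip.mp4"))  # s3://bucket/clip.mp4 (URI, unchanged)
print(maybe_prepend_base_path(None, "videos/clip_0001.mp4"))         # videos/clip_0001.mp4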
densepose/data/video/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .frame_selector import (
6
+ FrameSelectionStrategy,
7
+ RandomKFramesSelector,
8
+ FirstKFramesSelector,
9
+ LastKFramesSelector,
10
+ FrameTsList,
11
+ FrameSelector,
12
+ )
13
+
14
+ from .video_keyframe_dataset import (
15
+ VideoKeyframeDataset,
16
+ video_list_from_file,
17
+ list_keyframes,
18
+ read_keyframes,
19
+ )
densepose/data/video/frame_selector.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import random
6
+ from collections.abc import Callable
7
+ from enum import Enum
8
+ from typing import Callable as TCallable
9
+ from typing import List
10
+
11
+ FrameTsList = List[int]
12
+ FrameSelector = TCallable[[FrameTsList], FrameTsList]
13
+
14
+
15
+ class FrameSelectionStrategy(Enum):
16
+ """
17
+ Frame selection strategy used with videos:
18
+ - "random_k": select k random frames
19
+ - "first_k": select k first frames
20
+ - "last_k": select k last frames
21
+ - "all": select all frames
22
+ """
23
+
24
+ # fmt: off
25
+ RANDOM_K = "random_k"
26
+ FIRST_K = "first_k"
27
+ LAST_K = "last_k"
28
+ ALL = "all"
29
+ # fmt: on
30
+
31
+
32
+ class RandomKFramesSelector(Callable): # pyre-ignore[39]
33
+ """
34
+ Selector that retains at most `k` random frames
35
+ """
36
+
37
+ def __init__(self, k: int):
38
+ self.k = k
39
+
40
+ def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
41
+ """
42
+ Select `k` random frames
43
+
44
+ Args:
45
+ frames_tss (List[int]): timestamps of input frames
46
+ Returns:
47
+ List[int]: timestamps of selected frames
48
+ """
49
+ return random.sample(frame_tss, min(self.k, len(frame_tss)))
50
+
51
+
52
+ class FirstKFramesSelector(Callable): # pyre-ignore[39]
53
+ """
54
+ Selector that retains at most `k` first frames
55
+ """
56
+
57
+ def __init__(self, k: int):
58
+ self.k = k
59
+
60
+ def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
61
+ """
62
+ Select `k` first frames
63
+
64
+ Args:
65
+ frames_tss (List[int]): timestamps of input frames
66
+ Returns:
67
+ List[int]: timestamps of selected frames
68
+ """
69
+ return frame_tss[: self.k]
70
+
71
+
72
+ class LastKFramesSelector(Callable): # pyre-ignore[39]
73
+ """
74
+ Selector that retains at most `k` last frames from video data
75
+ """
76
+
77
+ def __init__(self, k: int):
78
+ self.k = k
79
+
80
+ def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
81
+ """
82
+ Select `k` last frames
83
+
84
+ Args:
85
+ frames_tss (List[int]): timestamps of input frames
86
+ Returns:
87
+ List[int]: timestamps of selected frames
88
+ """
89
+ return frame_tss[-self.k :]
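A brief usage sketch of the three selectors on a made-up list of keyframe timestamps:

# Illustrative use of the frame selectors (timestamps are fake timebase counts).
from densepose.data.video.frame_selector import (
    FirstKFramesSelector,
    LastKFramesSelector,
    RandomKFramesSelector,
)

keyframe_ts = [0, 250, 500, 750, 1000, 1250]
print(FirstKFramesSelector(3)(keyframe_ts))        # [0, 250, 500]
print(LastKFramesSelector(2)(keyframe_ts))         # [1000, 1250]
print(len(RandomKFramesSelector(4)(keyframe_ts)))  # 4 (random subset)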
densepose/data/video/video_keyframe_dataset.py ADDED
@@ -0,0 +1,304 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # pyre-unsafe
5
+
6
+ import csv
7
+ import logging
8
+ import numpy as np
9
+ from typing import Any, Callable, Dict, List, Optional, Union
10
+ import av
11
+ import torch
12
+ from torch.utils.data.dataset import Dataset
13
+
14
+ from detectron2.utils.file_io import PathManager
15
+
16
+ from ..utils import maybe_prepend_base_path
17
+ from .frame_selector import FrameSelector, FrameTsList
18
+
19
+ FrameList = List[av.frame.Frame] # pyre-ignore[16]
20
+ FrameTransform = Callable[[torch.Tensor], torch.Tensor]
21
+
22
+
23
+ def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
24
+ """
25
+ Traverses all keyframes of a video file. Returns a list of keyframe
26
+ timestamps. Timestamps are counts in timebase units.
27
+
28
+ Args:
29
+ video_fpath (str): Video file path
30
+ video_stream_idx (int): Video stream index (default: 0)
31
+ Returns:
32
+ List[int]: list of keyframe timestaps (timestamp is a count in timebase
33
+ units)
34
+ """
35
+ try:
36
+ with PathManager.open(video_fpath, "rb") as io:
37
+ # pyre-fixme[16]: Module `av` has no attribute `open`.
38
+ container = av.open(io, mode="r")
39
+ stream = container.streams.video[video_stream_idx]
40
+ keyframes = []
41
+ pts = -1
42
+ # Note: even though we request forward seeks for keyframes, sometimes
43
+ # a keyframe in backwards direction is returned. We introduce tolerance
44
+ # as a max count of ignored backward seeks
45
+ tolerance_backward_seeks = 2
46
+ while True:
47
+ try:
48
+ container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
49
+ except av.AVError as e:
50
+ # the exception occurs when the video length is exceeded,
51
+ # we then return whatever data we've already collected
52
+ logger = logging.getLogger(__name__)
53
+ logger.debug(
54
+ f"List keyframes: Error seeking video file {video_fpath}, "
55
+ f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
56
+ )
57
+ return keyframes
58
+ except OSError as e:
59
+ logger = logging.getLogger(__name__)
60
+ logger.warning(
61
+ f"List keyframes: Error seeking video file {video_fpath}, "
62
+ f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
63
+ )
64
+ return []
65
+ packet = next(container.demux(video=video_stream_idx))
66
+ if packet.pts is not None and packet.pts <= pts:
67
+ logger = logging.getLogger(__name__)
68
+ logger.warning(
69
+ f"Video file {video_fpath}, stream {video_stream_idx}: "
70
+ f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
71
+ f"tolerance {tolerance_backward_seeks}."
72
+ )
73
+ tolerance_backward_seeks -= 1
74
+ if tolerance_backward_seeks == 0:
75
+ return []
76
+ pts += 1
77
+ continue
78
+ tolerance_backward_seeks = 2
79
+ pts = packet.pts
80
+ if pts is None:
81
+ return keyframes
82
+ if packet.is_keyframe:
83
+ keyframes.append(pts)
84
+ return keyframes
85
+ except OSError as e:
86
+ logger = logging.getLogger(__name__)
87
+ logger.warning(
88
+ f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}"
89
+ )
90
+ except RuntimeError as e:
91
+ logger = logging.getLogger(__name__)
92
+ logger.warning(
93
+ f"List keyframes: Error opening video file container {video_fpath}, "
94
+ f"Runtime error: {e}"
95
+ )
96
+ return []
97
+
98
+
99
+ def read_keyframes(
100
+ video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
101
+ ) -> FrameList: # pyre-ignore[11]
102
+ """
103
+ Reads keyframe data from a video file.
104
+
105
+ Args:
106
+ video_fpath (str): Video file path
107
+ keyframes (List[int]): List of keyframe timestamps (as counts in
108
+ timebase units to be used in container seek operations)
109
+ video_stream_idx (int): Video stream index (default: 0)
110
+ Returns:
111
+ List[Frame]: list of frames that correspond to the specified timestamps
112
+ """
113
+ try:
114
+ with PathManager.open(video_fpath, "rb") as io:
115
+ # pyre-fixme[16]: Module `av` has no attribute `open`.
116
+ container = av.open(io)
117
+ stream = container.streams.video[video_stream_idx]
118
+ frames = []
119
+ for pts in keyframes:
120
+ try:
121
+ container.seek(pts, any_frame=False, stream=stream)
122
+ frame = next(container.decode(video=0))
123
+ frames.append(frame)
124
+ except av.AVError as e:
125
+ logger = logging.getLogger(__name__)
126
+ logger.warning(
127
+ f"Read keyframes: Error seeking video file {video_fpath}, "
128
+ f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
129
+ )
130
+ container.close()
131
+ return frames
132
+ except OSError as e:
133
+ logger = logging.getLogger(__name__)
134
+ logger.warning(
135
+ f"Read keyframes: Error seeking video file {video_fpath}, "
136
+ f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
137
+ )
138
+ container.close()
139
+ return frames
140
+ except StopIteration:
141
+ logger = logging.getLogger(__name__)
142
+ logger.warning(
143
+ f"Read keyframes: Error decoding frame from {video_fpath}, "
144
+ f"video stream {video_stream_idx}, pts {pts}"
145
+ )
146
+ container.close()
147
+ return frames
148
+
149
+ container.close()
150
+ return frames
151
+ except OSError as e:
152
+ logger = logging.getLogger(__name__)
153
+ logger.warning(
154
+ f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
155
+ )
156
+ except RuntimeError as e:
157
+ logger = logging.getLogger(__name__)
158
+ logger.warning(
159
+ f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
160
+ )
161
+ return []
162
+
163
+
164
+ def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
165
+ """
166
+ Create a list of paths to video files from a text file.
167
+
168
+ Args:
169
+ video_list_fpath (str): path to a plain text file with the list of videos
170
+ base_path (str): base path for entries from the video list (default: None)
171
+ """
172
+ video_list = []
173
+ with PathManager.open(video_list_fpath, "r") as io:
174
+ for line in io:
175
+ video_list.append(maybe_prepend_base_path(base_path, str(line.strip())))
176
+ return video_list
177
+
178
+
179
+ def read_keyframe_helper_data(fpath: str):
180
+ """
181
+ Read keyframe data from a file in CSV format: the header should contain
182
+ "video_id" and "keyframes" fields. Value specifications are:
183
+ video_id: int
184
+ keyframes: list(int)
185
+ Example of contents:
186
+ video_id,keyframes
187
+ 2,"[1,11,21,31,41,51,61,71,81]"
188
+
189
+ Args:
190
+ fpath (str): File containing keyframe data
191
+
192
+ Return:
193
+ video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
194
+ contains a list of keyframes for that video
195
+ """
196
+ video_id_to_keyframes = {}
197
+ try:
198
+ with PathManager.open(fpath, "r") as io:
199
+ csv_reader = csv.reader(io)
200
+ header = next(csv_reader)
201
+ video_id_idx = header.index("video_id")
202
+ keyframes_idx = header.index("keyframes")
203
+ for row in csv_reader:
204
+ video_id = int(row[video_id_idx])
205
+ assert (
206
+ video_id not in video_id_to_keyframes
207
+ ), f"Duplicate keyframes entry for video {fpath}"
208
+ video_id_to_keyframes[video_id] = (
209
+ [int(v) for v in row[keyframes_idx][1:-1].split(",")]
210
+ if len(row[keyframes_idx]) > 2
211
+ else []
212
+ )
213
+ except Exception as e:
214
+ logger = logging.getLogger(__name__)
215
+ logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
216
+ return video_id_to_keyframes
217
+
218
+
219
+ class VideoKeyframeDataset(Dataset):
220
+ """
221
+ Dataset that provides keyframes for a set of videos.
222
+ """
223
+
224
+ _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))
225
+
226
+ def __init__(
227
+ self,
228
+ video_list: List[str],
229
+ category_list: Union[str, List[str], None] = None,
230
+ frame_selector: Optional[FrameSelector] = None,
231
+ transform: Optional[FrameTransform] = None,
232
+ keyframe_helper_fpath: Optional[str] = None,
233
+ ):
234
+ """
235
+ Dataset constructor
236
+
237
+ Args:
238
+ video_list (List[str]): list of paths to video files
239
+ category_list (Union[str, List[str], None]): list of animal categories for each
240
+ video file. If it is a string, or None, this applies to all videos
241
+ frame_selector (Callable: KeyFrameList -> KeyFrameList):
242
+ selects keyframes to process, keyframes are given by
243
+ packet timestamps in timebase counts. If None, all keyframes
244
+ are selected (default: None)
245
+ transform (Callable: torch.Tensor -> torch.Tensor):
246
+ transforms a batch of RGB images (tensors of size [B, 3, H, W]),
247
+ returns a tensor of the same size. If None, no transform is
248
+ applied (default: None)
249
+
250
+ """
251
+ if type(category_list) is list:
252
+ self.category_list = category_list
253
+ else:
254
+ self.category_list = [category_list] * len(video_list)
255
+ assert len(video_list) == len(
256
+ self.category_list
257
+ ), "length of video and category lists must be equal"
258
+ self.video_list = video_list
259
+ self.frame_selector = frame_selector
260
+ self.transform = transform
261
+ self.keyframe_helper_data = (
262
+ read_keyframe_helper_data(keyframe_helper_fpath)
263
+ if keyframe_helper_fpath is not None
264
+ else None
265
+ )
266
+
267
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
268
+ """
269
+ Gets selected keyframes from a given video
270
+
271
+ Args:
272
+ idx (int): video index in the video list file
273
+ Returns:
274
+ A dictionary containing two keys:
275
+ images (torch.Tensor): tensor of size [N, 3, H, W] or of size
276
+ defined by the transform that contains keyframes data
277
+ categories (List[str]): categories of the frames
278
+ """
279
+ categories = [self.category_list[idx]]
280
+ fpath = self.video_list[idx]
281
+ keyframes = (
282
+ list_keyframes(fpath)
283
+ if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data
284
+ else self.keyframe_helper_data[idx]
285
+ )
286
+ transform = self.transform
287
+ frame_selector = self.frame_selector
288
+ if not keyframes:
289
+ return {"images": self._EMPTY_FRAMES, "categories": []}
290
+ if frame_selector is not None:
291
+ keyframes = frame_selector(keyframes)
292
+ frames = read_keyframes(fpath, keyframes)
293
+ if not frames:
294
+ return {"images": self._EMPTY_FRAMES, "categories": []}
295
+ frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
296
+ frames = torch.as_tensor(frames, device=torch.device("cpu"))
297
+ frames = frames[..., [2, 1, 0]] # RGB -> BGR
298
+ frames = frames.permute(0, 3, 1, 2).float() # NHWC -> NCHW
299
+ if transform is not None:
300
+ frames = transform(frames)
301
+ return {"images": frames, "categories": categories}
302
+
303
+ def __len__(self):
304
+ return len(self.video_list)
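A hedged construction sketch for `VideoKeyframeDataset`; the video list file and base path are placeholders, and indexing the dataset requires the listed videos to actually exist.

# Illustrative construction of a VideoKeyframeDataset (paths are placeholders).
from densepose.data.transform import ImageResizeTransform
from densepose.data.video import RandomKFramesSelector, VideoKeyframeDataset, video_list_from_file

video_list = video_list_from_file("video_list.txt", base_path="/datasets/videos")  # hypothetical file
dataset = VideoKeyframeDataset(
    video_list,
    category_list="chimpanzee",               # one category applied to all videos
    frame_selector=RandomKFramesSelector(5),  # keep at most 5 random keyframes per video
    transform=ImageResizeTransform(),         # resize as in ImageResizeTransform above
)
entry = dataset[0]  # requires the first video to be readable
print(entry["images"].shape, entry["categories"])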
densepose/engine/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .trainer import Trainer
densepose/engine/trainer.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ import logging
6
+ import os
7
+ from collections import OrderedDict
8
+ from typing import List, Optional, Union
9
+ import torch
10
+ from torch import nn
11
+
12
+ from detectron2.checkpoint import DetectionCheckpointer
13
+ from detectron2.config import CfgNode
14
+ from detectron2.engine import DefaultTrainer
15
+ from detectron2.evaluation import (
16
+ DatasetEvaluator,
17
+ DatasetEvaluators,
18
+ inference_on_dataset,
19
+ print_csv_format,
20
+ )
21
+ from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
22
+ from detectron2.utils import comm
23
+ from detectron2.utils.events import EventWriter, get_event_storage
24
+
25
+ from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
26
+ from densepose.data import (
27
+ DatasetMapper,
28
+ build_combined_loader,
29
+ build_detection_test_loader,
30
+ build_detection_train_loader,
31
+ build_inference_based_loaders,
32
+ has_inference_based_loaders,
33
+ )
34
+ from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
35
+ from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
36
+ from densepose.modeling.cse import Embedder
37
+
38
+
39
+ class SampleCountingLoader:
40
+ def __init__(self, loader):
41
+ self.loader = loader
42
+
43
+ def __iter__(self):
44
+ it = iter(self.loader)
45
+ storage = get_event_storage()
46
+ while True:
47
+ try:
48
+ batch = next(it)
49
+ num_inst_per_dataset = {}
50
+ for data in batch:
51
+ dataset_name = data["dataset"]
52
+ if dataset_name not in num_inst_per_dataset:
53
+ num_inst_per_dataset[dataset_name] = 0
54
+ num_inst = len(data["instances"])
55
+ num_inst_per_dataset[dataset_name] += num_inst
56
+ for dataset_name in num_inst_per_dataset:
57
+ storage.put_scalar(f"batch/{dataset_name}", num_inst_per_dataset[dataset_name])
58
+ yield batch
59
+ except StopIteration:
60
+ break
61
+
62
+
63
+ class SampleCountMetricPrinter(EventWriter):
64
+ def __init__(self):
65
+ self.logger = logging.getLogger(__name__)
66
+
67
+ def write(self):
68
+ storage = get_event_storage()
69
+ batch_stats_strs = []
70
+ for key, buf in storage.histories().items():
71
+ if key.startswith("batch/"):
72
+ batch_stats_strs.append(f"{key} {buf.avg(20)}")
73
+ self.logger.info(", ".join(batch_stats_strs))
74
+
75
+
76
+ class Trainer(DefaultTrainer):
77
+ @classmethod
78
+ def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
79
+ if isinstance(model, nn.parallel.DistributedDataParallel):
80
+ model = model.module
81
+ if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
82
+ return model.roi_heads.embedder
83
+ return None
84
+
85
+ # TODO: the only reason to copy the base class code here is to pass the embedder from
86
+ # the model to the evaluator; that should be refactored to avoid unnecessary copy-pasting
87
+ @classmethod
88
+ def test(
89
+ cls,
90
+ cfg: CfgNode,
91
+ model: nn.Module,
92
+ evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
93
+ ):
94
+ """
95
+ Args:
96
+ cfg (CfgNode):
97
+ model (nn.Module):
98
+ evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
99
+ :meth:`build_evaluator`. Otherwise, must have the same length as
100
+ ``cfg.DATASETS.TEST``.
101
+
102
+ Returns:
103
+ dict: a dict of result metrics
104
+ """
105
+ logger = logging.getLogger(__name__)
106
+ if isinstance(evaluators, DatasetEvaluator):
107
+ evaluators = [evaluators]
108
+ if evaluators is not None:
109
+ assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
110
+ len(cfg.DATASETS.TEST), len(evaluators)
111
+ )
112
+
113
+ results = OrderedDict()
114
+ for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
115
+ data_loader = cls.build_test_loader(cfg, dataset_name)
116
+ # When evaluators are passed in as arguments,
117
+ # implicitly assume that evaluators can be created before data_loader.
118
+ if evaluators is not None:
119
+ evaluator = evaluators[idx]
120
+ else:
121
+ try:
122
+ embedder = cls.extract_embedder_from_model(model)
123
+ evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
124
+ except NotImplementedError:
125
+ logger.warn(
126
+ "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
127
+ "or implement its `build_evaluator` method."
128
+ )
129
+ results[dataset_name] = {}
130
+ continue
131
+ if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
132
+ results_i = inference_on_dataset(model, data_loader, evaluator)
133
+ else:
134
+ results_i = {}
135
+ results[dataset_name] = results_i
136
+ if comm.is_main_process():
137
+ assert isinstance(
138
+ results_i, dict
139
+ ), "Evaluator must return a dict on the main process. Got {} instead.".format(
140
+ results_i
141
+ )
142
+ logger.info("Evaluation results for {} in csv format:".format(dataset_name))
143
+ print_csv_format(results_i)
144
+
145
+ if len(results) == 1:
146
+ results = list(results.values())[0]
147
+ return results
148
+
149
+ @classmethod
150
+ def build_evaluator(
151
+ cls,
152
+ cfg: CfgNode,
153
+ dataset_name: str,
154
+ output_folder: Optional[str] = None,
155
+ embedder: Optional[Embedder] = None,
156
+ ) -> DatasetEvaluators:
157
+ if output_folder is None:
158
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
159
+ evaluators = []
160
+ distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
161
+ # Note: we currently use COCO evaluator for both COCO and LVIS datasets
162
+ # to have compatible metrics. LVIS bbox evaluator could also be used
163
+ # with an adapter to properly handle filtered / mapped categories
164
+ # evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
165
+ # if evaluator_type == "coco":
166
+ # evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))
167
+ # elif evaluator_type == "lvis":
168
+ # evaluators.append(LVISEvaluator(dataset_name, output_dir=output_folder))
169
+ evaluators.append(
170
+ Detectron2COCOEvaluatorAdapter(
171
+ dataset_name, output_dir=output_folder, distributed=distributed
172
+ )
173
+ )
174
+ if cfg.MODEL.DENSEPOSE_ON:
175
+ storage = build_densepose_evaluator_storage(cfg, output_folder)
176
+ evaluators.append(
177
+ DensePoseCOCOEvaluator(
178
+ dataset_name,
179
+ distributed,
180
+ output_folder,
181
+ evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
182
+ min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
183
+ storage=storage,
184
+ embedder=embedder,
185
+ should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
186
+ mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
187
+ )
188
+ )
189
+ return DatasetEvaluators(evaluators)
190
+
191
+ @classmethod
192
+ def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
193
+ params = get_default_optimizer_params(
194
+ model,
195
+ base_lr=cfg.SOLVER.BASE_LR,
196
+ weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
197
+ bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
198
+ weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
199
+ overrides={
200
+ "features": {
201
+ "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
202
+ },
203
+ "embeddings": {
204
+ "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
205
+ },
206
+ },
207
+ )
208
+ optimizer = torch.optim.SGD(
209
+ params,
210
+ cfg.SOLVER.BASE_LR,
211
+ momentum=cfg.SOLVER.MOMENTUM,
212
+ nesterov=cfg.SOLVER.NESTEROV,
213
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY,
214
+ )
215
+ # pyre-fixme[6]: For 2nd param expected `Type[Optimizer]` but got `SGD`.
216
+ return maybe_add_gradient_clipping(cfg, optimizer)
217
+
218
+ @classmethod
219
+ def build_test_loader(cls, cfg: CfgNode, dataset_name):
220
+ return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))
221
+
222
+ @classmethod
223
+ def build_train_loader(cls, cfg: CfgNode):
224
+ data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
225
+ if not has_inference_based_loaders(cfg):
226
+ return data_loader
227
+ model = cls.build_model(cfg)
228
+ model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
229
+ DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
230
+ inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
231
+ loaders = [data_loader] + inference_based_loaders
232
+ ratios = [1.0] + ratios
233
+ combined_data_loader = build_combined_loader(cfg, loaders, ratios)
234
+ sample_counting_loader = SampleCountingLoader(combined_data_loader)
235
+ return sample_counting_loader
236
+
237
+ def build_writers(self):
238
+ writers = super().build_writers()
239
+ writers.append(SampleCountMetricPrinter())
240
+ return writers
241
+
242
+ @classmethod
243
+ def test_with_TTA(cls, cfg: CfgNode, model):
244
+ logger = logging.getLogger("detectron2.trainer")
245
+ # In the end of training, run an evaluation with TTA
246
+ # Only support some R-CNN models.
247
+ logger.info("Running inference with test-time augmentation ...")
248
+ transform_data = load_from_cfg(cfg)
249
+ model = DensePoseGeneralizedRCNNWithTTA(
250
+ cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
251
+ )
252
+ evaluators = [
253
+ cls.build_evaluator(
254
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
255
+ )
256
+ for name in cfg.DATASETS.TEST
257
+ ]
258
+ res = cls.test(cfg, model, evaluators) # pyre-ignore[6]
259
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
260
+ return res
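A stripped-down, standalone sketch of the per-dataset instance counting performed by `SampleCountingLoader`, with a plain Counter standing in for the detectron2 event storage and a fake batch:

# Minimal standalone sketch of per-dataset instance counting (no detectron2 needed).
from collections import Counter

def count_instances_per_dataset(batch):
    counts = Counter()
    for data in batch:
        counts[data["dataset"]] += len(data["instances"])
    return counts

fake_batch = [
    {"dataset": "densepose_coco", "instances": [0, 1, 2]},  # 3 fake instances
    {"dataset": "chimpnsee", "instances": [0]},             # 1 fake instance
]
print(count_instances_per_dataset(fake_batch))  # Counter({'densepose_coco': 3, 'chimpnsee': 1})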
densepose/evaluation/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .evaluator import DensePoseCOCOEvaluator
densepose/evaluation/d2_evaluator_adapter.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from detectron2.data.catalog import Metadata
6
+ from detectron2.evaluation import COCOEvaluator
7
+
8
+ from densepose.data.datasets.coco import (
9
+ get_contiguous_id_to_category_id_map,
10
+ maybe_filter_categories_cocoapi,
11
+ )
12
+
13
+
14
+ def _maybe_add_iscrowd_annotations(cocoapi) -> None:
15
+ for ann in cocoapi.dataset["annotations"]:
16
+ if "iscrowd" not in ann:
17
+ ann["iscrowd"] = 0
18
+
19
+
20
+ class Detectron2COCOEvaluatorAdapter(COCOEvaluator):
21
+ def __init__(
22
+ self,
23
+ dataset_name,
24
+ output_dir=None,
25
+ distributed=True,
26
+ ):
27
+ super().__init__(dataset_name, output_dir=output_dir, distributed=distributed)
28
+ maybe_filter_categories_cocoapi(dataset_name, self._coco_api)
29
+ _maybe_add_iscrowd_annotations(self._coco_api)
30
+ # substitute category metadata to account for categories
31
+ # that are mapped to the same contiguous id
32
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
33
+ self._maybe_substitute_metadata()
34
+
35
+ def _maybe_substitute_metadata(self):
36
+ cont_id_2_cat_id = get_contiguous_id_to_category_id_map(self._metadata)
37
+ cat_id_2_cont_id = self._metadata.thing_dataset_id_to_contiguous_id
38
+ if len(cont_id_2_cat_id) == len(cat_id_2_cont_id):
39
+ return
40
+
41
+ cat_id_2_cont_id_injective = {}
42
+ for cat_id, cont_id in cat_id_2_cont_id.items():
43
+ if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
44
+ cat_id_2_cont_id_injective[cat_id] = cont_id
45
+
46
+ metadata_new = Metadata(name=self._metadata.name)
47
+ for key, value in self._metadata.__dict__.items():
48
+ if key == "thing_dataset_id_to_contiguous_id":
49
+ setattr(metadata_new, key, cat_id_2_cont_id_injective)
50
+ else:
51
+ setattr(metadata_new, key, value)
52
+ self._metadata = metadata_new
densepose/evaluation/densepose_coco_evaluation.py ADDED
@@ -0,0 +1,1305 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # This is a modified version of cocoeval.py where we also have the densepose evaluation.
7
+
8
+ # pyre-unsafe
9
+
10
+ __author__ = "tsungyi"
11
+
12
+ import copy
13
+ import datetime
14
+ import logging
15
+ import numpy as np
16
+ import pickle
17
+ import time
18
+ from collections import defaultdict
19
+ from enum import Enum
20
+ from typing import Any, Dict, Tuple
21
+ import scipy.spatial.distance as ssd
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from pycocotools import mask as maskUtils
25
+ from scipy.io import loadmat
26
+ from scipy.ndimage import zoom as spzoom
27
+
28
+ from detectron2.utils.file_io import PathManager
29
+
30
+ from densepose.converters.chart_output_to_chart_result import resample_uv_tensors_to_bbox
31
+ from densepose.converters.segm_to_mask import (
32
+ resample_coarse_segm_tensor_to_bbox,
33
+ resample_fine_and_coarse_segm_tensors_to_bbox,
34
+ )
35
+ from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
36
+ from densepose.structures import DensePoseDataRelative
37
+ from densepose.structures.mesh import create_mesh
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ class DensePoseEvalMode(str, Enum):
43
+ # use both masks and geodesic distances (GPS * IOU) to compute scores
44
+ GPSM = "gpsm"
45
+ # use only geodesic distances (GPS) to compute scores
46
+ GPS = "gps"
47
+ # use only masks (IOU) to compute scores
48
+ IOU = "iou"
49
+
50
+
51
+ class DensePoseDataMode(str, Enum):
52
+ # use estimated IUV data (default mode)
53
+ IUV_DT = "iuvdt"
54
+ # use ground truth IUV data
55
+ IUV_GT = "iuvgt"
56
+ # use ground truth labels I and set UV to 0
57
+ I_GT_UV_0 = "igtuv0"
58
+ # use ground truth labels I and estimated UV coordinates
59
+ I_GT_UV_DT = "igtuvdt"
60
+ # use estimated labels I and set UV to 0
61
+ I_DT_UV_0 = "idtuv0"
62
+
63
+
64
+ class DensePoseCocoEval:
65
+ # Interface for evaluating detection on the Microsoft COCO dataset.
66
+ #
67
+ # The usage for CocoEval is as follows:
68
+ # cocoGt=..., cocoDt=... # load dataset and results
69
+ # E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
70
+ # E.params.recThrs = ...; # set parameters as desired
71
+ # E.evaluate(); # run per image evaluation
72
+ # E.accumulate(); # accumulate per image results
73
+ # E.summarize(); # display summary metrics of results
74
+ # For example usage see evalDemo.m and http://mscoco.org/.
75
+ #
76
+ # The evaluation parameters are as follows (defaults in brackets):
77
+ # imgIds - [all] N img ids to use for evaluation
78
+ # catIds - [all] K cat ids to use for evaluation
79
+ # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
80
+ # recThrs - [0:.01:1] R=101 recall thresholds for evaluation
81
+ # areaRng - [...] A=4 object area ranges for evaluation
82
+ # maxDets - [1 10 100] M=3 thresholds on max detections per image
83
+ # iouType - ['segm'] set iouType to 'segm', 'bbox', 'keypoints' or 'densepose'
84
+ # iouType replaced the now DEPRECATED useSegm parameter.
85
+ # useCats - [1] if true use category labels for evaluation
86
+ # Note: if useCats=0 category labels are ignored as in proposal scoring.
87
+ # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
88
+ #
89
+ # evaluate(): evaluates detections on every image and every category and
90
+ # concats the results into the "evalImgs" with fields:
91
+ # dtIds - [1xD] id for each of the D detections (dt)
92
+ # gtIds - [1xG] id for each of the G ground truths (gt)
93
+ # dtMatches - [TxD] matching gt id at each IoU or 0
94
+ # gtMatches - [TxG] matching dt id at each IoU or 0
95
+ # dtScores - [1xD] confidence of each dt
96
+ # gtIgnore - [1xG] ignore flag for each gt
97
+ # dtIgnore - [TxD] ignore flag for each dt at each IoU
98
+ #
99
+ # accumulate(): accumulates the per-image, per-category evaluation
100
+ # results in "evalImgs" into the dictionary "eval" with fields:
101
+ # params - parameters used for evaluation
102
+ # date - date evaluation was performed
103
+ # counts - [T,R,K,A,M] parameter dimensions (see above)
104
+ # precision - [TxRxKxAxM] precision for every evaluation setting
105
+ # recall - [TxKxAxM] max recall for every evaluation setting
106
+ # Note: precision and recall==-1 for settings with no gt objects.
107
+ #
108
+ # See also coco, mask, pycocoDemo, pycocoEvalDemo
109
+ #
110
+ # Microsoft COCO Toolbox. version 2.0
111
+ # Data, paper, and tutorials available at: http://mscoco.org/
112
+ # Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
113
+ # Licensed under the Simplified BSD License [see coco/license.txt]
114
+ def __init__(
115
+ self,
116
+ cocoGt=None,
117
+ cocoDt=None,
118
+ iouType: str = "densepose",
119
+ multi_storage=None,
120
+ embedder=None,
121
+ dpEvalMode: DensePoseEvalMode = DensePoseEvalMode.GPS,
122
+ dpDataMode: DensePoseDataMode = DensePoseDataMode.IUV_DT,
123
+ ):
124
+ """
125
+ Initialize CocoEval using coco APIs for gt and dt
126
+ :param cocoGt: coco object with ground truth annotations
127
+ :param cocoDt: coco object with detection results
128
+ :return: None
129
+ """
130
+ self.cocoGt = cocoGt # ground truth COCO API
131
+ self.cocoDt = cocoDt # detections COCO API
132
+ self.multi_storage = multi_storage
133
+ self.embedder = embedder
134
+ self._dpEvalMode = dpEvalMode
135
+ self._dpDataMode = dpDataMode
136
+ self.evalImgs = defaultdict(list) # per-image per-category eval results [KxAxI]
137
+ self.eval = {} # accumulated evaluation results
138
+ self._gts = defaultdict(list) # gt for evaluation
139
+ self._dts = defaultdict(list) # dt for evaluation
140
+ self.params = Params(iouType=iouType) # parameters
141
+ self._paramsEval = {} # parameters for evaluation
142
+ self.stats = [] # result summarization
143
+ self.ious = {} # ious between all gts and dts
144
+ if cocoGt is not None:
145
+ self.params.imgIds = sorted(cocoGt.getImgIds())
146
+ self.params.catIds = sorted(cocoGt.getCatIds())
147
+ self.ignoreThrBB = 0.7
148
+ self.ignoreThrUV = 0.9
149
+
150
+ def _loadGEval(self):
151
+ smpl_subdiv_fpath = PathManager.get_local_path(
152
+ "https://dl.fbaipublicfiles.com/densepose/data/SMPL_subdiv.mat"
153
+ )
154
+ pdist_transform_fpath = PathManager.get_local_path(
155
+ "https://dl.fbaipublicfiles.com/densepose/data/SMPL_SUBDIV_TRANSFORM.mat"
156
+ )
157
+ pdist_matrix_fpath = PathManager.get_local_path(
158
+ "https://dl.fbaipublicfiles.com/densepose/data/Pdist_matrix.pkl", timeout_sec=120
159
+ )
160
+ SMPL_subdiv = loadmat(smpl_subdiv_fpath)
161
+ self.PDIST_transform = loadmat(pdist_transform_fpath)
162
+ self.PDIST_transform = self.PDIST_transform["index"].squeeze()
163
+ UV = np.array([SMPL_subdiv["U_subdiv"], SMPL_subdiv["V_subdiv"]]).squeeze()
164
+ ClosestVertInds = np.arange(UV.shape[1]) + 1
165
+ self.Part_UVs = []
166
+ self.Part_ClosestVertInds = []
167
+ for i in np.arange(24):
168
+ self.Part_UVs.append(UV[:, SMPL_subdiv["Part_ID_subdiv"].squeeze() == (i + 1)])
169
+ self.Part_ClosestVertInds.append(
170
+ ClosestVertInds[SMPL_subdiv["Part_ID_subdiv"].squeeze() == (i + 1)]
171
+ )
172
+
173
+ with open(pdist_matrix_fpath, "rb") as hFile:
174
+ arrays = pickle.load(hFile, encoding="latin1")
175
+ self.Pdist_matrix = arrays["Pdist_matrix"]
176
+ self.Part_ids = np.array(SMPL_subdiv["Part_ID_subdiv"].squeeze())
177
+ # Mean geodesic distances for parts.
178
+ self.Mean_Distances = np.array([0, 0.351, 0.107, 0.126, 0.237, 0.173, 0.142, 0.128, 0.150])
179
+ # Coarse Part labels.
180
+ self.CoarseParts = np.array(
181
+ [0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8]
182
+ )
183
+
184
+ def _prepare(self):
185
+ """
186
+ Prepare ._gts and ._dts for evaluation based on params
187
+ :return: None
188
+ """
189
+
190
+ def _toMask(anns, coco):
191
+ # modify ann['segmentation'] by reference
192
+ for ann in anns:
193
+ # safeguard for invalid segmentation annotation;
194
+ # annotations containing empty lists exist in the posetrack
195
+ # dataset. This is not a correct segmentation annotation
196
+ # in terms of COCO format; we need to deal with it somehow
197
+ segm = ann["segmentation"]
198
+ if type(segm) is list and len(segm) == 0:
199
+ ann["segmentation"] = None
200
+ continue
201
+ rle = coco.annToRLE(ann)
202
+ ann["segmentation"] = rle
203
+
204
+ def _getIgnoreRegion(iid, coco):
205
+ img = coco.imgs[iid]
206
+
207
+ if "ignore_regions_x" not in img.keys():
208
+ return None
209
+
210
+ if len(img["ignore_regions_x"]) == 0:
211
+ return None
212
+
213
+ rgns_merged = [
214
+ [v for xy in zip(region_x, region_y) for v in xy]
215
+ for region_x, region_y in zip(img["ignore_regions_x"], img["ignore_regions_y"])
216
+ ]
217
+ rles = maskUtils.frPyObjects(rgns_merged, img["height"], img["width"])
218
+ rle = maskUtils.merge(rles)
219
+ return maskUtils.decode(rle)
220
+
221
+ def _checkIgnore(dt, iregion):
222
+ if iregion is None:
223
+ return True
224
+
225
+ bb = np.array(dt["bbox"]).astype(int)
226
+ x1, y1, x2, y2 = bb[0], bb[1], bb[0] + bb[2], bb[1] + bb[3]
227
+ x2 = min([x2, iregion.shape[1]])
228
+ y2 = min([y2, iregion.shape[0]])
229
+
230
+ if bb[2] * bb[3] == 0:
231
+ return False
232
+
233
+ crop_iregion = iregion[y1:y2, x1:x2]
234
+
235
+ if crop_iregion.sum() == 0:
236
+ return True
237
+
238
+ if "densepose" not in dt.keys(): # filtering boxes
239
+ return crop_iregion.sum() / bb[2] / bb[3] < self.ignoreThrBB
240
+
241
+ # filtering UVs
242
+ ignoremask = np.require(crop_iregion, requirements=["F"])
243
+ mask = self._extract_mask(dt)
244
+ uvmask = np.require(np.asarray(mask > 0), dtype=np.uint8, requirements=["F"])
245
+ uvmask_ = maskUtils.encode(uvmask)
246
+ ignoremask_ = maskUtils.encode(ignoremask)
247
+ uviou = maskUtils.iou([uvmask_], [ignoremask_], [1])[0]
248
+ return uviou < self.ignoreThrUV
249
+
250
+ p = self.params
251
+
252
+ if p.useCats:
253
+ gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
254
+ dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
255
+ else:
256
+ gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
257
+ dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
258
+
259
+ imns = self.cocoGt.loadImgs(p.imgIds)
260
+ self.size_mapping = {}
261
+ for im in imns:
262
+ self.size_mapping[im["id"]] = [im["height"], im["width"]]
263
+
264
+ # if iouType == 'uv', add point gt annotations
265
+ if p.iouType == "densepose":
266
+ self._loadGEval()
267
+
268
+ # convert ground truth to mask if iouType == 'segm'
269
+ if p.iouType == "segm":
270
+ _toMask(gts, self.cocoGt)
271
+ _toMask(dts, self.cocoDt)
272
+
273
+ # set ignore flag
274
+ for gt in gts:
275
+ gt["ignore"] = gt["ignore"] if "ignore" in gt else 0
276
+ gt["ignore"] = "iscrowd" in gt and gt["iscrowd"]
277
+ if p.iouType == "keypoints":
278
+ gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"]
279
+ if p.iouType == "densepose":
280
+ gt["ignore"] = ("dp_x" in gt) == 0
281
+ if p.iouType == "segm":
282
+ gt["ignore"] = gt["segmentation"] is None
283
+
284
+ self._gts = defaultdict(list) # gt for evaluation
285
+ self._dts = defaultdict(list) # dt for evaluation
286
+ self._igrgns = defaultdict(list)
287
+
288
+ for gt in gts:
289
+ iid = gt["image_id"]
290
+ if iid not in self._igrgns.keys():
291
+ self._igrgns[iid] = _getIgnoreRegion(iid, self.cocoGt)
292
+ if _checkIgnore(gt, self._igrgns[iid]):
293
+ self._gts[iid, gt["category_id"]].append(gt)
294
+ for dt in dts:
295
+ iid = dt["image_id"]
296
+ if (iid not in self._igrgns) or _checkIgnore(dt, self._igrgns[iid]):
297
+ self._dts[iid, dt["category_id"]].append(dt)
298
+
299
+ self.evalImgs = defaultdict(list) # per-image per-category evaluation results
300
+ self.eval = {} # accumulated evaluation results
301
+
302
+ def evaluate(self):
303
+ """
304
+ Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
305
+ :return: None
306
+ """
307
+ tic = time.time()
308
+ logger.info("Running per image DensePose evaluation... {}".format(self.params.iouType))
309
+ p = self.params
310
+ # add backward compatibility if useSegm is specified in params
311
+ if p.useSegm is not None:
312
+ p.iouType = "segm" if p.useSegm == 1 else "bbox"
313
+ logger.info("useSegm (deprecated) is not None. Running DensePose evaluation")
314
+ p.imgIds = list(np.unique(p.imgIds))
315
+ if p.useCats:
316
+ p.catIds = list(np.unique(p.catIds))
317
+ p.maxDets = sorted(p.maxDets)
318
+ self.params = p
319
+
320
+ self._prepare()
321
+ # loop through images, area range, max detection number
322
+ catIds = p.catIds if p.useCats else [-1]
323
+
324
+ if p.iouType in ["segm", "bbox"]:
325
+ computeIoU = self.computeIoU
326
+ elif p.iouType == "keypoints":
327
+ computeIoU = self.computeOks
328
+ elif p.iouType == "densepose":
329
+ computeIoU = self.computeOgps
330
+ if self._dpEvalMode in {DensePoseEvalMode.GPSM, DensePoseEvalMode.IOU}:
331
+ self.real_ious = {
332
+ (imgId, catId): self.computeDPIoU(imgId, catId)
333
+ for imgId in p.imgIds
334
+ for catId in catIds
335
+ }
336
+
337
+ self.ious = {
338
+ (imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds
339
+ }
340
+
341
+ evaluateImg = self.evaluateImg
342
+ maxDet = p.maxDets[-1]
343
+ self.evalImgs = [
344
+ evaluateImg(imgId, catId, areaRng, maxDet)
345
+ for catId in catIds
346
+ for areaRng in p.areaRng
347
+ for imgId in p.imgIds
348
+ ]
349
+ self._paramsEval = copy.deepcopy(self.params)
350
+ toc = time.time()
351
+ logger.info("DensePose evaluation DONE (t={:0.2f}s).".format(toc - tic))
352
+
353
+ def getDensePoseMask(self, polys):
354
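+ # "polys" is the DensePose-COCO "dp_masks" field: up to 14 per-part masks stored as
+ # RLE (despite the name) on a fixed 256x256 grid inside the person box. They are
+ # merged into one index mask whose value i marks coarse body part i (0 = background).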
+ maskGen = np.zeros([256, 256])
355
+ stop = min(len(polys) + 1, 15)
356
+ for i in range(1, stop):
357
+ if polys[i - 1]:
358
+ currentMask = maskUtils.decode(polys[i - 1])
359
+ maskGen[currentMask > 0] = i
360
+ return maskGen
361
+
362
+ def _generate_rlemask_on_image(self, mask, imgId, data):
363
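+ # Paste a box-sized mask onto a full-image canvas at the box location, clip it to
+ # the image bounds, and RLE-encode it so pycocotools mask IoU can be computed
+ # against whole-image ground-truth masks.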
+ bbox_xywh = np.array(data["bbox"])
364
+ x, y, w, h = bbox_xywh
365
+ im_h, im_w = self.size_mapping[imgId]
366
+ im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
367
+ if mask is not None:
368
+ x0 = max(int(x), 0)
369
+ x1 = min(int(x + w), im_w, int(x) + mask.shape[1])
370
+ y0 = max(int(y), 0)
371
+ y1 = min(int(y + h), im_h, int(y) + mask.shape[0])
372
+ y = int(y)
373
+ x = int(x)
374
+ im_mask[y0:y1, x0:x1] = mask[y0 - y : y1 - y, x0 - x : x1 - x]
375
+ im_mask = np.require(np.asarray(im_mask > 0), dtype=np.uint8, requirements=["F"])
376
+ rle_mask = maskUtils.encode(np.array(im_mask[:, :, np.newaxis], order="F"))[0]
377
+ return rle_mask
378
+
379
+ def computeDPIoU(self, imgId, catId):
380
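+ # Mask-level IoU between ground-truth foreground (from dp_masks or "segmentation")
+ # and the predicted coarse segmentation, both rendered as whole-image RLEs. These
+ # values back the GPSM and IOU evaluation modes; computeOgps() provides the GPS term.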
+ p = self.params
381
+ if p.useCats:
382
+ gt = self._gts[imgId, catId]
383
+ dt = self._dts[imgId, catId]
384
+ else:
385
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
386
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
387
+ if len(gt) == 0 and len(dt) == 0:
388
+ return []
389
+ inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
390
+ dt = [dt[i] for i in inds]
391
+ if len(dt) > p.maxDets[-1]:
392
+ dt = dt[0 : p.maxDets[-1]]
393
+
394
+ gtmasks = []
395
+ for g in gt:
396
+ if DensePoseDataRelative.S_KEY in g:
397
+ # convert DensePose mask to a binary mask
398
+ mask = np.minimum(self.getDensePoseMask(g[DensePoseDataRelative.S_KEY]), 1.0)
399
+ _, _, w, h = g["bbox"]
400
+ scale_x = float(max(w, 1)) / mask.shape[1]
401
+ scale_y = float(max(h, 1)) / mask.shape[0]
402
+ mask = spzoom(mask, (scale_y, scale_x), order=1, prefilter=False)
403
+ mask = np.array(mask > 0.5, dtype=np.uint8)
404
+ rle_mask = self._generate_rlemask_on_image(mask, imgId, g)
405
+ elif "segmentation" in g:
406
+ segmentation = g["segmentation"]
407
+ if isinstance(segmentation, list) and segmentation:
408
+ # polygons
409
+ im_h, im_w = self.size_mapping[imgId]
410
+ rles = maskUtils.frPyObjects(segmentation, im_h, im_w)
411
+ rle_mask = maskUtils.merge(rles)
412
+ elif isinstance(segmentation, dict):
413
+ if isinstance(segmentation["counts"], list):
414
+ # uncompressed RLE
415
+ im_h, im_w = self.size_mapping[imgId]
416
+ rle_mask = maskUtils.frPyObjects(segmentation, im_h, im_w)
417
+ else:
418
+ # compressed RLE
419
+ rle_mask = segmentation
420
+ else:
421
+ rle_mask = self._generate_rlemask_on_image(None, imgId, g)
422
+ else:
423
+ rle_mask = self._generate_rlemask_on_image(None, imgId, g)
424
+ gtmasks.append(rle_mask)
425
+
426
+ dtmasks = []
427
+ for d in dt:
428
+ mask = self._extract_mask(d)
429
+ mask = np.require(np.asarray(mask > 0), dtype=np.uint8, requirements=["F"])
430
+ rle_mask = self._generate_rlemask_on_image(mask, imgId, d)
431
+ dtmasks.append(rle_mask)
432
+
433
+ # compute iou between each dt and gt region
434
+ iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
435
+ iousDP = maskUtils.iou(dtmasks, gtmasks, iscrowd)
436
+ return iousDP
437
+
438
+ def computeIoU(self, imgId, catId):
439
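+ # Standard COCO IoU: boxes (xywh) or segmentations (RLE) are handed to pycocotools'
+ # maskUtils.iou; for crowd ground truth the measure becomes intersection over
+ # detection area instead of plain IoU.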
+ p = self.params
440
+ if p.useCats:
441
+ gt = self._gts[imgId, catId]
442
+ dt = self._dts[imgId, catId]
443
+ else:
444
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
445
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
446
+ if len(gt) == 0 and len(dt) == 0:
447
+ return []
448
+ inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
449
+ dt = [dt[i] for i in inds]
450
+ if len(dt) > p.maxDets[-1]:
451
+ dt = dt[0 : p.maxDets[-1]]
452
+
453
+ if p.iouType == "segm":
454
+ g = [g["segmentation"] for g in gt if g["segmentation"] is not None]
455
+ d = [d["segmentation"] for d in dt if d["segmentation"] is not None]
456
+ elif p.iouType == "bbox":
457
+ g = [g["bbox"] for g in gt]
458
+ d = [d["bbox"] for d in dt]
459
+ else:
460
+ raise Exception("unknown iouType for iou computation")
461
+
462
+ # compute iou between each dt and gt region
463
+ iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
464
+ ious = maskUtils.iou(d, g, iscrowd)
465
+ return ious
466
+
467
+ def computeOks(self, imgId, catId):
468
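+ # COCO keypoint OKS for a (dt, gt) pair:
+ #   OKS = sum_i exp(-d_i^2 / (2 * s^2 * k_i^2)) * [v_i > 0] / sum_i [v_i > 0]
+ # where d_i is the Euclidean distance for keypoint i, s^2 is the GT area and
+ # k_i = 2 * sigma_i is the per-keypoint constant ("vars" below).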
+ p = self.params
469
+ # dimension here should be Nxm
470
+ gts = self._gts[imgId, catId]
471
+ dts = self._dts[imgId, catId]
472
+ inds = np.argsort([-d["score"] for d in dts], kind="mergesort")
473
+ dts = [dts[i] for i in inds]
474
+ if len(dts) > p.maxDets[-1]:
475
+ dts = dts[0 : p.maxDets[-1]]
476
+ # if len(gts) == 0 and len(dts) == 0:
477
+ if len(gts) == 0 or len(dts) == 0:
478
+ return []
479
+ ious = np.zeros((len(dts), len(gts)))
480
+ sigmas = (
481
+ np.array(
482
+ [
483
+ 0.26,
484
+ 0.25,
485
+ 0.25,
486
+ 0.35,
487
+ 0.35,
488
+ 0.79,
489
+ 0.79,
490
+ 0.72,
491
+ 0.72,
492
+ 0.62,
493
+ 0.62,
494
+ 1.07,
495
+ 1.07,
496
+ 0.87,
497
+ 0.87,
498
+ 0.89,
499
+ 0.89,
500
+ ]
501
+ )
502
+ / 10.0
503
+ )
504
+ vars = (sigmas * 2) ** 2
505
+ k = len(sigmas)
506
+ # compute oks between each detection and ground truth object
507
+ for j, gt in enumerate(gts):
508
+ # create bounds for ignore regions (double the gt bbox)
509
+ g = np.array(gt["keypoints"])
510
+ xg = g[0::3]
511
+ yg = g[1::3]
512
+ vg = g[2::3]
513
+ k1 = np.count_nonzero(vg > 0)
514
+ bb = gt["bbox"]
515
+ x0 = bb[0] - bb[2]
516
+ x1 = bb[0] + bb[2] * 2
517
+ y0 = bb[1] - bb[3]
518
+ y1 = bb[1] + bb[3] * 2
519
+ for i, dt in enumerate(dts):
520
+ d = np.array(dt["keypoints"])
521
+ xd = d[0::3]
522
+ yd = d[1::3]
523
+ if k1 > 0:
524
+ # measure the per-keypoint distance if keypoints visible
525
+ dx = xd - xg
526
+ dy = yd - yg
527
+ else:
528
+ # measure minimum distance to keypoints in (x0,y0) & (x1,y1)
529
+ z = np.zeros(k)
530
+ dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0)
531
+ dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0)
532
+ e = (dx**2 + dy**2) / vars / (gt["area"] + np.spacing(1)) / 2
533
+ if k1 > 0:
534
+ e = e[vg > 0]
535
+ ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
536
+ return ious
537
+
538
+ def _extract_mask(self, dt: Dict[str, Any]) -> np.ndarray:
539
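+ # Supported detection payloads: quantized IUV results ("densepose"), an explicit
+ # binary mask ("cse_mask"), raw coarse segmentation scores ("coarse_segm", resized
+ # to the box and argmax-ed), or a storage handle ("record_id" + "rank") resolved
+ # through self.multi_storage.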
+ if "densepose" in dt:
540
+ densepose_results_quantized = dt["densepose"]
541
+ return densepose_results_quantized.labels_uv_uint8[0].numpy()
542
+ elif "cse_mask" in dt:
543
+ return dt["cse_mask"]
544
+ elif "coarse_segm" in dt:
545
+ dy = max(int(dt["bbox"][3]), 1)
546
+ dx = max(int(dt["bbox"][2]), 1)
547
+ return (
548
+ F.interpolate(
549
+ dt["coarse_segm"].unsqueeze(0),
550
+ (dy, dx),
551
+ mode="bilinear",
552
+ align_corners=False,
553
+ )
554
+ .squeeze(0)
555
+ .argmax(0)
556
+ .numpy()
557
+ .astype(np.uint8)
558
+ )
559
+ elif "record_id" in dt:
560
+ assert (
561
+ self.multi_storage is not None
562
+ ), f"Storage record id encountered in a detection {dt}, but no storage provided!"
563
+ record = self.multi_storage.get(dt["rank"], dt["record_id"])
564
+ coarse_segm = record["coarse_segm"]
565
+ dy = max(int(dt["bbox"][3]), 1)
566
+ dx = max(int(dt["bbox"][2]), 1)
567
+ return (
568
+ F.interpolate(
569
+ coarse_segm.unsqueeze(0),
570
+ (dy, dx),
571
+ mode="bilinear",
572
+ align_corners=False,
573
+ )
574
+ .squeeze(0)
575
+ .argmax(0)
576
+ .numpy()
577
+ .astype(np.uint8)
578
+ )
579
+ else:
580
+ raise Exception(f"No mask data in the detection: {dt}")
582
+
583
+ def _extract_iuv(
584
+ self, densepose_data: np.ndarray, py: np.ndarray, px: np.ndarray, gt: Dict[str, Any]
585
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
586
+ """
587
+ Extract arrays of I, U and V values at given points as numpy arrays
588
+ given the data mode stored in self._dpDataMode
589
+ """
590
+ if self._dpDataMode == DensePoseDataMode.IUV_DT:
591
+ # estimated labels and UV (default)
592
+ ipoints = densepose_data[0, py, px]
593
+ upoints = densepose_data[1, py, px] / 255.0 # convert from uint8 by /255.
594
+ vpoints = densepose_data[2, py, px] / 255.0
595
+ elif self._dpDataMode == DensePoseDataMode.IUV_GT:
596
+ # ground truth
597
+ ipoints = np.array(gt["dp_I"])
598
+ upoints = np.array(gt["dp_U"])
599
+ vpoints = np.array(gt["dp_V"])
600
+ elif self._dpDataMode == DensePoseDataMode.I_GT_UV_0:
601
+ # ground truth labels, UV = 0
602
+ ipoints = np.array(gt["dp_I"])
603
+ upoints = np.zeros_like(ipoints, dtype=float)
604
+ vpoints = np.zeros_like(ipoints, dtype=float)
605
+ elif self._dpDataMode == DensePoseDataMode.I_GT_UV_DT:
606
+ # ground truth labels, estimated UV
607
+ ipoints = np.array(gt["dp_I"])
608
+ upoints = densepose_data[1, py, px] / 255.0 # convert from uint8 by /255.
609
+ vpoints = densepose_data[2, py, px] / 255.0
610
+ elif self._dpDataMode == DensePoseDataMode.I_DT_UV_0:
611
+ # estimated labels, UV = 0
612
+ ipoints = densepose_data[0, py, px]
613
+ upoints = np.zeros_like(ipoints, dtype=float)
614
+ vpoints = np.zeros_like(ipoints, dtype=float)
615
+ else:
616
+ raise ValueError(f"Unknown data mode: {self._dpDataMode}")
617
+ return ipoints, upoints, vpoints
618
+
619
+ def computeOgps_single_pair(self, dt, gt, py, px, pt_mask):
620
+ if "densepose" in dt:
621
+ ipoints, upoints, vpoints = self.extract_iuv_from_quantized(dt, gt, py, px, pt_mask)
622
+ return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
623
+ elif "u" in dt:
624
+ ipoints, upoints, vpoints = self.extract_iuv_from_raw(dt, gt, py, px, pt_mask)
625
+ return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
626
+ elif "record_id" in dt:
627
+ assert (
628
+ self.multi_storage is not None
629
+ ), f"Storage record id encountered in detection {dt}, but no storage provided!"
630
+ record = self.multi_storage.get(dt["rank"], dt["record_id"])
631
+ record["bbox"] = dt["bbox"]
632
+ if "u" in record:
633
+ ipoints, upoints, vpoints = self.extract_iuv_from_raw(record, gt, py, px, pt_mask)
634
+ return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
635
+ elif "embedding" in record:
636
+ return self.computeOgps_single_pair_cse(
637
+ dt,
638
+ gt,
639
+ py,
640
+ px,
641
+ pt_mask,
642
+ record["coarse_segm"],
643
+ record["embedding"],
644
+ record["bbox"],
645
+ )
646
+ else:
647
+ raise Exception(f"Unknown record format: {record}")
648
+ elif "embedding" in dt:
649
+ return self.computeOgps_single_pair_cse(
650
+ dt, gt, py, px, pt_mask, dt["coarse_segm"], dt["embedding"], dt["bbox"]
651
+ )
652
+ raise Exception(f"Unknown detection format: {dt}")
653
+
654
+ def extract_iuv_from_quantized(self, dt, gt, py, px, pt_mask):
655
+ densepose_results_quantized = dt["densepose"]
656
+ ipoints, upoints, vpoints = self._extract_iuv(
657
+ densepose_results_quantized.labels_uv_uint8.numpy(), py, px, gt
658
+ )
659
+ ipoints[pt_mask == -1] = 0
660
+ return ipoints, upoints, vpoints
661
+
662
+ def extract_iuv_from_raw(self, dt, gt, py, px, pt_mask):
663
+ labels_dt = resample_fine_and_coarse_segm_tensors_to_bbox(
664
+ dt["fine_segm"].unsqueeze(0),
665
+ dt["coarse_segm"].unsqueeze(0),
666
+ dt["bbox"],
667
+ )
668
+ uv = resample_uv_tensors_to_bbox(
669
+ dt["u"].unsqueeze(0), dt["v"].unsqueeze(0), labels_dt.squeeze(0), dt["bbox"]
670
+ )
671
+ labels_uv_uint8 = torch.cat((labels_dt.byte(), (uv * 255).clamp(0, 255).byte()))
672
+ ipoints, upoints, vpoints = self._extract_iuv(labels_uv_uint8.numpy(), py, px, gt)
673
+ ipoints[pt_mask == -1] = 0
674
+ return ipoints, upoints, vpoints
675
+
676
+ def computeOgps_single_pair_iuv(self, dt, gt, ipoints, upoints, vpoints):
677
+ cVertsGT, ClosestVertsGTTransformed = self.findAllClosestVertsGT(gt)
678
+ cVerts = self.findAllClosestVertsUV(upoints, vpoints, ipoints)
679
+ # Get pairwise geodesic distances between gt and estimated mesh points.
680
+ dist = self.getDistancesUV(ClosestVertsGTTransformed, cVerts)
681
+ # Compute the Ogps measure.
682
+ # Find the mean geodesic normalization distance for
683
+ # each GT point, based on which part it is on.
684
+ Current_Mean_Distances = self.Mean_Distances[
685
+ self.CoarseParts[self.Part_ids[cVertsGT[cVertsGT > 0].astype(int) - 1]]
686
+ ]
687
+ return dist, Current_Mean_Distances
688
+
689
+ def computeOgps_single_pair_cse(
690
+ self, dt, gt, py, px, pt_mask, coarse_segm, embedding, bbox_xywh_abs
691
+ ):
692
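+ # CSE variant of the per-pair GPS computation: each annotated GT point is mapped to
+ # the mesh vertex whose embedding is closest to the predicted pixel embedding, and
+ # geodesic distances between GT and predicted vertices are returned together with
+ # the per-point normalization constants.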
+ # 0-based mesh vertex indices
693
+ cVertsGT = torch.as_tensor(gt["dp_vertex"], dtype=torch.int64)
694
+ # label for each pixel of the bbox, [H, W] tensor of long
695
+ labels_dt = resample_coarse_segm_tensor_to_bbox(
696
+ coarse_segm.unsqueeze(0), bbox_xywh_abs
697
+ ).squeeze(0)
698
+ x, y, w, h = bbox_xywh_abs
699
+ # embedding for each pixel of the bbox, [D, H, W] tensor of float32
700
+ embedding = F.interpolate(
701
+ embedding.unsqueeze(0), (int(h), int(w)), mode="bilinear", align_corners=False
702
+ ).squeeze(0)
703
+ # valid locations py, px
704
+ py_pt = torch.from_numpy(py[pt_mask > -1])
705
+ px_pt = torch.from_numpy(px[pt_mask > -1])
706
+ cVerts = torch.ones_like(cVertsGT) * -1
707
+ cVerts[pt_mask > -1] = self.findClosestVertsCse(
708
+ embedding, py_pt, px_pt, labels_dt, gt["ref_model"]
709
+ )
710
+ # Get pairwise geodesic distances between gt and estimated mesh points.
711
+ dist = self.getDistancesCse(cVertsGT, cVerts, gt["ref_model"])
712
+ # normalize distances
713
+ if (gt["ref_model"] == "smpl_27554") and ("dp_I" in gt):
714
+ Current_Mean_Distances = self.Mean_Distances[
715
+ self.CoarseParts[np.array(gt["dp_I"], dtype=int)]
716
+ ]
717
+ else:
718
+ Current_Mean_Distances = 0.255
719
+ return dist, Current_Mean_Distances
720
+
721
+ def computeOgps(self, imgId, catId):
722
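+ # Geodesic Point Similarity (GPS) for a (dt, gt) pair:
+ #   gps = mean_p exp(-g(p)^2 / (2 * kappa(p)^2))
+ # where g(p) is the geodesic distance between matched mesh vertices at annotated
+ # point p and kappa(p) is a per-part normalization (Current_Mean_Distances).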
+ p = self.params
723
+ # dimension here should be Nxm
724
+ g = self._gts[imgId, catId]
725
+ d = self._dts[imgId, catId]
726
+ inds = np.argsort([-d_["score"] for d_ in d], kind="mergesort")
727
+ d = [d[i] for i in inds]
728
+ if len(d) > p.maxDets[-1]:
729
+ d = d[0 : p.maxDets[-1]]
730
+ # if len(gts) == 0 and len(dts) == 0:
731
+ if len(g) == 0 or len(d) == 0:
732
+ return []
733
+ ious = np.zeros((len(d), len(g)))
734
+ # compute opgs between each detection and ground truth object
735
+ # sigma = self.sigma #0.255 # dist = 0.3m corresponds to ogps = 0.5
736
+ # 1 # dist = 0.3m corresponds to ogps = 0.96
737
+ # 1.45 # dist = 1.7m (person height) corresponds to ogps = 0.5)
738
+ for j, gt in enumerate(g):
739
+ if not gt["ignore"]:
740
+ g_ = gt["bbox"]
741
+ for i, dt in enumerate(d):
742
+ #
743
+ dy = int(dt["bbox"][3])
744
+ dx = int(dt["bbox"][2])
745
+ dp_x = np.array(gt["dp_x"]) * g_[2] / 255.0
746
+ dp_y = np.array(gt["dp_y"]) * g_[3] / 255.0
747
+ py = (dp_y + g_[1] - dt["bbox"][1]).astype(int)
748
+ px = (dp_x + g_[0] - dt["bbox"][0]).astype(int)
749
+ #
750
+ pts = np.zeros(len(px))
751
+ pts[px >= dx] = -1
752
+ pts[py >= dy] = -1
753
+ pts[px < 0] = -1
754
+ pts[py < 0] = -1
755
+ if len(pts) < 1:
756
+ ogps = 0.0
757
+ elif np.max(pts) == -1:
758
+ ogps = 0.0
759
+ else:
760
+ px[pts == -1] = 0
761
+ py[pts == -1] = 0
762
+ dists_between_matches, dist_norm_coeffs = self.computeOgps_single_pair(
763
+ dt, gt, py, px, pts
764
+ )
765
+ # Compute gps
766
+ ogps_values = np.exp(
767
+ -(dists_between_matches**2) / (2 * (dist_norm_coeffs**2))
768
+ )
769
+ #
770
+ ogps = np.mean(ogps_values) if len(ogps_values) > 0 else 0.0
771
+ ious[i, j] = ogps
772
+
773
+ gbb = [gt["bbox"] for gt in g]
774
+ dbb = [dt["bbox"] for dt in d]
775
+
776
+ # compute iou between each dt and gt region
777
+ iscrowd = [int(o.get("iscrowd", 0)) for o in g]
778
+ ious_bb = maskUtils.iou(dbb, gbb, iscrowd)
779
+ return ious, ious_bb
780
+
781
+ def evaluateImg(self, imgId, catId, aRng, maxDet):
782
+ """
783
+ Perform evaluation for a single category and image.
784
+ :return: dict (single image results)
785
+ """
786
+
787
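+ # Greedy COCO-style matching: detections sorted by score are assigned, at each
+ # IoU/GPS threshold, to the best still-unmatched GT above that threshold; GTs
+ # flagged "_ignore" (crowd, out-of-area, no DensePose points) are considered last
+ # and propagate their ignore flag to the matched detection.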
+ p = self.params
788
+ if p.useCats:
789
+ gt = self._gts[imgId, catId]
790
+ dt = self._dts[imgId, catId]
791
+ else:
792
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
793
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
794
+ if len(gt) == 0 and len(dt) == 0:
795
+ return None
796
+
797
+ for g in gt:
798
+ # g['_ignore'] = g['ignore']
799
+ if g["ignore"] or (g["area"] < aRng[0] or g["area"] > aRng[1]):
800
+ g["_ignore"] = True
801
+ else:
802
+ g["_ignore"] = False
803
+
804
+ # sort dt highest score first, sort gt ignore last
805
+ gtind = np.argsort([g["_ignore"] for g in gt], kind="mergesort")
806
+ gt = [gt[i] for i in gtind]
807
+ dtind = np.argsort([-d["score"] for d in dt], kind="mergesort")
808
+ dt = [dt[i] for i in dtind[0:maxDet]]
809
+ iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
810
+ # load computed ious
811
+ if p.iouType == "densepose":
812
+ # print('Checking the length', len(self.ious[imgId, catId]))
813
+ # if len(self.ious[imgId, catId]) == 0:
814
+ # print(self.ious[imgId, catId])
815
+ ious = (
816
+ self.ious[imgId, catId][0][:, gtind]
817
+ if len(self.ious[imgId, catId]) > 0
818
+ else self.ious[imgId, catId]
819
+ )
820
+ ioubs = (
821
+ self.ious[imgId, catId][1][:, gtind]
822
+ if len(self.ious[imgId, catId]) > 0
823
+ else self.ious[imgId, catId]
824
+ )
825
+ if self._dpEvalMode in {DensePoseEvalMode.GPSM, DensePoseEvalMode.IOU}:
826
+ iousM = (
827
+ self.real_ious[imgId, catId][:, gtind]
828
+ if len(self.real_ious[imgId, catId]) > 0
829
+ else self.real_ious[imgId, catId]
830
+ )
831
+ else:
832
+ ious = (
833
+ self.ious[imgId, catId][:, gtind]
834
+ if len(self.ious[imgId, catId]) > 0
835
+ else self.ious[imgId, catId]
836
+ )
837
+
838
+ T = len(p.iouThrs)
839
+ G = len(gt)
840
+ D = len(dt)
841
+ gtm = np.zeros((T, G))
842
+ dtm = np.zeros((T, D))
843
+ gtIg = np.array([g["_ignore"] for g in gt])
844
+ dtIg = np.zeros((T, D))
845
+ if np.all(gtIg) and p.iouType == "densepose":
846
+ dtIg = np.logical_or(dtIg, True)
847
+
848
+ if len(ious) > 0: # and not p.iouType == 'densepose':
849
+ for tind, t in enumerate(p.iouThrs):
850
+ for dind, d in enumerate(dt):
851
+ # information about best match so far (m=-1 -> unmatched)
852
+ iou = min([t, 1 - 1e-10])
853
+ m = -1
854
+ for gind, _g in enumerate(gt):
855
+ # if this gt already matched, and not a crowd, continue
856
+ if gtm[tind, gind] > 0 and not iscrowd[gind]:
857
+ continue
858
+ # if dt matched to reg gt, and on ignore gt, stop
859
+ if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
860
+ break
861
+ if p.iouType == "densepose":
862
+ if self._dpEvalMode == DensePoseEvalMode.GPSM:
863
+ new_iou = np.sqrt(iousM[dind, gind] * ious[dind, gind])
864
+ elif self._dpEvalMode == DensePoseEvalMode.IOU:
865
+ new_iou = iousM[dind, gind]
866
+ elif self._dpEvalMode == DensePoseEvalMode.GPS:
867
+ new_iou = ious[dind, gind]
868
+ else:
869
+ new_iou = ious[dind, gind]
870
+ if new_iou < iou:
871
+ continue
872
+ if new_iou == 0.0:
873
+ continue
874
+ # if match successful and best so far, store appropriately
875
+ iou = new_iou
876
+ m = gind
877
+ # if match made store id of match for both dt and gt
878
+ if m == -1:
879
+ continue
880
+ dtIg[tind, dind] = gtIg[m]
881
+ dtm[tind, dind] = gt[m]["id"]
882
+ gtm[tind, m] = d["id"]
883
+
884
+ if p.iouType == "densepose":
885
+ if not len(ioubs) == 0:
886
+ for dind, d in enumerate(dt):
887
+ # information about best match so far (m=-1 -> unmatched)
888
+ if dtm[tind, dind] == 0:
889
+ ioub = 0.8
890
+ m = -1
891
+ for gind, _g in enumerate(gt):
892
+ # if this gt already matched, and not a crowd, continue
893
+ if gtm[tind, gind] > 0 and not iscrowd[gind]:
894
+ continue
895
+ # continue to next gt unless better match made
896
+ if ioubs[dind, gind] < ioub:
897
+ continue
898
+ # if match successful and best so far, store appropriately
899
+ ioub = ioubs[dind, gind]
900
+ m = gind
901
+ # if match made store id of match for both dt and gt
902
+ if m > -1:
903
+ dtIg[:, dind] = gtIg[m]
904
+ if gtIg[m]:
905
+ dtm[tind, dind] = gt[m]["id"]
906
+ gtm[tind, m] = d["id"]
907
+ # set unmatched detections outside of area range to ignore
908
+ a = np.array([d["area"] < aRng[0] or d["area"] > aRng[1] for d in dt]).reshape((1, len(dt)))
909
+ dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0)))
910
+ # store results for given image and category
911
+ # print('Done with the function', len(self.ious[imgId, catId]))
912
+ return {
913
+ "image_id": imgId,
914
+ "category_id": catId,
915
+ "aRng": aRng,
916
+ "maxDet": maxDet,
917
+ "dtIds": [d["id"] for d in dt],
918
+ "gtIds": [g["id"] for g in gt],
919
+ "dtMatches": dtm,
920
+ "gtMatches": gtm,
921
+ "dtScores": [d["score"] for d in dt],
922
+ "gtIgnore": gtIg,
923
+ "dtIgnore": dtIg,
924
+ }
925
+
926
+ def accumulate(self, p=None):
927
+ """
928
+ Accumulate per image evaluation results and store the result in self.eval
929
+ :param p: input params for evaluation
930
+ :return: None
931
+ """
932
+ logger.info("Accumulating evaluation results...")
933
+ tic = time.time()
934
+ if not self.evalImgs:
935
+ logger.info("Please run evaluate() first")
936
+ # allows input customized parameters
937
+ if p is None:
938
+ p = self.params
939
+ p.catIds = p.catIds if p.useCats == 1 else [-1]
940
+ T = len(p.iouThrs)
941
+ R = len(p.recThrs)
942
+ K = len(p.catIds) if p.useCats else 1
943
+ A = len(p.areaRng)
944
+ M = len(p.maxDets)
945
+ precision = -(np.ones((T, R, K, A, M))) # -1 for the precision of absent categories
946
+ recall = -(np.ones((T, K, A, M)))
947
+
948
+ # create dictionary for future indexing
949
+ logger.info("Categories: {}".format(p.catIds))
950
+ _pe = self._paramsEval
951
+ catIds = _pe.catIds if _pe.useCats else [-1]
952
+ setK = set(catIds)
953
+ setA = set(map(tuple, _pe.areaRng))
954
+ setM = set(_pe.maxDets)
955
+ setI = set(_pe.imgIds)
956
+ # get inds to evaluate
957
+ k_list = [n for n, k in enumerate(p.catIds) if k in setK]
958
+ m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
959
+ a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
960
+ i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
961
+ I0 = len(_pe.imgIds)
962
+ A0 = len(_pe.areaRng)
963
+ # retrieve E at each category, area range, and max number of detections
964
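+ # self.evalImgs is laid out as catIds x areaRng x imgIds, so the entry for
+ # (category k0, area range a0, image i) lives at index k0 * A0 * I0 + a0 * I0 + i.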
+ for k, k0 in enumerate(k_list):
965
+ Nk = k0 * A0 * I0
966
+ for a, a0 in enumerate(a_list):
967
+ Na = a0 * I0
968
+ for m, maxDet in enumerate(m_list):
969
+ E = [self.evalImgs[Nk + Na + i] for i in i_list]
970
+ E = [e for e in E if e is not None]
971
+ if len(E) == 0:
972
+ continue
973
+ dtScores = np.concatenate([e["dtScores"][0:maxDet] for e in E])
974
+
975
+ # different sorting methods generate slightly different results;
976
+ # mergesort is used to be consistent with the Matlab implementation.
977
+ inds = np.argsort(-dtScores, kind="mergesort")
978
+
979
+ dtm = np.concatenate([e["dtMatches"][:, 0:maxDet] for e in E], axis=1)[:, inds]
980
+ dtIg = np.concatenate([e["dtIgnore"][:, 0:maxDet] for e in E], axis=1)[:, inds]
981
+ gtIg = np.concatenate([e["gtIgnore"] for e in E])
982
+ npig = np.count_nonzero(gtIg == 0)
983
+ if npig == 0:
984
+ continue
985
+ tps = np.logical_and(dtm, np.logical_not(dtIg))
986
+ fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg))
987
+ tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
988
+ fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
989
+ for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
990
+ tp = np.array(tp)
991
+ fp = np.array(fp)
992
+ nd = len(tp)
993
+ rc = tp / npig
994
+ pr = tp / (fp + tp + np.spacing(1))
995
+ q = np.zeros((R,))
996
+
997
+ if nd:
998
+ recall[t, k, a, m] = rc[-1]
999
+ else:
1000
+ recall[t, k, a, m] = 0
1001
+
1002
+ # numpy element access is slow without cython optimization;
1003
+ # using plain python lists gives a significant speed improvement
1004
+ pr = pr.tolist()
1005
+ q = q.tolist()
1006
+
1007
+ for i in range(nd - 1, 0, -1):
1008
+ if pr[i] > pr[i - 1]:
1009
+ pr[i - 1] = pr[i]
1010
+
1011
+ inds = np.searchsorted(rc, p.recThrs, side="left")
1012
+ try:
1013
+ for ri, pi in enumerate(inds):
1014
+ q[ri] = pr[pi]
1015
+ except Exception:
1016
+ pass
1017
+ precision[t, :, k, a, m] = np.array(q)
1018
+ logger.info(
1019
+ "Final: max precision {}, min precision {}".format(np.max(precision), np.min(precision))
1020
+ )
1021
+ self.eval = {
1022
+ "params": p,
1023
+ "counts": [T, R, K, A, M],
1024
+ "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
1025
+ "precision": precision,
1026
+ "recall": recall,
1027
+ }
1028
+ toc = time.time()
1029
+ logger.info("DONE (t={:0.2f}s).".format(toc - tic))
1030
+
1031
+ def summarize(self):
1032
+ """
1033
+ Compute and display summary metrics for evaluation results.
1034
+ Note this function can *only* be applied on the default parameter setting
1035
+ """
1036
+
1037
+ def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
1038
+ p = self.params
1039
+ iStr = " {:<18} {} @[ {}={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
1040
+ titleStr = "Average Precision" if ap == 1 else "Average Recall"
1041
+ typeStr = "(AP)" if ap == 1 else "(AR)"
1042
+ measure = "IoU"
1043
+ if self.params.iouType == "keypoints":
1044
+ measure = "OKS"
1045
+ elif self.params.iouType == "densepose":
1046
+ measure = "OGPS"
1047
+ iouStr = (
1048
+ "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
1049
+ if iouThr is None
1050
+ else "{:0.2f}".format(iouThr)
1051
+ )
1052
+
1053
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
1054
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
1055
+ if ap == 1:
1056
+ # dimension of precision: [TxRxKxAxM]
1057
+ s = self.eval["precision"]
1058
+ # IoU
1059
+ if iouThr is not None:
1060
+ t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0]
1061
+ s = s[t]
1062
+ s = s[:, :, :, aind, mind]
1063
+ else:
1064
+ # dimension of recall: [TxKxAxM]
1065
+ s = self.eval["recall"]
1066
+ if iouThr is not None:
1067
+ t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0]
1068
+ s = s[t]
1069
+ s = s[:, :, aind, mind]
1070
+ if len(s[s > -1]) == 0:
1071
+ mean_s = -1
1072
+ else:
1073
+ mean_s = np.mean(s[s > -1])
1074
+ logger.info(iStr.format(titleStr, typeStr, measure, iouStr, areaRng, maxDets, mean_s))
1075
+ return mean_s
1076
+
1077
+ def _summarizeDets():
1078
+ stats = np.zeros((12,))
1079
+ stats[0] = _summarize(1)
1080
+ stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
1081
+ stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
1082
+ stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2])
1083
+ stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2])
1084
+ stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2])
1085
+ stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
1086
+ stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
1087
+ stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
1088
+ stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2])
1089
+ stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2])
1090
+ stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2])
1091
+ return stats
1092
+
1093
+ def _summarizeKps():
1094
+ stats = np.zeros((10,))
1095
+ stats[0] = _summarize(1, maxDets=20)
1096
+ stats[1] = _summarize(1, maxDets=20, iouThr=0.5)
1097
+ stats[2] = _summarize(1, maxDets=20, iouThr=0.75)
1098
+ stats[3] = _summarize(1, maxDets=20, areaRng="medium")
1099
+ stats[4] = _summarize(1, maxDets=20, areaRng="large")
1100
+ stats[5] = _summarize(0, maxDets=20)
1101
+ stats[6] = _summarize(0, maxDets=20, iouThr=0.5)
1102
+ stats[7] = _summarize(0, maxDets=20, iouThr=0.75)
1103
+ stats[8] = _summarize(0, maxDets=20, areaRng="medium")
1104
+ stats[9] = _summarize(0, maxDets=20, areaRng="large")
1105
+ return stats
1106
+
1107
+ def _summarizeUvs():
1108
+ stats = [_summarize(1, maxDets=self.params.maxDets[0])]
1109
+ min_threshold = self.params.iouThrs.min()
1110
+ if min_threshold <= 0.201:
1111
+ stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.2)]
1112
+ if min_threshold <= 0.301:
1113
+ stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.3)]
1114
+ if min_threshold <= 0.401:
1115
+ stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.4)]
1116
+ stats += [
1117
+ _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.5),
1118
+ _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.75),
1119
+ _summarize(1, maxDets=self.params.maxDets[0], areaRng="medium"),
1120
+ _summarize(1, maxDets=self.params.maxDets[0], areaRng="large"),
1121
+ _summarize(0, maxDets=self.params.maxDets[0]),
1122
+ _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.5),
1123
+ _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.75),
1124
+ _summarize(0, maxDets=self.params.maxDets[0], areaRng="medium"),
1125
+ _summarize(0, maxDets=self.params.maxDets[0], areaRng="large"),
1126
+ ]
1127
+ return np.array(stats)
1128
+
1129
+ def _summarizeUvsOld():
1130
+ stats = np.zeros((18,))
1131
+ stats[0] = _summarize(1, maxDets=self.params.maxDets[0])
1132
+ stats[1] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.5)
1133
+ stats[2] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.55)
1134
+ stats[3] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.60)
1135
+ stats[4] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.65)
1136
+ stats[5] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.70)
1137
+ stats[6] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.75)
1138
+ stats[7] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.80)
1139
+ stats[8] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.85)
1140
+ stats[9] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.90)
1141
+ stats[10] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.95)
1142
+ stats[11] = _summarize(1, maxDets=self.params.maxDets[0], areaRng="medium")
1143
+ stats[12] = _summarize(1, maxDets=self.params.maxDets[0], areaRng="large")
1144
+ stats[13] = _summarize(0, maxDets=self.params.maxDets[0])
1145
+ stats[14] = _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.5)
1146
+ stats[15] = _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.75)
1147
+ stats[16] = _summarize(0, maxDets=self.params.maxDets[0], areaRng="medium")
1148
+ stats[17] = _summarize(0, maxDets=self.params.maxDets[0], areaRng="large")
1149
+ return stats
1150
+
1151
+ if not self.eval:
1152
+ raise Exception("Please run accumulate() first")
1153
+ iouType = self.params.iouType
1154
+ if iouType in ["segm", "bbox"]:
1155
+ summarize = _summarizeDets
1156
+ elif iouType in ["keypoints"]:
1157
+ summarize = _summarizeKps
1158
+ elif iouType in ["densepose"]:
1159
+ summarize = _summarizeUvs
1160
+ self.stats = summarize()
1161
+
1162
+ def __str__(self):
1163
+ self.summarize()
+ return ""
1164
+
1165
+ # ================ functions for dense pose ==============================
1166
+ def findAllClosestVertsUV(self, U_points, V_points, Index_points):
1167
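+ # Map estimated (I, U, V) values to mesh vertices: for each of the 24 fine parts,
+ # predicted UV coordinates are matched against the precomputed per-part UV table
+ # (self.Part_UVs) and the nearest tabulated vertex index is taken, re-indexed via
+ # self.PDIST_transform for the geodesic-distance lookup.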
+ ClosestVerts = np.ones(Index_points.shape) * -1
1168
+ for i in np.arange(24):
1169
+ #
1170
+ if (i + 1) in Index_points:
1171
+ UVs = np.array(
1172
+ [U_points[Index_points == (i + 1)], V_points[Index_points == (i + 1)]]
1173
+ )
1174
+ Current_Part_UVs = self.Part_UVs[i]
1175
+ Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i]
1176
+ D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
1177
+ ClosestVerts[Index_points == (i + 1)] = Current_Part_ClosestVertInds[
1178
+ np.argmin(D, axis=0)
1179
+ ]
1180
+ ClosestVertsTransformed = self.PDIST_transform[ClosestVerts.astype(int) - 1]
1181
+ ClosestVertsTransformed[ClosestVerts < 0] = 0
1182
+ return ClosestVertsTransformed
1183
+
1184
+ def findClosestVertsCse(self, embedding, py, px, mask, mesh_name):
1185
+ mesh_vertex_embeddings = self.embedder(mesh_name)
1186
+ pixel_embeddings = embedding[:, py, px].t().to(device="cuda")
1187
+ mask_vals = mask[py, px]
1188
+ edm = squared_euclidean_distance_matrix(pixel_embeddings, mesh_vertex_embeddings)
1189
+ vertex_indices = edm.argmin(dim=1).cpu()
1190
+ vertex_indices[mask_vals <= 0] = -1
1191
+ return vertex_indices
1192
+
1193
+ def findAllClosestVertsGT(self, gt):
1194
+ #
1195
+ I_gt = np.array(gt["dp_I"])
1196
+ U_gt = np.array(gt["dp_U"])
1197
+ V_gt = np.array(gt["dp_V"])
1198
+ #
1199
+ # print(I_gt)
1200
+ #
1201
+ ClosestVertsGT = np.ones(I_gt.shape) * -1
1202
+ for i in np.arange(24):
1203
+ if (i + 1) in I_gt:
1204
+ UVs = np.array([U_gt[I_gt == (i + 1)], V_gt[I_gt == (i + 1)]])
1205
+ Current_Part_UVs = self.Part_UVs[i]
1206
+ Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i]
1207
+ D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
1208
+ ClosestVertsGT[I_gt == (i + 1)] = Current_Part_ClosestVertInds[np.argmin(D, axis=0)]
1209
+ #
1210
+ ClosestVertsGTTransformed = self.PDIST_transform[ClosestVertsGT.astype(int) - 1]
1211
+ ClosestVertsGTTransformed[ClosestVertsGT < 0] = 0
1212
+ return ClosestVertsGT, ClosestVertsGTTransformed
1213
+
1214
+ def getDistancesCse(self, cVertsGT, cVerts, mesh_name):
1215
+ geodists_vertices = torch.ones_like(cVertsGT) * float("inf")
1216
+ selected = (cVertsGT >= 0) * (cVerts >= 0)
1217
+ mesh = create_mesh(mesh_name, "cpu")
1218
+ geodists_vertices[selected] = mesh.geodists[cVertsGT[selected], cVerts[selected]]
1219
+ return geodists_vertices.numpy()
1220
+
1221
+ def getDistancesUV(self, cVertsGT, cVerts):
1222
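+ # self.Pdist_matrix stores pairwise geodesic distances between the n = 27554 mesh
+ # vertices in condensed (upper-triangular) form; the arithmetic below maps a vertex
+ # pair (i, j) to its flat index k in that condensed vector.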
+ #
1223
+ n = 27554
1224
+ dists = []
1225
+ for d in range(len(cVertsGT)):
1226
+ if cVertsGT[d] > 0:
1227
+ if cVerts[d] > 0:
1228
+ i = cVertsGT[d] - 1
1229
+ j = cVerts[d] - 1
1230
+ if j == i:
1231
+ dists.append(0)
1232
+ elif j > i:
1233
+ ccc = i
1234
+ i = j
1235
+ j = ccc
1236
+ i = n - i - 1
1237
+ j = n - j - 1
1238
+ k = (n * (n - 1) / 2) - (n - i) * ((n - i) - 1) / 2 + j - i - 1
1239
+ k = (n * n - n) / 2 - k - 1
1240
+ dists.append(self.Pdist_matrix[int(k)][0])
1241
+ else:
1242
+ i = n - i - 1
1243
+ j = n - j - 1
1244
+ k = (n * (n - 1) / 2) - (n - i) * ((n - i) - 1) / 2 + j - i - 1
1245
+ k = (n * n - n) / 2 - k - 1
1246
+ dists.append(self.Pdist_matrix[int(k)][0])
1247
+ else:
1248
+ dists.append(np.inf)
1249
+ return np.atleast_1d(np.array(dists).squeeze())
1250
+
1251
+
1252
+ class Params:
1253
+ """
1254
+ Params for coco evaluation api
1255
+ """
1256
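+ # Hedged construction sketch (usage assumed from __init__ below):
+ #   Params(iouType="densepose") selects setUvParams(), i.e. maxDets=[20] and the
+ #   "all"/"medium"/"large" area ranges; "segm"/"bbox" use setDetParams().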
+
1257
+ def setDetParams(self):
1258
+ self.imgIds = []
1259
+ self.catIds = []
1260
+ # np.arange causes trouble: the points it generates can be slightly larger than the true values
1261
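+ # Resulting grids: iouThrs = [0.50, 0.55, ..., 0.95] (10 values) and
+ # recThrs = [0.00, 0.01, ..., 1.00] (101 values).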
+ self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
1262
+ self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
1263
+ self.maxDets = [1, 10, 100]
1264
+ self.areaRng = [
1265
+ [0**2, 1e5**2],
1266
+ [0**2, 32**2],
1267
+ [32**2, 96**2],
1268
+ [96**2, 1e5**2],
1269
+ ]
1270
+ self.areaRngLbl = ["all", "small", "medium", "large"]
1271
+ self.useCats = 1
1272
+
1273
+ def setKpParams(self):
1274
+ self.imgIds = []
1275
+ self.catIds = []
1276
+ # np.arange causes trouble: the points it generates can be slightly larger than the true values
1277
+ self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
1278
+ self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
1279
+ self.maxDets = [20]
1280
+ self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]]
1281
+ self.areaRngLbl = ["all", "medium", "large"]
1282
+ self.useCats = 1
1283
+
1284
+ def setUvParams(self):
1285
+ self.imgIds = []
1286
+ self.catIds = []
1287
+ self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
1288
+ self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
1289
+ self.maxDets = [20]
1290
+ self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]]
1291
+ self.areaRngLbl = ["all", "medium", "large"]
1292
+ self.useCats = 1
1293
+
1294
+ def __init__(self, iouType="segm"):
1295
+ if iouType == "segm" or iouType == "bbox":
1296
+ self.setDetParams()
1297
+ elif iouType == "keypoints":
1298
+ self.setKpParams()
1299
+ elif iouType == "densepose":
1300
+ self.setUvParams()
1301
+ else:
1302
+ raise Exception("iouType not supported")
1303
+ self.iouType = iouType
1304
+ # useSegm is deprecated
1305
+ self.useSegm = None