abubakar123456 commited on
Commit
1c54d21
1 Parent(s): 40a5209

Upload 686 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +46 -35
  2. .gitignore +31 -0
  3. __pycache__/utils.cpython-39.pyc +0 -0
  4. app.py +353 -0
  5. densepose/.DS_Store +0 -0
  6. densepose/__init__.py +22 -0
  7. densepose/__pycache__/__init__.cpython-39.pyc +0 -0
  8. densepose/__pycache__/config.cpython-39.pyc +0 -0
  9. densepose/config.py +278 -0
  10. densepose/converters/__init__.py +17 -0
  11. densepose/converters/__pycache__/__init__.cpython-39.pyc +0 -0
  12. densepose/converters/__pycache__/base.cpython-39.pyc +0 -0
  13. densepose/converters/__pycache__/builtin.cpython-39.pyc +0 -0
  14. densepose/converters/__pycache__/chart_output_hflip.cpython-39.pyc +0 -0
  15. densepose/converters/__pycache__/chart_output_to_chart_result.cpython-39.pyc +0 -0
  16. densepose/converters/__pycache__/hflip.cpython-39.pyc +0 -0
  17. densepose/converters/__pycache__/segm_to_mask.cpython-39.pyc +0 -0
  18. densepose/converters/__pycache__/to_chart_result.cpython-39.pyc +0 -0
  19. densepose/converters/__pycache__/to_mask.cpython-39.pyc +0 -0
  20. densepose/converters/base.py +95 -0
  21. densepose/converters/builtin.py +33 -0
  22. densepose/converters/chart_output_hflip.py +73 -0
  23. densepose/converters/chart_output_to_chart_result.py +190 -0
  24. densepose/converters/hflip.py +36 -0
  25. densepose/converters/segm_to_mask.py +152 -0
  26. densepose/converters/to_chart_result.py +72 -0
  27. densepose/converters/to_mask.py +51 -0
  28. densepose/data/.DS_Store +0 -0
  29. densepose/data/__init__.py +27 -0
  30. densepose/data/__pycache__/__init__.cpython-39.pyc +0 -0
  31. densepose/data/__pycache__/build.cpython-39.pyc +0 -0
  32. densepose/data/__pycache__/combined_loader.cpython-39.pyc +0 -0
  33. densepose/data/__pycache__/dataset_mapper.cpython-39.pyc +0 -0
  34. densepose/data/__pycache__/image_list_dataset.cpython-39.pyc +0 -0
  35. densepose/data/__pycache__/inference_based_loader.cpython-39.pyc +0 -0
  36. densepose/data/__pycache__/utils.cpython-39.pyc +0 -0
  37. densepose/data/build.py +738 -0
  38. densepose/data/combined_loader.py +46 -0
  39. densepose/data/dataset_mapper.py +170 -0
  40. densepose/data/datasets/__init__.py +7 -0
  41. densepose/data/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  42. densepose/data/datasets/__pycache__/builtin.cpython-39.pyc +0 -0
  43. densepose/data/datasets/__pycache__/chimpnsee.cpython-39.pyc +0 -0
  44. densepose/data/datasets/__pycache__/coco.cpython-39.pyc +0 -0
  45. densepose/data/datasets/__pycache__/dataset_type.cpython-39.pyc +0 -0
  46. densepose/data/datasets/__pycache__/lvis.cpython-39.pyc +0 -0
  47. densepose/data/datasets/builtin.py +18 -0
  48. densepose/data/datasets/chimpnsee.py +31 -0
  49. densepose/data/datasets/coco.py +434 -0
  50. densepose/data/datasets/dataset_type.py +13 -0
.gitattributes CHANGED
@@ -1,35 +1,46 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.py eol=lf
2
+ *.js eol=lf
3
+ *.jsx eol=lf
4
+ *.json eol=lf
5
+ # .gitattributes snippet to force users to use same line endings for project.
6
+ # Handle line endings automatically for files detected as text
7
+ # and leave all files detected as binary untouched.
8
+ * text=auto
9
+ # The above will handle all files NOT found below
10
+ # These files are text and should be normalized (Convert crlf => lf)
11
+ *.php text
12
+ *.css text
13
+ *.js text
14
+ *.json text
15
+ *.htm text
16
+ *.html text
17
+ *.xml text
18
+ *.txt text
19
+ *.ini text
20
+ *.inc text
21
+ *.pl text
22
+ *.rb text
23
+ *.py text
24
+ *.scm text
25
+ *.sql text
26
+ .htaccess text
27
+ *.sh text
28
+ # These files are binary and should be left untouched
29
+ # (binary is a macro for -text -diff)
30
+ *.png binary
31
+ *.jpg binary
32
+ *.jpeg binary
33
+ *.gif binary
34
+ *.ico binary
35
+ *.mov binary
36
+ *.mp4 binary
37
+ *.mp3 binary
38
+ *.flv binary
39
+ *.fla binary
40
+ *.swf binary
41
+ *.gz binary
42
+ *.zip binary
43
+ *.7z binary
44
+ *.ttf binary
45
+ *.pyc binary
46
+ detectron2/_C.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ .vscode
3
+ .idea
4
+ .venv
5
+ venv
6
+ vm
7
+ *.pyc
8
+ *.egg-info
9
+ __pycache__
10
+ .ebextensions
11
+ .spyproject
12
+ node_modules
13
+ bak
14
+ baks
15
+ logs
16
+ myTestes
17
+ myHelpers
18
+ conf
19
+ .requirements.txt.bak
20
+ USE.INFO
21
+ templates/src
22
+ dabolinux-clients-demo-32103b022bf6.json
23
+ media/
24
+ MANIFEST
25
+ build
26
+ dist
27
+ docs/_build
28
+ docs/_static
29
+ npm-debug.log
30
+ setup.cfg
31
+ pyproject.toml
__pycache__/utils.cpython-39.pyc ADDED
Binary file (16.5 kB). View file
 
app.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from datetime import datetime
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ device = torch.device('cpu') # Explicitly use CPU if desired
9
+
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from huggingface_hub import snapshot_download
12
+ from PIL import Image
13
+
14
+ from model.cloth_masker import AutoMasker, vis_mask
15
+ from model.pipeline import CatVTONPipeline
16
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
20
+ parser.add_argument(
21
+ "--base_model_path",
22
+ type=str,
23
+ default="runwayml/stable-diffusion-inpainting",
24
+ help=(
25
+ "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
26
+ ),
27
+ )
28
+ parser.add_argument(
29
+ "--resume_path",
30
+ type=str,
31
+ default="zhengchong/CatVTON",
32
+ help=(
33
+ "The Path to the checkpoint of trained tryon model."
34
+ ),
35
+ )
36
+ parser.add_argument(
37
+ "--output_dir",
38
+ type=str,
39
+ default="resource/demo/output",
40
+ help="The output directory where the model predictions will be written.",
41
+ )
42
+
43
+ parser.add_argument(
44
+ "--width",
45
+ type=int,
46
+ default=768,
47
+ help=(
48
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
49
+ " resolution"
50
+ ),
51
+ )
52
+ parser.add_argument(
53
+ "--height",
54
+ type=int,
55
+ default=1024,
56
+ help=(
57
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
58
+ " resolution"
59
+ ),
60
+ )
61
+ parser.add_argument(
62
+ "--repaint",
63
+ action="store_true",
64
+ help="Whether to repaint the result image with the original background."
65
+ )
66
+ parser.add_argument(
67
+ "--allow_tf32",
68
+ action="store_true",
69
+ default=True,
70
+ help=(
71
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
72
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
73
+ ),
74
+ )
75
+ parser.add_argument(
76
+ "--mixed_precision",
77
+ type=str,
78
+ default="bf16",
79
+ choices=["no", "fp16", "bf16"],
80
+ help=(
81
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
82
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
83
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
84
+ ),
85
+ )
86
+
87
+ args = parser.parse_args()
88
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
89
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
90
+ args.local_rank = env_local_rank
91
+
92
+ return args
93
+
94
+ def image_grid(imgs, rows, cols):
95
+ assert len(imgs) == rows * cols
96
+
97
+ w, h = imgs[0].size
98
+ grid = Image.new("RGB", size=(cols * w, rows * h))
99
+
100
+ for i, img in enumerate(imgs):
101
+ grid.paste(img, box=(i % cols * w, i // cols * h))
102
+ return grid
103
+
104
+
105
+ args = parse_args()
106
+ repo_path = snapshot_download(repo_id=args.resume_path)
107
+ # Pipeline
108
+ pipeline = CatVTONPipeline(
109
+ base_ckpt=args.base_model_path,
110
+ attn_ckpt=repo_path,
111
+ attn_ckpt_version="mix",
112
+ weight_dtype=init_weight_dtype(args.mixed_precision),
113
+ use_tf32=args.allow_tf32,
114
+ # device='cuda'
115
+ device='cpu'
116
+ )
117
+ # AutoMasker
118
+ mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
119
+ automasker = AutoMasker(
120
+ densepose_ckpt=os.path.join(repo_path, "DensePose"),
121
+ schp_ckpt=os.path.join(repo_path, "SCHP"),
122
+ # device='cuda',
123
+ device='cpu'
124
+ )
125
+
126
+ def submit_function(
127
+ person_image,
128
+ cloth_image,
129
+ cloth_type,
130
+ num_inference_steps,
131
+ guidance_scale,
132
+ seed,
133
+ show_type
134
+ ):
135
+ person_image, mask = person_image["background"], person_image["layers"][0]
136
+ mask = Image.open(mask).convert("L")
137
+ if len(np.unique(np.array(mask))) == 1:
138
+ mask = None
139
+ else:
140
+ mask = np.array(mask)
141
+ mask[mask > 0] = 255
142
+ mask = Image.fromarray(mask)
143
+
144
+ tmp_folder = args.output_dir
145
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
146
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
147
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
148
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
149
+
150
+ generator = None
151
+ if seed != -1:
152
+ # generator = torch.Generator(device='cuda').manual_seed(seed)
153
+ generator = torch.Generator(device='cpu').manual_seed(seed)
154
+
155
+ person_image = Image.open(person_image).convert("RGB")
156
+ cloth_image = Image.open(cloth_image).convert("RGB")
157
+ person_image = resize_and_crop(person_image, (args.width, args.height))
158
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
159
+
160
+ # Process mask
161
+ if mask is not None:
162
+ mask = resize_and_crop(mask, (args.width, args.height))
163
+ else:
164
+ mask = automasker(
165
+ person_image,
166
+ cloth_type
167
+ )['mask']
168
+ mask = mask_processor.blur(mask, blur_factor=9)
169
+
170
+ # Inference
171
+ # try:
172
+ result_image = pipeline(
173
+ image=person_image,
174
+ condition_image=cloth_image,
175
+ mask=mask,
176
+ num_inference_steps=num_inference_steps,
177
+ guidance_scale=guidance_scale,
178
+ generator=generator
179
+ )[0]
180
+ # except Exception as e:
181
+ # raise gr.Error(
182
+ # "An error occurred. Please try again later: {}".format(e)
183
+ # )
184
+
185
+ # Post-process
186
+ masked_person = vis_mask(person_image, mask)
187
+ save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
188
+ save_result_image.save(result_save_path)
189
+ if show_type == "result only":
190
+ return result_image
191
+ else:
192
+ width, height = person_image.size
193
+ if show_type == "input & result":
194
+ condition_width = width // 2
195
+ conditions = image_grid([person_image, cloth_image], 2, 1)
196
+ else:
197
+ condition_width = width // 3
198
+ conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
199
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
200
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
201
+ new_result_image.paste(conditions, (0, 0))
202
+ new_result_image.paste(result_image, (condition_width + 5, 0))
203
+ return new_result_image
204
+
205
+
206
+ def person_example_fn(image_path):
207
+ return image_path
208
+
209
+ HEADER = """
210
+ <h1 style="text-align: center;">
211
+ Fashable Virtual Tryon
212
+ </h1>
213
+
214
+ """
215
+
216
+ def app_gradio():
217
+ with gr.Blocks(title="CatVTON") as demo:
218
+ gr.Markdown(HEADER)
219
+ with gr.Row():
220
+ with gr.Column(scale=1, min_width=350):
221
+ with gr.Row():
222
+ image_path = gr.Image(
223
+ type="filepath",
224
+ interactive=True,
225
+ visible=False,
226
+ )
227
+ person_image = gr.ImageEditor(
228
+ interactive=True, label="Person Image", type="filepath"
229
+ )
230
+
231
+ with gr.Row():
232
+ with gr.Column(scale=1, min_width=230):
233
+ cloth_image = gr.Image(
234
+ interactive=True, label="Condition Image", type="filepath"
235
+ )
236
+ with gr.Column(scale=1, min_width=120):
237
+ gr.Markdown(
238
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
239
+ )
240
+ cloth_type = gr.Radio(
241
+ label="Try-On Cloth Type",
242
+ choices=["upper", "lower", "overall"],
243
+ value="upper",
244
+ )
245
+
246
+
247
+ submit = gr.Button("Submit")
248
+ gr.Markdown(
249
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
250
+ )
251
+
252
+ gr.Markdown(
253
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
254
+ )
255
+ with gr.Accordion("Advanced Options", open=False):
256
+ num_inference_steps = gr.Slider(
257
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
258
+ )
259
+ # Guidence Scale
260
+ guidance_scale = gr.Slider(
261
+ label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
262
+ )
263
+ # Random Seed
264
+ seed = gr.Slider(
265
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
266
+ )
267
+ show_type = gr.Radio(
268
+ label="Show Type",
269
+ choices=["result only", "input & result", "input & mask & result"],
270
+ value="input & mask & result",
271
+ )
272
+
273
+ with gr.Column(scale=2, min_width=500):
274
+ result_image = gr.Image(interactive=False, label="Result")
275
+ with gr.Row():
276
+ # Photo Examples
277
+ root_path = "resource/demo/example"
278
+ with gr.Column():
279
+ men_exm = gr.Examples(
280
+ examples=[
281
+ os.path.join(root_path, "person", "men", _)
282
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
283
+ ],
284
+ examples_per_page=4,
285
+ inputs=image_path,
286
+ label="Person Examples ①",
287
+ )
288
+ women_exm = gr.Examples(
289
+ examples=[
290
+ os.path.join(root_path, "person", "women", _)
291
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
292
+ ],
293
+ examples_per_page=4,
294
+ inputs=image_path,
295
+ label="Person Examples ②",
296
+ )
297
+ gr.Markdown(
298
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
299
+ )
300
+ with gr.Column():
301
+ condition_upper_exm = gr.Examples(
302
+ examples=[
303
+ os.path.join(root_path, "condition", "upper", _)
304
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
305
+ ],
306
+ examples_per_page=4,
307
+ inputs=cloth_image,
308
+ label="Condition Upper Examples",
309
+ )
310
+ condition_overall_exm = gr.Examples(
311
+ examples=[
312
+ os.path.join(root_path, "condition", "overall", _)
313
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
314
+ ],
315
+ examples_per_page=4,
316
+ inputs=cloth_image,
317
+ label="Condition Overall Examples",
318
+ )
319
+ condition_person_exm = gr.Examples(
320
+ examples=[
321
+ os.path.join(root_path, "condition", "person", _)
322
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
323
+ ],
324
+ examples_per_page=4,
325
+ inputs=cloth_image,
326
+ label="Condition Reference Person Examples",
327
+ )
328
+ gr.Markdown(
329
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
330
+ )
331
+
332
+ image_path.change(
333
+ person_example_fn, inputs=image_path, outputs=person_image
334
+ )
335
+
336
+ submit.click(
337
+ submit_function,
338
+ [
339
+ person_image,
340
+ cloth_image,
341
+ cloth_type,
342
+ num_inference_steps,
343
+ guidance_scale,
344
+ seed,
345
+ show_type,
346
+ ],
347
+ result_image,
348
+ )
349
+ demo.queue().launch(share=True, show_error=True)
350
+
351
+
352
+ if __name__ == "__main__":
353
+ app_gradio()
densepose/.DS_Store ADDED
Binary file (10.2 kB). View file
 
densepose/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from .data.datasets import builtin # just to register data
5
+ from .converters import builtin as builtin_converters # register converters
6
+ from .config import (
7
+ add_densepose_config,
8
+ add_densepose_head_config,
9
+ add_hrnet_config,
10
+ add_dataset_category_config,
11
+ add_bootstrap_config,
12
+ load_bootstrap_config,
13
+ )
14
+ from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
15
+ from .evaluation import DensePoseCOCOEvaluator
16
+ from .modeling.roi_heads import DensePoseROIHeads
17
+ from .modeling.test_time_augmentation import (
18
+ DensePoseGeneralizedRCNNWithTTA,
19
+ DensePoseDatasetMapperTTA,
20
+ )
21
+ from .utils.transform import load_from_cfg
22
+ from .modeling.hrfpn import build_hrfpn_backbone
densepose/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (947 Bytes). View file
 
densepose/__pycache__/config.cpython-39.pyc ADDED
Binary file (5.84 kB). View file
 
densepose/config.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding = utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # pyre-ignore-all-errors
4
+
5
+ from detectron2.config import CfgNode as CN
6
+
7
+
8
+ def add_dataset_category_config(cfg: CN) -> None:
9
+ """
10
+ Add config for additional category-related dataset options
11
+ - category whitelisting
12
+ - category mapping
13
+ """
14
+ _C = cfg
15
+ _C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
16
+ _C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
17
+ # class to mesh mapping
18
+ _C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
19
+
20
+
21
+ def add_evaluation_config(cfg: CN) -> None:
22
+ _C = cfg
23
+ _C.DENSEPOSE_EVALUATION = CN()
24
+ # evaluator type, possible values:
25
+ # - "iou": evaluator for models that produce iou data
26
+ # - "cse": evaluator for models that produce cse data
27
+ _C.DENSEPOSE_EVALUATION.TYPE = "iou"
28
+ # storage for DensePose results, possible values:
29
+ # - "none": no explicit storage, all the results are stored in the
30
+ # dictionary with predictions, memory intensive;
31
+ # historically the default storage type
32
+ # - "ram": RAM storage, uses per-process RAM storage, which is
33
+ # reduced to a single process storage on later stages,
34
+ # less memory intensive
35
+ # - "file": file storage, uses per-process file-based storage,
36
+ # the least memory intensive, but may create bottlenecks
37
+ # on file system accesses
38
+ _C.DENSEPOSE_EVALUATION.STORAGE = "none"
39
+ # minimum threshold for IOU values: the lower its values is,
40
+ # the more matches are produced (and the higher the AP score)
41
+ _C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5
42
+ # Non-distributed inference is slower (at inference time) but can avoid RAM OOM
43
+ _C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True
44
+ # evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context
45
+ _C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False
46
+ # meshes to compute mesh alignment for
47
+ _C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = []
48
+
49
+
50
+ def add_bootstrap_config(cfg: CN) -> None:
51
+ """ """
52
+ _C = cfg
53
+ _C.BOOTSTRAP_DATASETS = []
54
+ _C.BOOTSTRAP_MODEL = CN()
55
+ _C.BOOTSTRAP_MODEL.WEIGHTS = ""
56
+ # _C.BOOTSTRAP_MODEL.DEVICE = "cuda"
57
+ _C.BOOTSTRAP_MODEL.DEVICE = "cpu"
58
+
59
+
60
+ def get_bootstrap_dataset_config() -> CN:
61
+ _C = CN()
62
+ _C.DATASET = ""
63
+ # ratio used to mix data loaders
64
+ _C.RATIO = 0.1
65
+ # image loader
66
+ _C.IMAGE_LOADER = CN(new_allowed=True)
67
+ _C.IMAGE_LOADER.TYPE = ""
68
+ _C.IMAGE_LOADER.BATCH_SIZE = 4
69
+ _C.IMAGE_LOADER.NUM_WORKERS = 4
70
+ _C.IMAGE_LOADER.CATEGORIES = []
71
+ _C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
72
+ _C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
73
+ # inference
74
+ _C.INFERENCE = CN()
75
+ # batch size for model inputs
76
+ _C.INFERENCE.INPUT_BATCH_SIZE = 4
77
+ # batch size to group model outputs
78
+ _C.INFERENCE.OUTPUT_BATCH_SIZE = 2
79
+ # sampled data
80
+ _C.DATA_SAMPLER = CN(new_allowed=True)
81
+ _C.DATA_SAMPLER.TYPE = ""
82
+ _C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
83
+ # filter
84
+ _C.FILTER = CN(new_allowed=True)
85
+ _C.FILTER.TYPE = ""
86
+ return _C
87
+
88
+
89
+ def load_bootstrap_config(cfg: CN) -> None:
90
+ """
91
+ Bootstrap datasets are given as a list of `dict` that are not automatically
92
+ converted into CfgNode. This method processes all bootstrap dataset entries
93
+ and ensures that they are in CfgNode format and comply with the specification
94
+ """
95
+ if not cfg.BOOTSTRAP_DATASETS:
96
+ return
97
+
98
+ bootstrap_datasets_cfgnodes = []
99
+ for dataset_cfg in cfg.BOOTSTRAP_DATASETS:
100
+ _C = get_bootstrap_dataset_config().clone()
101
+ _C.merge_from_other_cfg(CN(dataset_cfg))
102
+ bootstrap_datasets_cfgnodes.append(_C)
103
+ cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes
104
+
105
+
106
+ def add_densepose_head_cse_config(cfg: CN) -> None:
107
+ """
108
+ Add configuration options for Continuous Surface Embeddings (CSE)
109
+ """
110
+ _C = cfg
111
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN()
112
+ # Dimensionality D of the embedding space
113
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16
114
+ # Embedder specifications for various mesh IDs
115
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True)
116
+ # normalization coefficient for embedding distances
117
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
118
+ # normalization coefficient for geodesic distances
119
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01
120
+ # embedding loss weight
121
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6
122
+ # embedding loss name, currently the following options are supported:
123
+ # - EmbeddingLoss: cross-entropy on vertex labels
124
+ # - SoftEmbeddingLoss: cross-entropy on vertex label combined with
125
+ # Gaussian penalty on distance between vertices
126
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss"
127
+ # optimizer hyperparameters
128
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0
129
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0
130
+ # Shape to shape cycle consistency loss parameters:
131
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
132
+ # shape to shape cycle consistency loss weight
133
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025
134
+ # norm type used for loss computation
135
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
136
+ # normalization term for embedding similarity matrices
137
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05
138
+ # maximum number of vertices to include into shape to shape cycle loss
139
+ # if negative or zero, all vertices are considered
140
+ # if positive, random subset of vertices of given size is considered
141
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936
142
+ # Pixel to shape cycle consistency loss parameters:
143
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
144
+ # pixel to shape cycle consistency loss weight
145
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001
146
+ # norm type used for loss computation
147
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
148
+ # map images to all meshes and back (if false, use only gt meshes from the batch)
149
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False
150
+ # Randomly select at most this number of pixels from every instance
151
+ # if negative or zero, all vertices are considered
152
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100
153
+ # normalization factor for pixel to pixel distances (higher value = smoother distribution)
154
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0
155
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
156
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
157
+
158
+
159
+ def add_densepose_head_config(cfg: CN) -> None:
160
+ """
161
+ Add config for densepose head.
162
+ """
163
+ _C = cfg
164
+
165
+ _C.MODEL.DENSEPOSE_ON = True
166
+
167
+ _C.MODEL.ROI_DENSEPOSE_HEAD = CN()
168
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NAME = ""
169
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
170
+ # Number of parts used for point labels
171
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24
172
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4
173
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
174
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
175
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2
176
+ _C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112
177
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2"
178
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28
179
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2
180
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2 # 15 or 2
181
+ # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
182
+ _C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7
183
+ # Loss weights for annotation masks.(14 Parts)
184
+ _C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0
185
+ # Loss weights for surface parts. (24 Parts)
186
+ _C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0
187
+ # Loss weights for UV regression.
188
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01
189
+ # Coarse segmentation is trained using instance segmentation task data
190
+ _C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False
191
+ # For Decoder
192
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True
193
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256
194
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256
195
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = ""
196
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4
197
+ # For DeepLab head
198
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN()
199
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
200
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0
201
+ # Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY
202
+ # Some registered predictors:
203
+ # "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts
204
+ # "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates
205
+ # and associated confidences for predefined charts (default)
206
+ # "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings
207
+ # and associated confidences for CSE
208
+ _C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
209
+ # Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY
210
+ # Some registered losses:
211
+ # "DensePoseChartLoss": loss for chart-based models that estimate
212
+ # segmentation and UV coordinates
213
+ # "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate
214
+ # segmentation, UV coordinates and the corresponding confidences (default)
215
+ _C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
216
+ # Confidences
217
+ # Enable learning UV confidences (variances) along with the actual values
218
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False})
219
+ # UV confidence lower bound
220
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01
221
+ # Enable learning segmentation confidences (variances) along with the actual values
222
+ _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False})
223
+ # Segmentation confidence lower bound
224
+ _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01
225
+ # Statistical model type for confidence learning, possible values:
226
+ # - "iid_iso": statistically independent identically distributed residuals
227
+ # with isotropic covariance
228
+ # - "indep_aniso": statistically independent residuals with anisotropic
229
+ # covariances
230
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso"
231
+ # List of angles for rotation in data augmentation during training
232
+ _C.INPUT.ROTATION_ANGLES = [0]
233
+ _C.TEST.AUG.ROTATION_ANGLES = () # Rotation TTA
234
+
235
+ add_densepose_head_cse_config(cfg)
236
+
237
+
238
+ def add_hrnet_config(cfg: CN) -> None:
239
+ """
240
+ Add config for HRNet backbone.
241
+ """
242
+ _C = cfg
243
+
244
+ # For HigherHRNet w32
245
+ _C.MODEL.HRNET = CN()
246
+ _C.MODEL.HRNET.STEM_INPLANES = 64
247
+ _C.MODEL.HRNET.STAGE2 = CN()
248
+ _C.MODEL.HRNET.STAGE2.NUM_MODULES = 1
249
+ _C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2
250
+ _C.MODEL.HRNET.STAGE2.BLOCK = "BASIC"
251
+ _C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4]
252
+ _C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64]
253
+ _C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM"
254
+ _C.MODEL.HRNET.STAGE3 = CN()
255
+ _C.MODEL.HRNET.STAGE3.NUM_MODULES = 4
256
+ _C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3
257
+ _C.MODEL.HRNET.STAGE3.BLOCK = "BASIC"
258
+ _C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
259
+ _C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128]
260
+ _C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM"
261
+ _C.MODEL.HRNET.STAGE4 = CN()
262
+ _C.MODEL.HRNET.STAGE4.NUM_MODULES = 3
263
+ _C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4
264
+ _C.MODEL.HRNET.STAGE4.BLOCK = "BASIC"
265
+ _C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
266
+ _C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
267
+ _C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM"
268
+
269
+ _C.MODEL.HRNET.HRFPN = CN()
270
+ _C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256
271
+
272
+
273
+ def add_densepose_config(cfg: CN) -> None:
274
+ add_densepose_head_config(cfg)
275
+ add_hrnet_config(cfg)
276
+ add_bootstrap_config(cfg)
277
+ add_dataset_category_config(cfg)
278
+ add_evaluation_config(cfg)
densepose/converters/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .hflip import HFlipConverter
6
+ from .to_mask import ToMaskConverter
7
+ from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
8
+ from .segm_to_mask import (
9
+ predictor_output_with_fine_and_coarse_segm_to_mask,
10
+ predictor_output_with_coarse_segm_to_mask,
11
+ resample_fine_and_coarse_segm_to_bbox,
12
+ )
13
+ from .chart_output_to_chart_result import (
14
+ densepose_chart_predictor_output_to_result,
15
+ densepose_chart_predictor_output_to_result_with_confidences,
16
+ )
17
+ from .chart_output_hflip import densepose_chart_predictor_output_hflip
densepose/converters/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (821 Bytes). View file
 
densepose/converters/__pycache__/base.cpython-39.pyc ADDED
Binary file (3.7 kB). View file
 
densepose/converters/__pycache__/builtin.cpython-39.pyc ADDED
Binary file (826 Bytes). View file
 
densepose/converters/__pycache__/chart_output_hflip.cpython-39.pyc ADDED
Binary file (1.97 kB). View file
 
densepose/converters/__pycache__/chart_output_to_chart_result.cpython-39.pyc ADDED
Binary file (6.05 kB). View file
 
densepose/converters/__pycache__/hflip.cpython-39.pyc ADDED
Binary file (1.37 kB). View file
 
densepose/converters/__pycache__/segm_to_mask.cpython-39.pyc ADDED
Binary file (5.78 kB). View file
 
densepose/converters/__pycache__/to_chart_result.cpython-39.pyc ADDED
Binary file (2.76 kB). View file
 
densepose/converters/__pycache__/to_mask.cpython-39.pyc ADDED
Binary file (1.78 kB). View file
 
densepose/converters/base.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Tuple, Type
6
+ import torch
7
+
8
+
9
+ class BaseConverter:
10
+ """
11
+ Converter base class to be reused by various converters.
12
+ Converter allows one to convert data from various source types to a particular
13
+ destination type. Each source type needs to register its converter. The
14
+ registration for each source type is valid for all descendants of that type.
15
+ """
16
+
17
+ @classmethod
18
+ def register(cls, from_type: Type, converter: Any = None):
19
+ """
20
+ Registers a converter for the specified type.
21
+ Can be used as a decorator (if converter is None), or called as a method.
22
+
23
+ Args:
24
+ from_type (type): type to register the converter for;
25
+ all instances of this type will use the same converter
26
+ converter (callable): converter to be registered for the given
27
+ type; if None, this method is assumed to be a decorator for the converter
28
+ """
29
+
30
+ if converter is not None:
31
+ cls._do_register(from_type, converter)
32
+
33
+ def wrapper(converter: Any) -> Any:
34
+ cls._do_register(from_type, converter)
35
+ return converter
36
+
37
+ return wrapper
38
+
39
+ @classmethod
40
+ def _do_register(cls, from_type: Type, converter: Any):
41
+ cls.registry[from_type] = converter # pyre-ignore[16]
42
+
43
+ @classmethod
44
+ def _lookup_converter(cls, from_type: Type) -> Any:
45
+ """
46
+ Perform recursive lookup for the given type
47
+ to find registered converter. If a converter was found for some base
48
+ class, it gets registered for this class to save on further lookups.
49
+
50
+ Args:
51
+ from_type: type for which to find a converter
52
+ Return:
53
+ callable or None - registered converter or None
54
+ if no suitable entry was found in the registry
55
+ """
56
+ if from_type in cls.registry: # pyre-ignore[16]
57
+ return cls.registry[from_type]
58
+ for base in from_type.__bases__:
59
+ converter = cls._lookup_converter(base)
60
+ if converter is not None:
61
+ cls._do_register(from_type, converter)
62
+ return converter
63
+ return None
64
+
65
+ @classmethod
66
+ def convert(cls, instance: Any, *args, **kwargs):
67
+ """
68
+ Convert an instance to the destination type using some registered
69
+ converter. Does recursive lookup for base classes, so there's no need
70
+ for explicit registration for derived classes.
71
+
72
+ Args:
73
+ instance: source instance to convert to the destination type
74
+ Return:
75
+ An instance of the destination type obtained from the source instance
76
+ Raises KeyError, if no suitable converter found
77
+ """
78
+ instance_type = type(instance)
79
+ converter = cls._lookup_converter(instance_type)
80
+ if converter is None:
81
+ if cls.dst_type is None: # pyre-ignore[16]
82
+ output_type_str = "itself"
83
+ else:
84
+ output_type_str = cls.dst_type
85
+ raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}")
86
+ return converter(instance, *args, **kwargs)
87
+
88
+
89
+ IntTupleBox = Tuple[int, int, int, int]
90
+
91
+
92
+ def make_int_box(box: torch.Tensor) -> IntTupleBox:
93
+ int_box = [0, 0, 0, 0]
94
+ int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
95
+ return int_box[0], int_box[1], int_box[2], int_box[3]
densepose/converters/builtin.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
6
+ from . import (
7
+ HFlipConverter,
8
+ ToChartResultConverter,
9
+ ToChartResultConverterWithConfidences,
10
+ ToMaskConverter,
11
+ densepose_chart_predictor_output_hflip,
12
+ densepose_chart_predictor_output_to_result,
13
+ densepose_chart_predictor_output_to_result_with_confidences,
14
+ predictor_output_with_coarse_segm_to_mask,
15
+ predictor_output_with_fine_and_coarse_segm_to_mask,
16
+ )
17
+
18
+ ToMaskConverter.register(
19
+ DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
20
+ )
21
+ ToMaskConverter.register(
22
+ DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
23
+ )
24
+
25
+ ToChartResultConverter.register(
26
+ DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
27
+ )
28
+
29
+ ToChartResultConverterWithConfidences.register(
30
+ DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
31
+ )
32
+
33
+ HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
densepose/converters/chart_output_hflip.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from dataclasses import fields
5
+ import torch
6
+
7
+ from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
8
+
9
+
10
+ def densepose_chart_predictor_output_hflip(
11
+ densepose_predictor_output: DensePoseChartPredictorOutput,
12
+ transform_data: DensePoseTransformData,
13
+ ) -> DensePoseChartPredictorOutput:
14
+ """
15
+ Change to take into account a Horizontal flip.
16
+ """
17
+ if len(densepose_predictor_output) > 0:
18
+
19
+ PredictorOutput = type(densepose_predictor_output)
20
+ output_dict = {}
21
+
22
+ for field in fields(densepose_predictor_output):
23
+ field_value = getattr(densepose_predictor_output, field.name)
24
+ # flip tensors
25
+ if isinstance(field_value, torch.Tensor):
26
+ setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))
27
+
28
+ densepose_predictor_output = _flip_iuv_semantics_tensor(
29
+ densepose_predictor_output, transform_data
30
+ )
31
+ densepose_predictor_output = _flip_segm_semantics_tensor(
32
+ densepose_predictor_output, transform_data
33
+ )
34
+
35
+ for field in fields(densepose_predictor_output):
36
+ output_dict[field.name] = getattr(densepose_predictor_output, field.name)
37
+
38
+ return PredictorOutput(**output_dict)
39
+ else:
40
+ return densepose_predictor_output
41
+
42
+
43
+ def _flip_iuv_semantics_tensor(
44
+ densepose_predictor_output: DensePoseChartPredictorOutput,
45
+ dp_transform_data: DensePoseTransformData,
46
+ ) -> DensePoseChartPredictorOutput:
47
+ point_label_symmetries = dp_transform_data.point_label_symmetries
48
+ uv_symmetries = dp_transform_data.uv_symmetries
49
+
50
+ N, C, H, W = densepose_predictor_output.u.shape
51
+ u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
52
+ v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
53
+ Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
54
+ None, :, None, None
55
+ ].expand(N, C - 1, H, W)
56
+ densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
57
+ densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]
58
+
59
+ for el in ["fine_segm", "u", "v"]:
60
+ densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
61
+ :, point_label_symmetries, :, :
62
+ ]
63
+ return densepose_predictor_output
64
+
65
+
66
+ def _flip_segm_semantics_tensor(
67
+ densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
68
+ ):
69
+ if densepose_predictor_output.coarse_segm.shape[1] > 2:
70
+ densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[
71
+ :, dp_transform_data.mask_label_symmetries, :, :
72
+ ]
73
+ return densepose_predictor_output
densepose/converters/chart_output_to_chart_result.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Dict
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures.boxes import Boxes, BoxMode
10
+
11
+ from ..structures import (
12
+ DensePoseChartPredictorOutput,
13
+ DensePoseChartResult,
14
+ DensePoseChartResultWithConfidences,
15
+ )
16
+ from . import resample_fine_and_coarse_segm_to_bbox
17
+ from .base import IntTupleBox, make_int_box
18
+
19
+
20
+ def resample_uv_tensors_to_bbox(
21
+ u: torch.Tensor,
22
+ v: torch.Tensor,
23
+ labels: torch.Tensor,
24
+ box_xywh_abs: IntTupleBox,
25
+ ) -> torch.Tensor:
26
+ """
27
+ Resamples U and V coordinate estimates for the given bounding box
28
+
29
+ Args:
30
+ u (tensor [1, C, H, W] of float): U coordinates
31
+ v (tensor [1, C, H, W] of float): V coordinates
32
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
33
+ outputs for the given bounding box
34
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
35
+ Return:
36
+ Resampled U and V coordinates - a tensor [2, H, W] of float
37
+ """
38
+ x, y, w, h = box_xywh_abs
39
+ w = max(int(w), 1)
40
+ h = max(int(h), 1)
41
+ u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
42
+ v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
43
+ uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
44
+ for part_id in range(1, u_bbox.size(1)):
45
+ uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
46
+ uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
47
+ return uv
48
+
49
+
50
+ def resample_uv_to_bbox(
51
+ predictor_output: DensePoseChartPredictorOutput,
52
+ labels: torch.Tensor,
53
+ box_xywh_abs: IntTupleBox,
54
+ ) -> torch.Tensor:
55
+ """
56
+ Resamples U and V coordinate estimates for the given bounding box
57
+
58
+ Args:
59
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
60
+ output to be resampled
61
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
62
+ outputs for the given bounding box
63
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
64
+ Return:
65
+ Resampled U and V coordinates - a tensor [2, H, W] of float
66
+ """
67
+ return resample_uv_tensors_to_bbox(
68
+ predictor_output.u,
69
+ predictor_output.v,
70
+ labels,
71
+ box_xywh_abs,
72
+ )
73
+
74
+
75
+ def densepose_chart_predictor_output_to_result(
76
+ predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
77
+ ) -> DensePoseChartResult:
78
+ """
79
+ Convert densepose chart predictor outputs to results
80
+
81
+ Args:
82
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
83
+ output to be converted to results, must contain only 1 output
84
+ boxes (Boxes): bounding box that corresponds to the predictor output,
85
+ must contain only 1 bounding box
86
+ Return:
87
+ DensePose chart-based result (DensePoseChartResult)
88
+ """
89
+ assert len(predictor_output) == 1 and len(boxes) == 1, (
90
+ f"Predictor output to result conversion can operate only single outputs"
91
+ f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
92
+ )
93
+
94
+ boxes_xyxy_abs = boxes.tensor.clone()
95
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
96
+ box_xywh = make_int_box(boxes_xywh_abs[0])
97
+
98
+ labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
99
+ uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
100
+ return DensePoseChartResult(labels=labels, uv=uv)
101
+
102
+
103
+ def resample_confidences_to_bbox(
104
+ predictor_output: DensePoseChartPredictorOutput,
105
+ labels: torch.Tensor,
106
+ box_xywh_abs: IntTupleBox,
107
+ ) -> Dict[str, torch.Tensor]:
108
+ """
109
+ Resamples confidences for the given bounding box
110
+
111
+ Args:
112
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
113
+ output to be resampled
114
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
115
+ outputs for the given bounding box
116
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
117
+ Return:
118
+ Resampled confidences - a dict of [H, W] tensors of float
119
+ """
120
+
121
+ x, y, w, h = box_xywh_abs
122
+ w = max(int(w), 1)
123
+ h = max(int(h), 1)
124
+
125
+ confidence_names = [
126
+ "sigma_1",
127
+ "sigma_2",
128
+ "kappa_u",
129
+ "kappa_v",
130
+ "fine_segm_confidence",
131
+ "coarse_segm_confidence",
132
+ ]
133
+ confidence_results = {key: None for key in confidence_names}
134
+ confidence_names = [
135
+ key for key in confidence_names if getattr(predictor_output, key) is not None
136
+ ]
137
+ confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device)
138
+
139
+ # assign data from channels that correspond to the labels
140
+ for key in confidence_names:
141
+ resampled_confidence = F.interpolate(
142
+ getattr(predictor_output, key),
143
+ (h, w),
144
+ mode="bilinear",
145
+ align_corners=False,
146
+ )
147
+ result = confidence_base.clone()
148
+ for part_id in range(1, predictor_output.u.size(1)):
149
+ if resampled_confidence.size(1) != predictor_output.u.size(1):
150
+ # confidence is not part-based, don't try to fill it part by part
151
+ continue
152
+ result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
153
+
154
+ if resampled_confidence.size(1) != predictor_output.u.size(1):
155
+ # confidence is not part-based, fill the data with the first channel
156
+ # (targeted for segmentation confidences that have only 1 channel)
157
+ result = resampled_confidence[0, 0]
158
+
159
+ confidence_results[key] = result
160
+
161
+ return confidence_results # pyre-ignore[7]
162
+
163
+
164
+ def densepose_chart_predictor_output_to_result_with_confidences(
165
+ predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
166
+ ) -> DensePoseChartResultWithConfidences:
167
+ """
168
+ Convert densepose chart predictor outputs to results
169
+
170
+ Args:
171
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
172
+ output with confidences to be converted to results, must contain only 1 output
173
+ boxes (Boxes): bounding box that corresponds to the predictor output,
174
+ must contain only 1 bounding box
175
+ Return:
176
+ DensePose chart-based result with confidences (DensePoseChartResultWithConfidences)
177
+ """
178
+ assert len(predictor_output) == 1 and len(boxes) == 1, (
179
+ f"Predictor output to result conversion can operate only single outputs"
180
+ f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
181
+ )
182
+
183
+ boxes_xyxy_abs = boxes.tensor.clone()
184
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
185
+ box_xywh = make_int_box(boxes_xywh_abs[0])
186
+
187
+ labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
188
+ uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
189
+ confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh)
190
+ return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
densepose/converters/hflip.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+
7
+ from .base import BaseConverter
8
+
9
+
10
+ class HFlipConverter(BaseConverter):
11
+ """
12
+ Converts various DensePose predictor outputs to DensePose results.
13
+ Each DensePose predictor output type has to register its convertion strategy.
14
+ """
15
+
16
+ registry = {}
17
+ dst_type = None
18
+
19
+ @classmethod
20
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
21
+ # inconsistently.
22
+ def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
23
+ """
24
+ Performs an horizontal flip on DensePose predictor outputs.
25
+ Does recursive lookup for base classes, so there's no need
26
+ for explicit registration for derived classes.
27
+
28
+ Args:
29
+ predictor_outputs: DensePose predictor output to be converted to BitMasks
30
+ transform_data: Anything useful for the flip
31
+ Return:
32
+ An instance of the same type as predictor_outputs
33
+ """
34
+ return super(HFlipConverter, cls).convert(
35
+ predictor_outputs, transform_data, *args, **kwargs
36
+ )
densepose/converters/segm_to_mask.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures import BitMasks, Boxes, BoxMode
10
+
11
+ from .base import IntTupleBox, make_int_box
12
+ from .to_mask import ImageSizeType
13
+
14
+
15
+ def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
16
+ """
17
+ Resample coarse segmentation tensor to the given
18
+ bounding box and derive labels for each pixel of the bounding box
19
+
20
+ Args:
21
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
22
+        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+            corner coordinates, width (W) and height (H)
+    Return:
+        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+    """
+    x, y, w, h = box_xywh_abs
+    w = max(int(w), 1)
+    h = max(int(h), 1)
+    labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
+    return labels
+
+
+def resample_fine_and_coarse_segm_tensors_to_bbox(
+    fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
+):
+    """
+    Resample fine and coarse segmentation tensors to the given
+    bounding box and derive labels for each pixel of the bounding box
+
+    Args:
+        fine_segm: float tensor of shape [1, C, Hout, Wout]
+        coarse_segm: float tensor of shape [1, K, Hout, Wout]
+        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+            corner coordinates, width (W) and height (H)
+    Return:
+        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+    """
+    x, y, w, h = box_xywh_abs
+    w = max(int(w), 1)
+    h = max(int(h), 1)
+    # coarse segmentation
+    coarse_segm_bbox = F.interpolate(
+        coarse_segm,
+        (h, w),
+        mode="bilinear",
+        align_corners=False,
+    ).argmax(dim=1)
+    # combined coarse and fine segmentation
+    labels = (
+        F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
+        * (coarse_segm_bbox > 0).long()
+    )
+    return labels
+
+
+def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
+    """
+    Resample fine and coarse segmentation outputs from a predictor to the given
+    bounding box and derive labels for each pixel of the bounding box
+
+    Args:
+        predictor_output: DensePose predictor output that contains segmentation
+            results to be resampled
+        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+            corner coordinates, width (W) and height (H)
+    Return:
+        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+    """
+    return resample_fine_and_coarse_segm_tensors_to_bbox(
+        predictor_output.fine_segm,
+        predictor_output.coarse_segm,
+        box_xywh_abs,
+    )
+
+
+def predictor_output_with_coarse_segm_to_mask(
+    predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
+) -> BitMasks:
+    """
+    Convert predictor output with coarse segmentation to a mask.
+    Assumes that predictor output has the following attributes:
+     - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
+       unnormalized scores for N instances; D is the number of coarse
+       segmentation labels, H and W is the resolution of the estimate
+
+    Args:
+        predictor_output: DensePose predictor output to be converted to mask
+        boxes (Boxes): bounding boxes that correspond to the DensePose
+            predictor outputs
+        image_size_hw (tuple [int, int]): image height Himg and width Wimg
+    Return:
+        BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
+        a mask of the size of the image for each instance
+    """
+    H, W = image_size_hw
+    boxes_xyxy_abs = boxes.tensor.clone()
+    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+    N = len(boxes_xywh_abs)
+    masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
+    for i in range(len(boxes_xywh_abs)):
+        box_xywh = make_int_box(boxes_xywh_abs[i])
+        box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
+        x, y, w, h = box_xywh
+        masks[i, y : y + h, x : x + w] = box_mask
+
+    return BitMasks(masks)
+
+
+def predictor_output_with_fine_and_coarse_segm_to_mask(
+    predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
+) -> BitMasks:
+    """
+    Convert predictor output with coarse and fine segmentation to a mask.
+    Assumes that predictor output has the following attributes:
+     - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
+       unnormalized scores for N instances; D is the number of coarse
+       segmentation labels, H and W is the resolution of the estimate
+     - fine_segm (tensor of size [N, C, H, W]): fine segmentation
+       unnormalized scores for N instances; C is the number of fine
+       segmentation labels, H and W is the resolution of the estimate
+
+    Args:
+        predictor_output: DensePose predictor output to be converted to mask
+        boxes (Boxes): bounding boxes that correspond to the DensePose
+            predictor outputs
+        image_size_hw (tuple [int, int]): image height Himg and width Wimg
+    Return:
+        BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
+        a mask of the size of the image for each instance
+    """
+    H, W = image_size_hw
+    boxes_xyxy_abs = boxes.tensor.clone()
+    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+    N = len(boxes_xywh_abs)
+    masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
+    for i in range(len(boxes_xywh_abs)):
+        box_xywh = make_int_box(boxes_xywh_abs[i])
+        labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
+        x, y, w, h = box_xywh
+        masks[i, y : y + h, x : x + w] = labels_i > 0
+    return BitMasks(masks)
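
A rough usage sketch (not part of the diff above): the resampling helper can be applied to a single chart-based predictor output to obtain per-pixel part labels inside one detection box. The `output` object and its `fine_segm` / `coarse_segm` fields are assumed to follow the shapes documented in the docstrings; the helper name below is the one defined in this file.

    from densepose.converters.segm_to_mask import resample_fine_and_coarse_segm_to_bbox

    def part_labels_for_box(output, box_xywh):
        # `output` is assumed to carry fine_segm [1, C, Hout, Wout] and
        # coarse_segm [1, K, Hout, Wout]; the box is (x, y, w, h) in pixels
        box = tuple(int(v) for v in box_xywh)
        labels = resample_fine_and_coarse_segm_to_bbox(output, box)  # [1, h, w]
        return labels.squeeze(0)  # 0 = background, 1..C-1 = body part charts
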
densepose/converters/to_chart_result.py ADDED
@@ -0,0 +1,72 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Any
+
+from detectron2.structures import Boxes
+
+from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
+from .base import BaseConverter
+
+
+class ToChartResultConverter(BaseConverter):
+    """
+    Converts various DensePose predictor outputs to DensePose results.
+    Each DensePose predictor output type has to register its conversion strategy.
+    """
+
+    registry = {}
+    dst_type = DensePoseChartResult
+
+    @classmethod
+    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
+    #  inconsistently.
+    def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
+        """
+        Convert DensePose predictor outputs to DensePoseResult using some registered
+        converter. Does recursive lookup for base classes, so there's no need
+        for explicit registration for derived classes.
+
+        Args:
+            predictor_outputs: DensePose predictor output to be
+                converted to chart results
+            boxes (Boxes): bounding boxes that correspond to the DensePose
+                predictor outputs
+        Return:
+            An instance of DensePoseChartResult. If no suitable converter was found, raises KeyError
+        """
+        return super(ToChartResultConverter, cls).convert(predictor_outputs, boxes, *args, **kwargs)
+
+
+class ToChartResultConverterWithConfidences(BaseConverter):
+    """
+    Converts various DensePose predictor outputs to DensePose results.
+    Each DensePose predictor output type has to register its conversion strategy.
+    """
+
+    registry = {}
+    dst_type = DensePoseChartResultWithConfidences
+
+    @classmethod
+    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
+    #  inconsistently.
+    def convert(
+        cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
+    ) -> DensePoseChartResultWithConfidences:
+        """
+        Convert DensePose predictor outputs to DensePoseResult with confidences
+        using some registered converter. Does recursive lookup for base classes,
+        so there's no need for explicit registration for derived classes.
+
+        Args:
+            predictor_outputs: DensePose predictor output with confidences
+                to be converted to chart results
+            boxes (Boxes): bounding boxes that correspond to the DensePose
+                predictor outputs
+        Return:
+            An instance of DensePoseChartResultWithConfidences. If no suitable converter was found, raises KeyError
+        """
+        return super(ToChartResultConverterWithConfidences, cls).convert(
+            predictor_outputs, boxes, *args, **kwargs
+        )
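
A minimal sketch of how this converter is typically invoked. It assumes `instances` is a detectron2 `Instances` object whose `pred_densepose` and `pred_boxes` fields were populated by a chart-based DensePose model; those field names and the helper function are assumptions for illustration, not part of the diff.

    from densepose.converters import ToChartResultConverter

    def to_chart_results(instances):
        outputs, boxes = instances.pred_densepose, instances.pred_boxes
        # one DensePoseChartResult per detected instance; raises KeyError
        # if no converter is registered for type(outputs)
        return [
            ToChartResultConverter.convert(outputs[i], boxes[[i]])
            for i in range(len(instances))
        ]
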
densepose/converters/to_mask.py ADDED
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Any, Tuple
+
+from detectron2.structures import BitMasks, Boxes
+
+from .base import BaseConverter
+
+ImageSizeType = Tuple[int, int]
+
+
+class ToMaskConverter(BaseConverter):
+    """
+    Converts various DensePose predictor outputs to masks
+    in bit mask format (see `BitMasks`). Each DensePose predictor output type
+    has to register its conversion strategy.
+    """
+
+    registry = {}
+    dst_type = BitMasks
+
+    @classmethod
+    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
+    #  inconsistently.
+    def convert(
+        cls,
+        densepose_predictor_outputs: Any,
+        boxes: Boxes,
+        image_size_hw: ImageSizeType,
+        *args,
+        **kwargs
+    ) -> BitMasks:
+        """
+        Convert DensePose predictor outputs to BitMasks using some registered
+        converter. Does recursive lookup for base classes, so there's no need
+        for explicit registration for derived classes.
+
+        Args:
+            densepose_predictor_outputs: DensePose predictor output to be
+                converted to BitMasks
+            boxes (Boxes): bounding boxes that correspond to the DensePose
+                predictor outputs
+            image_size_hw (tuple [int, int]): image height and width
+        Return:
+            An instance of `BitMasks`. If no suitable converter was found, raises KeyError
+        """
+        return super(ToMaskConverter, cls).convert(
+            densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
+        )
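
For completeness, a hedged sketch of turning the DensePose predictions of one image into `BitMasks` via this converter. It assumes the usual `pred_densepose` / `pred_boxes` fields on a detectron2 `Instances` object and that `instances.image_size` holds (H, W); the wrapper function is hypothetical.

    from densepose.converters import ToMaskConverter

    def densepose_to_bitmasks(instances):
        return ToMaskConverter.convert(
            instances.pred_densepose, instances.pred_boxes, instances.image_size
        )
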
densepose/data/.DS_Store ADDED
Binary file (8.2 kB). View file
 
densepose/data/__init__.py ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .meshes import builtin
+from .build import (
+    build_detection_test_loader,
+    build_detection_train_loader,
+    build_combined_loader,
+    build_frame_selector,
+    build_inference_based_loaders,
+    has_inference_based_loaders,
+    BootstrapDatasetFactoryCatalog,
+)
+from .combined_loader import CombinedDataLoader
+from .dataset_mapper import DatasetMapper
+from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
+from .image_list_dataset import ImageListDataset
+from .utils import is_relative_local_path, maybe_prepend_base_path
+
+# ensure the builtin datasets are registered
+from . import datasets
+
+# ensure the bootstrap datasets builders are registered
+from . import build
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
densepose/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.07 kB). View file
 
densepose/data/__pycache__/build.cpython-39.pyc ADDED
Binary file (23.8 kB). View file
 
densepose/data/__pycache__/combined_loader.cpython-39.pyc ADDED
Binary file (1.83 kB). View file
 
densepose/data/__pycache__/dataset_mapper.cpython-39.pyc ADDED
Binary file (5.42 kB). View file
 
densepose/data/__pycache__/image_list_dataset.cpython-39.pyc ADDED
Binary file (2.67 kB). View file
 
densepose/data/__pycache__/inference_based_loader.cpython-39.pyc ADDED
Binary file (5.8 kB). View file
 
densepose/data/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.64 kB). View file
 
densepose/data/build.py ADDED
@@ -0,0 +1,738 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ import itertools
6
+ import logging
7
+ import numpy as np
8
+ from collections import UserDict, defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple
11
+ import torch
12
+ from torch.utils.data.dataset import Dataset
13
+
14
+ from detectron2.config import CfgNode
15
+ from detectron2.data.build import build_detection_test_loader as d2_build_detection_test_loader
16
+ from detectron2.data.build import build_detection_train_loader as d2_build_detection_train_loader
17
+ from detectron2.data.build import (
18
+ load_proposals_into_dataset,
19
+ print_instances_class_histogram,
20
+ trivial_batch_collator,
21
+ worker_init_reset_seed,
22
+ )
23
+ from detectron2.data.catalog import DatasetCatalog, Metadata, MetadataCatalog
24
+ from detectron2.data.samplers import TrainingSampler
25
+ from detectron2.utils.comm import get_world_size
26
+
27
+ from densepose.config import get_bootstrap_dataset_config
28
+ from densepose.modeling import build_densepose_embedder
29
+
30
+ from .combined_loader import CombinedDataLoader, Loader
31
+ from .dataset_mapper import DatasetMapper
32
+ from .datasets.coco import DENSEPOSE_CSE_KEYS_WITHOUT_MASK, DENSEPOSE_IUV_KEYS_WITHOUT_MASK
33
+ from .datasets.dataset_type import DatasetType
34
+ from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
35
+ from .samplers import (
36
+ DensePoseConfidenceBasedSampler,
37
+ DensePoseCSEConfidenceBasedSampler,
38
+ DensePoseCSEUniformSampler,
39
+ DensePoseUniformSampler,
40
+ MaskFromDensePoseSampler,
41
+ PredictionToGroundTruthSampler,
42
+ )
43
+ from .transform import ImageResizeTransform
44
+ from .utils import get_category_to_class_mapping, get_class_to_mesh_name_mapping
45
+ from .video import (
46
+ FirstKFramesSelector,
47
+ FrameSelectionStrategy,
48
+ LastKFramesSelector,
49
+ RandomKFramesSelector,
50
+ VideoKeyframeDataset,
51
+ video_list_from_file,
52
+ )
53
+
54
+ __all__ = ["build_detection_train_loader", "build_detection_test_loader"]
55
+
56
+
57
+ Instance = Dict[str, Any]
58
+ InstancePredicate = Callable[[Instance], bool]
59
+
60
+
61
+ def _compute_num_images_per_worker(cfg: CfgNode) -> int:
62
+ num_workers = get_world_size()
63
+ images_per_batch = cfg.SOLVER.IMS_PER_BATCH
64
+ assert (
65
+ images_per_batch % num_workers == 0
66
+ ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
67
+ images_per_batch, num_workers
68
+ )
69
+ assert (
70
+ images_per_batch >= num_workers
71
+ ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
72
+ images_per_batch, num_workers
73
+ )
74
+ images_per_worker = images_per_batch // num_workers
75
+ return images_per_worker
76
+
77
+
78
+ def _map_category_id_to_contiguous_id(dataset_name: str, dataset_dicts: Iterable[Instance]) -> None:
79
+ meta = MetadataCatalog.get(dataset_name)
80
+ for dataset_dict in dataset_dicts:
81
+ for ann in dataset_dict["annotations"]:
82
+ ann["category_id"] = meta.thing_dataset_id_to_contiguous_id[ann["category_id"]]
83
+
84
+
85
+ @dataclass
86
+ class _DatasetCategory:
87
+ """
88
+ Class representing category data in a dataset:
89
+ - id: category ID, as specified in the dataset annotations file
90
+ - name: category name, as specified in the dataset annotations file
91
+ - mapped_id: category ID after applying category maps (DATASETS.CATEGORY_MAPS config option)
92
+ - mapped_name: category name after applying category maps
93
+ - dataset_name: dataset in which the category is defined
94
+
95
+ For example, when training models in a class-agnostic manner, one could take LVIS 1.0
96
+ dataset and map the animal categories to the same category as human data from COCO:
97
+ id = 225
98
+ name = "cat"
99
+ mapped_id = 1
100
+ mapped_name = "person"
101
+ dataset_name = "lvis_v1_animals_dp_train"
102
+ """
103
+
104
+ id: int
105
+ name: str
106
+ mapped_id: int
107
+ mapped_name: str
108
+ dataset_name: str
109
+
110
+
111
+ _MergedCategoriesT = Dict[int, List[_DatasetCategory]]
112
+
113
+
114
+ def _add_category_id_to_contiguous_id_maps_to_metadata(
115
+ merged_categories: _MergedCategoriesT,
116
+ ) -> None:
117
+ merged_categories_per_dataset = {}
118
+ for contiguous_cat_id, cat_id in enumerate(sorted(merged_categories.keys())):
119
+ for cat in merged_categories[cat_id]:
120
+ if cat.dataset_name not in merged_categories_per_dataset:
121
+ merged_categories_per_dataset[cat.dataset_name] = defaultdict(list)
122
+ merged_categories_per_dataset[cat.dataset_name][cat_id].append(
123
+ (
124
+ contiguous_cat_id,
125
+ cat,
126
+ )
127
+ )
128
+
129
+ logger = logging.getLogger(__name__)
130
+ for dataset_name, merged_categories in merged_categories_per_dataset.items():
131
+ meta = MetadataCatalog.get(dataset_name)
132
+ if not hasattr(meta, "thing_classes"):
133
+ meta.thing_classes = []
134
+ meta.thing_dataset_id_to_contiguous_id = {}
135
+ meta.thing_dataset_id_to_merged_id = {}
136
+ else:
137
+ meta.thing_classes.clear()
138
+ meta.thing_dataset_id_to_contiguous_id.clear()
139
+ meta.thing_dataset_id_to_merged_id.clear()
140
+ logger.info(f"Dataset {dataset_name}: category ID to contiguous ID mapping:")
141
+ for _cat_id, categories in sorted(merged_categories.items()):
142
+ added_to_thing_classes = False
143
+ for contiguous_cat_id, cat in categories:
144
+ if not added_to_thing_classes:
145
+ meta.thing_classes.append(cat.mapped_name)
146
+ added_to_thing_classes = True
147
+ meta.thing_dataset_id_to_contiguous_id[cat.id] = contiguous_cat_id
148
+ meta.thing_dataset_id_to_merged_id[cat.id] = cat.mapped_id
149
+ logger.info(f"{cat.id} ({cat.name}) -> {contiguous_cat_id}")
150
+
151
+
152
+ def _maybe_create_general_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
153
+ def has_annotations(instance: Instance) -> bool:
154
+ return "annotations" in instance
155
+
156
+ def has_only_crowd_anotations(instance: Instance) -> bool:
157
+ for ann in instance["annotations"]:
158
+ if ann.get("is_crowd", 0) == 0:
159
+ return False
160
+ return True
161
+
162
+ def general_keep_instance_predicate(instance: Instance) -> bool:
163
+ return has_annotations(instance) and not has_only_crowd_anotations(instance)
164
+
165
+ if not cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS:
166
+ return None
167
+ return general_keep_instance_predicate
168
+
169
+
170
+ def _maybe_create_keypoints_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
171
+
172
+ min_num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
173
+
174
+ def has_sufficient_num_keypoints(instance: Instance) -> bool:
175
+ num_kpts = sum(
176
+ (np.array(ann["keypoints"][2::3]) > 0).sum()
177
+ for ann in instance["annotations"]
178
+ if "keypoints" in ann
179
+ )
180
+ return num_kpts >= min_num_keypoints
181
+
182
+ if cfg.MODEL.KEYPOINT_ON and (min_num_keypoints > 0):
183
+ return has_sufficient_num_keypoints
184
+ return None
185
+
186
+
187
+ def _maybe_create_mask_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
188
+ if not cfg.MODEL.MASK_ON:
189
+ return None
190
+
191
+ def has_mask_annotations(instance: Instance) -> bool:
192
+ return any("segmentation" in ann for ann in instance["annotations"])
193
+
194
+ return has_mask_annotations
195
+
196
+
197
+ def _maybe_create_densepose_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
198
+ if not cfg.MODEL.DENSEPOSE_ON:
199
+ return None
200
+
201
+ use_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
202
+
203
+ def has_densepose_annotations(instance: Instance) -> bool:
204
+ for ann in instance["annotations"]:
205
+ if all(key in ann for key in DENSEPOSE_IUV_KEYS_WITHOUT_MASK) or all(
206
+ key in ann for key in DENSEPOSE_CSE_KEYS_WITHOUT_MASK
207
+ ):
208
+ return True
209
+ if use_masks and "segmentation" in ann:
210
+ return True
211
+ return False
212
+
213
+ return has_densepose_annotations
214
+
215
+
216
+ def _maybe_create_specific_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
217
+ specific_predicate_creators = [
218
+ _maybe_create_keypoints_keep_instance_predicate,
219
+ _maybe_create_mask_keep_instance_predicate,
220
+ _maybe_create_densepose_keep_instance_predicate,
221
+ ]
222
+ predicates = [creator(cfg) for creator in specific_predicate_creators]
223
+ predicates = [p for p in predicates if p is not None]
224
+ if not predicates:
225
+ return None
226
+
227
+ def combined_predicate(instance: Instance) -> bool:
228
+ return any(p(instance) for p in predicates)
229
+
230
+ return combined_predicate
231
+
232
+
233
+ def _get_train_keep_instance_predicate(cfg: CfgNode):
234
+ general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
235
+ combined_specific_keep_predicate = _maybe_create_specific_keep_instance_predicate(cfg)
236
+
237
+ def combined_general_specific_keep_predicate(instance: Instance) -> bool:
238
+ return general_keep_predicate(instance) and combined_specific_keep_predicate(instance)
239
+
240
+ if (general_keep_predicate is None) and (combined_specific_keep_predicate is None):
241
+ return None
242
+ if general_keep_predicate is None:
243
+ return combined_specific_keep_predicate
244
+ if combined_specific_keep_predicate is None:
245
+ return general_keep_predicate
246
+ return combined_general_specific_keep_predicate
247
+
248
+
249
+ def _get_test_keep_instance_predicate(cfg: CfgNode):
250
+ general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
251
+ return general_keep_predicate
252
+
253
+
254
+ def _maybe_filter_and_map_categories(
255
+ dataset_name: str, dataset_dicts: List[Instance]
256
+ ) -> List[Instance]:
257
+ meta = MetadataCatalog.get(dataset_name)
258
+ category_id_map = meta.thing_dataset_id_to_contiguous_id
259
+ filtered_dataset_dicts = []
260
+ for dataset_dict in dataset_dicts:
261
+ anns = []
262
+ for ann in dataset_dict["annotations"]:
263
+ cat_id = ann["category_id"]
264
+ if cat_id not in category_id_map:
265
+ continue
266
+ ann["category_id"] = category_id_map[cat_id]
267
+ anns.append(ann)
268
+ dataset_dict["annotations"] = anns
269
+ filtered_dataset_dicts.append(dataset_dict)
270
+ return filtered_dataset_dicts
271
+
272
+
273
+ def _add_category_whitelists_to_metadata(cfg: CfgNode) -> None:
274
+ for dataset_name, whitelisted_cat_ids in cfg.DATASETS.WHITELISTED_CATEGORIES.items():
275
+ meta = MetadataCatalog.get(dataset_name)
276
+ meta.whitelisted_categories = whitelisted_cat_ids
277
+ logger = logging.getLogger(__name__)
278
+ logger.info(
279
+ "Whitelisted categories for dataset {}: {}".format(
280
+ dataset_name, meta.whitelisted_categories
281
+ )
282
+ )
283
+
284
+
285
+ def _add_category_maps_to_metadata(cfg: CfgNode) -> None:
286
+ for dataset_name, category_map in cfg.DATASETS.CATEGORY_MAPS.items():
287
+ category_map = {
288
+ int(cat_id_src): int(cat_id_dst) for cat_id_src, cat_id_dst in category_map.items()
289
+ }
290
+ meta = MetadataCatalog.get(dataset_name)
291
+ meta.category_map = category_map
292
+ logger = logging.getLogger(__name__)
293
+ logger.info("Category maps for dataset {}: {}".format(dataset_name, meta.category_map))
294
+
295
+
296
+ def _add_category_info_to_bootstrapping_metadata(dataset_name: str, dataset_cfg: CfgNode) -> None:
297
+ meta = MetadataCatalog.get(dataset_name)
298
+ meta.category_to_class_mapping = get_category_to_class_mapping(dataset_cfg)
299
+ meta.categories = dataset_cfg.CATEGORIES
300
+ meta.max_count_per_category = dataset_cfg.MAX_COUNT_PER_CATEGORY
301
+ logger = logging.getLogger(__name__)
302
+ logger.info(
303
+ "Category to class mapping for dataset {}: {}".format(
304
+ dataset_name, meta.category_to_class_mapping
305
+ )
306
+ )
307
+
308
+
309
+ def _maybe_add_class_to_mesh_name_map_to_metadata(dataset_names: List[str], cfg: CfgNode) -> None:
310
+ for dataset_name in dataset_names:
311
+ meta = MetadataCatalog.get(dataset_name)
312
+ if not hasattr(meta, "class_to_mesh_name"):
313
+ meta.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
314
+
315
+
316
+ def _merge_categories(dataset_names: Collection[str]) -> _MergedCategoriesT:
317
+ merged_categories = defaultdict(list)
318
+ category_names = {}
319
+ for dataset_name in dataset_names:
320
+ meta = MetadataCatalog.get(dataset_name)
321
+ whitelisted_categories = meta.get("whitelisted_categories")
322
+ category_map = meta.get("category_map", {})
323
+ cat_ids = (
324
+ whitelisted_categories if whitelisted_categories is not None else meta.categories.keys()
325
+ )
326
+ for cat_id in cat_ids:
327
+ cat_name = meta.categories[cat_id]
328
+ cat_id_mapped = category_map.get(cat_id, cat_id)
329
+ if cat_id_mapped == cat_id or cat_id_mapped in cat_ids:
330
+ category_names[cat_id] = cat_name
331
+ else:
332
+ category_names[cat_id] = str(cat_id_mapped)
333
+ # assign temporary mapped category name, this name can be changed
334
+ # during the second pass, since mapped ID can correspond to a category
335
+ # from a different dataset
336
+ cat_name_mapped = meta.categories[cat_id_mapped]
337
+ merged_categories[cat_id_mapped].append(
338
+ _DatasetCategory(
339
+ id=cat_id,
340
+ name=cat_name,
341
+ mapped_id=cat_id_mapped,
342
+ mapped_name=cat_name_mapped,
343
+ dataset_name=dataset_name,
344
+ )
345
+ )
346
+ # second pass to assign proper mapped category names
347
+ for cat_id, categories in merged_categories.items():
348
+ for cat in categories:
349
+ if cat_id in category_names and cat.mapped_name != category_names[cat_id]:
350
+ cat.mapped_name = category_names[cat_id]
351
+
352
+ return merged_categories
353
+
354
+
355
+ def _warn_if_merged_different_categories(merged_categories: _MergedCategoriesT) -> None:
356
+ logger = logging.getLogger(__name__)
357
+ for cat_id in merged_categories:
358
+ merged_categories_i = merged_categories[cat_id]
359
+ first_cat_name = merged_categories_i[0].name
360
+ if len(merged_categories_i) > 1 and not all(
361
+ cat.name == first_cat_name for cat in merged_categories_i[1:]
362
+ ):
363
+ cat_summary_str = ", ".join(
364
+ [f"{cat.id} ({cat.name}) from {cat.dataset_name}" for cat in merged_categories_i]
365
+ )
366
+ logger.warning(
367
+ f"Merged category {cat_id} corresponds to the following categories: "
368
+ f"{cat_summary_str}"
369
+ )
370
+
371
+
372
+ def combine_detection_dataset_dicts(
373
+ dataset_names: Collection[str],
374
+ keep_instance_predicate: Optional[InstancePredicate] = None,
375
+ proposal_files: Optional[Collection[str]] = None,
376
+ ) -> List[Instance]:
377
+ """
378
+ Load and prepare dataset dicts for training / testing
379
+
380
+ Args:
381
+ dataset_names (Collection[str]): a list of dataset names
382
+ keep_instance_predicate (Callable: Dict[str, Any] -> bool): predicate
383
+ applied to instance dicts which defines whether to keep the instance
384
+ proposal_files (Collection[str]): if given, a list of object proposal files
385
+ that match each dataset in `dataset_names`.
386
+ """
387
+ assert len(dataset_names)
388
+ if proposal_files is None:
389
+ proposal_files = [None] * len(dataset_names)
390
+ assert len(dataset_names) == len(proposal_files)
391
+ # load datasets and metadata
392
+ dataset_name_to_dicts = {}
393
+ for dataset_name in dataset_names:
394
+ dataset_name_to_dicts[dataset_name] = DatasetCatalog.get(dataset_name)
395
+ assert len(dataset_name_to_dicts), f"Dataset '{dataset_name}' is empty!"
396
+ # merge categories, requires category metadata to be loaded
397
+ # cat_id -> [(orig_cat_id, cat_name, dataset_name)]
398
+ merged_categories = _merge_categories(dataset_names)
399
+ _warn_if_merged_different_categories(merged_categories)
400
+ merged_category_names = [
401
+ merged_categories[cat_id][0].mapped_name for cat_id in sorted(merged_categories)
402
+ ]
403
+ # map to contiguous category IDs
404
+ _add_category_id_to_contiguous_id_maps_to_metadata(merged_categories)
405
+ # load annotations and dataset metadata
406
+ for dataset_name, proposal_file in zip(dataset_names, proposal_files):
407
+ dataset_dicts = dataset_name_to_dicts[dataset_name]
408
+ assert len(dataset_dicts), f"Dataset '{dataset_name}' is empty!"
409
+ if proposal_file is not None:
410
+ dataset_dicts = load_proposals_into_dataset(dataset_dicts, proposal_file)
411
+ dataset_dicts = _maybe_filter_and_map_categories(dataset_name, dataset_dicts)
412
+ print_instances_class_histogram(dataset_dicts, merged_category_names)
413
+ dataset_name_to_dicts[dataset_name] = dataset_dicts
414
+
415
+ if keep_instance_predicate is not None:
416
+ all_datasets_dicts_plain = [
417
+ d
418
+ for d in itertools.chain.from_iterable(dataset_name_to_dicts.values())
419
+ if keep_instance_predicate(d)
420
+ ]
421
+ else:
422
+ all_datasets_dicts_plain = list(
423
+ itertools.chain.from_iterable(dataset_name_to_dicts.values())
424
+ )
425
+ return all_datasets_dicts_plain
426
+
427
+
428
+ def build_detection_train_loader(cfg: CfgNode, mapper=None):
429
+ """
430
+ A data loader is created in a way similar to that of Detectron2.
431
+ The main differences are:
432
+ - it allows combining datasets with different but compatible object category sets
433
+
434
+ The data loader is created by the following steps:
435
+ 1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
436
+ 2. Start workers to work on the dicts. Each worker will:
437
+ * Map each metadata dict into another format to be consumed by the model.
438
+ * Batch them by simply putting dicts into a list.
439
+ The batched ``list[mapped_dict]`` is what this dataloader will return.
440
+
441
+ Args:
442
+ cfg (CfgNode): the config
443
+ mapper (callable): a callable which takes a sample (dict) from dataset and
444
+ returns the format to be consumed by the model.
445
+ By default it will be `DatasetMapper(cfg, True)`.
446
+
447
+ Returns:
448
+ an infinite iterator of training data
449
+ """
450
+
451
+ _add_category_whitelists_to_metadata(cfg)
452
+ _add_category_maps_to_metadata(cfg)
453
+ _maybe_add_class_to_mesh_name_map_to_metadata(cfg.DATASETS.TRAIN, cfg)
454
+ dataset_dicts = combine_detection_dataset_dicts(
455
+ cfg.DATASETS.TRAIN,
456
+ keep_instance_predicate=_get_train_keep_instance_predicate(cfg),
457
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
458
+ )
459
+ if mapper is None:
460
+ mapper = DatasetMapper(cfg, True)
461
+ return d2_build_detection_train_loader(cfg, dataset=dataset_dicts, mapper=mapper)
462
+
463
+
464
+ def build_detection_test_loader(cfg, dataset_name, mapper=None):
465
+ """
466
+ Similar to `build_detection_train_loader`.
467
+ But this function uses the given `dataset_name` argument (instead of the names in cfg),
468
+ and uses batch size 1.
469
+
470
+ Args:
471
+ cfg: a detectron2 CfgNode
472
+ dataset_name (str): a name of the dataset that's available in the DatasetCatalog
473
+ mapper (callable): a callable which takes a sample (dict) from dataset
474
+ and returns the format to be consumed by the model.
475
+ By default it will be `DatasetMapper(cfg, False)`.
476
+
477
+ Returns:
478
+ DataLoader: a torch DataLoader, that loads the given detection
479
+ dataset, with test-time transformation and batching.
480
+ """
481
+ _add_category_whitelists_to_metadata(cfg)
482
+ _add_category_maps_to_metadata(cfg)
483
+ _maybe_add_class_to_mesh_name_map_to_metadata([dataset_name], cfg)
484
+ dataset_dicts = combine_detection_dataset_dicts(
485
+ [dataset_name],
486
+ keep_instance_predicate=_get_test_keep_instance_predicate(cfg),
487
+ proposal_files=(
488
+ [cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]]
489
+ if cfg.MODEL.LOAD_PROPOSALS
490
+ else None
491
+ ),
492
+ )
493
+ sampler = None
494
+ if not cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE:
495
+ sampler = torch.utils.data.SequentialSampler(dataset_dicts)
496
+ if mapper is None:
497
+ mapper = DatasetMapper(cfg, False)
498
+ return d2_build_detection_test_loader(
499
+ dataset_dicts, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, sampler=sampler
500
+ )
501
+
502
+
503
+ def build_frame_selector(cfg: CfgNode):
504
+ strategy = FrameSelectionStrategy(cfg.STRATEGY)
505
+ if strategy == FrameSelectionStrategy.RANDOM_K:
506
+ frame_selector = RandomKFramesSelector(cfg.NUM_IMAGES)
507
+ elif strategy == FrameSelectionStrategy.FIRST_K:
508
+ frame_selector = FirstKFramesSelector(cfg.NUM_IMAGES)
509
+ elif strategy == FrameSelectionStrategy.LAST_K:
510
+ frame_selector = LastKFramesSelector(cfg.NUM_IMAGES)
511
+ elif strategy == FrameSelectionStrategy.ALL:
512
+ frame_selector = None
513
+ # pyre-fixme[61]: `frame_selector` may not be initialized here.
514
+ return frame_selector
515
+
516
+
517
+ def build_transform(cfg: CfgNode, data_type: str):
518
+ if cfg.TYPE == "resize":
519
+ if data_type == "image":
520
+ return ImageResizeTransform(cfg.MIN_SIZE, cfg.MAX_SIZE)
521
+ raise ValueError(f"Unknown transform {cfg.TYPE} for data type {data_type}")
522
+
523
+
524
+ def build_combined_loader(cfg: CfgNode, loaders: Collection[Loader], ratios: Sequence[float]):
525
+ images_per_worker = _compute_num_images_per_worker(cfg)
526
+ return CombinedDataLoader(loaders, images_per_worker, ratios)
527
+
528
+
529
+ def build_bootstrap_dataset(dataset_name: str, cfg: CfgNode) -> Sequence[torch.Tensor]:
530
+ """
531
+ Build dataset that provides data to bootstrap on
532
+
533
+ Args:
534
+ dataset_name (str): Name of the dataset, needs to have associated metadata
535
+ to load the data
536
+ cfg (CfgNode): bootstrapping config
537
+ Returns:
538
+ Sequence[Tensor] - dataset that provides image batches, Tensors of size
539
+ [N, C, H, W] of type float32
540
+ """
541
+ logger = logging.getLogger(__name__)
542
+ _add_category_info_to_bootstrapping_metadata(dataset_name, cfg)
543
+ meta = MetadataCatalog.get(dataset_name)
544
+ factory = BootstrapDatasetFactoryCatalog.get(meta.dataset_type)
545
+ dataset = None
546
+ if factory is not None:
547
+ dataset = factory(meta, cfg)
548
+ if dataset is None:
549
+ logger.warning(f"Failed to create dataset {dataset_name} of type {meta.dataset_type}")
550
+ return dataset
551
+
552
+
553
+ def build_data_sampler(cfg: CfgNode, sampler_cfg: CfgNode, embedder: Optional[torch.nn.Module]):
554
+ if sampler_cfg.TYPE == "densepose_uniform":
555
+ data_sampler = PredictionToGroundTruthSampler()
556
+ # transform densepose pred -> gt
557
+ data_sampler.register_sampler(
558
+ "pred_densepose",
559
+ "gt_densepose",
560
+ DensePoseUniformSampler(count_per_class=sampler_cfg.COUNT_PER_CLASS),
561
+ )
562
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
563
+ return data_sampler
564
+ elif sampler_cfg.TYPE == "densepose_UV_confidence":
565
+ data_sampler = PredictionToGroundTruthSampler()
566
+ # transform densepose pred -> gt
567
+ data_sampler.register_sampler(
568
+ "pred_densepose",
569
+ "gt_densepose",
570
+ DensePoseConfidenceBasedSampler(
571
+ confidence_channel="sigma_2",
572
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
573
+ search_proportion=0.5,
574
+ ),
575
+ )
576
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
577
+ return data_sampler
578
+ elif sampler_cfg.TYPE == "densepose_fine_segm_confidence":
579
+ data_sampler = PredictionToGroundTruthSampler()
580
+ # transform densepose pred -> gt
581
+ data_sampler.register_sampler(
582
+ "pred_densepose",
583
+ "gt_densepose",
584
+ DensePoseConfidenceBasedSampler(
585
+ confidence_channel="fine_segm_confidence",
586
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
587
+ search_proportion=0.5,
588
+ ),
589
+ )
590
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
591
+ return data_sampler
592
+ elif sampler_cfg.TYPE == "densepose_coarse_segm_confidence":
593
+ data_sampler = PredictionToGroundTruthSampler()
594
+ # transform densepose pred -> gt
595
+ data_sampler.register_sampler(
596
+ "pred_densepose",
597
+ "gt_densepose",
598
+ DensePoseConfidenceBasedSampler(
599
+ confidence_channel="coarse_segm_confidence",
600
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
601
+ search_proportion=0.5,
602
+ ),
603
+ )
604
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
605
+ return data_sampler
606
+ elif sampler_cfg.TYPE == "densepose_cse_uniform":
607
+ assert embedder is not None
608
+ data_sampler = PredictionToGroundTruthSampler()
609
+ # transform densepose pred -> gt
610
+ data_sampler.register_sampler(
611
+ "pred_densepose",
612
+ "gt_densepose",
613
+ DensePoseCSEUniformSampler(
614
+ cfg=cfg,
615
+ use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
616
+ embedder=embedder,
617
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
618
+ ),
619
+ )
620
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
621
+ return data_sampler
622
+ elif sampler_cfg.TYPE == "densepose_cse_coarse_segm_confidence":
623
+ assert embedder is not None
624
+ data_sampler = PredictionToGroundTruthSampler()
625
+ # transform densepose pred -> gt
626
+ data_sampler.register_sampler(
627
+ "pred_densepose",
628
+ "gt_densepose",
629
+ DensePoseCSEConfidenceBasedSampler(
630
+ cfg=cfg,
631
+ use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
632
+ embedder=embedder,
633
+ confidence_channel="coarse_segm_confidence",
634
+ count_per_class=sampler_cfg.COUNT_PER_CLASS,
635
+ search_proportion=0.5,
636
+ ),
637
+ )
638
+ data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
639
+ return data_sampler
640
+
641
+ raise ValueError(f"Unknown data sampler type {sampler_cfg.TYPE}")
642
+
643
+
644
+ def build_data_filter(cfg: CfgNode):
645
+ if cfg.TYPE == "detection_score":
646
+ min_score = cfg.MIN_VALUE
647
+ return ScoreBasedFilter(min_score=min_score)
648
+ raise ValueError(f"Unknown data filter type {cfg.TYPE}")
649
+
650
+
651
+ def build_inference_based_loader(
652
+ cfg: CfgNode,
653
+ dataset_cfg: CfgNode,
654
+ model: torch.nn.Module,
655
+ embedder: Optional[torch.nn.Module] = None,
656
+ ) -> InferenceBasedLoader:
657
+ """
658
+ Constructs data loader based on inference results of a model.
659
+ """
660
+ dataset = build_bootstrap_dataset(dataset_cfg.DATASET, dataset_cfg.IMAGE_LOADER)
661
+ meta = MetadataCatalog.get(dataset_cfg.DATASET)
662
+ training_sampler = TrainingSampler(len(dataset))
663
+ data_loader = torch.utils.data.DataLoader(
664
+ dataset, # pyre-ignore[6]
665
+ batch_size=dataset_cfg.IMAGE_LOADER.BATCH_SIZE,
666
+ sampler=training_sampler,
667
+ num_workers=dataset_cfg.IMAGE_LOADER.NUM_WORKERS,
668
+ collate_fn=trivial_batch_collator,
669
+ worker_init_fn=worker_init_reset_seed,
670
+ )
671
+ return InferenceBasedLoader(
672
+ model,
673
+ data_loader=data_loader,
674
+ data_sampler=build_data_sampler(cfg, dataset_cfg.DATA_SAMPLER, embedder),
675
+ data_filter=build_data_filter(dataset_cfg.FILTER),
676
+ shuffle=True,
677
+ batch_size=dataset_cfg.INFERENCE.OUTPUT_BATCH_SIZE,
678
+ inference_batch_size=dataset_cfg.INFERENCE.INPUT_BATCH_SIZE,
679
+ category_to_class_mapping=meta.category_to_class_mapping,
680
+ )
681
+
682
+
683
+ def has_inference_based_loaders(cfg: CfgNode) -> bool:
684
+ """
685
+ Returns True if at least one inference-based loader must
686
+ be instantiated for training
687
+ """
688
+ return len(cfg.BOOTSTRAP_DATASETS) > 0
689
+
690
+
691
+ def build_inference_based_loaders(
692
+ cfg: CfgNode, model: torch.nn.Module
693
+ ) -> Tuple[List[InferenceBasedLoader], List[float]]:
694
+ loaders = []
695
+ ratios = []
696
+ embedder = build_densepose_embedder(cfg).to(device=model.device) # pyre-ignore[16]
697
+ for dataset_spec in cfg.BOOTSTRAP_DATASETS:
698
+ dataset_cfg = get_bootstrap_dataset_config().clone()
699
+ dataset_cfg.merge_from_other_cfg(CfgNode(dataset_spec))
700
+ loader = build_inference_based_loader(cfg, dataset_cfg, model, embedder)
701
+ loaders.append(loader)
702
+ ratios.append(dataset_cfg.RATIO)
703
+ return loaders, ratios
704
+
705
+
706
+ def build_video_list_dataset(meta: Metadata, cfg: CfgNode):
707
+ video_list_fpath = meta.video_list_fpath
708
+ video_base_path = meta.video_base_path
709
+ category = meta.category
710
+ if cfg.TYPE == "video_keyframe":
711
+ frame_selector = build_frame_selector(cfg.SELECT)
712
+ transform = build_transform(cfg.TRANSFORM, data_type="image")
713
+ video_list = video_list_from_file(video_list_fpath, video_base_path)
714
+ keyframe_helper_fpath = getattr(cfg, "KEYFRAME_HELPER", None)
715
+ return VideoKeyframeDataset(
716
+ video_list, category, frame_selector, transform, keyframe_helper_fpath
717
+ )
718
+
719
+
720
+ class _BootstrapDatasetFactoryCatalog(UserDict):
721
+ """
722
+ A global dictionary that stores information about bootstrapped datasets creation functions
723
+ from metadata and config, for diverse DatasetType
724
+ """
725
+
726
+ def register(self, dataset_type: DatasetType, factory: Callable[[Metadata, CfgNode], Dataset]):
727
+ """
728
+ Args:
729
+ dataset_type (DatasetType): a DatasetType e.g. DatasetType.VIDEO_LIST
730
+ factory (Callable[Metadata, CfgNode]): a callable which takes Metadata and cfg
731
+ arguments and returns a dataset object.
732
+ """
733
+ assert dataset_type not in self, "Dataset '{}' is already registered!".format(dataset_type)
734
+ self[dataset_type] = factory
735
+
736
+
737
+ BootstrapDatasetFactoryCatalog = _BootstrapDatasetFactoryCatalog()
738
+ BootstrapDatasetFactoryCatalog.register(DatasetType.VIDEO_LIST, build_video_list_dataset)
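
A small sketch of how the loaders defined above are usually wired together. It assumes a config file whose datasets are already registered and that `add_densepose_config` from densepose.config has been used to extend the base detectron2 config; the `make_loaders` helper is hypothetical.

    from detectron2.config import get_cfg
    from densepose.config import add_densepose_config
    from densepose.data import build_detection_test_loader, build_detection_train_loader

    def make_loaders(config_fpath):
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(config_fpath)
        train_loader = build_detection_train_loader(cfg)                     # infinite iterator
        test_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])  # batch size 1
        return train_loader, test_loader
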
densepose/data/combined_loader.py ADDED
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import random
+from collections import deque
+from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
+
+Loader = Iterable[Any]
+
+
+def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
+    if not pool:
+        pool.extend(next(iterator))
+    return pool.popleft()
+
+
+class CombinedDataLoader:
+    """
+    Combines data loaders using the provided sampling ratios
+    """
+
+    BATCH_COUNT = 100
+
+    def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
+        self.loaders = loaders
+        self.batch_size = batch_size
+        self.ratios = ratios
+
+    def __iter__(self) -> Iterator[List[Any]]:
+        iters = [iter(loader) for loader in self.loaders]
+        indices = []
+        pool = [deque()] * len(iters)
+        # infinite iterator, as in D2
+        while True:
+            if not indices:
+                # just a buffer of indices, its size doesn't matter
+                # as long as it's a multiple of batch_size
+                k = self.batch_size * self.BATCH_COUNT
+                indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
+            try:
+                batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
+            except StopIteration:
+                break
+            indices = indices[self.batch_size :]
+            yield batch
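
A toy, self-contained sketch of the sampling behaviour; the two "loaders" here are plain lists of single-element batches, which is enough to satisfy the `Loader` protocol above.

    from densepose.data.combined_loader import CombinedDataLoader

    loader_a = [[f"a{i}"] for i in range(100)]  # each yielded element is a batch (a list)
    loader_b = [[f"b{i}"] for i in range(100)]
    combined = CombinedDataLoader([loader_a, loader_b], batch_size=4, ratios=(0.75, 0.25))

    for batch in combined:
        print(batch)  # e.g. ['a0', 'a1', 'b0', 'a2'] -- roughly a 3:1 mix of the two sources
        break
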
densepose/data/dataset_mapper.py ADDED
@@ -0,0 +1,170 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # pyre-unsafe
5
+
6
+ import copy
7
+ import logging
8
+ from typing import Any, Dict, List, Tuple
9
+ import torch
10
+
11
+ from detectron2.data import MetadataCatalog
12
+ from detectron2.data import detection_utils as utils
13
+ from detectron2.data import transforms as T
14
+ from detectron2.layers import ROIAlign
15
+ from detectron2.structures import BoxMode
16
+ from detectron2.utils.file_io import PathManager
17
+
18
+ from densepose.structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
19
+
20
+
21
+ def build_augmentation(cfg, is_train):
22
+ logger = logging.getLogger(__name__)
23
+ result = utils.build_augmentation(cfg, is_train)
24
+ if is_train:
25
+ random_rotation = T.RandomRotation(
26
+ cfg.INPUT.ROTATION_ANGLES, expand=False, sample_style="choice"
27
+ )
28
+ result.append(random_rotation)
29
+ logger.info("DensePose-specific augmentation used in training: " + str(random_rotation))
30
+ return result
31
+
32
+
33
+ class DatasetMapper:
34
+ """
35
+ A customized version of `detectron2.data.DatasetMapper`
36
+ """
37
+
38
+ def __init__(self, cfg, is_train=True):
39
+ self.augmentation = build_augmentation(cfg, is_train)
40
+
41
+ # fmt: off
42
+ self.img_format = cfg.INPUT.FORMAT
43
+ self.mask_on = (
44
+ cfg.MODEL.MASK_ON or (
45
+ cfg.MODEL.DENSEPOSE_ON
46
+ and cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS)
47
+ )
48
+ self.keypoint_on = cfg.MODEL.KEYPOINT_ON
49
+ self.densepose_on = cfg.MODEL.DENSEPOSE_ON
50
+ assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
51
+ # fmt: on
52
+ if self.keypoint_on and is_train:
53
+ # Flip only makes sense in training
54
+ self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
55
+ else:
56
+ self.keypoint_hflip_indices = None
57
+
58
+ if self.densepose_on:
59
+ densepose_transform_srcs = [
60
+ MetadataCatalog.get(ds).densepose_transform_src
61
+ for ds in cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
62
+ ]
63
+ assert len(densepose_transform_srcs) > 0
64
+ # TODO: check that DensePose transformation data is the same for
65
+ # all the datasets. Otherwise one would have to pass DB ID with
66
+ # each entry to select proper transformation data. For now, since
67
+ # all DensePose annotated data uses the same data semantics, we
68
+ # omit this check.
69
+ densepose_transform_data_fpath = PathManager.get_local_path(densepose_transform_srcs[0])
70
+ self.densepose_transform_data = DensePoseTransformData.load(
71
+ densepose_transform_data_fpath
72
+ )
73
+
74
+ self.is_train = is_train
75
+
76
+ def __call__(self, dataset_dict):
77
+ """
78
+ Args:
79
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
80
+
81
+ Returns:
82
+ dict: a format that builtin models in detectron2 accept
83
+ """
84
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
85
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
86
+ utils.check_image_size(dataset_dict, image)
87
+
88
+ image, transforms = T.apply_transform_gens(self.augmentation, image)
89
+ image_shape = image.shape[:2] # h, w
90
+ dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
91
+
92
+ if not self.is_train:
93
+ dataset_dict.pop("annotations", None)
94
+ return dataset_dict
95
+
96
+ for anno in dataset_dict["annotations"]:
97
+ if not self.mask_on:
98
+ anno.pop("segmentation", None)
99
+ if not self.keypoint_on:
100
+ anno.pop("keypoints", None)
101
+
102
+ # USER: Implement additional transformations if you have other types of data
103
+ # USER: Don't call transpose_densepose if you don't need
104
+ annos = [
105
+ self._transform_densepose(
106
+ utils.transform_instance_annotations(
107
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
108
+ ),
109
+ transforms,
110
+ )
111
+ for obj in dataset_dict.pop("annotations")
112
+ if obj.get("iscrowd", 0) == 0
113
+ ]
114
+
115
+ if self.mask_on:
116
+ self._add_densepose_masks_as_segmentation(annos, image_shape)
117
+
118
+ instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
119
+ densepose_annotations = [obj.get("densepose") for obj in annos]
120
+ if densepose_annotations and not all(v is None for v in densepose_annotations):
121
+ instances.gt_densepose = DensePoseList(
122
+ densepose_annotations, instances.gt_boxes, image_shape
123
+ )
124
+
125
+ dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
126
+ return dataset_dict
127
+
128
+ def _transform_densepose(self, annotation, transforms):
129
+ if not self.densepose_on:
130
+ return annotation
131
+
132
+ # Handle densepose annotations
133
+ is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
134
+ if is_valid:
135
+ densepose_data = DensePoseDataRelative(annotation, cleanup=True)
136
+ densepose_data.apply_transform(transforms, self.densepose_transform_data)
137
+ annotation["densepose"] = densepose_data
138
+ else:
139
+ # logger = logging.getLogger(__name__)
140
+ # logger.debug("Could not load DensePose annotation: {}".format(reason_not_valid))
141
+ DensePoseDataRelative.cleanup_annotation(annotation)
142
+ # NOTE: annotations for certain instances may be unavailable.
143
+ # 'None' is accepted by the DensePoseList data structure.
144
+ annotation["densepose"] = None
145
+ return annotation
146
+
147
+ def _add_densepose_masks_as_segmentation(
148
+ self, annotations: List[Dict[str, Any]], image_shape_hw: Tuple[int, int]
149
+ ):
150
+ for obj in annotations:
151
+ if ("densepose" not in obj) or ("segmentation" in obj):
152
+ continue
153
+ # DP segmentation: torch.Tensor [S, S] of float32, S=256
154
+ segm_dp = torch.zeros_like(obj["densepose"].segm)
155
+ segm_dp[obj["densepose"].segm > 0] = 1
156
+ segm_h, segm_w = segm_dp.shape
157
+ bbox_segm_dp = torch.tensor((0, 0, segm_h - 1, segm_w - 1), dtype=torch.float32)
158
+ # image bbox
159
+ x0, y0, x1, y1 = (
160
+ v.item() for v in BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
161
+ )
162
+ segm_aligned = (
163
+ ROIAlign((y1 - y0, x1 - x0), 1.0, 0, aligned=True)
164
+ .forward(segm_dp.view(1, 1, *segm_dp.shape), bbox_segm_dp)
165
+ .squeeze()
166
+ )
167
+ image_mask = torch.zeros(*image_shape_hw, dtype=torch.float32)
168
+ image_mask[y0:y1, x0:x1] = segm_aligned
169
+ # segmentation for BitMask: np.array [H, W] of bool
170
+ obj["segmentation"] = image_mask >= 0.5
densepose/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from . import builtin  # ensure the builtin datasets are registered
+
+__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
densepose/data/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (406 Bytes). View file
 
densepose/data/datasets/__pycache__/builtin.cpython-39.pyc ADDED
Binary file (597 Bytes). View file
 
densepose/data/datasets/__pycache__/chimpnsee.cpython-39.pyc ADDED
Binary file (1.06 kB). View file
 
densepose/data/datasets/__pycache__/coco.cpython-39.pyc ADDED
Binary file (11.7 kB). View file
 
densepose/data/datasets/__pycache__/dataset_type.cpython-39.pyc ADDED
Binary file (521 Bytes). View file
 
densepose/data/datasets/__pycache__/lvis.cpython-39.pyc ADDED
Binary file (7.85 kB). View file
 
densepose/data/datasets/builtin.py ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+from .chimpnsee import register_dataset as register_chimpnsee_dataset
+from .coco import BASE_DATASETS as BASE_COCO_DATASETS
+from .coco import DATASETS as COCO_DATASETS
+from .coco import register_datasets as register_coco_datasets
+from .lvis import DATASETS as LVIS_DATASETS
+from .lvis import register_datasets as register_lvis_datasets
+
+DEFAULT_DATASETS_ROOT = "datasets"
+
+
+register_coco_datasets(COCO_DATASETS, DEFAULT_DATASETS_ROOT)
+register_coco_datasets(BASE_COCO_DATASETS, DEFAULT_DATASETS_ROOT)
+register_lvis_datasets(LVIS_DATASETS, DEFAULT_DATASETS_ROOT)
+
+register_chimpnsee_dataset(DEFAULT_DATASETS_ROOT)  # pyre-ignore[19]
densepose/data/datasets/chimpnsee.py ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Optional
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+
+from ..utils import maybe_prepend_base_path
+from .dataset_type import DatasetType
+
+CHIMPNSEE_DATASET_NAME = "chimpnsee"
+
+
+def register_dataset(datasets_root: Optional[str] = None) -> None:
+    def empty_load_callback():
+        pass
+
+    video_list_fpath = maybe_prepend_base_path(
+        datasets_root,
+        "chimpnsee/cdna.eva.mpg.de/video_list.txt",
+    )
+    video_base_path = maybe_prepend_base_path(datasets_root, "chimpnsee/cdna.eva.mpg.de")
+
+    DatasetCatalog.register(CHIMPNSEE_DATASET_NAME, empty_load_callback)
+    MetadataCatalog.get(CHIMPNSEE_DATASET_NAME).set(
+        dataset_type=DatasetType.VIDEO_LIST,
+        video_list_fpath=video_list_fpath,
+        video_base_path=video_base_path,
+        category="chimpanzee",
+    )
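
As a quick check (hypothetical snippet), the registered metadata can be read back through detectron2's catalogs once the densepose dataset modules have been imported:

    from detectron2.data import MetadataCatalog
    from densepose.data import datasets  # importing this package runs the registrations above

    meta = MetadataCatalog.get("chimpnsee")
    print(meta.dataset_type, meta.video_list_fpath, meta.video_base_path, meta.category)
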
densepose/data/datasets/coco.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ import contextlib
5
+ import io
6
+ import logging
7
+ import os
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Any, Dict, Iterable, List, Optional
11
+ from fvcore.common.timer import Timer
12
+
13
+ from detectron2.data import DatasetCatalog, MetadataCatalog
14
+ from detectron2.structures import BoxMode
15
+ from detectron2.utils.file_io import PathManager
16
+
17
+ from ..utils import maybe_prepend_base_path
18
+
19
+ DENSEPOSE_MASK_KEY = "dp_masks"
20
+ DENSEPOSE_IUV_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_I", "dp_U", "dp_V"]
21
+ DENSEPOSE_CSE_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_vertex", "ref_model"]
22
+ DENSEPOSE_ALL_POSSIBLE_KEYS = set(
23
+ DENSEPOSE_IUV_KEYS_WITHOUT_MASK + DENSEPOSE_CSE_KEYS_WITHOUT_MASK + [DENSEPOSE_MASK_KEY]
24
+ )
25
+ DENSEPOSE_METADATA_URL_PREFIX = "https://dl.fbaipublicfiles.com/densepose/data/"
26
+
27
+
28
+ @dataclass
29
+ class CocoDatasetInfo:
30
+ name: str
31
+ images_root: str
32
+ annotations_fpath: str
33
+
34
+
35
+ DATASETS = [
36
+ CocoDatasetInfo(
37
+ name="densepose_coco_2014_train",
38
+ images_root="coco/train2014",
39
+ annotations_fpath="coco/annotations/densepose_train2014.json",
40
+ ),
41
+ CocoDatasetInfo(
42
+ name="densepose_coco_2014_minival",
43
+ images_root="coco/val2014",
44
+ annotations_fpath="coco/annotations/densepose_minival2014.json",
45
+ ),
46
+ CocoDatasetInfo(
47
+ name="densepose_coco_2014_minival_100",
48
+ images_root="coco/val2014",
49
+ annotations_fpath="coco/annotations/densepose_minival2014_100.json",
50
+ ),
51
+ CocoDatasetInfo(
52
+ name="densepose_coco_2014_valminusminival",
53
+ images_root="coco/val2014",
54
+ annotations_fpath="coco/annotations/densepose_valminusminival2014.json",
55
+ ),
56
+ CocoDatasetInfo(
57
+ name="densepose_coco_2014_train_cse",
58
+ images_root="coco/train2014",
59
+ annotations_fpath="coco_cse/densepose_train2014_cse.json",
60
+ ),
61
+ CocoDatasetInfo(
62
+ name="densepose_coco_2014_minival_cse",
63
+ images_root="coco/val2014",
64
+ annotations_fpath="coco_cse/densepose_minival2014_cse.json",
65
+ ),
66
+ CocoDatasetInfo(
67
+ name="densepose_coco_2014_minival_100_cse",
68
+ images_root="coco/val2014",
69
+ annotations_fpath="coco_cse/densepose_minival2014_100_cse.json",
70
+ ),
71
+ CocoDatasetInfo(
72
+ name="densepose_coco_2014_valminusminival_cse",
73
+ images_root="coco/val2014",
74
+ annotations_fpath="coco_cse/densepose_valminusminival2014_cse.json",
75
+ ),
76
+ CocoDatasetInfo(
77
+ name="densepose_chimps",
78
+ images_root="densepose_chimps/images",
79
+ annotations_fpath="densepose_chimps/densepose_chimps_densepose.json",
80
+ ),
81
+ CocoDatasetInfo(
82
+ name="densepose_chimps_cse_train",
83
+ images_root="densepose_chimps/images",
84
+ annotations_fpath="densepose_chimps/densepose_chimps_cse_train.json",
85
+ ),
86
+ CocoDatasetInfo(
87
+ name="densepose_chimps_cse_val",
88
+ images_root="densepose_chimps/images",
89
+ annotations_fpath="densepose_chimps/densepose_chimps_cse_val.json",
90
+ ),
91
+ CocoDatasetInfo(
92
+ name="posetrack2017_train",
93
+ images_root="posetrack2017/posetrack_data_2017",
94
+ annotations_fpath="posetrack2017/densepose_posetrack_train2017.json",
95
+ ),
96
+ CocoDatasetInfo(
97
+ name="posetrack2017_val",
98
+ images_root="posetrack2017/posetrack_data_2017",
99
+ annotations_fpath="posetrack2017/densepose_posetrack_val2017.json",
100
+ ),
101
+ CocoDatasetInfo(
102
+ name="lvis_v05_train",
103
+ images_root="coco/train2017",
104
+ annotations_fpath="lvis/lvis_v0.5_plus_dp_train.json",
105
+ ),
106
+ CocoDatasetInfo(
107
+ name="lvis_v05_val",
108
+ images_root="coco/val2017",
109
+ annotations_fpath="lvis/lvis_v0.5_plus_dp_val.json",
110
+ ),
111
+ ]
112
+
113
+
114
+ BASE_DATASETS = [
115
+ CocoDatasetInfo(
116
+ name="base_coco_2017_train",
117
+ images_root="coco/train2017",
118
+ annotations_fpath="coco/annotations/instances_train2017.json",
119
+ ),
120
+ CocoDatasetInfo(
121
+ name="base_coco_2017_val",
122
+ images_root="coco/val2017",
123
+ annotations_fpath="coco/annotations/instances_val2017.json",
124
+ ),
125
+ CocoDatasetInfo(
126
+ name="base_coco_2017_val_100",
127
+ images_root="coco/val2017",
128
+ annotations_fpath="coco/annotations/instances_val2017_100.json",
129
+ ),
130
+ ]
+
+
+def get_metadata(base_path: Optional[str]) -> Dict[str, Any]:
+    """
+    Returns metadata associated with COCO DensePose datasets
+
+    Args:
+        base_path: Optional[str]
+            Base path used to load metadata from
+
+    Returns:
+        Dict[str, Any]
+            Metadata in the form of a dictionary
+    """
+    meta = {
+        "densepose_transform_src": maybe_prepend_base_path(base_path, "UV_symmetry_transforms.mat"),
+        "densepose_smpl_subdiv": maybe_prepend_base_path(base_path, "SMPL_subdiv.mat"),
+        "densepose_smpl_subdiv_transform": maybe_prepend_base_path(
+            base_path,
+            "SMPL_SUBDIV_TRANSFORM.mat",
+        ),
+    }
+    return meta
+
+
+def _load_coco_annotations(json_file: str):
+    """
+    Load COCO annotations from a JSON file
+
+    Args:
+        json_file: str
+            Path to the file to load annotations from
+    Returns:
+        Instance of `pycocotools.coco.COCO` that provides access to annotations
+        data
+    """
+    from pycocotools.coco import COCO
+
+    logger = logging.getLogger(__name__)
+    timer = Timer()
+    with contextlib.redirect_stdout(io.StringIO()):
+        coco_api = COCO(json_file)
+    if timer.seconds() > 1:
+        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
+    return coco_api
+
+
+def _add_categories_metadata(dataset_name: str, categories: List[Dict[str, Any]]):
+    meta = MetadataCatalog.get(dataset_name)
+    meta.categories = {c["id"]: c["name"] for c in categories}
+    logger = logging.getLogger(__name__)
+    logger.info("Dataset {} categories: {}".format(dataset_name, meta.categories))
+
+
+def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]):
+    if "minival" in json_file:
+        # Skip validation on COCO2014 valminusminival and minival annotations
+        # The ratio of buggy annotations there is tiny and does not affect accuracy
+        # Therefore we explicitly white-list them
+        return
+    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
+    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
+        json_file
+    )
+
+
+def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "bbox" not in ann_dict:
+        return
+    obj["bbox"] = ann_dict["bbox"]
+    obj["bbox_mode"] = BoxMode.XYWH_ABS
+
+
+def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "segmentation" not in ann_dict:
+        return
+    segm = ann_dict["segmentation"]
+    if not isinstance(segm, dict):
+        # filter out invalid polygons (< 3 points)
+        segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
+        if len(segm) == 0:
+            return
+    obj["segmentation"] = segm
+
+
+def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "keypoints" not in ann_dict:
+        return
+    keypts = ann_dict["keypoints"]  # list[int]
+    for idx, v in enumerate(keypts):
+        if idx % 3 != 2:
+            # COCO's segmentation coordinates are floating points in [0, H or W],
+            # but keypoint coordinates are integers in [0, H-1 or W-1]
+            # Therefore we assume the coordinates are "pixel indices" and
+            # add 0.5 to convert to floating point coordinates.
+            keypts[idx] = v + 0.5
+    obj["keypoints"] = keypts
+
+
+def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
+        if key in ann_dict:
+            obj[key] = ann_dict[key]
+
+
+def _combine_images_with_annotations(
+    dataset_name: str,
+    image_root: str,
+    img_datas: Iterable[Dict[str, Any]],
+    ann_datas: Iterable[Iterable[Dict[str, Any]]],
+):
+
+    ann_keys = ["iscrowd", "category_id"]
+    dataset_dicts = []
+    contains_video_frame_info = False
+
+    for img_dict, ann_dicts in zip(img_datas, ann_datas):
+        record = {}
+        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
+        record["height"] = img_dict["height"]
+        record["width"] = img_dict["width"]
+        record["image_id"] = img_dict["id"]
+        record["dataset"] = dataset_name
+        if "frame_id" in img_dict:
+            record["frame_id"] = img_dict["frame_id"]
+            record["video_id"] = img_dict.get("vid_id", None)
+            contains_video_frame_info = True
+        objs = []
+        for ann_dict in ann_dicts:
+            assert ann_dict["image_id"] == record["image_id"]
+            assert ann_dict.get("ignore", 0) == 0
+            obj = {key: ann_dict[key] for key in ann_keys if key in ann_dict}
+            _maybe_add_bbox(obj, ann_dict)
+            _maybe_add_segm(obj, ann_dict)
+            _maybe_add_keypoints(obj, ann_dict)
+            _maybe_add_densepose(obj, ann_dict)
+            objs.append(obj)
+        record["annotations"] = objs
+        dataset_dicts.append(record)
+    if contains_video_frame_info:
+        create_video_frame_mapping(dataset_name, dataset_dicts)
+    return dataset_dicts
+
+
+def get_contiguous_id_to_category_id_map(metadata):
+    cat_id_2_cont_id = metadata.thing_dataset_id_to_contiguous_id
+    cont_id_2_cat_id = {}
+    for cat_id, cont_id in cat_id_2_cont_id.items():
+        if cont_id in cont_id_2_cat_id:
+            continue
+        cont_id_2_cat_id[cont_id] = cat_id
+    return cont_id_2_cat_id
+
+
+def maybe_filter_categories_cocoapi(dataset_name, coco_api):
+    meta = MetadataCatalog.get(dataset_name)
+    cont_id_2_cat_id = get_contiguous_id_to_category_id_map(meta)
+    cat_id_2_cont_id = meta.thing_dataset_id_to_contiguous_id
+    # filter categories
+    cats = []
+    for cat in coco_api.dataset["categories"]:
+        cat_id = cat["id"]
+        if cat_id not in cat_id_2_cont_id:
+            continue
+        cont_id = cat_id_2_cont_id[cat_id]
+        if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
+            cats.append(cat)
+    coco_api.dataset["categories"] = cats
+    # filter annotations, if multiple categories are mapped to a single
+    # contiguous ID, use only one category ID and map all annotations to that category ID
+    anns = []
+    for ann in coco_api.dataset["annotations"]:
+        cat_id = ann["category_id"]
+        if cat_id not in cat_id_2_cont_id:
+            continue
+        cont_id = cat_id_2_cont_id[cat_id]
+        ann["category_id"] = cont_id_2_cat_id[cont_id]
+        anns.append(ann)
+    coco_api.dataset["annotations"] = anns
+    # recreate index
+    coco_api.createIndex()
+
+
+def maybe_filter_and_map_categories_cocoapi(dataset_name, coco_api):
+    meta = MetadataCatalog.get(dataset_name)
+    category_id_map = meta.thing_dataset_id_to_contiguous_id
+    # map categories
+    cats = []
+    for cat in coco_api.dataset["categories"]:
+        cat_id = cat["id"]
+        if cat_id not in category_id_map:
+            continue
+        cat["id"] = category_id_map[cat_id]
+        cats.append(cat)
+    coco_api.dataset["categories"] = cats
+    # map annotation categories
+    anns = []
+    for ann in coco_api.dataset["annotations"]:
+        cat_id = ann["category_id"]
+        if cat_id not in category_id_map:
+            continue
+        ann["category_id"] = category_id_map[cat_id]
+        anns.append(ann)
+    coco_api.dataset["annotations"] = anns
+    # recreate index
+    coco_api.createIndex()
+
+
+def create_video_frame_mapping(dataset_name, dataset_dicts):
+    mapping = defaultdict(dict)
+    for d in dataset_dicts:
+        video_id = d.get("video_id")
+        if video_id is None:
+            continue
+        mapping[video_id].update({d["frame_id"]: d["file_name"]})
+    MetadataCatalog.get(dataset_name).set(video_frame_mapping=mapping)
+
+
+def load_coco_json(annotations_json_file: str, image_root: str, dataset_name: str):
+    """
+    Loads a JSON file with annotations in COCO instances format.
+    Replaces `detectron2.data.datasets.coco.load_coco_json` to handle metadata
+    in a more flexible way. Postpones category mapping to a later stage to be
+    able to combine several datasets with different (but coherent) sets of
+    categories.
+
+    Args:
+        annotations_json_file: str
+            Path to the JSON file with annotations in COCO instances format.
+        image_root: str
+            Directory that contains all the images
+        dataset_name: str
+            The name that identifies a dataset, e.g. "densepose_coco_2014_train"
+    """
+    coco_api = _load_coco_annotations(PathManager.get_local_path(annotations_json_file))
+    _add_categories_metadata(dataset_name, coco_api.loadCats(coco_api.getCatIds()))
+    # sort indices for reproducible results
+    img_ids = sorted(coco_api.imgs.keys())
+    # imgs is a list of dicts, each looks something like:
+    # {'license': 4,
+    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
+    #  'file_name': 'COCO_val2014_000000001268.jpg',
+    #  'height': 427,
+    #  'width': 640,
+    #  'date_captured': '2013-11-17 05:57:24',
+    #  'id': 1268}
+    imgs = coco_api.loadImgs(img_ids)
+    logger = logging.getLogger(__name__)
+    logger.info("Loaded {} images in COCO format from {}".format(len(imgs), annotations_json_file))
+    # anns is a list[list[dict]], where each dict is an annotation
+    # record for an object. The inner list enumerates the objects in an image
+    # and the outer list enumerates over images.
+    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
+    _verify_annotations_have_unique_ids(annotations_json_file, anns)
+    dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
+    return dataset_records
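For reference, a hedged usage sketch of `load_coco_json` (the paths and dataset name below are illustrative assumptions); each returned record pairs one image with its annotation objects, using exactly the keys assembled by `_combine_images_with_annotations` above.

    from densepose.data.datasets.coco import load_coco_json

    records = load_coco_json(
        annotations_json_file="datasets/coco/annotations/densepose_minival2014_100.json",
        image_root="datasets/coco/val2014",
        dataset_name="densepose_coco_2014_minival_100",
    )
    # each record looks roughly like:
    # {"file_name": "datasets/coco/val2014/COCO_val2014_....jpg", "height": 427,
    #  "width": 640, "image_id": 1268, "dataset": "densepose_coco_2014_minival_100",
    #  "annotations": [{"category_id": 1, "bbox": [...], "bbox_mode": BoxMode.XYWH_ABS, ...}]}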
+
+
+def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None):
+    """
+    Registers provided COCO DensePose dataset
+
+    Args:
+        dataset_data: CocoDatasetInfo
+            Dataset data
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
+    images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
+
+    def load_annotations():
+        return load_coco_json(
+            annotations_json_file=annotations_fpath,
+            image_root=images_root,
+            dataset_name=dataset_data.name,
+        )
+
+    DatasetCatalog.register(dataset_data.name, load_annotations)
+    MetadataCatalog.get(dataset_data.name).set(
+        json_file=annotations_fpath,
+        image_root=images_root,
+        **get_metadata(DENSEPOSE_METADATA_URL_PREFIX)
+    )
+
+
+def register_datasets(
+    datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
+):
+    """
+    Registers provided COCO DensePose datasets
+
+    Args:
+        datasets_data: Iterable[CocoDatasetInfo]
+            An iterable of dataset descriptors
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    for dataset_data in datasets_data:
+        register_dataset(dataset_data, datasets_root)
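Taken together, registering every built-in dataset is one call per list. The sketch below is illustrative only (it mirrors what a builtin-registration module would typically do); the `DATASETS` name for the list closed earlier in this file and the "datasets" root folder are assumptions.

    from detectron2.data import DatasetCatalog

    from densepose.data.datasets.coco import BASE_DATASETS, DATASETS, register_datasets

    register_datasets(DATASETS, datasets_root="datasets")       # DensePose annotations
    register_datasets(BASE_DATASETS, datasets_root="datasets")  # plain COCO instances

    # after registration, records can be resolved lazily by name:
    records = DatasetCatalog.get("densepose_coco_2014_minival_100")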
densepose/data/datasets/dataset_type.py ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from enum import Enum
+
+
+class DatasetType(Enum):
+    """
+    Dataset type, mostly used for datasets that contain data to bootstrap models on
+    """
+
+    VIDEO_LIST = "video_list"