diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e47318b66a78291b44506796e270497c2228478a
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,203 @@
+Copyright (c) 2022 SenseTime. All Rights Reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2020 MMClassification Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Pose_Anything_Teaser.png b/Pose_Anything_Teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..a85ca59f36cf61651492e8b01009e1faaacb2b2e
Binary files /dev/null and b/Pose_Anything_Teaser.png differ
diff --git a/README.md b/README.md
index 04d327139c3c1c4130534b3c4c253b3dc21696cd..e256010348e30a09da1beefff2bee9b63c819f7b 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,145 @@
----
-title: PoseAnything
-emoji: 🏢
-colorFrom: red
-colorTo: red
-sdk: gradio
-sdk_version: 4.11.0
-app_file: app.py
-pinned: false
-license: apache-2.0
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Pose Anything: A Graph-Based Approach for Category-Agnostic Pose Estimation
+
+
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pose-anything-a-graph-based-approach-for/2d-pose-estimation-on-mp-100)](https://paperswithcode.com/sota/2d-pose-estimation-on-mp-100?p=pose-anything-a-graph-based-approach-for)
+
+By [Or Hirschorn](https://scholar.google.co.il/citations?user=GgFuT_QAAAAJ&hl=iw&oi=ao) and [Shai Avidan](https://scholar.google.co.il/citations?hl=iw&user=hpItE1QAAAAJ)
+
+This repo is the official implementation of "[Pose Anything: A Graph-Based Approach for Category-Agnostic Pose Estimation](https://arxiv.org/pdf/2311.17891.pdf)".
+
+
+![Pose Anything Teaser](Pose_Anything_Teaser.png)
+
+## Introduction
+
+We present a novel approach to category-agnostic pose estimation (CAPE) that leverages the inherent geometrical relations between keypoints through a newly designed Graph Transformer Decoder. By capturing and incorporating this crucial structural information, our method enhances the accuracy of keypoint localization, marking a significant departure from conventional CAPE techniques that treat keypoints as isolated entities.
+
+## Citation
+If you find this useful, please cite this work as follows:
+```bibtex
+@misc{hirschorn2023pose,
+ title={Pose Anything: A Graph-Based Approach for Category-Agnostic Pose Estimation},
+ author={Or Hirschorn and Shai Avidan},
+ year={2023},
+ eprint={2311.17891},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+## Getting Started
+
+### Docker [Recommended]
+We provide a Docker image for easy use.
+You can pull the image, which contains all the required libraries and packages, from Docker Hub:
+
+```
+docker pull orhir/pose_anything
+docker run --name pose_anything -v {DATA_DIR}:/workspace/PoseAnything/PoseAnything/data/mp100 -it orhir/pose_anything /bin/bash
+```
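+
+For example, if your MP-100 data lives under a hypothetical host path `/data/mp100`:
+```
+docker run --name pose_anything -v /data/mp100:/workspace/PoseAnything/PoseAnything/data/mp100 -it orhir/pose_anything /bin/bash
+```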
+### Conda Environment
+We train and evaluate our model on Python 3.8 and PyTorch 2.0.1 with CUDA 12.1.
+
+Please first install PyTorch and torchvision following the official [PyTorch documentation](https://pytorch.org/get-started/locally/).
+Then, follow [MMPose](https://mmpose.readthedocs.io/en/latest/installation.html) to install the following packages:
+```
+mmcv-full==1.6.2
+mmpose==0.29.0
+```
+Having installed these packages, run:
+```
+python setup.py develop
+```
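+
+To sanity-check the environment, this minimal snippet (assuming the versions pinned above) prints the installed versions:
+```python
+# Quick environment check; the expected versions are the ones pinned above.
+import torch
+import mmcv
+import mmpose
+
+print(torch.__version__)          # e.g. 2.0.1
+print(mmcv.__version__)           # expected: 1.6.2
+print(mmpose.__version__)         # expected: 0.29.0
+print(torch.cuda.is_available())  # True if CUDA is set up correctly
+```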
+
+## Demo on Custom Images
+We provide demo code for testing our model on custom images.
+
+***A bigger and more accurate version of the model - COMING SOON!***
+
+### Gradio Demo
+First, install gradio:
+```
+pip install gradio==3.44.0
+```
+Then, download the [pretrained model](https://drive.google.com/file/d/1RT1Q8AMEa1kj6k9ZqrtWIKyuR4Jn4Pqc/view?usp=drive_link) and run:
+```
+python app.py --checkpoint [path_to_pretrained_ckpt]
+```
+### Terminal Demo
+Download the [pretrained model](https://drive.google.com/file/d/1RT1Q8AMEa1kj6k9ZqrtWIKyuR4Jn4Pqc/view?usp=drive_link) and run:
+
+```
+python demo.py --support [path_to_support_image] --query [path_to_query_image] --config configs/demo_b.py --checkpoint [path_to_pretrained_ckpt]
+```
+***Note:*** The demo code supports any config with a suitable checkpoint file. More pre-trained models can be found in the evaluation section.
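+
+For example, using the Swin 1-shot split-1 config (the image and checkpoint paths below are placeholders):
+```
+python demo.py --support examples/support.jpg --query examples/query.jpg --config configs/1shot-swin/graph_split1_config.py --checkpoint checkpoints/1shot-swin_graph_split1.pth
+```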
+
+
+## MP-100 Dataset
+Please follow the [official guide](https://github.com/luminxu/Pose-for-Everything/blob/main/mp100/README.md) to prepare the MP-100 dataset for training and evaluation, and organize the data structure properly.
+
+We provide an updated annotation file, which includes skeleton definitions, in the following [link](https://drive.google.com/drive/folders/1uRyGB-P5Tc_6TmAZ6RnOi0SWjGq9b28T?usp=sharing).
+
+**Please note:**
+
+The current version of the MP-100 dataset includes some discrepancies and filename errors:
+1. The dataset referred to as DeepFashion is actually the DeepFashion2 dataset. The link in the official repo is wrong; use this [repo](https://github.com/switchablenorms/DeepFashion2/tree/master) instead.
+2. We provide a script to fix CarFusion filename errors, which can be run by:
+```
+python tools/fix_carfusion.py [path_to_CarFusion_dataset] [path_to_mp100_annotation]
+```
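+
+For example (both arguments are placeholders for your local dataset and annotation paths):
+```
+python tools/fix_carfusion.py data/carfusion data/mp100/annotations
+```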
+
+## Training
+
+### Backbone Options
+To use the pre-trained Swin-Transformer backbone from our paper, we provide the weights, taken from this [repo](https://github.com/microsoft/Swin-Transformer/blob/main/MODELHUB.md), at the following [link](https://drive.google.com/drive/folders/1-q4mSxlNAUwDlevc3Hm5Ij0l_2OGkrcg?usp=sharing).
+These should be placed in the `./pretrained` folder.
+
+We also support DINO and ResNet backbones. To use one of them, change the `pretrained` field in the config file to `dinov2`, `dino`, or `resnet`; the pretrained weights are then loaded automatically from the official repo, as sketched below.
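+
+A minimal sketch of the relevant config change (only the `pretrained` field differs from the shipped configs; the rest of the model dict stays as provided):
+```python
+# Not a complete config -- only the backbone selection is shown.
+model = dict(
+    type='PoseAnythingModel',
+    pretrained='dino',  # one of 'dinov2', 'dino', 'resnet', or a path to Swin weights
+    # encoder_config, keypoint_head, etc. remain as in the provided config files
+)
+```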
+
+### Training
+To train the model, run:
+```
+python train.py --config [path_to_config_file] --work-dir [path_to_work_dir]
+```
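+
+For example, to train the Swin 1-shot model on split 1 (the work directory is an arbitrary output path):
+```
+python train.py --config configs/1shot-swin/graph_split1_config.py --work-dir work_dirs/1shot_swin_split1
+```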
+
+## Evaluation and Pretrained Models
+You can download the pretrained checkpoints from the following [link](https://drive.google.com/drive/folders/1RmrqzE3g0qYRD5xn54-aXEzrIkdYXpEW?usp=sharing).
+
+Here we provide the evaluation results of our pretrained models on the MP-100 dataset, along with the config files and checkpoints:
+
+### 1-Shot Models
+| Setting | split 1 | split 2 | split 3 | split 4 | split 5 |
+|:-------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|
+| Tiny | 91.06 | 88.24 | 86.09 | 86.17 | 85.78 |
+| | [link](https://drive.google.com/file/d/1GubmkVkqybs-eD4hiRkgBzkUVGE_rIFX/view?usp=drive_link) / [config](configs/1shots/graph_split1_config.py) | [link](https://drive.google.com/file/d/1EEekDF3xV_wJOVk7sCQWUA8ygUKzEm2l/view?usp=drive_link) / [config](configs/1shots/graph_split2_config.py) | [link](https://drive.google.com/file/d/1FuwpNBdPI3mfSovta2fDGKoqJynEXPZQ/view?usp=drive_link) / [config](configs/1shots/graph_split3_config.py) | [link](https://drive.google.com/file/d/1_SSqSANuZlbC0utzIfzvZihAW9clefcR/view?usp=drive_link) / [config](configs/1shots/graph_split4_config.py) | [link](https://drive.google.com/file/d/1nUHr07W5F55u-FKQEPFq_CECgWZOKKLF/view?usp=drive_link) / [config](configs/1shots/graph_split5_config.py) |
+| Small | 93.66 | 90.42 | 89.79 | 88.68 | 89.61 |
+| | [link](https://drive.google.com/file/d/1RT1Q8AMEa1kj6k9ZqrtWIKyuR4Jn4Pqc/view?usp=drive_link) / [config](configs/1shot-swin/graph_split1_config.py) | [link](https://drive.google.com/file/d/1BT5b8MlnkflcdhTFiBROIQR3HccLsPQd/view?usp=drive_link) / [config](configs/1shot-swin/graph_split2_config.py) | [link](https://drive.google.com/file/d/1Z64cw_1CSDGObabSAWKnMK0BA_bqDHxn/view?usp=drive_link) / [config](configs/1shot-swin/graph_split3_config.py) | [link](https://drive.google.com/file/d/1vf82S8LAjIzpuBcbEoDCa26cR8DqNriy/view?usp=drive_link) / [config](configs/1shot-swin/graph_split4_config.py) | [link](https://drive.google.com/file/d/14FNx0JNbkS2CvXQMiuMU_kMZKFGO2rDV/view?usp=drive_link) / [config](configs/1shot-swin/graph_split5_config.py) |
+
+### 5-Shot Models
+| Setting | split 1 | split 2 | split 3 | split 4 | split 5 |
+|:-------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|
+| Tiny | 94.18 | 91.46 | 90.50 | 90.18 | 89.47 |
+| | [link](https://drive.google.com/file/d/1PeMuwv5YwiF3UCE5oN01Qchu5K3BaQ9L/view?usp=drive_link) / [config](configs/5shots/graph_split1_config.py) | [link](https://drive.google.com/file/d/1enIapPU1D8lZOET7q_qEjnhC1HFy3jWK/view?usp=drive_link) / [config](configs/5shots/graph_split2_config.py) | [link](https://drive.google.com/file/d/1MTeZ9Ba-ucLuqX0KBoLbBD5PaEct7VUp/view?usp=drive_link) / [config](configs/5shots/graph_split3_config.py) | [link](https://drive.google.com/file/d/1U2N7DI2F0v7NTnPCEEAgx-WKeBZNAFoa/view?usp=drive_link) / [config](configs/5shots/graph_split4_config.py) | [link](https://drive.google.com/file/d/1wapJDgtBWtmz61JNY7ktsFyvckRKiR2C/view?usp=drive_link) / [config](configs/5shots/graph_split5_config.py) |
+| Small | 96.51 | 92.15 | 91.99 | 92.01 | 92.36 |
+| | [link](https://drive.google.com/file/d/1p5rnA0MhmndSKEbyXMk49QXvNE03QV2p/view?usp=drive_link) / [config](configs/5shot-swin/graph_split1_config.py) | [link](https://drive.google.com/file/d/1Q3KNyUW_Gp3JytYxUPhkvXFiDYF6Hv8w/view?usp=drive_link) / [config](configs/5shot-swin/graph_split2_config.py) | [link](https://drive.google.com/file/d/1gWgTk720fSdAf_ze1FkfXTW0t7k-69dV/view?usp=drive_link) / [config](configs/5shot-swin/graph_split3_config.py) | [link](https://drive.google.com/file/d/1LuaRQ8a6AUPrkr7l5j2W6Fe_QbgASkwY/view?usp=drive_link) / [config](configs/5shot-swin/graph_split4_config.py) | [link](https://drive.google.com/file/d/1z--MAOPCwMG_GQXru9h2EStbnIvtHv1L/view?usp=drive_link) / [config](configs/5shot-swin/graph_split5_config.py) |
+
+### Evaluation
+The evaluation on a single GPU takes approximately 30 minutes.
+
+To evaluate the pretrained model, run:
+```
+python test.py [path_to_config_file] [path_to_pretrained_ckpt]
+```
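+
+For example, to evaluate the Tiny 1-shot split-1 model (the checkpoint filename is a placeholder for the file downloaded from the table above):
+```
+python test.py configs/1shots/graph_split1_config.py checkpoints/tiny_1shot_split1.pth
+```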
+## Acknowledgement
+
+Our code is based on code from:
+ - [MMPose](https://github.com/open-mmlab/mmpose)
+ - [CapeFormer](https://github.com/flyinglynx/CapeFormer)
+
+
+## License
+This project is released under the Apache 2.0 license.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..15d4c5e5551490155fe01d9fe19348a6907fcfa7
--- /dev/null
+++ b/app.py
@@ -0,0 +1,320 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+import random
+
+# os.system('python -m pip install timm')
+# os.system('python -m pip install -U openxlab')
+# os.system('python -m pip install -U pillow')
+# os.system('python -m pip install Openmim')
+# os.system('python -m mim install mmengine')
+os.system('python -m mim install "mmcv-full==1.6.2"')
+os.system('python -m mim install "mmpose==0.29.0"')
+os.system('python -m mim install "gradio==3.44.0"')
+os.system('python setup.py develop')
+
+import gradio as gr
+import numpy as np
+import torch
+from PIL import ImageDraw, Image
+from matplotlib import pyplot as plt
+from mmcv import Config
+from mmcv.runner import load_checkpoint
+from mmpose.core import wrap_fp16_model
+from mmpose.models import build_posenet
+from torchvision import transforms
+from demo import Resize_Pad
+from models import *
+import matplotlib
+
+matplotlib.use('agg')
+
+
+def plot_results(support_img, query_img, support_kp, support_w, query_kp,
+ query_w, skeleton,
+ initial_proposals, prediction, radius=6):
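+    # Predictions are normalized to [0, 1]; take the last decoder layer's
+    # output and scale by the (square) image size to get pixel coordinates.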
+ h, w, c = support_img.shape
+ prediction = prediction[-1].cpu().numpy() * h
+ query_img = (query_img - np.min(query_img)) / (
+ np.max(query_img) - np.min(query_img))
+    # NOTE: inside this loop, w is rebound to the per-keypoint weight vector.
+    for img, w, keypoint in zip([query_img], [query_w], [prediction]):
+ f, axes = plt.subplots()
+ plt.imshow(img)
+ for k in range(keypoint.shape[0]):
+ if w[k] > 0:
+ kp = keypoint[k, :2]
+ c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)
+ patch = plt.Circle(kp, radius, color=c)
+ axes.add_patch(patch)
+ axes.text(kp[0], kp[1], k)
+ plt.draw()
+ for l, limb in enumerate(skeleton):
+ kp = keypoint[:, :2]
+ if l > len(COLORS) - 1:
+ c = [x / 255 for x in random.sample(range(0, 255), 3)]
+ else:
+ c = [x / 255 for x in COLORS[l]]
+ if w[limb[0]] > 0 and w[limb[1]] > 0:
+ patch = plt.Line2D([kp[limb[0], 0], kp[limb[1], 0]],
+ [kp[limb[0], 1], kp[limb[1], 1]],
+ linewidth=6, color=c, alpha=0.6)
+ axes.add_artist(patch)
+        plt.axis('off')  # hide the axes
+ plt.subplots_adjust(0, 0, 1, 1, 0, 0)
+ return plt
+
+
+COLORS = [
+ [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
+ [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
+ [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
+ [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]
+]
+
+kp_src = []
+skeleton = []
+count = 0
+color_idx = 0
+prev_pt = None
+prev_pt_idx = None
+prev_clicked = None
+original_support_image = None
+checkpoint_path = ''
+
+def process(query_img,
+ cfg_path='configs/demo_b.py'):
+ global skeleton
+ cfg = Config.fromfile(cfg_path)
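+    # Scale the keypoints clicked on the 128x128 support canvas to the encoder
+    # input size, then flip (row, col) into (x, y) order for the model.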
+ kp_src_np = np.array(kp_src).copy().astype(np.float32)
+ kp_src_np[:, 0] = kp_src_np[:, 0] / 128. * cfg.model.encoder_config.img_size
+ kp_src_np[:, 1] = kp_src_np[:, 1] / 128. * cfg.model.encoder_config.img_size
+ kp_src_np = np.flip(kp_src_np, 1).copy()
+ kp_src_tensor = torch.tensor(kp_src_np).float()
+ preprocess = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ Resize_Pad(cfg.model.encoder_config.img_size,
+ cfg.model.encoder_config.img_size)])
+
+ if len(skeleton) == 0:
+ skeleton = [(0, 0)]
+
+ support_img = preprocess(original_support_image).flip(0)[None]
+ np_query = np.array(query_img)[:, :, ::-1].copy()
+ q_img = preprocess(np_query).flip(0)[None]
+ # Create heatmap from keypoints
+ genHeatMap = TopDownGenerateTargetFewShot()
+ data_cfg = cfg.data_cfg
+ data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size,
+ cfg.model.encoder_config.img_size])
+ data_cfg['joint_weights'] = None
+ data_cfg['use_different_joint_weights'] = False
+    kp_src_3d = torch.cat(
+        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
+    kp_src_3d_weight = torch.cat(
+        (torch.ones_like(kp_src_tensor),
+         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
+ target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,
+ kp_src_3d,
+ kp_src_3d_weight,
+ sigma=1)
+ target_s = torch.tensor(target_s).float()[None]
+ target_weight_s = torch.ones_like(
+ torch.tensor(target_weight_s).float()[None])
+
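+    # Pack a single support/query episode into the dict format expected by the
+    # model's forward(); metas carry skeleton and geometry info per sample.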
+ data = {
+ 'img_s': [support_img],
+ 'img_q': q_img,
+ 'target_s': [target_s],
+ 'target_weight_s': [target_weight_s],
+ 'target_q': None,
+ 'target_weight_q': None,
+ 'return_loss': False,
+ 'img_metas': [{'sample_skeleton': [skeleton],
+ 'query_skeleton': skeleton,
+ 'sample_joints_3d': [kp_src_3d],
+ 'query_joints_3d': kp_src_3d,
+ 'sample_center': [kp_src_tensor.mean(dim=0)],
+ 'query_center': kp_src_tensor.mean(dim=0),
+ 'sample_scale': [
+ kp_src_tensor.max(dim=0)[0] -
+ kp_src_tensor.min(dim=0)[0]],
+ 'query_scale': kp_src_tensor.max(dim=0)[0] -
+ kp_src_tensor.min(dim=0)[0],
+ 'sample_rotation': [0],
+ 'query_rotation': 0,
+ 'sample_bbox_score': [1],
+ 'query_bbox_score': 1,
+ 'query_image_file': '',
+ 'sample_image_file': [''],
+ }]
+ }
+ # Load model
+ model = build_posenet(cfg.model)
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ wrap_fp16_model(model)
+ load_checkpoint(model, checkpoint_path, map_location='cpu')
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**data)
+ # visualize results
+ vis_s_weight = target_weight_s[0]
+ vis_q_weight = target_weight_s[0]
+ vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
+ vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
+ support_kp = kp_src_3d
+ out = plot_results(vis_s_image,
+ vis_q_image,
+ support_kp,
+ vis_s_weight,
+ None,
+ vis_q_weight,
+ skeleton,
+ None,
+ torch.tensor(outputs['points']).squeeze(0),
+ )
+ return out
+
+
+with gr.Blocks() as demo:
+ gr.Markdown('''
+ # Pose Anything Demo
+ We present a novel approach to category agnostic pose estimation that leverages the inherent geometrical relations between keypoints through a newly designed Graph Transformer Decoder. By capturing and incorporating this crucial structural information, our method enhances the accuracy of keypoint localization, marking a significant departure from conventional CAPE techniques that treat keypoints as isolated entities.
+ ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](https://github.com/orhir/PoseAnything)
+ 
+ ## Instructions
+ 1. Upload an image of the object you want to pose on the **left** image.
+ 2. Click on the **left** image to mark keypoints.
+ 3. Click on the keypoints on the **right** image to mark limbs.
+    4. Upload an image of the object you want to pose as the query image (**bottom**).
+ 5. Click **Evaluate** to pose the query image.
+ ''')
+ with gr.Row():
+ support_img = gr.Image(label="Support Image",
+ type="pil",
+ info='Click to mark keypoints').style(
+ height=256, width=256)
+ posed_support = gr.Image(label="Posed Support Image",
+ type="pil",
+ interactive=False).style(height=256, width=256)
+ with gr.Row():
+ query_img = gr.Image(label="Query Image",
+ type="pil").style(height=256, width=256)
+ with gr.Row():
+ eval_btn = gr.Button(value="Evaluate")
+ with gr.Row():
+ output_img = gr.Plot(label="Output Image", height=256, width=256)
+
+
+ def get_select_coords(kp_support,
+ limb_support,
+ evt: gr.SelectData,
+ r=0.015):
+ pixels_in_queue = set()
+ pixels_in_queue.add((evt.index[1], evt.index[0]))
+ while len(pixels_in_queue) > 0:
+ pixel = pixels_in_queue.pop()
+ if pixel[0] is not None and pixel[
+ 1] is not None and pixel not in kp_src:
+ kp_src.append(pixel)
+ else:
+ print("Invalid pixel")
+ if limb_support is None:
+ canvas_limb = kp_support
+ else:
+ canvas_limb = limb_support
+ canvas_kp = kp_support
+ w, h = canvas_kp.size
+ draw_pose = ImageDraw.Draw(canvas_kp)
+ draw_limb = ImageDraw.Draw(canvas_limb)
+ r = int(r * w)
+ leftUpPoint = (pixel[1] - r, pixel[0] - r)
+ rightDownPoint = (pixel[1] + r, pixel[0] + r)
+ twoPointList = [leftUpPoint, rightDownPoint]
+ draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
+ draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
+
+ return canvas_kp, canvas_limb
+
+
+ def get_limbs(kp_support,
+ evt: gr.SelectData,
+ r=0.02, width=0.02):
+ global count, color_idx, prev_pt, skeleton, prev_pt_idx, prev_clicked
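+        # Snap the click to the nearest marked keypoint; two consecutive clicks
+        # on distinct keypoints define a limb that is added to the skeleton.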
+ curr_pixel = (evt.index[1], evt.index[0])
+ pixels_in_queue = set()
+ pixels_in_queue.add((evt.index[1], evt.index[0]))
+ canvas_kp = kp_support
+ w, h = canvas_kp.size
+ r = int(r * w)
+ width = int(width * w)
+ while (len(pixels_in_queue) > 0 and
+ curr_pixel != prev_clicked and
+ len(kp_src) > 0):
+ pixel = pixels_in_queue.pop()
+ prev_clicked = pixel
+ closest_point = min(kp_src,
+ key=lambda p: (p[0] - pixel[0]) ** 2 +
+ (p[1] - pixel[1]) ** 2)
+ closest_point_index = kp_src.index(closest_point)
+ draw_limb = ImageDraw.Draw(canvas_kp)
+ if color_idx < len(COLORS):
+ c = COLORS[color_idx]
+ else:
+ c = random.choices(range(256), k=3)
+ leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
+ rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
+ twoPointList = [leftUpPoint, rightDownPoint]
+ draw_limb.ellipse(twoPointList, fill=tuple(c))
+ if count == 0:
+ prev_pt = closest_point[1], closest_point[0]
+ prev_pt_idx = closest_point_index
+ count = count + 1
+ else:
+ if prev_pt_idx != closest_point_index:
+ # Create Line and add Limb
+ draw_limb.line([prev_pt, (closest_point[1], closest_point[0])],
+ fill=tuple(c),
+ width=width)
+ skeleton.append((prev_pt_idx, closest_point_index))
+ color_idx = color_idx + 1
+ else:
+ draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
+ count = 0
+ return canvas_kp
+
+
+ def set_query(support_img):
+ global original_support_image
+ skeleton.clear()
+ kp_src.clear()
+ original_support_image = np.array(support_img)[:, :, ::-1].copy()
+ support_img = support_img.resize((128, 128), Image.Resampling.LANCZOS)
+ return support_img, support_img
+
+
+ support_img.select(get_select_coords,
+ [support_img, posed_support],
+ [support_img, posed_support],
+ )
+ support_img.upload(set_query,
+ inputs=support_img,
+                       outputs=[support_img, posed_support])
+ posed_support.select(get_limbs,
+ posed_support,
+ posed_support)
+ eval_btn.click(fn=process,
+ inputs=[query_img],
+ outputs=output_img)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Pose Anything Demo')
+ parser.add_argument('--checkpoint',
+ help='checkpoint path',
+                        default='https://huggingface.co/orhir/PoseAnything/resolve/main/1shot-swin_graph_split1.pth')
+ args = parser.parse_args()
+ checkpoint_path = args.checkpoint
+ demo.launch()
diff --git a/configs/1shot-swin/base_split1_config.py b/configs/1shot-swin/base_split1_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8860b3410374a46a43a89d63398a2841fefdf073
--- /dev/null
+++ b/configs/1shot-swin/base_split1_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shot-swin/base_split2_config.py b/configs/1shot-swin/base_split2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c23ce58f0b1dfa4470f36d5cbf92e9e5b5e4061
--- /dev/null
+++ b/configs/1shot-swin/base_split2_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shot-swin/base_split3_config.py b/configs/1shot-swin/base_split3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..906b444a4240e748d13061508131bb2b59a3db35
--- /dev/null
+++ b/configs/1shot-swin/base_split3_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shot-swin/base_split4_config.py b/configs/1shot-swin/base_split4_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..449a1a87720523d18235c0d5bac6a60dce4e2ec8
--- /dev/null
+++ b/configs/1shot-swin/base_split4_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shot-swin/base_split5_config.py b/configs/1shot-swin/base_split5_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..09347fa0ce0dd4ae94f968ab4a2a66dc091754ff
--- /dev/null
+++ b/configs/1shot-swin/base_split5_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
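+# Episodic few-shot protocol (as configured above): validation samples 100
+# episodes with 1 support shot and 15 queries per category, while testing
+# samples 200 episodes and reports PCK at thresholds from 0.05 to 0.25
+# (the exact threshold normalization, typically bbox size, is an assumption
+# based on standard PCK usage).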
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
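+# Usage sketch (an assumption that this repo follows mmcv-style config
+# loading, as mmpose-based projects typically do):
+#   from mmcv import Config
+#   cfg = Config.fromfile('configs/1shot-swin/graph_split1_config.py')
+#   print(cfg.model.keypoint_head.transformer.d_model)  # -> 256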
diff --git a/configs/1shot-swin/graph_split1_config.py b/configs/1shot-swin/graph_split1_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc75313f9493abbdabd5b1af8e6ffbf6cee00ba
--- /dev/null
+++ b/configs/1shot-swin/graph_split1_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
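+# Evaluation runs every 25 epochs and reports PCK, NME, AUC and EPE;
+# key_indicator='PCK' makes PCK the metric used to rank checkpoints.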
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
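+# Schedule: linear warmup over the first 1000 iterations starting at 0.1% of
+# the base LR, then step decay at epochs 160 and 180 of the 200-epoch run
+# (by mmcv's default step gamma of 0.1, assuming the default is not
+# overridden elsewhere).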
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
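+# Category-agnostic setup: keypoints are predicted through a single output
+# channel, and max_kpt_num=100 caps (and presumably pads) the per-category
+# keypoint count so batches from different categories align.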
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
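+ # Note: graph_decoder='pre' is the only model-level setting that
+ # distinguishes these graph_* configs from the base_* variants;
+ # presumably it enables the graph (GCN) layers ahead of the
+ # decoder's attention blocks.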
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shot-swin/graph_split2_config.py b/configs/1shot-swin/graph_split2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8083cac1ed6b8e13c6a2f9e8c16df165f5aea197
--- /dev/null
+++ b/configs/1shot-swin/graph_split2_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shot-swin/graph_split3_config.py b/configs/1shot-swin/graph_split3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..57cc8b6704e635308de8c3cbbbb446080c526cdf
--- /dev/null
+++ b/configs/1shot-swin/graph_split3_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shot-swin/graph_split4_config.py b/configs/1shot-swin/graph_split4_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a3c0039ec6df175100b66ee82b43c342eb73fc8
--- /dev/null
+++ b/configs/1shot-swin/graph_split4_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shot-swin/graph_split5_config.py b/configs/1shot-swin/graph_split5_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a115d8385470862b8ce465db64a8c860761417fe
--- /dev/null
+++ b/configs/1shot-swin/graph_split5_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/base_split1_config.py b/configs/1shots/base_split1_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9342eb1bbee19b0dee29b6c1f1ac751c0c36ef0a
--- /dev/null
+++ b/configs/1shots/base_split1_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
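+ # SwinV2-tiny: the last stage outputs 8 * embed_dim = 768 channels, which
+ # is why keypoint_head.in_channels is 768 below (the 1shot-swin configs
+ # use Swin-base, embed_dim=128, hence in_channels=1024 there).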
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
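+# Targets are 64x64 heatmaps for 256x256 inputs (stride 4); the
+# TopDownGenerateTargetFewShot step below renders each keypoint as a
+# Gaussian of sigma=1 heatmap pixel.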
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
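+# Batch size: 16 samples per GPU here; the heavier Swin-base 224px configs in
+# configs/1shot-swin use 8, presumably to fit GPU memory.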
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/base_split2_config.py b/configs/1shots/base_split2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e80cf01cbca1767b4feebcea4b4ff35ca180ae4f
--- /dev/null
+++ b/configs/1shots/base_split2_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
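+# Optional sanity check (a sketch, assuming the DETR convention that sine
+# positional encodings emit 2 * num_feats channels, which must equal the
+# transformer width d_model):
+assert 2 * model['keypoint_head']['positional_encoding']['num_feats'] == \
+    model['keypoint_head']['transformer']['d_model']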
diff --git a/configs/1shots/base_split3_config.py b/configs/1shots/base_split3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..770e861e54dc49bbd539297d15c44fe6aba487d6
--- /dev/null
+++ b/configs/1shots/base_split3_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/base_split4_config.py b/configs/1shots/base_split4_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5fe0672d598b133e0f0527f2b8b00670ded109a
--- /dev/null
+++ b/configs/1shots/base_split4_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/base_split5_config.py b/configs/1shots/base_split5_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..31d8871cdd132b3ae8c41091e756a41d31356b89
--- /dev/null
+++ b/configs/1shots/base_split5_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/graph_split1_config.py b/configs/1shots/graph_split1_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf21ccf58c08d61dddb46f9fb91bc34ee78871c
--- /dev/null
+++ b/configs/1shots/graph_split1_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/graph_split2_config.py b/configs/1shots/graph_split2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b7f6d35c538ad916cd431df9a81ce2ab6aed364
--- /dev/null
+++ b/configs/1shots/graph_split2_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
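+    # Episodic evaluation: each episode samples num_shots support images and
+    # num_queries query images per category, repeated for num_episodes.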
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/graph_split3_config.py b/configs/1shots/graph_split3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..828343883ae89403453109331674c57a1a316922
--- /dev/null
+++ b/configs/1shots/graph_split3_config.py
@@ -0,0 +1,191 @@
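+# Pose Anything, 1-shot setting, MP-100 split 3: identical to the split-2
+# graph config except for the annotation paths.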
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/graph_split4_config.py b/configs/1shots/graph_split4_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dda7d8d744f08ef233b62b182b92cc6407420311
--- /dev/null
+++ b/configs/1shots/graph_split4_config.py
@@ -0,0 +1,191 @@
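+# Pose Anything, 1-shot setting, MP-100 split 4: identical to the split-2
+# graph config except for the annotation paths.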
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/1shots/graph_split5_config.py b/configs/1shots/graph_split5_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde7be1a300b8223fc078dae35fd197d18a5f2d
--- /dev/null
+++ b/configs/1shots/graph_split5_config.py
@@ -0,0 +1,191 @@
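+# Pose Anything, 1-shot setting, MP-100 split 5: identical to the split-2
+# graph config except for the annotation paths.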
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/base_split1_config.py b/configs/5shot-swin/base_split1_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..61437a3e34ce64a467ae6d942b6ffbb2222b3524
--- /dev/null
+++ b/configs/5shot-swin/base_split1_config.py
@@ -0,0 +1,190 @@
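+# Pose Anything, 5-shot setting, MP-100 split 1: SwinV2-Base encoder
+# (ImageNet-22k pretrained) with a plain transformer decoder (no
+# graph_decoder). Input size is 224 and batch size 8, presumably to match
+# the backbone's pretraining resolution and memory budget.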
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/base_split2_config.py b/configs/5shot-swin/base_split2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2f71c2b80d54ba94a1848e6dd0e12964aceee6d
--- /dev/null
+++ b/configs/5shot-swin/base_split2_config.py
@@ -0,0 +1,190 @@
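+# Pose Anything, 5-shot setting, MP-100 split 2: identical to the split-1
+# base config except for the annotation paths.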
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/base_split3_config.py b/configs/5shot-swin/base_split3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..739deccb183b8b0b5ea9238133200950efd47e7b
--- /dev/null
+++ b/configs/5shot-swin/base_split3_config.py
@@ -0,0 +1,190 @@
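+# Pose Anything, 5-shot setting, MP-100 split 3: identical to the split-1
+# base config except for the annotation paths.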
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/base_split4_config.py b/configs/5shot-swin/base_split4_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..152e75fd729f7fcb737812ea12bf3ebc1bdba96b
--- /dev/null
+++ b/configs/5shot-swin/base_split4_config.py
@@ -0,0 +1,190 @@
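+# Pose Anything, 5-shot setting, MP-100 split 4: identical to the split-1
+# base config except for the annotation paths.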
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/base_split5_config.py b/configs/5shot-swin/base_split5_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a9a8a4aa3284a6d8d6e550fd3fefe23ee538277
--- /dev/null
+++ b/configs/5shot-swin/base_split5_config.py
@@ -0,0 +1,190 @@
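+# Pose Anything, 5-shot setting, MP-100 split 5: identical to the split-1
+# base config except for the annotation paths.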
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/graph_split1_config.py b/configs/5shot-swin/graph_split1_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..49af4b51dfa069cb86b6780f27187b2baf0c4c55
--- /dev/null
+++ b/configs/5shot-swin/graph_split1_config.py
@@ -0,0 +1,191 @@
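+# Pose Anything, 5-shot setting, MP-100 split 1: same SwinV2-Base setup as
+# the 5-shot base configs, with the graph-based decoder enabled
+# (graph_decoder='pre').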
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/graph_split2_config.py b/configs/5shot-swin/graph_split2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9004a1bdbb321392dde71fb75ade25d78f3db656
--- /dev/null
+++ b/configs/5shot-swin/graph_split2_config.py
@@ -0,0 +1,191 @@
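+# Pose Anything, 5-shot setting, MP-100 split 2: identical to the 5-shot
+# graph split-1 config except for the annotation paths.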
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/graph_split3_config.py b/configs/5shot-swin/graph_split3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfeefa53e98045a2ef41ae81c72ef20a98bf4f8e
--- /dev/null
+++ b/configs/5shot-swin/graph_split3_config.py
@@ -0,0 +1,191 @@
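+# Pose Anything, 5-shot setting, MP-100 split 3: identical to the 5-shot
+# graph split-1 config except for the annotation paths.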
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+        pck_threshold_list=[0.05, 0.10, 0.15, 0.20, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/graph_split4_config.py b/configs/5shot-swin/graph_split4_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbd51d5d5d0b3643147c823f498225db42fe5e4d
--- /dev/null
+++ b/configs/5shot-swin/graph_split4_config.py
@@ -0,0 +1,191 @@
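+# Pose Anything, 5-shot setting, MP-100 split 4: identical to the 5-shot
+# graph split-1 config except for the annotation paths.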
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
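+# Evaluate every 25 epochs on PCK / NME / AUC / EPE; key_indicator selects PCK
+# as the metric used to pick the best checkpoint.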
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
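+# Warm up linearly over the first 1000 iterations, then decay the base lr of
+# 1e-5 at epochs 160 and 180 (by the mmcv step-policy default factor of 10).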
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
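+# MP-100 categories differ in keypoint count, so the head runs with a single
+# output channel and each category is capped at max_kpt_num=100 keypoints;
+# the dataset presumably resolves the per-category keypoint mapping.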
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
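+# SinePositionalEncoding emits num_feats features per spatial axis, so
+# num_feats=128 produces the 256-dim embedding that matches d_model=256.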
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shot-swin/graph_split5_config.py b/configs/5shot-swin/graph_split5_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c95a5040c661fb267e571b2af0f3a7c92d6aabc8
--- /dev/null
+++ b/configs/5shot-swin/graph_split5_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_base_22k_500k.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/base_split1_config.py b/configs/5shots/base_split1_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..09813293fc94bbe1adc89fa85b40e00e26becc02
--- /dev/null
+++ b/configs/5shots/base_split1_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
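+# SwinV2-tiny: embed_dim=96 doubles at each of the three patch-merging stages,
+# giving a final feature dim of 96 * 2**3 = 768, which is exactly
+# keypoint_head.in_channels; the 5shot-swin configs use embed_dim=128 -> 1024.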
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
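+# 256 / patch size 4 = 64 tokens per side, evenly divisible by window_size=16;
+# the 5shot-swin configs pair image_size=224 with window_size=14 (56 = 4 * 14)
+# for the same reason.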
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/base_split2_config.py b/configs/5shots/base_split2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a726688e3d3804b661a70d4fa525bc390a6e3386
--- /dev/null
+++ b/configs/5shots/base_split2_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/base_split3_config.py b/configs/5shots/base_split3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8825d2444856ba2819a656b218f21edd62ec002a
--- /dev/null
+++ b/configs/5shots/base_split3_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/base_split4_config.py b/configs/5shots/base_split4_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d61a56b3958153255d46298202a3b5270866ec9
--- /dev/null
+++ b/configs/5shots/base_split4_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/base_split5_config.py b/configs/5shots/base_split5_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f35a5f03cec585e7992e589109fd969a84c7bb5f
--- /dev/null
+++ b/configs/5shots/base_split5_config.py
@@ -0,0 +1,190 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/graph_split1_config.py b/configs/5shots/graph_split1_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c62b520877f9230f85dc680c9e9a849c7cea15a
--- /dev/null
+++ b/configs/5shots/graph_split1_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
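+# Relative to configs/5shots/base_split1_config.py, the only change here is
+# graph_decoder='pre'; the value presumably controls where the graph layers
+# sit relative to the attention layers inside the decoder.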
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/graph_split2_config.py b/configs/5shots/graph_split2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d4ebf7942a10fd9018911f8554895feacf9202c
--- /dev/null
+++ b/configs/5shots/graph_split2_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split2_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/graph_split3_config.py b/configs/5shots/graph_split3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..06a5afd6b1702ad26e050e4268baca585f64d72f
--- /dev/null
+++ b/configs/5shots/graph_split3_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split3_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/graph_split4_config.py b/configs/5shots/graph_split4_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c0b83dd5407c4beac9bc66d0dbf1db9aa8b735e
--- /dev/null
+++ b/configs/5shots/graph_split4_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split4_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/5shots/graph_split5_config.py b/configs/5shots/graph_split5_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f4ef2e52ba7ae3d7780c79bd83b131f7c1a9f4
--- /dev/null
+++ b/configs/5shots/graph_split5_config.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='pretrained/swinv2_tiny_patch4_window16_256.pth',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=16,
+ drop_path_rate=0.2,
+ img_size=256,
+ upsample="bilinear"
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=768,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=768,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split5_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=5,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/demo.py b/configs/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7d26e61d791010a5ecf26bfa44327d3e74b3d23
--- /dev/null
+++ b/configs/demo.py
@@ -0,0 +1,194 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='TransformerPoseTwoStage',
+ pretrained='swinv2_large',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=16,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.2,
+ img_size=256,
+ ),
+ keypoint_head=dict(
+ type='TwoStageHead',
+ in_channels=1536,
+ transformer=dict(
+ type='TwoStageSupportRefineTransformer',
+ d_model=384,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=1536,
+ dropout=0.1,
+ similarity_proj_dim=384,
+ dynamic_proj_dim=192,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=True,
+ support_pos_embed=False,
+ heatmap_loss_weight=2.0,
+ skeleton_loss_weight=0.02,
+ num_samples=0,
+ support_embedding_type="fixed",
+ num_support=100,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=192, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[256, 256],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_all.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/configs/demo_b.py b/configs/demo_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..a90293cfec7a91719691275b2fbd8a9a48a3ec7c
--- /dev/null
+++ b/configs/demo_b.py
@@ -0,0 +1,191 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=20)
+evaluation = dict(
+ interval=25,
+ metric=['PCK', 'NME', 'AUC', 'EPE'],
+ key_indicator='PCK',
+ gpu_collect=True,
+ res_folder='')
+optimizer = dict(
+ type='Adam',
+ lr=1e-5,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[160, 180])
+total_epochs = 200
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=1,
+ dataset_joints=1,
+ dataset_channel=[
+ [
+ 0,
+ ],
+ ],
+ inference_channel=[
+ 0,
+ ],
+ max_kpt_num=100)
+
+# model settings
+model = dict(
+ type='PoseAnythingModel',
+ pretrained='swinv2_base',
+ encoder_config=dict(
+ type='SwinTransformerV2',
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=14,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ drop_path_rate=0.1,
+ img_size=224,
+ ),
+ keypoint_head=dict(
+ type='PoseHead',
+ in_channels=1024,
+ transformer=dict(
+ type='EncoderDecoder',
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder='pre',
+ dim_feedforward=1024,
+ dropout=0.1,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+        with_heatmap_loss=True,
+        heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True)),
+ # training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=False,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[224, 224],
+ heatmap_size=[64, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=15,
+ scale_factor=0.15),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs', 'category_id', 'skeleton',
+ ]),
+]
+
+valid_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffineFewShot'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTargetFewShot', sigma=1),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs', 'category_id',
+ 'skeleton',
+ ]),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/mp100'
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_train.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TransformerPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_val.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=100,
+ pipeline=valid_pipeline),
+ test=dict(
+ type='TestPoseDataset',
+ ann_file=f'{data_root}/annotations/mp100_split1_test.json',
+ img_prefix=f'{data_root}/images/',
+ # img_prefix=f'{data_root}',
+ data_cfg=data_cfg,
+ valid_class_ids=None,
+ max_kpt_num=channel_cfg['max_kpt_num'],
+ num_shots=1,
+ num_queries=15,
+ num_episodes=200,
+ pck_threshold_list=[0.05, 0.10, 0.15, 0.2, 0.25],
+ pipeline=test_pipeline),
+)
+vis_backends = [
+ dict(type='LocalVisBackend'),
+ dict(type='TensorboardVisBackend'),
+]
+visualizer = dict(
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+shuffle_cfg = dict(interval=1)
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..82d270fd7a86731e2d697721c72e1dfe0d66df86
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,289 @@
+import argparse
+import copy
+import os
+import pickle
+import random
+import cv2
+import numpy as np
+import torch
+from mmcv import Config, DictAction
+from mmcv.cnn import fuse_conv_bn
+from mmcv.runner import load_checkpoint
+from mmpose.core import wrap_fp16_model
+from mmpose.models import build_posenet
+from torchvision import transforms
+from models import *
+import torchvision.transforms.functional as F
+
+from tools.visualization import plot_results
+
+COLORS = [
+ [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
+ [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
+ [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
+ [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]]
+
+class Resize_Pad:
+ def __init__(self, w=256, h=256):
+ self.w = w
+ self.h = h
+
+    def __call__(self, image):
+        # image is a C x H x W tensor
+        _, h_1, w_1 = image.shape
+        ratio_1 = h_1 / w_1
+        # check whether the image is already (nearly) square
+        if round(ratio_1, 2) != 1:
+            # zero-pad the shorter side to preserve the aspect ratio
+            if ratio_1 > 1:  # taller than wide: pad left/right
+                wp = int(h_1 - w_1) // 2
+                image = F.pad(image, (wp, 0, wp, 0), 0, "constant")
+            else:  # wider than tall: pad top/bottom
+                hp = int(w_1 - h_1) // 2
+                image = F.pad(image, (0, hp, 0, hp), 0, "constant")
+        return F.resize(image, [self.h, self.w])
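+    # Worked example (illustrative): a 3 x 300 x 200 tensor (H=300, W=200)
+    # receives 50 px of zero padding on the left and right, becoming
+    # 300 x 300, and is then resized to (self.h, self.w).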
+
+
+def transform_keypoints_to_pad_and_resize(keypoints, image_size):
+ trans_keypoints = keypoints.clone()
+ h, w = image_size[:2]
+ ratio_1 = w / h
+ if ratio_1 > 1:
+ # width is bigger than height - pad height
+ hp = int(w - h)
+ hp = hp // 2
+ trans_keypoints[:, 1] = keypoints[:, 1] + hp
+ trans_keypoints *= (256. / w)
+ else:
+ # height is bigger than width - pad width
+        wp = int(h - w)
+ wp = wp // 2
+ trans_keypoints[:, 0] = keypoints[:, 0] + wp
+ trans_keypoints *= (256. / h)
+ return trans_keypoints
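+# Worked example (illustrative): for a 300 x 400 (H x W) image, ratio_1 > 1,
+# so a keypoint at (200, 150) is shifted to (200, 200) by the 50 px height
+# padding and then scaled by 256/400 to (128, 128).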
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Pose Anything Demo')
+ parser.add_argument('--support', help='Image file')
+ parser.add_argument('--query', help='Image file')
+ parser.add_argument('--config', default=None, help='test config file path')
+ parser.add_argument('--checkpoint', default=None, help='checkpoint file')
+    parser.add_argument('--outdir', default='output', help='output directory')
+
+ parser.add_argument(
+ '--fuse-conv-bn',
+ action='store_true',
+        help='Whether to fuse conv and bn; this will slightly increase '
+        'the inference speed')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ default={},
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. For example, '
+ "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+ args = parser.parse_args()
+ return args
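+
+# Example invocation (paths are illustrative):
+#   python demo.py --support examples/support.jpg --query examples/query.jpg \
+#       --config configs/demo_b.py --checkpoint pretrained/demo_b.pth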
+
+
+def merge_configs(cfg1, cfg2):
+ # Merge cfg2 into cfg1
+ # Overwrite cfg1 if repeated, ignore if value is None.
+ cfg1 = {} if cfg1 is None else cfg1.copy()
+ cfg2 = {} if cfg2 is None else cfg2
+ for k, v in cfg2.items():
+ if v:
+ cfg1[k] = v
+ return cfg1
+
+
+def main():
+ random.seed(0)
+ np.random.seed(0)
+ torch.manual_seed(0)
+
+ args = parse_args()
+ cfg = Config.fromfile(args.config)
+
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+ cfg.data.test.test_mode = True
+
+ os.makedirs(args.outdir, exist_ok=True)
+
+ # Load data
+ support_img = cv2.imread(args.support)
+ query_img = cv2.imread(args.query)
+ if support_img is None or query_img is None:
+        raise ValueError('Failed to read images')
+
+ preprocess = transforms.Compose([
+ transforms.ToTensor(),
+ Resize_Pad(cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size)])
+
+ padded_support_img = preprocess(support_img).cpu().numpy().transpose(1, 2, 0) * 255
+ frame = copy.deepcopy(padded_support_img.astype(np.uint8).copy())
+ kp_src = []
+ skeleton = []
+ count = 0
+ prev_pt = None
+ prev_pt_idx = None
+ color_idx = 0
+
+ def selectKP(event, x, y, flags, param):
+ nonlocal kp_src, frame
+        # on left click, record the (x, y) location of the click and draw a
+        # circle; on right click, clear all selected keypoints
+
+ if event == cv2.EVENT_LBUTTONDOWN:
+ kp_src.append((x, y))
+ cv2.circle(frame, (x, y), 2, (0, 0, 255), 1)
+ cv2.imshow("Source", frame)
+
+ if event == cv2.EVENT_RBUTTONDOWN:
+ kp_src = []
+            frame = copy.deepcopy(padded_support_img.astype(np.uint8))
+ cv2.imshow("Source", frame)
+
+ def draw_line(event, x, y, flags, param):
+ nonlocal skeleton, kp_src, frame, count, prev_pt, prev_pt_idx, marked_frame, color_idx
+ if event == cv2.EVENT_LBUTTONDOWN:
+ closest_point = min(kp_src, key=lambda p: (p[0] - x) ** 2 + (p[1] - y) ** 2)
+ closest_point_index = kp_src.index(closest_point)
+            if color_idx < len(COLORS):
+                c = COLORS[color_idx]
+            else:
+                c = random.choices(range(256), k=3)
+ cv2.circle(frame, closest_point, 2, c, 1)
+ if count == 0:
+ prev_pt = closest_point
+ prev_pt_idx = closest_point_index
+ count = count + 1
+ cv2.imshow("Source", frame)
+ else:
+ cv2.line(frame, prev_pt, closest_point, c, 2)
+ cv2.imshow("Source", frame)
+ count = 0
+ skeleton.append((prev_pt_idx, closest_point_index))
+ color_idx = color_idx + 1
+ elif event == cv2.EVENT_RBUTTONDOWN:
+ frame = copy.deepcopy(marked_frame)
+ cv2.imshow("Source", frame)
+ count = 0
+ color_idx = 0
+ skeleton = []
+ prev_pt = None
+
+ cv2.namedWindow("Source", cv2.WINDOW_NORMAL)
+ cv2.resizeWindow('Source', 800, 600)
+ cv2.setMouseCallback("Source", selectKP)
+ cv2.imshow("Source", frame)
+
+ # keep looping until points have been selected
+    print('Press any key when finished marking the points!')
+ while True:
+ if cv2.waitKey(1) > 0:
+ break
+
+ marked_frame = copy.deepcopy(frame)
+ cv2.setMouseCallback("Source", draw_line)
+    print('Press any key when finished creating the skeleton!')
+ while True:
+ if cv2.waitKey(1) > 0:
+ break
+
+ cv2.destroyAllWindows()
+ kp_src = torch.tensor(kp_src).float()
+ preprocess = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ Resize_Pad(cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size)])
+
+ if len(skeleton) == 0:
+ skeleton = [(0, 0)]
+
+ support_img = preprocess(support_img).flip(0)[None]
+ query_img = preprocess(query_img).flip(0)[None]
+ # Create heatmap from keypoints
+ genHeatMap = TopDownGenerateTargetFewShot()
+ data_cfg = cfg.data_cfg
+ data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size])
+ data_cfg['joint_weights'] = None
+ data_cfg['use_different_joint_weights'] = False
+ kp_src_3d = torch.concatenate((kp_src, torch.zeros(kp_src.shape[0], 1)), dim=-1)
+ kp_src_3d_weight = torch.concatenate((torch.ones_like(kp_src), torch.zeros(kp_src.shape[0], 1)), dim=-1)
+ target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)
+ target_s = torch.tensor(target_s).float()[None]
+ target_weight_s = torch.tensor(target_weight_s).float()[None]
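+    # target_s stacks one Gaussian heatmap per marked keypoint (assumed shape
+    # [1, K, 64, 64] given data_cfg['heatmap_size']); target_weight_s masks
+    # out unlabeled keypoints.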
+
+ data = {
+ 'img_s': [support_img],
+ 'img_q': query_img,
+ 'target_s': [target_s],
+ 'target_weight_s': [target_weight_s],
+ 'target_q': None,
+ 'target_weight_q': None,
+ 'return_loss': False,
+ 'img_metas': [{'sample_skeleton': [skeleton],
+ 'query_skeleton': skeleton,
+ 'sample_joints_3d': [kp_src_3d],
+ 'query_joints_3d': kp_src_3d,
+ 'sample_center': [kp_src.mean(dim=0)],
+ 'query_center': kp_src.mean(dim=0),
+ 'sample_scale': [kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0]],
+ 'query_scale': kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0],
+ 'sample_rotation': [0],
+ 'query_rotation': 0,
+ 'sample_bbox_score': [1],
+ 'query_bbox_score': 1,
+ 'query_image_file': '',
+ 'sample_image_file': [''],
+ }]
+ }
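+    # This dict mimics a test-time sample from the MP-100 datasets: support
+    # entries are lists (one item per shot), query entries are single
+    # tensors, and the center/scale metadata is derived here from the marked
+    # keypoints rather than from an annotation file.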
+
+ # Load model
+ model = build_posenet(cfg.model)
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ wrap_fp16_model(model)
+ load_checkpoint(model, args.checkpoint, map_location='cpu')
+ if args.fuse_conv_bn:
+ model = fuse_conv_bn(model)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**data)
+
+ # visualize results
+ vis_s_weight = target_weight_s[0]
+ vis_q_weight = target_weight_s[0]
+ vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
+ vis_q_image = query_img[0].detach().cpu().numpy().transpose(1, 2, 0)
+ support_kp = kp_src_3d
+
+ plot_results(vis_s_image,
+ vis_q_image,
+ support_kp,
+ vis_s_weight,
+ None,
+ vis_q_weight,
+ skeleton,
+ None,
+ torch.tensor(outputs['points']).squeeze(0),
+ out_dir=args.outdir)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..1241d78a7b16acb6330508151c931472cd79f871
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,50 @@
+ARG PYTORCH="2.0.1"
+ARG CUDA="11.7"
+ARG CUDNN="8"
+
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
+ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
+ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
+ENV TZ=Asia/Kolkata DEBIAN_FRONTEND=noninteractive
+# To fix GPG key error when running apt-get update
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
+
+RUN apt-get update && apt-get install -y git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install xtcocotools
+RUN pip install cython
+RUN pip install xtcocotools
+# Install MMEngine and MMCV
+RUN pip install openmim
+RUN mim install mmengine
+RUN mim install "mmpose==0.28.1"
+RUN mim install "mmcv-full==1.5.3"
+RUN pip install -U torchmetrics timm
+RUN pip install numpy scipy --upgrade
+RUN pip install future tensorboard
+
+
+COPY models PoseAnything/models
+COPY configs PoseAnything/configs
+COPY pretrained PoseAnything/pretrained
+COPY requirements.txt PoseAnything/
+COPY tools PoseAnything/tools
+COPY setup.cfg PoseAnything/
+COPY setup.py PoseAnything/
+COPY test.py PoseAnything/
+COPY train.py PoseAnything/
+COPY README.md PoseAnything/
+
+RUN mkdir -p PoseAnything/data/mp100
+WORKDIR PoseAnything
+
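+# Build and run (illustrative; run from the repository root):
+#   docker build -t pose-anything -f docker/Dockerfile .
+#   docker run --gpus all -it pose-anything
+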
+# Install MMPose
+RUN conda clean --all
+ENV FORCE_CUDA="1"
+RUN python setup.py develop
\ No newline at end of file
diff --git a/gradio_teaser.png b/gradio_teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..7deea0ceb765af763283898f9ff6ffe856b72c9d
Binary files /dev/null and b/gradio_teaser.png differ
diff --git a/models/VERSION b/models/VERSION
new file mode 100644
index 0000000000000000000000000000000000000000..0ea3a944b399d25f7e1b8fe684d754eb8da9fe7f
--- /dev/null
+++ b/models/VERSION
@@ -0,0 +1 @@
+0.2.0
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d12119c91b4a54a136b3a13a5a695bfa90d27ea8
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,3 @@
+from .core import * # noqa
+from .datasets import * # noqa
+from .models import * # noqa
diff --git a/models/apis/__init__.py b/models/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..500c844f99bf7725c185e94c289ccf5613d09da5
--- /dev/null
+++ b/models/apis/__init__.py
@@ -0,0 +1,5 @@
+from .train import train_model
+
+__all__ = [
+ 'train_model'
+]
diff --git a/models/apis/train.py b/models/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f428ad292a7947d31625290779c6ca444ecda4
--- /dev/null
+++ b/models/apis/train.py
@@ -0,0 +1,126 @@
+import os
+
+import torch
+from models.core.custom_hooks.shuffle_hooks import ShufflePairedSamplesHook
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
+ build_optimizer)
+from mmpose.core import DistEvalHook, EvalHook, Fp16OptimizerHook
+from mmpose.datasets import build_dataloader
+from mmpose.utils import get_root_logger
+
+
+def train_model(model,
+ dataset,
+ val_dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ meta=None):
+ """Train model entry function.
+
+ Args:
+ model (nn.Module): The model to be trained.
+        dataset (Dataset): Train dataset.
+        val_dataset (Dataset): Validation dataset, used when ``validate``
+            is True.
+ cfg (dict): The config dict for training.
+ distributed (bool): Whether to use distributed training.
+ Default: False.
+ validate (bool): Whether to do evaluation. Default: False.
+ timestamp (str | None): Local time for runner. Default: None.
+ meta (dict | None): Meta dict to record some important information.
+ Default: None
+ """
+ logger = get_root_logger(cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ dataloader_setting = dict(
+ samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
+ workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
+ # cfg.gpus will be ignored if distributed
+ num_gpus=len(cfg.gpu_ids),
+ dist=distributed,
+ seed=cfg.seed,
+ pin_memory=False,
+ )
+ dataloader_setting = dict(dataloader_setting,
+ **cfg.data.get('train_dataloader', {}))
+
+ data_loaders = [
+ build_dataloader(ds, **dataloader_setting) for ds in dataset
+ ]
+
+ # put model on gpus
+ if distributed:
+        # Sets the `find_unused_parameters` parameter of
+        # torch.nn.parallel.DistributedDataParallel.
+        # NOTE: the default was changed from True to False for faster
+        # training.
+        find_unused_parameters = cfg.get('find_unused_parameters', False)
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ model = MMDataParallel(
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+ # build runner
+ optimizer = build_optimizer(model, cfg.optimizer)
+ runner = EpochBasedRunner(
+ model,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta)
+ # an ugly workaround to make .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ # fp16 setting
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ optimizer_config = Fp16OptimizerHook(
+ **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+ elif distributed and 'type' not in cfg.optimizer_config:
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
+ else:
+ optimizer_config = cfg.optimizer_config
+
+ # register hooks
+ runner.register_training_hooks(cfg.lr_config, optimizer_config,
+ cfg.checkpoint_config, cfg.log_config,
+ cfg.get('momentum_config', None))
+ if distributed:
+ runner.register_hook(DistSamplerSeedHook())
+
+ shuffle_cfg = cfg.get('shuffle_cfg', None)
+ if shuffle_cfg is not None:
+ for data_loader in data_loaders:
+ runner.register_hook(ShufflePairedSamplesHook(data_loader, **shuffle_cfg))
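+    # With shuffle_cfg = dict(interval=1) (as in the configs), support/query
+    # pairs are re-drawn via random_paired_samples() after every epoch.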
+
+ # register eval hooks
+ if validate:
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['res_folder'] = os.path.join(cfg.work_dir, eval_cfg['res_folder'])
+ dataloader_setting = dict(
+ # samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
+ samples_per_gpu=1,
+ workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
+ # cfg.gpus will be ignored if distributed
+ num_gpus=len(cfg.gpu_ids),
+ dist=distributed,
+ shuffle=False,
+ pin_memory=False,
+ )
+ dataloader_setting = dict(dataloader_setting,
+ **cfg.data.get('val_dataloader', {}))
+ val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
+ eval_hook = DistEvalHook if distributed else EvalHook
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
diff --git a/models/core/__init__.py b/models/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/core/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/core/custom_hooks/shuffle_hooks.py b/models/core/custom_hooks/shuffle_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..43c3c28995939a34e510005a68a97ae2e6b5f907
--- /dev/null
+++ b/models/core/custom_hooks/shuffle_hooks.py
@@ -0,0 +1,29 @@
+from mmcv.runner import Hook
+from mmpose.utils import get_root_logger
+from torch.utils.data import DataLoader
+
+
+class ShufflePairedSamplesHook(Hook):
+ """Non-Distributed ShufflePairedSamples.
+ After each training epoch, run FewShotKeypointDataset.random_paired_samples()
+ """
+
+ def __init__(self,
+ dataloader,
+ interval=1):
+ if not isinstance(dataloader, DataLoader):
+ raise TypeError(f'dataloader must be a pytorch DataLoader, '
+ f'but got {type(dataloader)}')
+
+ self.dataloader = dataloader
+ self.interval = interval
+ self.logger = get_root_logger()
+
+    def after_train_epoch(self, runner):
+        """Re-draw the paired support/query samples after each epoch."""
+        if not self.every_n_epochs(runner, self.interval):
+            return
+        self.dataloader.dataset.random_paired_samples()
diff --git a/models/datasets/__init__.py b/models/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..25529624cf32c145ca0bf686af2899bb386d5d28
--- /dev/null
+++ b/models/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .builder import * # noqa
+from .datasets import * # noqa
+from .pipelines import * # noqa
diff --git a/models/datasets/builder.py b/models/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f75d7db623cdd0498a13d823ffb563fe6d2a4de
--- /dev/null
+++ b/models/datasets/builder.py
@@ -0,0 +1,54 @@
+from mmcv.utils import build_from_cfg
+from mmpose.datasets.builder import DATASETS
+from mmpose.datasets.dataset_wrappers import RepeatDataset
+from torch.utils.data.dataset import ConcatDataset
+
+
+def _concat_cfg(cfg):
+ replace = ['ann_file', 'img_prefix']
+ channels = ['num_joints', 'dataset_channel']
+ concat_cfg = []
+ for i in range(len(cfg['type'])):
+ cfg_tmp = cfg.deepcopy()
+ cfg_tmp['type'] = cfg['type'][i]
+ for item in replace:
+ assert item in cfg_tmp
+ assert len(cfg['type']) == len(cfg[item]), (cfg[item])
+ cfg_tmp[item] = cfg[item][i]
+ for item in channels:
+ assert item in cfg_tmp['data_cfg']
+ assert len(cfg['type']) == len(cfg['data_cfg'][item])
+ cfg_tmp['data_cfg'][item] = cfg['data_cfg'][item][i]
+ concat_cfg.append(cfg_tmp)
+ return concat_cfg
+
+
+def _check_valid(cfg):
+ replace = ['num_joints', 'dataset_channel']
+ if isinstance(cfg['data_cfg'][replace[0]], (list, tuple)):
+ for item in replace:
+ cfg['data_cfg'][item] = cfg['data_cfg'][item][0]
+ return cfg
+
+
+def build_dataset(cfg, default_args=None):
+ """Build a dataset from config dict.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ default_args (dict, optional): Default initialization arguments.
+ Default: None.
+
+ Returns:
+ Dataset: The constructed dataset.
+ """
+ if isinstance(cfg['type'], (list, tuple)): # In training, type=TransformerPoseDataset
+ dataset = ConcatDataset(
+ [build_dataset(c, default_args) for c in _concat_cfg(cfg)])
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
+ else:
+        cfg = _check_valid(cfg)
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+ return dataset
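+
+# Typical use (illustrative): build_dataset(cfg.data.train). This mirrors the
+# mmpose builder but also accepts a list of dataset types for concatenation.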
diff --git a/models/datasets/datasets/__init__.py b/models/datasets/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca2431da38469ffe6d6e066db3539172555ad992
--- /dev/null
+++ b/models/datasets/datasets/__init__.py
@@ -0,0 +1,5 @@
+from .mp100 import (FewShotKeypointDataset, FewShotBaseDataset,
+ TransformerBaseDataset, TransformerPoseDataset)
+
+__all__ = ['FewShotBaseDataset', 'FewShotKeypointDataset',
+ 'TransformerBaseDataset', 'TransformerPoseDataset']
diff --git a/models/datasets/datasets/mp100/__init__.py b/models/datasets/datasets/mp100/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25e8a582047541857f7fe8f2ec3f182322b60f3
--- /dev/null
+++ b/models/datasets/datasets/mp100/__init__.py
@@ -0,0 +1,12 @@
+from .fewshot_base_dataset import FewShotBaseDataset
+from .fewshot_dataset import FewShotKeypointDataset
+from .test_base_dataset import TestBaseDataset
+from .test_dataset import TestPoseDataset
+from .transformer_base_dataset import TransformerBaseDataset
+from .transformer_dataset import TransformerPoseDataset
+
+__all__ = [
+ 'FewShotKeypointDataset', 'FewShotBaseDataset',
+ 'TransformerPoseDataset', 'TransformerBaseDataset',
+ 'TestBaseDataset', 'TestPoseDataset'
+]
diff --git a/models/datasets/datasets/mp100/fewshot_base_dataset.py b/models/datasets/datasets/mp100/fewshot_base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d7b9a75102ddc78fe6916b99db72cc9cdaaf1c5
--- /dev/null
+++ b/models/datasets/datasets/mp100/fewshot_base_dataset.py
@@ -0,0 +1,224 @@
+import copy
+from abc import ABCMeta, abstractmethod
+
+import json_tricks as json
+import numpy as np
+from mmcv.parallel import DataContainer as DC
+from mmpose.core.evaluation.top_down_eval import (keypoint_pck_accuracy)
+from mmpose.datasets import DATASETS
+from mmpose.datasets.pipelines import Compose
+from torch.utils.data import Dataset
+
+
+@DATASETS.register_module()
+class FewShotBaseDataset(Dataset, metaclass=ABCMeta):
+
+ def __init__(self,
+ ann_file,
+ img_prefix,
+ data_cfg,
+ pipeline,
+ test_mode=False):
+ self.image_info = {}
+ self.ann_info = {}
+
+ self.annotations_path = ann_file
+ if not img_prefix.endswith('/'):
+ img_prefix = img_prefix + '/'
+ self.img_prefix = img_prefix
+ self.pipeline = pipeline
+ self.test_mode = test_mode
+
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
+ self.ann_info['num_joints'] = data_cfg['num_joints']
+
+ self.ann_info['flip_pairs'] = None
+
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
+
+ self.db = []
+ self.num_shots = 1
+ self.paired_samples = []
+ self.pipeline = Compose(self.pipeline)
+
+ @abstractmethod
+ def _get_db(self):
+ """Load dataset."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _select_kpt(self, obj, kpt_id):
+ """Select kpt."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
+ """Evaluate keypoint results."""
+ raise NotImplementedError
+
+ @staticmethod
+ def _write_keypoint_results(keypoints, res_file):
+ """Write results into a json file."""
+
+ with open(res_file, 'w') as f:
+ json.dump(keypoints, f, sort_keys=True, indent=4)
+
+ def _report_metric(self,
+ res_file,
+ metrics,
+ pck_thr=0.2,
+ pckh_thr=0.7,
+ auc_nor=30):
+ """Keypoint evaluation.
+
+ Args:
+ res_file (str): Json file stored prediction results.
+ metrics (str | list[str]): Metric to be performed.
+ Options: 'PCK', 'PCKh', 'AUC', 'EPE'.
+ pck_thr (float): PCK threshold, default as 0.2.
+ pckh_thr (float): PCKh threshold, default as 0.7.
+ auc_nor (float): AUC normalization factor, default as 30 pixel.
+
+ Returns:
+ List: Evaluation results for evaluation metric.
+ """
+ info_str = []
+
+ with open(res_file, 'r') as fin:
+ preds = json.load(fin)
+ assert len(preds) == len(self.paired_samples)
+
+ outputs = []
+ gts = []
+ masks = []
+ threshold_bbox = []
+ threshold_head_box = []
+
+ for pred, pair in zip(preds, self.paired_samples):
+ item = self.db[pair[-1]]
+ outputs.append(np.array(pred['keypoints'])[:, :-1])
+ gts.append(np.array(item['joints_3d'])[:, :-1])
+
+ mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
+ mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
+ for id_s in pair[:-1]:
+ mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
+ masks.append(np.bitwise_and(mask_query, mask_sample))
+
+ if 'PCK' in metrics:
+ bbox = np.array(item['bbox'])
+ bbox_thr = np.max(bbox[2:])
+ threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
+ if 'PCKh' in metrics:
+ head_box_thr = item['head_size']
+ threshold_head_box.append(
+ np.array([head_box_thr, head_box_thr]))
+
+ if 'PCK' in metrics:
+ pck_avg = []
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
+ _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt, 0),
+ np.expand_dims(mask, 0), pck_thr, np.expand_dims(thr_bbox, 0))
+ pck_avg.append(pck)
+ info_str.append(('PCK', np.mean(pck_avg)))
+
+ return info_str
+
+ def _merge_obj(self, Xs_list, Xq, idx):
+ """ merge Xs_list and Xq.
+
+ :param Xs_list: N-shot samples X
+ :param Xq: query X
+ :param idx: id of paired_samples
+ :return: Xall
+ """
+ Xall = dict()
+ Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
+ Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
+ Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
+ xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]
+
+ Xall['img_q'] = Xq['img']
+ Xall['target_q'] = Xq['target']
+ Xall['target_weight_q'] = Xq['target_weight']
+ xq_img_metas = Xq['img_metas'].data
+
+ img_metas = dict()
+ for key in xq_img_metas.keys():
+ img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
+ img_metas['query_' + key] = xq_img_metas[key]
+ img_metas['bbox_id'] = idx
+
+ Xall['img_metas'] = DC(img_metas, cpu_only=True)
+
+ return Xall
+
+ def __len__(self):
+ """Get the size of the dataset."""
+ return len(self.paired_samples)
+
+ def __getitem__(self, idx):
+ """Get the sample given index."""
+
+ pair_ids = self.paired_samples[idx]
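+        # pair_ids holds num_shots support ids followed by the query id,
+        # i.e. [s_1, ..., s_N, q].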
+ assert len(pair_ids) == self.num_shots + 1
+ sample_id_list = pair_ids[:self.num_shots]
+ query_id = pair_ids[-1]
+
+ sample_obj_list = []
+ for sample_id in sample_id_list:
+ sample_obj = copy.deepcopy(self.db[sample_id])
+ sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
+ sample_obj_list.append(sample_obj)
+
+ query_obj = copy.deepcopy(self.db[query_id])
+ query_obj['ann_info'] = copy.deepcopy(self.ann_info)
+
+ if not self.test_mode:
+ # randomly select "one" keypoint
+ sample_valid = (sample_obj_list[0]['joints_3d_visible'][:, 0] > 0)
+ for sample_obj in sample_obj_list:
+ sample_valid = sample_valid & (sample_obj['joints_3d_visible'][:, 0] > 0)
+ query_valid = (query_obj['joints_3d_visible'][:, 0] > 0)
+
+ valid_s = np.where(sample_valid)[0]
+ valid_q = np.where(query_valid)[0]
+ valid_sq = np.where(sample_valid & query_valid)[0]
+ if len(valid_sq) > 0:
+ kpt_id = np.random.choice(valid_sq)
+ elif len(valid_s) > 0:
+ kpt_id = np.random.choice(valid_s)
+ elif len(valid_q) > 0:
+ kpt_id = np.random.choice(valid_q)
+ else:
+ kpt_id = np.random.choice(np.array(range(len(query_valid))))
+
+ for i in range(self.num_shots):
+ sample_obj_list[i] = self._select_kpt(sample_obj_list[i], kpt_id)
+ query_obj = self._select_kpt(query_obj, kpt_id)
+
+ # when test, all keypoints will be preserved.
+
+ Xs_list = []
+ for sample_obj in sample_obj_list:
+ Xs = self.pipeline(sample_obj)
+ Xs_list.append(Xs)
+ Xq = self.pipeline(query_obj)
+
+ Xall = self._merge_obj(Xs_list, Xq, idx)
+ Xall['skeleton'] = self.db[query_id]['skeleton']
+
+ return Xall
+
+ def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
+ """sort kpts and remove the repeated ones."""
+ kpts = sorted(kpts, key=lambda x: x[key])
+ num = len(kpts)
+ for i in range(num - 1, 0, -1):
+ if kpts[i][key] == kpts[i - 1][key]:
+ del kpts[i]
+
+ return kpts
diff --git a/models/datasets/datasets/mp100/fewshot_dataset.py b/models/datasets/datasets/mp100/fewshot_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..40ee4cad3c66c9549a7b2196f615a7eb4bef883a
--- /dev/null
+++ b/models/datasets/datasets/mp100/fewshot_dataset.py
@@ -0,0 +1,314 @@
+import os
+import random
+from collections import OrderedDict
+
+import numpy as np
+from mmpose.datasets import DATASETS
+from xtcocotools.coco import COCO
+
+from .fewshot_base_dataset import FewShotBaseDataset
+
+
+@DATASETS.register_module()
+class FewShotKeypointDataset(FewShotBaseDataset):
+
+ def __init__(self,
+ ann_file,
+ img_prefix,
+ data_cfg,
+ pipeline,
+ valid_class_ids,
+ num_shots=1,
+ num_queries=100,
+ num_episodes=1,
+ test_mode=False):
+ super().__init__(
+ ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)
+
+ self.ann_info['flip_pairs'] = []
+
+ self.ann_info['upper_body_ids'] = []
+ self.ann_info['lower_body_ids'] = []
+
+ self.ann_info['use_different_joint_weights'] = False
+ self.ann_info['joint_weights'] = np.array([1., ],
+ dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
+
+ self.coco = COCO(ann_file)
+
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
+ self.img_ids = self.coco.getImgIds()
+ self.classes = [
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
+ ]
+
+ self.num_classes = len(self.classes)
+ self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
+ self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
+
+ if valid_class_ids is not None:
+ self.valid_class_ids = valid_class_ids
+ else:
+ self.valid_class_ids = self.coco.getCatIds()
+ self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
+
+ self.cats = self.coco.cats
+
+ # Also update self.cat2obj
+ self.db = self._get_db()
+
+ self.num_shots = num_shots
+
+ if not test_mode:
+ # Update every training epoch
+ self.random_paired_samples()
+ else:
+ self.num_queries = num_queries
+ self.num_episodes = num_episodes
+ self.make_paired_samples()
+
+ def random_paired_samples(self):
+ num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
+
+ # balance the dataset
+ max_num_data = max(num_datas)
+
+ all_samples = []
+ for cls in self.valid_class_ids:
+ for i in range(max_num_data):
+ shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
+ all_samples.append(shot)
+
+ self.paired_samples = np.array(all_samples)
+ np.random.shuffle(self.paired_samples)
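+        # Note: every class contributes max_num_data pairs regardless of its
+        # size, so smaller categories are sampled more often to balance the
+        # epoch.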
+
+ def make_paired_samples(self):
+ random.seed(1)
+ np.random.seed(0)
+
+ all_samples = []
+ for cls in self.valid_class_ids:
+ for _ in range(self.num_episodes):
+ shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
+ sample_ids = shots[:self.num_shots]
+ query_ids = shots[self.num_shots:]
+ for query_id in query_ids:
+ all_samples.append(sample_ids + [query_id])
+
+ self.paired_samples = np.array(all_samples)
+
+ def _select_kpt(self, obj, kpt_id):
+ obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id + 1]
+ obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id + 1]
+ obj['kpt_id'] = kpt_id
+
+ return obj
+
+ @staticmethod
+ def _get_mapping_id_name(imgs):
+ """
+ Args:
+ imgs (dict): dict of image info.
+
+ Returns:
+ tuple: Image name & id mapping dicts.
+
+ - id2name (dict): Mapping image id to name.
+ - name2id (dict): Mapping image name to id.
+ """
+ id2name = {}
+ name2id = {}
+ for image_id, image in imgs.items():
+ file_name = image['file_name']
+ id2name[image_id] = file_name
+ name2id[file_name] = image_id
+
+ return id2name, name2id
+
+ def _get_db(self):
+ """Ground truth bbox and keypoints."""
+ self.obj_id = 0
+
+ self.cat2obj = {}
+ for i in self.coco.getCatIds():
+ self.cat2obj.update({i: []})
+
+ gt_db = []
+ for img_id in self.img_ids:
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
+ return gt_db
+
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
+ """load annotation from COCOAPI.
+
+ Note:
+ bbox:[x1, y1, w, h]
+ Args:
+ img_id: coco image id
+ Returns:
+ dict: db entry
+ """
+ img_ann = self.coco.loadImgs(img_id)[0]
+ width = img_ann['width']
+ height = img_ann['height']
+
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
+ objs = self.coco.loadAnns(ann_ids)
+
+ # sanitize bboxes
+ valid_objs = []
+ for obj in objs:
+ if 'bbox' not in obj:
+ continue
+ x, y, w, h = obj['bbox']
+ x1 = max(0, x)
+ y1 = max(0, y)
+ x2 = min(width - 1, x1 + max(0, w - 1))
+ y2 = min(height - 1, y1 + max(0, h - 1))
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
+ valid_objs.append(obj)
+ objs = valid_objs
+
+ bbox_id = 0
+ rec = []
+ for obj in objs:
+ if 'keypoints' not in obj:
+ continue
+ if max(obj['keypoints']) == 0:
+ continue
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
+ continue
+
+ category_id = obj['category_id']
+ # the number of keypoint for this specific category
+ cat_kpt_num = int(len(obj['keypoints']) / 3)
+
+ joints_3d = np.zeros((cat_kpt_num, 3), dtype=np.float32)
+ joints_3d_visible = np.zeros((cat_kpt_num, 3), dtype=np.float32)
+
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
+ joints_3d[:, :2] = keypoints[:, :2]
+ joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
+
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
+
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
+
+ self.cat2obj[category_id].append(self.obj_id)
+
+ rec.append({
+ 'image_file': image_file,
+ 'center': center,
+ 'scale': scale,
+ 'rotation': 0,
+ 'bbox': obj['clean_bbox'][:4],
+ 'bbox_score': 1,
+ 'joints_3d': joints_3d,
+ 'joints_3d_visible': joints_3d_visible,
+ 'category_id': category_id,
+ 'cat_kpt_num': cat_kpt_num,
+ 'bbox_id': self.obj_id,
+ 'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
+ })
+ bbox_id = bbox_id + 1
+ self.obj_id += 1
+
+ return rec
+
+ def _xywh2cs(self, x, y, w, h):
+ """This encodes bbox(x,y,w,w) into (center, scale)
+
+ Args:
+ x, y, w, h
+
+ Returns:
+ tuple: A tuple containing center and scale.
+
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
+ """
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
+ #
+ # if (not self.test_mode) and np.random.rand() < 0.3:
+ # center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
+
+ if w > aspect_ratio * h:
+ h = w * 1.0 / aspect_ratio
+ elif w < aspect_ratio * h:
+ w = h * aspect_ratio
+
+ # pixel std is 200.0
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
+ # padding to include proper amount of context
+ scale = scale * 1.25
+
+ return center, scale
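+    # Worked example (illustrative): with a square image_size, a bbox
+    # (x=10, y=20, w=100, h=50) gives center (60, 45); h is raised to 100 to
+    # match the aspect ratio, and scale = 1.25 * [100/200, 100/200]
+    # = [0.625, 0.625].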
+
+ def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
+ """Evaluate interhand2d keypoint results. The pose prediction results
+ will be saved in `${res_folder}/result_keypoints.json`.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+            outputs (list(preds, boxes, image_path, output_heatmap))
+                :preds (np.ndarray[N,K,3]): The first two dimensions are
+                    coordinates, score is the third dimension of the array.
+                :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],
+                    scale[1], area, score]
+                :image_paths (list[str]): For example,
+                    'Capture12/0390_dh_touchROM/cam410209/image62434.jpg'.
+                :output_heatmap (np.ndarray[N, K, H, W]): model outputs.
+
+ res_folder (str): Path of directory to save the results.
+ metric (str | list[str]): Metric to be performed.
+ Options: 'PCK', 'AUC', 'EPE'.
+
+ Returns:
+ dict: Evaluation results for evaluation metric.
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['PCK', 'AUC', 'EPE']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
+
+ kpts = []
+ for output in outputs:
+ preds = output['preds']
+ boxes = output['boxes']
+ image_paths = output['image_paths']
+ bbox_ids = output['bbox_ids']
+
+ batch_size = len(image_paths)
+ for i in range(batch_size):
+ image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
+
+ kpts.append({
+ 'keypoints': preds[i].tolist(),
+ 'center': boxes[i][0:2].tolist(),
+ 'scale': boxes[i][2:4].tolist(),
+ 'area': float(boxes[i][4]),
+ 'score': float(boxes[i][5]),
+ 'image_id': image_id,
+ 'bbox_id': bbox_ids[i]
+ })
+ kpts = self._sort_and_unique_bboxes(kpts)
+
+ self._write_keypoint_results(kpts, res_file)
+ info_str = self._report_metric(res_file, metrics)
+ name_value = OrderedDict(info_str)
+
+ return name_value
diff --git a/models/datasets/datasets/mp100/test_base_dataset.py b/models/datasets/datasets/mp100/test_base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a163ffee588867d20b926a0f4320faf462f635d8
--- /dev/null
+++ b/models/datasets/datasets/mp100/test_base_dataset.py
@@ -0,0 +1,230 @@
+import copy
+from abc import ABCMeta, abstractmethod
+
+import json_tricks as json
+import numpy as np
+from mmcv.parallel import DataContainer as DC
+from mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
+ keypoint_pck_accuracy)
+from mmpose.datasets import DATASETS
+from mmpose.datasets.pipelines import Compose
+from torch.utils.data import Dataset
+
+
+@DATASETS.register_module()
+class TestBaseDataset(Dataset, metaclass=ABCMeta):
+
+ def __init__(self,
+ ann_file,
+ img_prefix,
+ data_cfg,
+ pipeline,
+ test_mode=True,
+ PCK_threshold_list=[0.05, 0.1, 0.15, 0.2, 0.25]):
+ self.image_info = {}
+ self.ann_info = {}
+
+ self.annotations_path = ann_file
+ if not img_prefix.endswith('/'):
+ img_prefix = img_prefix + '/'
+ self.img_prefix = img_prefix
+ self.pipeline = pipeline
+ self.test_mode = test_mode
+ self.PCK_threshold_list = PCK_threshold_list
+
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
+ self.ann_info['num_joints'] = data_cfg['num_joints']
+
+ self.ann_info['flip_pairs'] = None
+
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
+
+ self.db = []
+ self.num_shots = 1
+ self.paired_samples = []
+ self.pipeline = Compose(self.pipeline)
+
+ @abstractmethod
+ def _get_db(self):
+ """Load dataset."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _select_kpt(self, obj, kpt_id):
+ """Select kpt."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
+ """Evaluate keypoint results."""
+ raise NotImplementedError
+
+ @staticmethod
+ def _write_keypoint_results(keypoints, res_file):
+ """Write results into a json file."""
+
+ with open(res_file, 'w') as f:
+ json.dump(keypoints, f, sort_keys=True, indent=4)
+
+ def _report_metric(self,
+ res_file,
+ metrics):
+ """Keypoint evaluation.
+
+ Args:
+ res_file (str): Json file stored prediction results.
+            metrics (str | list[str]): Metric to be performed.
+                Options: 'PCK', 'PCKh', 'AUC', 'EPE'. PCK thresholds are
+                taken from ``self.PCK_threshold_list``.
+
+ Returns:
+ List: Evaluation results for evaluation metric.
+ """
+ info_str = []
+
+ with open(res_file, 'r') as fin:
+ preds = json.load(fin)
+ assert len(preds) == len(self.paired_samples)
+
+ outputs = []
+ gts = []
+ masks = []
+ threshold_bbox = []
+ threshold_head_box = []
+
+ for pred, pair in zip(preds, self.paired_samples):
+ item = self.db[pair[-1]]
+ outputs.append(np.array(pred['keypoints'])[:, :-1])
+ gts.append(np.array(item['joints_3d'])[:, :-1])
+
+ mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
+ mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
+ for id_s in pair[:-1]:
+ mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
+ masks.append(np.bitwise_and(mask_query, mask_sample))
+
+ if 'PCK' in metrics or 'NME' in metrics or 'AUC' in metrics:
+ bbox = np.array(item['bbox'])
+ bbox_thr = np.max(bbox[2:])
+ threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
+ if 'PCKh' in metrics:
+ head_box_thr = item['head_size']
+ threshold_head_box.append(
+ np.array([head_box_thr, head_box_thr]))
+
+ if 'PCK' in metrics:
+ pck_results = dict()
+ for pck_thr in self.PCK_threshold_list:
+ pck_results[pck_thr] = []
+
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
+ for pck_thr in self.PCK_threshold_list:
+ _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt, 0),
+ np.expand_dims(mask, 0), pck_thr, np.expand_dims(thr_bbox, 0))
+ pck_results[pck_thr].append(pck)
+
+ mPCK = 0
+ for pck_thr in self.PCK_threshold_list:
+ info_str.append(['PCK@' + str(pck_thr), np.mean(pck_results[pck_thr])])
+ mPCK += np.mean(pck_results[pck_thr])
+ info_str.append(['mPCK', mPCK / len(self.PCK_threshold_list)])
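+            # mPCK is the PCK averaged over all configured thresholds
+            # (0.05-0.25 in these configs).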
+
+ if 'NME' in metrics:
+ nme_results = []
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
+ nme = keypoint_nme(np.expand_dims(output, 0), np.expand_dims(gt, 0), np.expand_dims(mask, 0),
+ np.expand_dims(thr_bbox, 0))
+ nme_results.append(nme)
+ info_str.append(['NME', np.mean(nme_results)])
+
+ if 'AUC' in metrics:
+ auc_results = []
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
+ auc = keypoint_auc(np.expand_dims(output, 0), np.expand_dims(gt, 0), np.expand_dims(mask, 0),
+ thr_bbox[0])
+ auc_results.append(auc)
+ info_str.append(['AUC', np.mean(auc_results)])
+
+ if 'EPE' in metrics:
+ epe_results = []
+ for (output, gt, mask) in zip(outputs, gts, masks):
+ epe = keypoint_epe(np.expand_dims(output, 0), np.expand_dims(gt, 0), np.expand_dims(mask, 0))
+ epe_results.append(epe)
+ info_str.append(['EPE', np.mean(epe_results)])
+ return info_str
+
+ def _merge_obj(self, Xs_list, Xq, idx):
+ """ merge Xs_list and Xq.
+
+ :param Xs_list: N-shot samples X
+ :param Xq: query X
+ :param idx: id of paired_samples
+ :return: Xall
+ """
+ Xall = dict()
+ Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
+ Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
+ Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
+ xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]
+
+ Xall['img_q'] = Xq['img']
+ Xall['target_q'] = Xq['target']
+ Xall['target_weight_q'] = Xq['target_weight']
+ xq_img_metas = Xq['img_metas'].data
+
+ img_metas = dict()
+ for key in xq_img_metas.keys():
+ img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
+ img_metas['query_' + key] = xq_img_metas[key]
+ img_metas['bbox_id'] = idx
+
+ Xall['img_metas'] = DC(img_metas, cpu_only=True)
+
+ return Xall
+
+ def __len__(self):
+ """Get the size of the dataset."""
+ return len(self.paired_samples)
+
+ def __getitem__(self, idx):
+ """Get the sample given index."""
+
+ pair_ids = self.paired_samples[idx] # [supported id * shots, query id]
+ assert len(pair_ids) == self.num_shots + 1
+ sample_id_list = pair_ids[:self.num_shots]
+ query_id = pair_ids[-1]
+
+ sample_obj_list = []
+ for sample_id in sample_id_list:
+ sample_obj = copy.deepcopy(self.db[sample_id])
+ sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
+ sample_obj_list.append(sample_obj)
+
+ query_obj = copy.deepcopy(self.db[query_id])
+ query_obj['ann_info'] = copy.deepcopy(self.ann_info)
+
+ Xs_list = []
+ for sample_obj in sample_obj_list:
+ Xs = self.pipeline(sample_obj) # dict with ['img', 'target', 'target_weight', 'img_metas'],
+ Xs_list.append(Xs) # Xs['target'] is of shape [100, map_h, map_w]
+ Xq = self.pipeline(query_obj)
+
+ Xall = self._merge_obj(Xs_list, Xq, idx)
+ Xall['skeleton'] = self.db[query_id]['skeleton']
+
+ return Xall
+
+ def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
+ """sort kpts and remove the repeated ones."""
+ kpts = sorted(kpts, key=lambda x: x[key])
+ num = len(kpts)
+ for i in range(num - 1, 0, -1):
+ if kpts[i][key] == kpts[i - 1][key]:
+ del kpts[i]
+
+ return kpts
diff --git a/models/datasets/datasets/mp100/test_dataset.py b/models/datasets/datasets/mp100/test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb217abf92a5c8fb5380361eeeaef5a9aa04acc
--- /dev/null
+++ b/models/datasets/datasets/mp100/test_dataset.py
@@ -0,0 +1,321 @@
+import os
+import random
+from collections import OrderedDict
+
+import numpy as np
+from mmpose.datasets import DATASETS
+from xtcocotools.coco import COCO
+
+from .test_base_dataset import TestBaseDataset
+
+
+@DATASETS.register_module()
+class TestPoseDataset(TestBaseDataset):
+
+ def __init__(self,
+ ann_file,
+ img_prefix,
+ data_cfg,
+ pipeline,
+ valid_class_ids,
+ max_kpt_num=None,
+ num_shots=1,
+ num_queries=100,
+ num_episodes=1,
+ pck_threshold_list=[0.05, 0.1, 0.15, 0.20, 0.25],
+ test_mode=True):
+ super().__init__(
+ ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode, PCK_threshold_list=pck_threshold_list)
+
+ self.ann_info['flip_pairs'] = []
+
+ self.ann_info['upper_body_ids'] = []
+ self.ann_info['lower_body_ids'] = []
+
+ self.ann_info['use_different_joint_weights'] = False
+ # uniform weight for every joint (np.ones avoids the invalid reshape of
+ # a single-element array when num_joints > 1)
+ self.ann_info['joint_weights'] = np.ones(
+ (self.ann_info['num_joints'], 1), dtype=np.float32)
+
+ self.coco = COCO(ann_file)
+
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
+ self.img_ids = self.coco.getImgIds()
+ self.classes = [
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
+ ]
+
+ self.num_classes = len(self.classes)
+ self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
+ self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
+
+ if valid_class_ids is not None: # None by default
+ self.valid_class_ids = valid_class_ids
+ else:
+ self.valid_class_ids = self.coco.getCatIds()
+ self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
+
+ self.cats = self.coco.cats
+ self.max_kpt_num = max_kpt_num
+
+ # Also update self.cat2obj
+ self.db = self._get_db()
+
+ self.num_shots = num_shots
+
+ if not test_mode:
+ # Update every training epoch
+ self.random_paired_samples()
+ else:
+ self.num_queries = num_queries
+ self.num_episodes = num_episodes
+ self.make_paired_samples()
+
+ def random_paired_samples(self):
+ num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
+
+ # balance the dataset
+ max_num_data = max(num_datas)
+
+ all_samples = []
+ for cls in self.valid_class_ids:
+ for _ in range(max_num_data):
+ shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
+ all_samples.append(shot)
+
+ self.paired_samples = np.array(all_samples)
+ np.random.shuffle(self.paired_samples)
+
+ def make_paired_samples(self):
+ random.seed(1)
+ np.random.seed(0)
+
+ all_samples = []
+ for cls in self.valid_class_ids:
+ for _ in range(self.num_episodes):
+ shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
+ sample_ids = shots[:self.num_shots]
+ query_ids = shots[self.num_shots:]
+ for query_id in query_ids:
+ all_samples.append(sample_ids + [query_id])
+
+ self.paired_samples = np.array(all_samples)
+
+ def _select_kpt(self, obj, kpt_id):
+ obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id + 1]
+ obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id + 1]
+ obj['kpt_id'] = kpt_id
+
+ return obj
+
+ @staticmethod
+ def _get_mapping_id_name(imgs):
+ """
+ Args:
+ imgs (dict): dict of image info.
+
+ Returns:
+ tuple: Image name & id mapping dicts.
+
+ - id2name (dict): Mapping image id to name.
+ - name2id (dict): Mapping image name to id.
+ """
+ id2name = {}
+ name2id = {}
+ for image_id, image in imgs.items():
+ file_name = image['file_name']
+ id2name[image_id] = file_name
+ name2id[file_name] = image_id
+
+ return id2name, name2id
+
+ def _get_db(self):
+ """Ground truth bbox and keypoints."""
+ self.obj_id = 0
+
+ self.cat2obj = {}
+ for i in self.coco.getCatIds():
+ self.cat2obj.update({i: []})
+
+ gt_db = []
+ for img_id in self.img_ids:
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
+ return gt_db
+
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
+ """load annotation from COCOAPI.
+
+ Note:
+ bbox:[x1, y1, w, h]
+ Args:
+ img_id: coco image id
+ Returns:
+ dict: db entry
+ """
+ img_ann = self.coco.loadImgs(img_id)[0]
+ width = img_ann['width']
+ height = img_ann['height']
+
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
+ objs = self.coco.loadAnns(ann_ids)
+
+ # sanitize bboxes
+ valid_objs = []
+ for obj in objs:
+ if 'bbox' not in obj:
+ continue
+ x, y, w, h = obj['bbox']
+ x1 = max(0, x)
+ y1 = max(0, y)
+ x2 = min(width - 1, x1 + max(0, w - 1))
+ y2 = min(height - 1, y1 + max(0, h - 1))
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
+ valid_objs.append(obj)
+ objs = valid_objs
+
+ bbox_id = 0
+ rec = []
+ for obj in objs:
+ if 'keypoints' not in obj:
+ continue
+ if max(obj['keypoints']) == 0:
+ continue
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
+ continue
+
+ category_id = obj['category_id']
+ # the number of keypoints for this specific category
+ cat_kpt_num = len(obj['keypoints']) // 3
+ if self.max_kpt_num is None:
+ kpt_num = cat_kpt_num
+ else:
+ kpt_num = self.max_kpt_num
+
+ joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)
+ joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)
+
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
+ joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]
+ joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])
+
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
+
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
+
+ self.cat2obj[category_id].append(self.obj_id)
+
+ rec.append({
+ 'image_file': image_file,
+ 'center': center,
+ 'scale': scale,
+ 'rotation': 0,
+ 'bbox': obj['clean_bbox'][:4],
+ 'bbox_score': 1,
+ 'joints_3d': joints_3d,
+ 'joints_3d_visible': joints_3d_visible,
+ 'category_id': category_id,
+ 'cat_kpt_num': cat_kpt_num,
+ 'bbox_id': self.obj_id,
+ 'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
+ })
+ bbox_id = bbox_id + 1
+ self.obj_id += 1
+
+ return rec
+
+ def _xywh2cs(self, x, y, w, h):
+ """This encodes bbox(x,y,w,w) into (center, scale)
+
+ Args:
+ x, y, w, h
+
+ Returns:
+ tuple: A tuple containing center and scale.
+
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
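+
+ Example (illustrative numbers, assuming aspect_ratio == 1.0):
+ x, y, w, h = 10, 20, 100, 50
+ -> center = [60., 45.]; h is padded to 100, so
+ scale = [0.625, 0.625] # 100 / 200 * 1.25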
+ """
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
+ #
+ # if (not self.test_mode) and np.random.rand() < 0.3:
+ # center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
+
+ if w > aspect_ratio * h:
+ h = w * 1.0 / aspect_ratio
+ elif w < aspect_ratio * h:
+ w = h * aspect_ratio
+
+ # pixel std is 200.0
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
+ # padding to include proper amount of context
+ scale = scale * 1.25
+
+ return center, scale
+
+ def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
+ """Evaluate interhand2d keypoint results. The pose prediction results
+ will be saved in `${res_folder}/result_keypoints.json`.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+ outputs (list(preds, boxes, image_path, output_heatmap))
+ :preds (np.ndarray[N,K,3]): The first two dimensions are
+ coordinates, score is the third dimension of the array.
+ :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
+ , scale[1],area, score]
+ :image_paths (list[str]): For example, ['C', 'a', 'p', 't',
+ 'u', 'r', 'e', '1', '2', '/', '0', '3', '9', '0', '_',
+ 'd', 'h', '_', 't', 'o', 'u', 'c', 'h', 'R', 'O', 'M',
+ '/', 'c', 'a', 'm', '4', '1', '0', '2', '0', '9', '/',
+ 'i', 'm', 'a', 'g', 'e', '6', '2', '4', '3', '4', '.',
+ 'j', 'p', 'g']
+ :output_heatmap (np.ndarray[N, K, H, W]): model outpus.
+
+ res_folder (str): Path of directory to save the results.
+ metric (str | list[str]): Metric to be performed.
+ Options: 'PCK', 'AUC', 'EPE'.
+
+ Returns:
+ dict: Evaluation results for evaluation metric.
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
+
+ kpts = []
+ for output in outputs:
+ preds = output['preds']
+ boxes = output['boxes']
+ image_paths = output['image_paths']
+ bbox_ids = output['bbox_ids']
+
+ batch_size = len(image_paths)
+ for i in range(batch_size):
+ image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
+
+ kpts.append({
+ 'keypoints': preds[i].tolist(),
+ 'center': boxes[i][0:2].tolist(),
+ 'scale': boxes[i][2:4].tolist(),
+ 'area': float(boxes[i][4]),
+ 'score': float(boxes[i][5]),
+ 'image_id': image_id,
+ 'bbox_id': bbox_ids[i]
+ })
+ kpts = self._sort_and_unique_bboxes(kpts)
+
+ self._write_keypoint_results(kpts, res_file)
+ info_str = self._report_metric(res_file, metrics)
+ name_value = OrderedDict(info_str)
+
+ return name_value
diff --git a/models/datasets/datasets/mp100/transformer_base_dataset.py b/models/datasets/datasets/mp100/transformer_base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4faa5bf2fb366b58b25be95af8f8943d5da357f
--- /dev/null
+++ b/models/datasets/datasets/mp100/transformer_base_dataset.py
@@ -0,0 +1,198 @@
+import copy
+from abc import ABCMeta, abstractmethod
+
+import json_tricks as json
+import numpy as np
+from mmcv.parallel import DataContainer as DC
+from mmpose.core.evaluation.top_down_eval import (keypoint_pck_accuracy)
+from mmpose.datasets import DATASETS
+from mmpose.datasets.pipelines import Compose
+from torch.utils.data import Dataset
+
+
+@DATASETS.register_module()
+class TransformerBaseDataset(Dataset, metaclass=ABCMeta):
+
+ def __init__(self,
+ ann_file,
+ img_prefix,
+ data_cfg,
+ pipeline,
+ test_mode=False):
+ self.image_info = {}
+ self.ann_info = {}
+
+ self.annotations_path = ann_file
+ if not img_prefix.endswith('/'):
+ img_prefix = img_prefix + '/'
+ self.img_prefix = img_prefix
+ self.pipeline = pipeline
+ self.test_mode = test_mode
+
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
+ self.ann_info['num_joints'] = data_cfg['num_joints']
+
+ self.ann_info['flip_pairs'] = None
+
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
+
+ self.db = []
+ self.num_shots = 1
+ self.paired_samples = []
+ self.pipeline = Compose(self.pipeline)
+
+ @abstractmethod
+ def _get_db(self):
+ """Load dataset."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _select_kpt(self, obj, kpt_id):
+ """Select kpt."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
+ """Evaluate keypoint results."""
+ raise NotImplementedError
+
+ @staticmethod
+ def _write_keypoint_results(keypoints, res_file):
+ """Write results into a json file."""
+
+ with open(res_file, 'w') as f:
+ json.dump(keypoints, f, sort_keys=True, indent=4)
+
+ def _report_metric(self,
+ res_file,
+ metrics,
+ pck_thr=0.2,
+ pckh_thr=0.7,
+ auc_nor=30):
+ """Keypoint evaluation.
+
+ Args:
+ res_file (str): Json file stored prediction results.
+ metrics (str | list[str]): Metric to be performed.
+ Options: 'PCK', 'PCKh', 'AUC', 'EPE'; only 'PCK' is
+ computed by this base class.
+ pck_thr (float): PCK threshold, default as 0.2.
+ pckh_thr (float): PCKh threshold, default as 0.7.
+ auc_nor (float): AUC normalization factor, default as 30 pixel.
+
+ Returns:
+ List: Evaluation results for evaluation metric.
+ """
+ info_str = []
+
+ with open(res_file, 'r') as fin:
+ preds = json.load(fin)
+ assert len(preds) == len(self.paired_samples)
+
+ outputs = []
+ gts = []
+ masks = []
+ threshold_bbox = []
+ threshold_head_box = []
+
+ for pred, pair in zip(preds, self.paired_samples):
+ item = self.db[pair[-1]]
+ outputs.append(np.array(pred['keypoints'])[:, :-1])
+ gts.append(np.array(item['joints_3d'])[:, :-1])
+
+ mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
+ mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
+ for id_s in pair[:-1]:
+ mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
+ masks.append(np.bitwise_and(mask_query, mask_sample))
+
+ if 'PCK' in metrics:
+ bbox = np.array(item['bbox'])
+ bbox_thr = np.max(bbox[2:])
+ threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
+ if 'PCKh' in metrics:
+ head_box_thr = item['head_size']
+ threshold_head_box.append(
+ np.array([head_box_thr, head_box_thr]))
+
+ if 'PCK' in metrics:
+ pck_avg = []
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
+ _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt, 0),
+ np.expand_dims(mask, 0), pck_thr, np.expand_dims(thr_bbox, 0))
+ pck_avg.append(pck)
+ info_str.append(('PCK', np.mean(pck_avg)))
+
+ return info_str
+
+ def _merge_obj(self, Xs_list, Xq, idx):
+ """ merge Xs_list and Xq.
+
+ :param Xs_list: N-shot samples X
+ :param Xq: query X
+ :param idx: id of paired_samples
+ :return: Xall
+ """
+ Xall = dict()
+ Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
+ Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
+ Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
+ xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]
+
+ Xall['img_q'] = Xq['img']
+ Xall['target_q'] = Xq['target']
+ Xall['target_weight_q'] = Xq['target_weight']
+ xq_img_metas = Xq['img_metas'].data
+
+ img_metas = dict()
+ for key in xq_img_metas.keys():
+ img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
+ img_metas['query_' + key] = xq_img_metas[key]
+ img_metas['bbox_id'] = idx
+
+ Xall['img_metas'] = DC(img_metas, cpu_only=True)
+
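+ # Layout sketch (illustrative, num_shots=2):
+ # Xall['img_s'] -> [img_shot1, img_shot2]; Xall['img_q'] -> query image
+ # Xall['img_metas'].data -> keys 'sample_<k>', 'query_<k>', 'bbox_id'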
+ return Xall
+
+ def __len__(self):
+ """Get the size of the dataset."""
+ return len(self.paired_samples)
+
+ def __getitem__(self, idx):
+ """Get the sample given index."""
+
+ pair_ids = self.paired_samples[idx] # num_shots support ids followed by the query id
+ assert len(pair_ids) == self.num_shots + 1
+ sample_id_list = pair_ids[:self.num_shots]
+ query_id = pair_ids[-1]
+
+ sample_obj_list = []
+ for sample_id in sample_id_list:
+ sample_obj = copy.deepcopy(self.db[sample_id])
+ sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
+ sample_obj_list.append(sample_obj)
+
+ query_obj = copy.deepcopy(self.db[query_id])
+ query_obj['ann_info'] = copy.deepcopy(self.ann_info)
+
+ Xs_list = []
+ for sample_obj in sample_obj_list:
+ Xs = self.pipeline(sample_obj) # dict with ['img', 'target', 'target_weight', 'img_metas'],
+ Xs_list.append(Xs) # Xs['target'] is of shape [100, map_h, map_w]
+ Xq = self.pipeline(query_obj)
+
+ Xall = self._merge_obj(Xs_list, Xq, idx)
+ Xall['skeleton'] = self.db[query_id]['skeleton']
+ return Xall
+
+ def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
+ """sort kpts and remove the repeated ones."""
+ kpts = sorted(kpts, key=lambda x: x[key])
+ num = len(kpts)
+ for i in range(num - 1, 0, -1):
+ if kpts[i][key] == kpts[i - 1][key]:
+ del kpts[i]
+
+ return kpts
diff --git a/models/datasets/datasets/mp100/transformer_dataset.py b/models/datasets/datasets/mp100/transformer_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..73a7a7a702ceeaa14a9073f1d6d8d70fc1e4b449
--- /dev/null
+++ b/models/datasets/datasets/mp100/transformer_dataset.py
@@ -0,0 +1,321 @@
+import os
+import random
+from collections import OrderedDict
+
+import numpy as np
+from mmpose.datasets import DATASETS
+from xtcocotools.coco import COCO
+
+from .transformer_base_dataset import TransformerBaseDataset
+
+
+@DATASETS.register_module()
+class TransformerPoseDataset(TransformerBaseDataset):
+
+ def __init__(self,
+ ann_file,
+ img_prefix,
+ data_cfg,
+ pipeline,
+ valid_class_ids,
+ max_kpt_num=None,
+ num_shots=1,
+ num_queries=100,
+ num_episodes=1,
+ test_mode=False):
+ super().__init__(
+ ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)
+
+ self.ann_info['flip_pairs'] = []
+
+ self.ann_info['upper_body_ids'] = []
+ self.ann_info['lower_body_ids'] = []
+
+ self.ann_info['use_different_joint_weights'] = False
+ # uniform weight for every joint (np.ones avoids the invalid reshape of
+ # a single-element array when num_joints > 1)
+ self.ann_info['joint_weights'] = np.ones(
+ (self.ann_info['num_joints'], 1), dtype=np.float32)
+
+ self.coco = COCO(ann_file)
+
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
+ self.img_ids = self.coco.getImgIds()
+ self.classes = [
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
+ ]
+
+ self.num_classes = len(self.classes)
+ self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
+ self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
+
+ if valid_class_ids is not None: # None by default
+ self.valid_class_ids = valid_class_ids
+ else:
+ self.valid_class_ids = self.coco.getCatIds()
+ self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
+
+ self.cats = self.coco.cats
+ self.max_kpt_num = max_kpt_num
+
+ # Also update self.cat2obj
+ self.db = self._get_db()
+
+ self.num_shots = num_shots
+
+ if not test_mode:
+ # Update every training epoch
+ self.random_paired_samples()
+ else:
+ self.num_queries = num_queries
+ self.num_episodes = num_episodes
+ self.make_paired_samples()
+
+ def random_paired_samples(self):
+ num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
+
+ # balance the dataset
+ max_num_data = max(num_datas)
+
+ all_samples = []
+ for cls in self.valid_class_ids:
+ for _ in range(max_num_data):
+ shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
+ all_samples.append(shot)
+
+ self.paired_samples = np.array(all_samples)
+ np.random.shuffle(self.paired_samples)
+
+ def make_paired_samples(self):
+ random.seed(1)
+ np.random.seed(0)
+
+ all_samples = []
+ for cls in self.valid_class_ids:
+ for _ in range(self.num_episodes):
+ shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
+ sample_ids = shots[:self.num_shots]
+ query_ids = shots[self.num_shots:]
+ for query_id in query_ids:
+ all_samples.append(sample_ids + [query_id])
+
+ self.paired_samples = np.array(all_samples)
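+
+ # Layout sketch (illustrative): with num_shots=1, num_queries=2 and a
+ # single class, each row is [support_id, query_id], e.g.
+ # paired_samples -> np.array([[s0, q0], [s0, q1]]) # shape (2, 2)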
+
+ def _select_kpt(self, obj, kpt_id):
+ obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id + 1]
+ obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id + 1]
+ obj['kpt_id'] = kpt_id
+
+ return obj
+
+ @staticmethod
+ def _get_mapping_id_name(imgs):
+ """
+ Args:
+ imgs (dict): dict of image info.
+
+ Returns:
+ tuple: Image name & id mapping dicts.
+
+ - id2name (dict): Mapping image id to name.
+ - name2id (dict): Mapping image name to id.
+ """
+ id2name = {}
+ name2id = {}
+ for image_id, image in imgs.items():
+ file_name = image['file_name']
+ id2name[image_id] = file_name
+ name2id[file_name] = image_id
+
+ return id2name, name2id
+
+ def _get_db(self):
+ """Ground truth bbox and keypoints."""
+ self.obj_id = 0
+
+ self.cat2obj = {}
+ for i in self.coco.getCatIds():
+ self.cat2obj.update({i: []})
+
+ gt_db = []
+ for img_id in self.img_ids:
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
+
+ return gt_db
+
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
+ """load annotation from COCOAPI.
+
+ Note:
+ bbox:[x1, y1, w, h]
+ Args:
+ img_id: coco image id
+ Returns:
+ dict: db entry
+ """
+ img_ann = self.coco.loadImgs(img_id)[0]
+ width = img_ann['width']
+ height = img_ann['height']
+
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
+ objs = self.coco.loadAnns(ann_ids)
+
+ # sanitize bboxes
+ valid_objs = []
+ for obj in objs:
+ if 'bbox' not in obj:
+ continue
+ x, y, w, h = obj['bbox']
+ x1 = max(0, x)
+ y1 = max(0, y)
+ x2 = min(width - 1, x1 + max(0, w - 1))
+ y2 = min(height - 1, y1 + max(0, h - 1))
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
+ valid_objs.append(obj)
+ objs = valid_objs
+
+ bbox_id = 0
+ rec = []
+ for obj in objs:
+ if 'keypoints' not in obj:
+ continue
+ if max(obj['keypoints']) == 0:
+ continue
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
+ continue
+
+ category_id = obj['category_id']
+ # the number of keypoints for this specific category
+ cat_kpt_num = len(obj['keypoints']) // 3
+ if self.max_kpt_num is None:
+ kpt_num = cat_kpt_num
+ else:
+ kpt_num = self.max_kpt_num
+
+ joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)
+ joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)
+
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
+ joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]
+ joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])
+
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
+
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
+ if os.path.exists(image_file):
+ self.cat2obj[category_id].append(self.obj_id)
+
+ rec.append({
+ 'image_file': image_file,
+ 'center': center,
+ 'scale': scale,
+ 'rotation': 0,
+ 'bbox': obj['clean_bbox'][:4],
+ 'bbox_score': 1,
+ 'joints_3d': joints_3d,
+ 'joints_3d_visible': joints_3d_visible,
+ 'category_id': category_id,
+ 'cat_kpt_num': cat_kpt_num,
+ 'bbox_id': self.obj_id,
+ 'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
+ })
+ bbox_id = bbox_id + 1
+ self.obj_id += 1
+
+ return rec
+
+ def _xywh2cs(self, x, y, w, h):
+ """This encodes bbox(x,y,w,w) into (center, scale)
+
+ Args:
+ x, y, w, h
+
+ Returns:
+ tuple: A tuple containing center and scale.
+
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
+ """
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
+ #
+ # if (not self.test_mode) and np.random.rand() < 0.3:
+ # center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
+
+ if w > aspect_ratio * h:
+ h = w * 1.0 / aspect_ratio
+ elif w < aspect_ratio * h:
+ w = h * aspect_ratio
+
+ # pixel std is 200.0
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
+ # padding to include proper amount of context
+ scale = scale * 1.25
+
+ return center, scale
+
+ def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
+ """Evaluate interhand2d keypoint results. The pose prediction results
+ will be saved in `${res_folder}/result_keypoints.json`.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+ outputs (list(preds, boxes, image_path, output_heatmap))
+ :preds (np.ndarray[N,K,3]): The first two dimensions are
+ coordinates, score is the third dimension of the array.
+ :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
+ , scale[1],area, score]
+ :image_paths (list[str]): For example, ['C', 'a', 'p', 't',
+ 'u', 'r', 'e', '1', '2', '/', '0', '3', '9', '0', '_',
+ 'd', 'h', '_', 't', 'o', 'u', 'c', 'h', 'R', 'O', 'M',
+ '/', 'c', 'a', 'm', '4', '1', '0', '2', '0', '9', '/',
+ 'i', 'm', 'a', 'g', 'e', '6', '2', '4', '3', '4', '.',
+ 'j', 'p', 'g']
+ :output_heatmap (np.ndarray[N, K, H, W]): model outpus.
+
+ res_folder (str): Path of directory to save the results.
+ metric (str | list[str]): Metric to be performed.
+ Options: 'PCK', 'AUC', 'EPE'.
+
+ Returns:
+ dict: Evaluation results for evaluation metric.
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
+
+ kpts = []
+ for output in outputs:
+ preds = output['preds']
+ boxes = output['boxes']
+ image_paths = output['image_paths']
+ bbox_ids = output['bbox_ids']
+
+ batch_size = len(image_paths)
+ for i in range(batch_size):
+ image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
+
+ kpts.append({
+ 'keypoints': preds[i].tolist(),
+ 'center': boxes[i][0:2].tolist(),
+ 'scale': boxes[i][2:4].tolist(),
+ 'area': float(boxes[i][4]),
+ 'score': float(boxes[i][5]),
+ 'image_id': image_id,
+ 'bbox_id': bbox_ids[i]
+ })
+ kpts = self._sort_and_unique_bboxes(kpts)
+
+ self._write_keypoint_results(kpts, res_file)
+ info_str = self._report_metric(res_file, metrics)
+ name_value = OrderedDict(info_str)
+
+ return name_value
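+
+# Usage sketch (illustrative; `dataset` and `outputs` are assumed to come from
+# an mmpose-style test loop):
+# name_value = dataset.evaluate(outputs, res_folder='work_dirs/eval', metric='PCK')
+# print(name_value['PCK'])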
diff --git a/models/datasets/pipelines/__init__.py b/models/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9d29b3ab6f7e2ab2bf7baf40df0d7861c64820a
--- /dev/null
+++ b/models/datasets/pipelines/__init__.py
@@ -0,0 +1,6 @@
+from .top_down_transform import (TopDownAffineFewShot,
+ TopDownGenerateTargetFewShot)
+
+__all__ = [
+ 'TopDownGenerateTargetFewShot', 'TopDownAffineFewShot'
+]
diff --git a/models/datasets/pipelines/post_transforms.py b/models/datasets/pipelines/post_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1025daf5cc87ca3d9a3a204a8df05ca8af725fd
--- /dev/null
+++ b/models/datasets/pipelines/post_transforms.py
@@ -0,0 +1,121 @@
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
+# Original licence: Copyright (c) Microsoft, under the MIT License.
+# ------------------------------------------------------------------------------
+
+import cv2
+import numpy as np
+
+
+def get_affine_transform(center,
+ scale,
+ rot,
+ output_size,
+ shift=(0., 0.),
+ inv=False):
+ """Get the affine transform matrix, given the center/scale/rot/output_size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ]): Size of the destination heatmaps.
+ shift (tuple[float]): Shift translation ratio wrt the width/height,
+ each in the range [0, 1]. Default (0., 0.).
+ inv (bool): Option to inverse the affine transform direction.
+ (inv=False: src->dst or inv=True: dst->src)
+
+ Returns:
+ np.ndarray: The transform matrix.
+ """
+ assert len(center) == 2
+ assert len(scale) == 2
+ assert len(output_size) == 2
+ assert len(shift) == 2
+
+ # pixel_std is 200.
+ scale_tmp = scale * 200.0
+
+ shift = np.array(shift)
+ src_w = scale_tmp[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ rot_rad = np.pi * rot / 180
+ src_dir = rotate_point([0., src_w * -0.5], rot_rad)
+ dst_dir = np.array([0., dst_w * -0.5])
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale_tmp * shift
+ src[1, :] = center + src_dir + scale_tmp * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+
+def affine_transform(pt, trans_mat):
+ """Apply an affine transformation to the points.
+
+ Args:
+ pt (np.ndarray): a 2 dimensional point to be transformed
+ trans_mat (np.ndarray): 2x3 matrix of an affine transform
+
+ Returns:
+ np.ndarray: Transformed points.
+ """
+ assert len(pt) == 2
+ new_pt = np.array(trans_mat) @ np.array([pt[0], pt[1], 1.])
+
+ return new_pt
+
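+# Usage sketch (illustrative; `img`, `center`, `scale`, `kpt_xy` are assumed
+# inputs): warp an image crop and a keypoint with the same matrix, e.g.
+# trans = get_affine_transform(center, scale, rot=0., output_size=np.array([192, 256]))
+# patch = cv2.warpAffine(img, trans, (192, 256), flags=cv2.INTER_LINEAR)
+# kpt_out = affine_transform(kpt_xy, trans) # kpt_xy: np.ndarray([x, y])
+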
+
+def _get_3rd_point(a, b):
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+
+ Args:
+ a (np.ndarray): point(x,y)
+ b (np.ndarray): point(x,y)
+
+ Returns:
+ np.ndarray: The 3rd point.
+ """
+ assert len(a) == 2
+ assert len(b) == 2
+ direction = a - b
+ third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
+
+ return third_pt
+
+
+def rotate_point(pt, angle_rad):
+ """Rotate a point by an angle.
+
+ Args:
+ pt (list[float]): 2 dimensional point to be rotated
+ angle_rad (float): rotation angle by radian
+
+ Returns:
+ list[float]: Rotated point.
+ """
+ assert len(pt) == 2
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ new_x = pt[0] * cs - pt[1] * sn
+ new_y = pt[0] * sn + pt[1] * cs
+ rotated_pt = [new_x, new_y]
+
+ return rotated_pt
diff --git a/models/datasets/pipelines/top_down_transform.py b/models/datasets/pipelines/top_down_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac3ed1defef3d56f36b428f82afc4ebbb96a67a
--- /dev/null
+++ b/models/datasets/pipelines/top_down_transform.py
@@ -0,0 +1,378 @@
+import cv2
+import numpy as np
+from mmpose.core.post_processing import (get_warp_matrix,
+ warp_affine_joints)
+from mmpose.datasets.builder import PIPELINES
+
+from .post_transforms import (affine_transform,
+ get_affine_transform)
+
+
+@PIPELINES.register_module()
+class TopDownAffineFewShot:
+ """Affine transform the image to make input.
+
+ Required keys: 'img', 'joints_3d', 'joints_3d_visible', 'ann_info',
+ 'scale', 'rotation' and 'center'. Modified keys: 'img', 'joints_3d',
+ and 'joints_3d_visible'.
+
+ Args:
+ use_udp (bool): To use unbiased data processing.
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
+ """
+
+ def __init__(self, use_udp=False):
+ self.use_udp = use_udp
+
+ def __call__(self, results):
+ image_size = results['ann_info']['image_size']
+
+ img = results['img']
+ joints_3d = results['joints_3d']
+ joints_3d_visible = results['joints_3d_visible']
+ c = results['center']
+ s = results['scale']
+ r = results['rotation']
+
+ if self.use_udp:
+ trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0)
+ img = cv2.warpAffine(
+ img,
+ trans, (int(image_size[0]), int(image_size[1])),
+ flags=cv2.INTER_LINEAR)
+ joints_3d[:, 0:2] = \
+ warp_affine_joints(joints_3d[:, 0:2].copy(), trans)
+ else:
+ trans = get_affine_transform(c, s, r, image_size)
+ img = cv2.warpAffine(
+ img,
+ trans, (int(image_size[0]), int(image_size[1])),
+ flags=cv2.INTER_LINEAR)
+ for i in range(len(joints_3d)):
+ if joints_3d_visible[i, 0] > 0.0:
+ joints_3d[i,
+ 0:2] = affine_transform(joints_3d[i, 0:2], trans)
+
+ results['img'] = img
+ results['joints_3d'] = joints_3d
+ results['joints_3d_visible'] = joints_3d_visible
+
+ return results
+
+
+@PIPELINES.register_module()
+class TopDownGenerateTargetFewShot:
+ """Generate the target heatmap.
+
+ Required keys: 'joints_3d', 'joints_3d_visible', 'ann_info'.
+ Modified keys: 'target', and 'target_weight'.
+
+ Args:
+ sigma: Sigma of heatmap gaussian for 'MSRA' approach.
+ kernel: Kernel of heatmap gaussian for the 'Megvii' approach
+ (unused by the encodings supported here).
+ encoding (str): Approach to generate target heatmaps.
+ Currently supported approaches: 'MSRA', 'UDP'.
+ Default: 'MSRA'
+
+ unbiased_encoding (bool): Option to use unbiased
+ encoding methods.
+ Paper ref: Zhang et al. Distribution-Aware Coordinate
+ Representation for Human Pose Estimation (CVPR 2020).
+ valid_radius_factor (float): Radius factor of the positive area
+ in the classification heatmap for UDP.
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
+ target_type (str): supported targets: 'GaussianHeatMap',
+ 'CombinedTarget'. Default:'GaussianHeatMap'
+ CombinedTarget: The combination of classification target
+ (response map) and regression target (offset map).
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
+ """
+
+ def __init__(self,
+ sigma=2,
+ kernel=(11, 11),
+ valid_radius_factor=0.0546875,
+ target_type='GaussianHeatMap',
+ encoding='MSRA',
+ unbiased_encoding=False):
+ self.sigma = sigma
+ self.unbiased_encoding = unbiased_encoding
+ self.kernel = kernel
+ self.valid_radius_factor = valid_radius_factor
+ self.target_type = target_type
+ self.encoding = encoding
+
+ def _msra_generate_target(self, cfg, joints_3d, joints_3d_visible, sigma):
+ """Generate the target heatmap via "MSRA" approach.
+
+ Args:
+ cfg (dict): data config
+ joints_3d: np.ndarray ([num_joints, 3])
+ joints_3d_visible: np.ndarray ([num_joints, 3])
+ sigma: Sigma of heatmap gaussian
+ Returns:
+ tuple: A tuple containing targets.
+
+ - target: Target heatmaps.
+ - target_weight: (1: visible, 0: invisible)
+ """
+ num_joints = len(joints_3d)
+ image_size = cfg['image_size']
+ W, H = cfg['heatmap_size']
+ joint_weights = cfg['joint_weights']
+ use_different_joint_weights = cfg['use_different_joint_weights']
+ assert not use_different_joint_weights
+
+ target_weight = np.zeros((num_joints, 1), dtype=np.float32)
+ target = np.zeros((num_joints, H, W), dtype=np.float32)
+
+ # 3-sigma rule
+ tmp_size = sigma * 3
+
+ if self.unbiased_encoding:
+ for joint_id in range(num_joints):
+ target_weight[joint_id] = joints_3d_visible[joint_id, 0]
+
+ feat_stride = image_size / [W, H]
+ mu_x = joints_3d[joint_id][0] / feat_stride[0]
+ mu_y = joints_3d[joint_id][1] / feat_stride[1]
+ # Check that any part of the gaussian is in-bounds
+ ul = [mu_x - tmp_size, mu_y - tmp_size]
+ br = [mu_x + tmp_size + 1, mu_y + tmp_size + 1]
+ if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
+ target_weight[joint_id] = 0
+
+ if target_weight[joint_id] == 0:
+ continue
+
+ x = np.arange(0, W, 1, np.float32)
+ y = np.arange(0, H, 1, np.float32)
+ y = y[:, None]
+
+ if target_weight[joint_id] > 0.5:
+ target[joint_id] = np.exp(-((x - mu_x) ** 2 +
+ (y - mu_y) ** 2) /
+ (2 * sigma ** 2))
+ else:
+ for joint_id in range(num_joints):
+ target_weight[joint_id] = joints_3d_visible[joint_id, 0]
+
+ feat_stride = image_size / [W, H]
+ mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)
+ mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)
+ # Check that any part of the gaussian is in-bounds
+ ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
+ br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
+ if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
+ target_weight[joint_id] = 0
+
+ if target_weight[joint_id] > 0.5:
+ size = 2 * tmp_size + 1
+ x = np.arange(0, size, 1, np.float32)
+ y = x[:, None]
+ x0 = y0 = size // 2
+ # The gaussian is not normalized,
+ # we want the center value to equal 1
+ g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+
+ # Usable gaussian range
+ g_x = max(0, -ul[0]), min(br[0], W) - ul[0]
+ g_y = max(0, -ul[1]), min(br[1], H) - ul[1]
+ # Image range
+ img_x = max(0, ul[0]), min(br[0], W)
+ img_y = max(0, ul[1]), min(br[1], H)
+
+ target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
+ g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+
+ if use_different_joint_weights:
+ target_weight = np.multiply(target_weight, joint_weights)
+
+ return target, target_weight
+
+ def _udp_generate_target(self, cfg, joints_3d, joints_3d_visible, factor,
+ target_type):
+ """Generate the target heatmap via 'UDP' approach. Paper ref: Huang et
+ al. The Devil is in the Details: Delving into Unbiased Data Processing
+ for Human Pose Estimation (CVPR 2020).
+
+ Note:
+ num keypoints: K
+ heatmap height: H
+ heatmap width: W
+ num target channels: C
+ C = K if target_type=='GaussianHeatMap'
+ C = 3*K if target_type=='CombinedTarget'
+
+ Args:
+ cfg (dict): data config
+ joints_3d (np.ndarray[K, 3]): Annotated keypoints.
+ joints_3d_visible (np.ndarray[K, 3]): Visibility of keypoints.
+ factor (float): kernel factor for GaussianHeatMap target or
+ valid radius factor for CombinedTarget.
+ target_type (str): 'GaussianHeatMap' or 'CombinedTarget'.
+ GaussianHeatMap: Heatmap target with gaussian distribution.
+ CombinedTarget: The combination of classification target
+ (response map) and regression target (offset map).
+
+ Returns:
+ tuple: A tuple containing targets.
+
+ - target (np.ndarray[C, H, W]): Target heatmaps.
+ - target_weight (np.ndarray[K, 1]): (1: visible, 0: invisible)
+ """
+ num_joints = len(joints_3d)
+ image_size = cfg['image_size']
+ heatmap_size = cfg['heatmap_size']
+ joint_weights = cfg['joint_weights']
+ use_different_joint_weights = cfg['use_different_joint_weights']
+ assert not use_different_joint_weights
+
+ target_weight = np.ones((num_joints, 1), dtype=np.float32)
+ target_weight[:, 0] = joints_3d_visible[:, 0]
+
+ assert target_type in ['GaussianHeatMap', 'CombinedTarget']
+
+ if target_type == 'GaussianHeatMap':
+ target = np.zeros((num_joints, heatmap_size[1], heatmap_size[0]),
+ dtype=np.float32)
+
+ tmp_size = factor * 3
+
+ # prepare for gaussian
+ size = 2 * tmp_size + 1
+ x = np.arange(0, size, 1, np.float32)
+ y = x[:, None]
+
+ for joint_id in range(num_joints):
+ feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)
+ mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)
+ mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)
+ # Check that any part of the gaussian is in-bounds
+ ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
+ br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
+ if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \
+ or br[0] < 0 or br[1] < 0:
+ # gaussian lies entirely outside the heatmap: mark invisible and skip
+ target_weight[joint_id] = 0
+ continue
+
+ # Generate gaussian
+ mu_x_ac = joints_3d[joint_id][0] / feat_stride[0]
+ mu_y_ac = joints_3d[joint_id][1] / feat_stride[1]
+ x0 = y0 = size // 2
+ x0 += mu_x_ac - mu_x
+ y0 += mu_y_ac - mu_y
+ g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * factor ** 2))
+
+ # Usable gaussian range
+ g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
+ g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
+ # Image range
+ img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
+ img_y = max(0, ul[1]), min(br[1], heatmap_size[1])
+
+ v = target_weight[joint_id]
+ if v > 0.5:
+ target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
+ g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+ elif target_type == 'CombinedTarget':
+ target = np.zeros(
+ (num_joints, 3, heatmap_size[1] * heatmap_size[0]),
+ dtype=np.float32)
+ feat_width = heatmap_size[0]
+ feat_height = heatmap_size[1]
+ feat_x_int = np.arange(0, feat_width)
+ feat_y_int = np.arange(0, feat_height)
+ feat_x_int, feat_y_int = np.meshgrid(feat_x_int, feat_y_int)
+ feat_x_int = feat_x_int.flatten()
+ feat_y_int = feat_y_int.flatten()
+ # Calculate the radius of the positive area in classification
+ # heatmap.
+ valid_radius = factor * heatmap_size[1]
+ feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)
+ for joint_id in range(num_joints):
+ mu_x = joints_3d[joint_id][0] / feat_stride[0]
+ mu_y = joints_3d[joint_id][1] / feat_stride[1]
+ x_offset = (mu_x - feat_x_int) / valid_radius
+ y_offset = (mu_y - feat_y_int) / valid_radius
+ dis = x_offset ** 2 + y_offset ** 2
+ keep_pos = np.where(dis <= 1)[0]
+ v = target_weight[joint_id]
+ if v > 0.5:
+ target[joint_id, 0, keep_pos] = 1
+ target[joint_id, 1, keep_pos] = x_offset[keep_pos]
+ target[joint_id, 2, keep_pos] = y_offset[keep_pos]
+ target = target.reshape(num_joints * 3, heatmap_size[1],
+ heatmap_size[0])
+
+ if use_different_joint_weights:
+ target_weight = np.multiply(target_weight, joint_weights)
+
+ return target, target_weight
+
+ def __call__(self, results):
+ """Generate the target heatmap."""
+ joints_3d = results['joints_3d']
+ joints_3d_visible = results['joints_3d_visible']
+
+ assert self.encoding in ['MSRA', 'UDP']
+
+ if self.encoding == 'MSRA':
+ if isinstance(self.sigma, list):
+ num_sigmas = len(self.sigma)
+ cfg = results['ann_info']
+ num_joints = len(joints_3d)
+ heatmap_size = cfg['heatmap_size']
+
+ target = np.empty(
+ (0, num_joints, heatmap_size[1], heatmap_size[0]),
+ dtype=np.float32)
+ target_weight = np.empty((0, num_joints, 1), dtype=np.float32)
+ for i in range(num_sigmas):
+ target_i, target_weight_i = self._msra_generate_target(
+ cfg, joints_3d, joints_3d_visible, self.sigma[i])
+ target = np.concatenate([target, target_i[None]], axis=0)
+ target_weight = np.concatenate(
+ [target_weight, target_weight_i[None]], axis=0)
+ else:
+ target, target_weight = self._msra_generate_target(
+ results['ann_info'], joints_3d, joints_3d_visible,
+ self.sigma)
+ elif self.encoding == 'UDP':
+ if self.target_type == 'CombinedTarget':
+ factors = self.valid_radius_factor
+ channel_factor = 3
+ elif self.target_type == 'GaussianHeatMap':
+ factors = self.sigma
+ channel_factor = 1
+ if isinstance(factors, list):
+ num_factors = len(factors)
+ cfg = results['ann_info']
+ num_joints = len(joints_3d)
+ W, H = cfg['heatmap_size']
+
+ target = np.empty((0, channel_factor * num_joints, H, W),
+ dtype=np.float32)
+ target_weight = np.empty((0, num_joints, 1), dtype=np.float32)
+ for i in range(num_factors):
+ target_i, target_weight_i = self._udp_generate_target(
+ cfg, joints_3d, joints_3d_visible, factors[i],
+ self.target_type)
+ target = np.concatenate([target, target_i[None]], axis=0)
+ target_weight = np.concatenate(
+ [target_weight, target_weight_i[None]], axis=0)
+ else:
+ target, target_weight = self._udp_generate_target(
+ results['ann_info'], joints_3d, joints_3d_visible, factors,
+ self.target_type)
+ else:
+ raise ValueError(
+ f'Encoding approach {self.encoding} is not supported!')
+
+ results['target'] = target
+ results['target_weight'] = target_weight
+
+ return results
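+
+# Pipeline sketch (illustrative): multi-sigma MSRA targets stack along a new
+# leading axis, e.g.
+# gen = TopDownGenerateTargetFewShot(sigma=[2, 3], encoding='MSRA')
+# results = gen(results) # results['target']: (2, K, heatmap_h, heatmap_w)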
diff --git a/models/models/__init__.py b/models/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d72100303b1ec7ddb0ca901e154e0d53d6abb3a8
--- /dev/null
+++ b/models/models/__init__.py
@@ -0,0 +1,3 @@
+from .backbones import * # noqa
+from .detectors import * # noqa
+from .keypoint_heads import * # noqa
diff --git a/models/models/backbones/__init__.py b/models/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ff5119ffe5d6bfc66c9fcfa8128b521a1a3563d
--- /dev/null
+++ b/models/models/backbones/__init__.py
@@ -0,0 +1 @@
+from .swin_transformer_v2 import SwinTransformerV2
diff --git a/models/models/backbones/simmim.py b/models/models/backbones/simmim.py
new file mode 100644
index 0000000000000000000000000000000000000000..c591856212a84a7b237ecc44f5067b4fbb0bbb26
--- /dev/null
+++ b/models/models/backbones/simmim.py
@@ -0,0 +1,210 @@
+# --------------------------------------------------------
+# SimMIM
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Zhenda Xie
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_
+
+from .swin_transformer import SwinTransformer
+from .swin_transformer_v2 import SwinTransformerV2
+
+
+def norm_targets(targets, patch_size):
+ assert patch_size % 2 == 1
+
+ targets_ = targets
+ targets_count = torch.ones_like(targets)
+
+ targets_square = targets ** 2.
+
+ targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2,
+ count_include_pad=False)
+ targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2,
+ count_include_pad=False)
+ targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2,
+ count_include_pad=True) * (patch_size ** 2)
+
+ targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1))
+ targets_var = torch.clamp(targets_var, min=0.)
+
+ targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5
+
+ return targets_
+
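+# Usage sketch (illustrative): normalize reconstruction targets by local
+# patch statistics before the masked L1 loss, e.g.
+# imgs = torch.randn(4, 3, 192, 192) # hypothetical batch
+# tgt = norm_targets(imgs, patch_size=47) # patch_size must be odd
+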
+
+class SwinTransformerForSimMIM(SwinTransformer):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ assert self.num_classes == 0
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ trunc_normal_(self.mask_token, mean=0., std=.02)
+
+ def forward(self, x, mask):
+ x = self.patch_embed(x)
+
+ assert mask is not None
+ B, L, _ = x.shape
+
+ mask_tokens = self.mask_token.expand(B, L, -1)
+ w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
+ x = x * (1. - w) + mask_tokens * w
+
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x)
+ x = self.norm(x)
+
+ x = x.transpose(1, 2)
+ B, C, L = x.shape
+ H = W = int(L ** 0.5)
+ x = x.reshape(B, C, H, W)
+ return x
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return super().no_weight_decay() | {'mask_token'}
+
+
+class SwinTransformerV2ForSimMIM(SwinTransformerV2):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ assert self.num_classes == 0
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ trunc_normal_(self.mask_token, mean=0., std=.02)
+
+ def forward(self, x, mask):
+ x = self.patch_embed(x)
+
+ assert mask is not None
+ B, L, _ = x.shape
+
+ mask_tokens = self.mask_token.expand(B, L, -1)
+ w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
+ x = x * (1. - w) + mask_tokens * w
+
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x)
+ x = self.norm(x)
+
+ x = x.transpose(1, 2)
+ B, C, L = x.shape
+ H = W = int(L ** 0.5)
+ x = x.reshape(B, C, H, W)
+ return x
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return super().no_weight_decay() | {'mask_token'}
+
+
+class SimMIM(nn.Module):
+ def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
+ super().__init__()
+ self.config = config
+ self.encoder = encoder
+ self.encoder_stride = encoder_stride
+
+ self.decoder = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.encoder.num_features,
+ out_channels=self.encoder_stride ** 2 * 3, kernel_size=1),
+ nn.PixelShuffle(self.encoder_stride),
+ )
+
+ self.in_chans = in_chans
+ self.patch_size = patch_size
+
+ def forward(self, x, mask):
+ z = self.encoder(x, mask)
+ x_rec = self.decoder(z)
+
+ mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(
+ 1).contiguous()
+
+ # norm target as prompted
+ if self.config.NORM_TARGET.ENABLE:
+ x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE)
+
+ loss_recon = F.l1_loss(x, x_rec, reduction='none')
+ loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
+ return loss
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ if hasattr(self.encoder, 'no_weight_decay'):
+ return {'encoder.' + i for i in self.encoder.no_weight_decay()}
+ return {}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ if hasattr(self.encoder, 'no_weight_decay_keywords'):
+ return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
+ return {}
+
+
+def build_simmim(config):
+ model_type = config.MODEL.TYPE
+ if model_type == 'swin':
+ encoder = SwinTransformerForSimMIM(
+ img_size=config.DATA.IMG_SIZE,
+ patch_size=config.MODEL.SWIN.PATCH_SIZE,
+ in_chans=config.MODEL.SWIN.IN_CHANS,
+ num_classes=0,
+ embed_dim=config.MODEL.SWIN.EMBED_DIM,
+ depths=config.MODEL.SWIN.DEPTHS,
+ num_heads=config.MODEL.SWIN.NUM_HEADS,
+ window_size=config.MODEL.SWIN.WINDOW_SIZE,
+ mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
+ qkv_bias=config.MODEL.SWIN.QKV_BIAS,
+ qk_scale=config.MODEL.SWIN.QK_SCALE,
+ drop_rate=config.MODEL.DROP_RATE,
+ drop_path_rate=config.MODEL.DROP_PATH_RATE,
+ ape=config.MODEL.SWIN.APE,
+ patch_norm=config.MODEL.SWIN.PATCH_NORM,
+ use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+ encoder_stride = 32
+ in_chans = config.MODEL.SWIN.IN_CHANS
+ patch_size = config.MODEL.SWIN.PATCH_SIZE
+ elif model_type == 'swinv2':
+ encoder = SwinTransformerV2ForSimMIM(
+ img_size=config.DATA.IMG_SIZE,
+ patch_size=config.MODEL.SWINV2.PATCH_SIZE,
+ in_chans=config.MODEL.SWINV2.IN_CHANS,
+ num_classes=0,
+ embed_dim=config.MODEL.SWINV2.EMBED_DIM,
+ depths=config.MODEL.SWINV2.DEPTHS,
+ num_heads=config.MODEL.SWINV2.NUM_HEADS,
+ window_size=config.MODEL.SWINV2.WINDOW_SIZE,
+ mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
+ qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
+ drop_rate=config.MODEL.DROP_RATE,
+ drop_path_rate=config.MODEL.DROP_PATH_RATE,
+ ape=config.MODEL.SWINV2.APE,
+ patch_norm=config.MODEL.SWINV2.PATCH_NORM,
+ use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+ encoder_stride = 32
+ in_chans = config.MODEL.SWINV2.IN_CHANS
+ patch_size = config.MODEL.SWINV2.PATCH_SIZE
+ else:
+ raise NotImplementedError(f"Unknown pre-train model: {model_type}")
+
+ model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride, in_chans=in_chans,
+ patch_size=patch_size)
+
+ return model
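+
+# Usage sketch (illustrative; `config` follows the SimMIM-style namespace used
+# above, and `mask` marks masked patches with 1):
+# model = build_simmim(config)
+# loss = model(images, mask) # images: (B, 3, H, W); mask: (B, H//ps, W//ps)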
diff --git a/models/models/backbones/swin_mlp.py b/models/models/backbones/swin_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..115c43cd1f2d8788c4dcc2f1b83a091035e640c3
--- /dev/null
+++ b/models/models/backbones/swin_mlp.py
@@ -0,0 +1,468 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
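+# Round-trip sketch (illustrative shapes): the two helpers are inverses, e.g.
+# x = torch.randn(2, 56, 56, 96) # (B, H, W, C) with H, W divisible by 7
+# win = window_partition(x, 7) # (2 * 8 * 8, 7, 7, 96)
+# assert torch.equal(window_reverse(win, 7, 56, 56), x)
+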
+
+class SwinMLPBlock(nn.Module):
+ r""" Swin MLP Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.padding = [self.window_size - self.shift_size, self.shift_size,
+ self.window_size - self.shift_size, self.shift_size] # P_l,P_r,P_t,P_b
+
+ self.norm1 = norm_layer(dim)
+ # use group convolution to implement multi-head MLP
+ self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2,
+ self.num_heads * self.window_size ** 2,
+ kernel_size=1,
+ groups=self.num_heads)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # shift
+ if self.shift_size > 0:
+ P_l, P_r, P_t, P_b = self.padding
+ shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], "constant", 0)
+ else:
+ shifted_x = x
+ _, _H, _W, _ = shifted_x.shape
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # Window/Shifted-Window Spatial MLP
+ x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads)
+ x_windows_heads = x_windows_heads.transpose(1, 2) # nW*B, nH, window_size*window_size, C//nH
+ x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size,
+ C // self.num_heads)
+ spatial_mlp_windows = self.spatial_mlp(x_windows_heads) # nW*B, nH*window_size*window_size, C//nH
+ spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size,
+ C // self.num_heads).transpose(1, 2)
+ spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C)
+
+ # merge windows
+ spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W) # B H' W' C
+
+ # reverse shift
+ if self.shift_size > 0:
+ P_l, P_r, P_t, P_b = self.padding
+ x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+
+ # Window/Shifted-Window Spatial MLP
+ if self.shift_size > 0:
+ nW = (H / self.window_size + 1) * (W / self.window_size + 1)
+ else:
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin MLP layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., drop=0., drop_path=0.,
+ norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinMLPBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
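+
+# Shape sketch for PatchEmbed under its defaults (illustrative values only):
+#   embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
+#   tokens = embed(torch.randn(1, 3, 224, 224))   # (1, 3136, 96): 56x56 patches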
+
+
+class SwinMLP(nn.Module):
+ r""" Swin MLP
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin MLP layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ drop_rate (float): Dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+ window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ patches_resolution[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ drop=drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Linear, nn.Conv1d)):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x)
+
+ x = self.norm(x) # B L C
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
+ x = torch.flatten(x, 1)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+ flops += self.num_features * self.num_classes
+ return flops
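+
+# A minimal smoke test for SwinMLP (illustrative only; the defaults above are used):
+#   model = SwinMLP()
+#   logits = model(torch.randn(1, 3, 224, 224))   # (1, 1000)
+#   print(f"{model.flops() / 1e9:.2f} GFLOPs")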
diff --git a/models/models/backbones/swin_transformer.py b/models/models/backbones/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dde06bc5b4c4e56f5c09adf84679b01e5eeaf488
--- /dev/null
+++ b/models/models/backbones/swin_transformer.py
@@ -0,0 +1,614 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+try:
+ import os, sys
+
+ kernel_path = os.path.abspath(os.path.join('..'))
+ sys.path.append(kernel_path)
+ from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
+
+except ImportError:
+ WindowProcess = None
+ WindowProcessReverse = None
+ print("[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.")
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
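+
+# window_partition and window_reverse are exact inverses; a quick sanity check
+# with hypothetical shapes:
+#   x = torch.randn(2, 8, 8, 32)                        # B, H, W, C
+#   w = window_partition(x, 4)                          # (8, 4, 4, 32): 2x2 windows per image
+#   assert torch.equal(window_reverse(w, 4, 8, 8), x)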
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both shifted and non-shifted windows.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
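+
+# Sizing note for the default 7x7 window: relative offsets along each axis take
+# 2*7-1 = 13 values, so relative_position_bias_table has 13*13 = 169 rows (one
+# num_heads-dim bias per offset), and relative_position_index is 49x49, picking
+# a table row for every query/key pair.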
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ fused_window_process (bool, optional): If True, use a single fused kernel for window shift & window partition (and similarly for the reverse) for acceleration. Default: False
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ fused_window_process=False):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
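+ # Windows that straddle the cyclic-shift seam mix tokens from up to 9
+ # distinct regions (the cnt labels above); token pairs with different
+ # labels receive -100, which the softmax maps to ~0 attention.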
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+ self.fused_window_process = fused_window_process
+
+ def forward(self, x):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ if not self.fused_window_process:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ else:
+ x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
+ else:
+ shifted_x = x
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ if not self.fused_window_process:
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
+ else:
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+ x = shifted_x
+ x = x.view(B, H * W, C)
+ x = shortcut + self.drop_path(x)
+
+ # FFN
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ fused_window_process (bool, optional): If True, use a single fused kernel for window shift & window partition (and similarly for the reverse) for acceleration. Default: False
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ fused_window_process=False):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ fused_window_process=fused_window_process)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+
+class SwinTransformer(nn.Module):
+ r""" Swin Transformer
+ A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ fused_window_process (bool, optional): If True, use a single fused kernel for window shift & window partition (and similarly for the reverse) for acceleration. Default: False
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, fused_window_process=False, **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ patches_resolution[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ fused_window_process=fused_window_process)
+ self.layers.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x)
+
+ x = self.norm(x) # B L C
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
+ x = torch.flatten(x, 1)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+ flops += self.num_features * self.num_classes
+ return flops
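+
+# Minimal usage sketch (illustrative; the defaults correspond to Swin-T):
+#   model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
+#   logits = model(torch.randn(1, 3, 224, 224))   # (1, 1000)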
diff --git a/models/models/backbones/swin_transformer_moe.py b/models/models/backbones/swin_transformer_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b07d21740f08e8e2bb9ace6cd4fed6ada4ee042
--- /dev/null
+++ b/models/models/backbones/swin_transformer_moe.py
@@ -0,0 +1,824 @@
+# --------------------------------------------------------
+# Swin Transformer MoE
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+try:
+ from tutel import moe as tutel_moe
+except ImportError:
+ tutel_moe = None
+ print("Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.")
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
+ mlp_fc2_bias=True):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class MoEMlp(nn.Module):
+ def __init__(self, in_features, hidden_features, num_local_experts, top_value, capacity_factor=1.25,
+ cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True,
+ gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, init_std=0.02,
+ mlp_fc2_bias=True):
+ super().__init__()
+
+ self.in_features = in_features
+ self.hidden_features = hidden_features
+ self.num_local_experts = num_local_experts
+ self.top_value = top_value
+ self.capacity_factor = capacity_factor
+ self.cosine_router = cosine_router
+ self.normalize_gate = normalize_gate
+ self.use_bpr = use_bpr
+ self.init_std = init_std
+ self.mlp_fc2_bias = mlp_fc2_bias
+
+ self.dist_rank = dist.get_rank()
+
+ self._dropout = nn.Dropout(p=moe_drop)
+
+ _gate_type = {'type': 'cosine_top' if cosine_router else 'top',
+ 'k': top_value, 'capacity_factor': capacity_factor,
+ 'gate_noise': gate_noise, 'fp32_gate': True}
+ if cosine_router:
+ _gate_type['proj_dim'] = cosine_router_dim
+ _gate_type['init_t'] = cosine_router_init_t
+ self._moe_layer = tutel_moe.moe_layer(
+ gate_type=_gate_type,
+ model_dim=in_features,
+ experts={'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_features,
+ 'activation_fn': lambda x: self._dropout(F.gelu(x))},
+ scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True),
+ seeds=(1, self.dist_rank + 1, self.dist_rank + 1),
+ batch_prioritized_routing=use_bpr,
+ normalize_gate=normalize_gate,
+ is_gshard_loss=is_gshard_loss,
+ )
+ if not self.mlp_fc2_bias:
+ self._moe_layer.experts.batched_fc2_bias.requires_grad = False
+
+ def forward(self, x):
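+ # tutel's moe_layer attaches the auxiliary load-balancing loss to the output
+ # tensor as `.l_aux`; it is returned alongside the features so the caller can
+ # fold it (scaled by aux_loss_weight) into the training loss.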
+ x = self._moe_layer(x)
+ return x, x.l_aux
+
+ def extra_repr(self) -> str:
+ return f'[Statistics-{self.dist_rank}] param count for MoE, ' \
+ f'in_features = {self.in_features}, hidden_features = {self.hidden_features}, ' \
+ f'num_local_experts = {self.num_local_experts}, top_value = {self.top_value}, ' \
+ f'cosine_router={self.cosine_router} normalize_gate={self.normalize_gate}, use_bpr = {self.use_bpr}'
+
+ def _init_weights(self):
+ if hasattr(self._moe_layer, "experts"):
+ trunc_normal_(self._moe_layer.experts.batched_fc1_w, std=self.init_std)
+ trunc_normal_(self._moe_layer.experts.batched_fc2_w, std=self.init_std)
+ nn.init.constant_(self._moe_layer.experts.batched_fc1_bias, 0)
+ nn.init.constant_(self._moe_layer.experts.batched_fc2_bias, 0)
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both shifted and non-shifted windows.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
+ pretrained_window_size=[0, 0]):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+ else:
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True
+ init_std (float): Initialization std. Default: 0.02
+ pretrained_window_size (int): Window size in pre-training.
+ is_moe (bool): If True, this block is a MoE block.
+ num_local_experts (int): number of local experts in each device (GPU). Default: 1
+ top_value (int): the value of k in top-k gating. Default: 1
+ capacity_factor (float): the capacity factor in MoE. Default: 1.25
+ cosine_router (bool): Whether to use cosine router. Default: False
+ normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False
+ use_bpr (bool): Whether to use batch-prioritized-routing. Default: True
+ is_gshard_loss (bool): If True, use the GShard balance loss.
+ If False, use the load loss and importance loss in "arXiv:1701.06538". Default: True
+ gate_noise (float): the noise ratio in top-k gating. Default: 1.0
+ cosine_router_dim (int): Projection dimension in cosine router.
+ cosine_router_init_t (float): Initialization temperature in cosine router.
+ moe_drop (float): Dropout rate in MoE. Default: 0.0
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, mlp_fc2_bias=True, init_std=0.02, pretrained_window_size=0,
+ is_moe=False, num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False,
+ normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0,
+ cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ self.is_moe = is_moe
+ self.capacity_factor = capacity_factor
+ self.top_value = top_value
+
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
+ pretrained_window_size=to_2tuple(pretrained_window_size))
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if self.is_moe:
+ self.mlp = MoEMlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ num_local_experts=num_local_experts,
+ top_value=top_value,
+ capacity_factor=capacity_factor,
+ cosine_router=cosine_router,
+ normalize_gate=normalize_gate,
+ use_bpr=use_bpr,
+ is_gshard_loss=is_gshard_loss,
+ gate_noise=gate_noise,
+ cosine_router_dim=cosine_router_dim,
+ cosine_router_init_t=cosine_router_init_t,
+ moe_drop=moe_drop,
+ mlp_fc2_bias=mlp_fc2_bias,
+ init_std=init_std)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
+ mlp_fc2_bias=mlp_fc2_bias)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+ x = shortcut + self.drop_path(x)
+
+ # FFN
+ shortcut = x
+ x = self.norm2(x)
+ if self.is_moe:
+ x, l_aux = self.mlp(x)
+ x = shortcut + self.drop_path(x)
+ return x, l_aux
+ else:
+ x = shortcut + self.drop_path(self.mlp(x))
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ if self.is_moe:
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio * self.capacity_factor * self.top_value
+ else:
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True
+ init_std (float): Initialization std. Default: 0.02
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ moe_block (list(int)): Indices of the MoE blocks within this layer; [-1] disables MoE. Default: [-1]
+ num_local_experts (int): number of local experts in each device (GPU). Default: 1
+ top_value (int): the value of k in top-k gating. Default: 1
+ capacity_factor (float): the capacity factor in MoE. Default: 1.25
+ cosine_router (bool): Whether to use cosine router. Default: False
+ normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False
+ use_bpr (bool): Whether to use batch-prioritized-routing. Default: True
+ is_gshard_loss (bool): If True, use the GShard balance loss.
+ If False, use the load loss and importance loss in "arXiv:1701.06538". Default: True
+ gate_noise (float): the noise ratio in top-k gating. Default: 1.0
+ cosine_router_dim (int): Projection dimension in cosine router.
+ cosine_router_init_t (float): Initialization temperature in cosine router.
+ moe_drop (float): Dropout rate in MoE. Default: 0.0
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
+ mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_size=0,
+ moe_block=[-1], num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False,
+ normalize_gate=False, use_bpr=True, is_gshard_loss=True,
+ cosine_router_dim=256, cosine_router_init_t=0.5, gate_noise=1.0, moe_drop=0.0):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ mlp_fc2_bias=mlp_fc2_bias,
+ init_std=init_std,
+ pretrained_window_size=pretrained_window_size,
+
+ is_moe=True if i in moe_block else False,
+ num_local_experts=num_local_experts,
+ top_value=top_value,
+ capacity_factor=capacity_factor,
+ cosine_router=cosine_router,
+ normalize_gate=normalize_gate,
+ use_bpr=use_bpr,
+ is_gshard_loss=is_gshard_loss,
+ gate_noise=gate_noise,
+ cosine_router_dim=cosine_router_dim,
+ cosine_router_init_t=cosine_router_init_t,
+ moe_drop=moe_drop)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ l_aux = 0.0
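+ # Dense blocks return a tensor while MoE blocks return (features, l_aux);
+ # accumulate the auxiliary losses across all MoE blocks in this stage.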
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ out = checkpoint.checkpoint(blk, x)
+ else:
+ out = blk(x)
+ if isinstance(out, tuple):
+ x = out[0]
+ cur_l_aux = out[1]
+ l_aux = cur_l_aux + l_aux
+ else:
+ x = out
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x, l_aux
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+
+class SwinTransformerMoE(nn.Module):
+ r""" Swin Transformer
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True
+ init_std (float): Initialization std. Default: 0.02
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
+ moe_blocks (tuple(tuple(int))): Indices of the MoE blocks in each layer.
+ num_local_experts (int): number of local experts in each device (GPU). Default: 1
+ top_value (int): the value of k in top-k gating. Default: 1
+ capacity_factor (float): the capacity factor in MoE. Default: 1.25
+ cosine_router (bool): Whether to use cosine router. Default: False
+ normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False
+ use_bpr (bool): Whether to use batch-prioritized-routing. Default: True
+ is_gshard_loss (bool): If True, use the GShard balance loss.
+ If False, use the load loss and importance loss in "arXiv:1701.06538". Default: True
+ gate_noise (float): the noise ratio in top-k gating. Default: 1.0
+ cosine_router_dim (int): Projection dimension in cosine router.
+ cosine_router_init_t (float): Initialization temperature in cosine router.
+ moe_drop (float): Dropout rate in MoE. Default: 0.0
+        aux_loss_weight (float): auxiliary loss weight. Default: 0.01
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0],
+ moe_blocks=[[-1], [-1], [-1], [-1]], num_local_experts=1, top_value=1, capacity_factor=1.25,
+ cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0,
+ cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, aux_loss_weight=0.01, **kwargs):
+ super().__init__()
+ self._ddp_params_and_buffers_to_ignore = list()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+ self.init_std = init_std
+ self.aux_loss_weight = aux_loss_weight
+ self.num_local_experts = num_local_experts
+ self.global_experts = num_local_experts * dist.get_world_size() if num_local_experts > 0 \
+ else dist.get_world_size() // (-num_local_experts)
+ self.sharded_count = (1.0 / num_local_experts) if num_local_experts > 0 else (-num_local_experts)
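+        # Illustrative arithmetic: with num_local_experts=2 on 8 GPUs,
+        # global_experts = 2 * 8 = 16 and sharded_count = 0.5 (each GPU holds
+        # two whole experts); with num_local_experts=-2 on 8 GPUs,
+        # global_experts = 8 // 2 = 4 and sharded_count = 2 (each expert is
+        # sharded across two GPUs).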
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=self.init_std)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ patches_resolution[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ mlp_fc2_bias=mlp_fc2_bias,
+ init_std=init_std,
+ use_checkpoint=use_checkpoint,
+ pretrained_window_size=pretrained_window_sizes[i_layer],
+
+ moe_block=moe_blocks[i_layer],
+ num_local_experts=num_local_experts,
+ top_value=top_value,
+ capacity_factor=capacity_factor,
+ cosine_router=cosine_router,
+ normalize_gate=normalize_gate,
+ use_bpr=use_bpr,
+ is_gshard_loss=is_gshard_loss,
+ gate_noise=gate_noise,
+ cosine_router_dim=cosine_router_dim,
+ cosine_router_init_t=cosine_router_init_t,
+ moe_drop=moe_drop)
+ self.layers.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=self.init_std)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, MoEMlp):
+ m._init_weights()
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {"cpb_mlp", 'relative_position_bias_table', 'fc1_bias', 'fc2_bias',
+ 'temperature', 'cosine_projector', 'sim_matrix'}
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+ l_aux = 0.0
+ for layer in self.layers:
+ x, cur_l_aux = layer(x)
+ l_aux = cur_l_aux + l_aux
+
+ x = self.norm(x) # B L C
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
+ x = torch.flatten(x, 1)
+ return x, l_aux
+
+ def forward(self, x):
+ x, l_aux = self.forward_features(x)
+ x = self.head(x)
+ return x, l_aux * self.aux_loss_weight
+
+ def add_param_to_skip_allreduce(self, param_name):
+ self._ddp_params_and_buffers_to_ignore.append(param_name)
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+ flops += self.num_features * self.num_classes
+ return flops
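+
+# Minimal usage sketch (illustrative; assumes torch.distributed is already
+# initialized, which the MoE gates require, and that the moe_blocks indices
+# below are just an example configuration):
+#
+#   model = SwinTransformerMoE(moe_blocks=[[-1], [-1], [1, 3, 5], [1]])
+#   logits, l_aux = model(torch.randn(2, 3, 224, 224))
+#   loss = criterion(logits, labels) + l_aux  # l_aux is already scaled by
+#                                             # aux_loss_weight in forward()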
diff --git a/models/models/backbones/swin_transformer_v2.py b/models/models/backbones/swin_transformer_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..7af47b2bb3f01f5e71912c42157dac4b7e162feb
--- /dev/null
+++ b/models/models/backbones/swin_transformer_v2.py
@@ -0,0 +1,680 @@
+# --------------------------------------------------------
+# Swin Transformer V2
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from mmpose.models.builder import BACKBONES
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
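+
+# Illustrative shape walk-through: for x of shape (2, 56, 56, 96) and
+# window_size=7, window_partition returns (2*8*8, 7, 7, 96) = (128, 7, 7, 96);
+# window_reverse(windows, 7, 56, 56) then recovers the original (2, 56, 56, 96).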
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
+ pretrained_window_size=[0, 0]):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+ else:
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
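+        # Log-spaced coordinates (Swin V2): after scaling to [-8, 8], the
+        # sign-preserving log2 map compresses values into roughly
+        # [-1.06, 1.06] (e.g. 8 -> log2(9)/log2(8) ~ 1.057), keeping the
+        # cpb_mlp inputs well conditioned even for large windows.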
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01, device=x.device))).exp()
+ attn = attn * logit_scale
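+        # Scaled cosine attention (Swin V2): similarities lie in [-1, 1] and
+        # are multiplied by a learnable per-head temperature whose exp() is
+        # clamped to at most 1/0.01 = 100.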
+
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pretrained_window_size (int): Window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ pretrained_window_size=to_2tuple(pretrained_window_size))
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
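+        # Residual post-norm (Swin V2): LayerNorm is applied to the branch
+        # output *before* the residual addition, unlike the pre-norm used in
+        # Swin V1.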
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.reduction(x)
+ x = self.norm(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ flops += H * W * self.dim // 2
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ pretrained_window_size=0):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pretrained_window_size=pretrained_window_size)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+ def _init_respostnorm(self):
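+        # Zero-initializing the post-norm scales makes every block start as
+        # an identity mapping, which stabilizes early training under the
+        # Swin V2 res-post-norm scheme.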
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, 0)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, 0)
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+
+@BACKBONES.register_module()
+class SwinTransformerV2(nn.Module):
+ r""" Swin Transformer
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+        img_size (int | tuple(int)): Input image size. Default: 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0],
+ multi_scale=False, upsample='deconv', **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ patches_resolution[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ pretrained_window_size=pretrained_window_sizes[i_layer])
+ self.layers.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.multi_scale = multi_scale
+ if self.multi_scale:
+ self.scales = [1, 2, 4, 4]
+ self.upsample = nn.ModuleList()
+ features = [int(embed_dim * 2 ** i) for i in range(1, self.num_layers)] + [self.num_features]
+ self.multi_scale_fuse = nn.Conv2d(sum(features), self.num_features, 1)
+ for i in range(self.num_layers):
+ self.upsample.append(nn.Upsample(scale_factor=self.scales[i]))
+ else:
+ if upsample == 'deconv':
+ self.upsample = nn.ConvTranspose2d(self.num_features, self.num_features, 2, stride=2)
+ elif upsample == 'new_deconv':
+ self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ nn.Conv2d(self.num_features, self.num_features, 3, stride=1, padding=1),
+ nn.BatchNorm2d(self.num_features),
+ nn.ReLU(inplace=True)
+ )
+ elif upsample == 'new_deconv2':
+ self.upsample = nn.Sequential(nn.Upsample(scale_factor=2),
+ nn.Conv2d(self.num_features, self.num_features, 3, stride=1, padding=1),
+ nn.BatchNorm2d(self.num_features),
+ nn.ReLU(inplace=True)
+ )
+ elif upsample == 'bilinear':
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+ else:
+ self.upsample = nn.Identity()
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+ for bly in self.layers:
+ bly._init_respostnorm()
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}
+
+ def forward_features(self, x):
+ B, C, H, W = x.shape
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ if self.multi_scale:
+ # x_2d = x.view(B, H // 4, W // 4, -1).permute(0, 3, 1, 2) # B C H W
+ # features = [self.upsample[0](x_2d)]
+ features = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+                x_2d = x.view(B, H // (8 * self.scales[i]),
+                              W // (8 * self.scales[i]), -1).permute(0, 3, 1, 2)  # B C H W
+ features.append(self.upsample[i](x_2d))
+ x = torch.cat(features, dim=1)
+ x = self.multi_scale_fuse(x)
+ x = x.view(B, self.num_features, -1).permute(0, 2, 1)
+ x = self.norm(x) # B L C
+ x = x.view(B, H // 8, W // 8, self.num_features).permute(0, 3, 1, 2) # B C H W
+
+ else:
+ for layer in self.layers:
+ x = layer(x)
+ x = self.norm(x) # B L C
+ x = x.view(B, H // 32, W // 32, self.num_features).permute(0, 3, 1, 2) # B C H W
+ x = self.upsample(x)
+
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+ flops += self.num_features * self.num_classes
+ return flops
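+
+# Minimal usage sketch (illustrative): as an mmpose backbone this module is
+# typically queried via forward_features, which returns a 2D feature map:
+#
+#   backbone = SwinTransformerV2(img_size=256, window_size=8, num_classes=0,
+#                                upsample='deconv')
+#   feat = backbone.forward_features(torch.randn(1, 3, 256, 256))
+#   # feat: [1, 768, 16, 16] -- 1/32 resolution, then upsampled x2 by deconv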
diff --git a/models/models/backbones/swin_utils.py b/models/models/backbones/swin_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3453f850256cfac99b83f9abfef57b85b47c2ce8
--- /dev/null
+++ b/models/models/backbones/swin_utils.py
@@ -0,0 +1,116 @@
+# --------------------------------------------------------
+# SimMIM
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# Modified by Zhenda Xie
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+from scipy import interpolate
+
+
+def load_pretrained(config, model, logger):
+ checkpoint = torch.load(config, map_location='cpu')
+ checkpoint_model = checkpoint['model']
+
+    if any('encoder.' in k for k in checkpoint_model.keys()):
+        checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if
+                            k.startswith('encoder.')}
+        print('Detected pre-trained encoder checkpoint, removing the [encoder.] prefix.')
+    else:
+        print('Detected non-pre-trained checkpoint, loading keys as-is.')
+
+    checkpoint_model = remap_pretrained_keys_swin(model, checkpoint_model, logger)
+ msg = model.load_state_dict(checkpoint_model, strict=False)
+ print(msg)
+
+ del checkpoint
+ torch.cuda.empty_cache()
+
+
+def remap_pretrained_keys_swin(model, checkpoint_model, logger):
+ state_dict = model.state_dict()
+
+    # Geometric interpolation when the pre-trained window size mismatches the fine-tuned window size
+ all_keys = list(checkpoint_model.keys())
+ for key in all_keys:
+ if "relative_position_bias_table" in key:
+ relative_position_bias_table_pretrained = checkpoint_model[key]
+ relative_position_bias_table_current = state_dict[key]
+ L1, nH1 = relative_position_bias_table_pretrained.size()
+ L2, nH2 = relative_position_bias_table_current.size()
+ if nH1 != nH2:
+ print(f"Error in loading {key}, passing......")
+ else:
+ if L1 != L2:
+ print(f"{key}: Interpolate relative_position_bias_table using geo.")
+ src_size = int(L1 ** 0.5)
+ dst_size = int(L2 ** 0.5)
+
+ def geometric_progression(a, r, n):
+ return a * (1.0 - r ** n) / (1.0 - r)
+
+ left, right = 1.01, 1.5
+ while right - left > 1e-6:
+ q = (left + right) / 2.0
+ gp = geometric_progression(1, q, src_size // 2)
+ if gp > dst_size // 2:
+ right = q
+ else:
+ left = q
+
+ # if q > 1.090307:
+ # q = 1.090307
+
+ dis = []
+ cur = 1
+ for i in range(src_size // 2):
+ dis.append(cur)
+ cur += q ** (i + 1)
+
+ r_ids = [-_ for _ in reversed(dis)]
+
+ x = r_ids + [0] + dis
+ y = r_ids + [0] + dis
+
+ t = dst_size // 2.0
+ dx = np.arange(-t, t + 0.1, 1.0)
+ dy = np.arange(-t, t + 0.1, 1.0)
+
+ print("Original positions = %s" % str(x))
+ print("Target positions = %s" % str(dx))
+
+ all_rel_pos_bias = []
+
+ for i in range(nH1):
+ z = relative_position_bias_table_pretrained[:, i].view(src_size, src_size).float().numpy()
+ f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
+ all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(
+ relative_position_bias_table_pretrained.device))
+
+ new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+ checkpoint_model[key] = new_rel_pos_bias
+
+ # delete relative_position_index since we always re-init it
+ relative_position_index_keys = [k for k in checkpoint_model.keys() if "relative_position_index" in k]
+ for k in relative_position_index_keys:
+ del checkpoint_model[k]
+
+ # delete relative_coords_table since we always re-init it
+ relative_coords_table_keys = [k for k in checkpoint_model.keys() if "relative_coords_table" in k]
+ for k in relative_coords_table_keys:
+ del checkpoint_model[k]
+
+ # re-map keys due to name change
+ rpe_mlp_keys = [k for k in checkpoint_model.keys() if "rpe_mlp" in k]
+ for k in rpe_mlp_keys:
+ checkpoint_model[k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint_model.pop(k)
+
+ # delete attn_mask since we always re-init it
+ attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
+ for k in attn_mask_keys:
+ del checkpoint_model[k]
+
+ return checkpoint_model
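+
+# Illustrative example: resizing the bias table from a 7x7 pre-training
+# window (L1 = 13*13 = 169) to a 12x12 fine-tuning window (L2 = 23*23 = 529),
+# the bisection above finds the ratio q of a geometric progression whose
+# src_size // 2 terms span dst_size // 2, and the table is resampled with
+# cubic interpolation on that non-uniform grid.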
diff --git a/models/models/detectors/__init__.py b/models/models/detectors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..038450a700175662ed8fc1e92c67b4cdcae54cd1
--- /dev/null
+++ b/models/models/detectors/__init__.py
@@ -0,0 +1,3 @@
+from .pam import PoseAnythingModel
+
+__all__ = ['PoseAnythingModel']
diff --git a/models/models/detectors/pam.py b/models/models/detectors/pam.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc49b339e792c59b4bd43b04e5cf646685d959ce
--- /dev/null
+++ b/models/models/detectors/pam.py
@@ -0,0 +1,381 @@
+import math
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv.image import imwrite
+from mmcv.visualization.image import imshow
+from mmpose.models import builder
+from mmpose.models.builder import POSENETS
+from mmpose.models.detectors.base import BasePose
+
+from models.models.backbones.swin_utils import load_pretrained
+
+
+@POSENETS.register_module()
+class PoseAnythingModel(BasePose):
+ """Few-shot keypoint detectors.
+ Args:
+ keypoint_head (dict): Keypoint head to process feature.
+ encoder_config (dict): Config for encoder. Default: None.
+ pretrained (str): Path to the pretrained models.
+ train_cfg (dict): Config for training. Default: None.
+ test_cfg (dict): Config for testing. Default: None.
+ """
+
+ def __init__(self,
+ keypoint_head,
+ encoder_config,
+ pretrained=False,
+ train_cfg=None,
+ test_cfg=None):
+ super().__init__()
+ self.backbone, self.backbone_type = self.init_backbone(pretrained, encoder_config)
+ self.keypoint_head = builder.build_head(keypoint_head)
+ self.keypoint_head.init_weights()
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.target_type = test_cfg.get('target_type',
+ 'GaussianHeatMap') # GaussianHeatMap
+
+ def init_backbone(self, pretrained, encoder_config):
+ if 'swin' in pretrained:
+ encoder_sample = builder.build_backbone(encoder_config)
+ if '.pth' in pretrained:
+ load_pretrained(pretrained, encoder_sample, logger=None)
+ backbone = 'swin'
+ elif 'dino' in pretrained:
+ if 'dinov2' in pretrained:
+ repo = 'facebookresearch/dinov2'
+ backbone = 'dinov2'
+ else:
+ repo = 'facebookresearch/dino:main'
+ backbone = 'dino'
+ encoder_sample = torch.hub.load(repo, pretrained)
+ elif 'resnet' in pretrained:
+ pretrained = 'torchvision://resnet50'
+ encoder_config = dict(type='ResNet', depth=50, out_indices=(3,))
+ encoder_sample = builder.build_backbone(encoder_config)
+ encoder_sample.init_weights(pretrained)
+ backbone = 'resnet50'
+ else:
+ raise NotImplementedError(f'backbone {pretrained} not supported')
+ return encoder_sample, backbone
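+    # Illustrative `pretrained` values (hypothetical examples) and the branch
+    # each selects in init_backbone:
+    #   'path/to/swin_base.pth' -> Swin backbone built from encoder_config,
+    #                              weights loaded via load_pretrained
+    #   'dino_vits8'            -> torch.hub facebookresearch/dino:main
+    #   'dinov2_vits14'         -> torch.hub facebookresearch/dinov2
+    #   'resnet50'              -> torchvision ResNet-50 via the mmpose builder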
+
+ @property
+ def with_keypoint(self):
+ """Check if has keypoint_head."""
+ return hasattr(self, 'keypoint_head')
+
+ def init_weights(self, pretrained=None):
+ """Weight initialization for model."""
+ self.backbone.init_weights(pretrained)
+ self.encoder_query.init_weights(pretrained)
+ self.keypoint_head.init_weights()
+
+ def forward(self,
+ img_s,
+ img_q,
+ target_s=None,
+ target_weight_s=None,
+ target_q=None,
+ target_weight_q=None,
+ img_metas=None,
+ return_loss=True,
+ **kwargs):
+ """Defines the computation performed at every call."""
+
+ if return_loss:
+ return self.forward_train(img_s, target_s, target_weight_s, img_q,
+ target_q, target_weight_q, img_metas,
+ **kwargs)
+ else:
+ return self.forward_test(img_s, target_s, target_weight_s, img_q,
+ target_q, target_weight_q, img_metas,
+ **kwargs)
+
+ def forward_dummy(self, img_s, target_s, target_weight_s, img_q, target_q,
+ target_weight_q, img_metas, **kwargs):
+ return self.predict(
+ img_s, target_s, target_weight_s, img_q, img_metas)
+
+ def forward_train(self,
+ img_s,
+ target_s,
+ target_weight_s,
+ img_q,
+ target_q,
+ target_weight_q,
+ img_metas,
+ **kwargs):
+
+ """Defines the computation performed at every call when training."""
+ bs, _, h, w = img_q.shape
+
+ output, initial_proposals, similarity_map, mask_s = self.predict(
+ img_s, target_s, target_weight_s, img_q, img_metas)
+
+ # parse the img meta to get the target keypoints
+ target_keypoints = self.parse_keypoints_from_img_meta(img_metas, output.device, keyword='query')
+ target_sizes = torch.tensor([img_q.shape[-2], img_q.shape[-1]]).unsqueeze(0).repeat(img_q.shape[0], 1, 1)
+
+ # if return loss
+ losses = dict()
+ if self.with_keypoint:
+ keypoint_losses = self.keypoint_head.get_loss(
+ output, initial_proposals, similarity_map, target_keypoints,
+ target_q, target_weight_q * mask_s, target_sizes)
+ losses.update(keypoint_losses)
+ keypoint_accuracy = self.keypoint_head.get_accuracy(output[-1],
+ target_keypoints,
+ target_weight_q * mask_s,
+ target_sizes,
+ height=h)
+ losses.update(keypoint_accuracy)
+
+ return losses
+
+ def forward_test(self,
+ img_s,
+ target_s,
+ target_weight_s,
+ img_q,
+ target_q,
+ target_weight_q,
+ img_metas=None,
+ **kwargs):
+
+ """Defines the computation performed at every call when testing."""
+ batch_size, _, img_height, img_width = img_q.shape
+
+ output, initial_proposals, similarity_map, _ = self.predict(img_s, target_s, target_weight_s, img_q, img_metas)
+ predicted_pose = output[-1].detach().cpu().numpy() # [bs, num_query, 2]
+
+ result = {}
+ if self.with_keypoint:
+ keypoint_result = self.keypoint_head.decode(img_metas, predicted_pose, img_size=[img_width, img_height])
+ result.update(keypoint_result)
+
+ result.update({
+ "points":
+ torch.cat((initial_proposals, output.squeeze(1))).cpu().numpy()
+ })
+ result.update({"sample_image_file": img_metas[0]['sample_image_file']})
+
+ return result
+
+ def predict(self,
+ img_s,
+ target_s,
+ target_weight_s,
+ img_q,
+ img_metas=None):
+
+ batch_size, _, img_height, img_width = img_q.shape
+        assert all(i['sample_skeleton'][0] != i['query_skeleton'] for i in img_metas)
+ skeleton = [i['sample_skeleton'][0] for i in img_metas]
+
+ feature_q, feature_s = self.extract_features(img_s, img_q)
+
+ mask_s = target_weight_s[0]
+ for target_weight in target_weight_s:
+ mask_s = mask_s * target_weight
+
+ output, initial_proposals, similarity_map = self.keypoint_head(feature_q, feature_s, target_s, mask_s, skeleton)
+
+ return output, initial_proposals, similarity_map, mask_s
+
+ def extract_features(self, img_s, img_q):
+ if self.backbone_type == 'swin':
+ feature_q = self.backbone.forward_features(img_q) # [bs, C, h, w]
+ feature_s = [self.backbone.forward_features(img) for img in img_s]
+ elif self.backbone_type == 'dino':
+ batch_size, _, img_height, img_width = img_q.shape
+ feature_q = self.backbone.get_intermediate_layers(img_q, n=1)[0][:, 1:] \
+ .reshape(batch_size, img_height // 8, img_width // 8, -1).permute(0, 3, 1, 2) # [bs, 3, h, w]
+ feature_s = [self.backbone.get_intermediate_layers(img, n=1)[0][:, 1:].
+ reshape(batch_size, img_height // 8, img_width // 8, -1).permute(0, 3, 1, 2) for img in img_s]
+ elif self.backbone_type == 'dinov2':
+ batch_size, _, img_height, img_width = img_q.shape
+ feature_q = self.backbone.get_intermediate_layers(img_q, n=1, reshape=True)[0] # [bs, c, h, w]
+ feature_s = [self.backbone.get_intermediate_layers(img, n=1, reshape=True)[0] for img in img_s]
+ else:
+ feature_s = [self.backbone(img) for img in img_s]
+ feature_q = self.encoder_query(img_q)
+
+ return feature_q, feature_s
+
+ def parse_keypoints_from_img_meta(self, img_meta, device, keyword='query'):
+ """Parse keypoints from the img_meta.
+
+ Args:
+ img_meta (dict): Image meta info.
+ device (torch.device): Device of the output keypoints.
+ keyword (str): 'query' or 'sample'. Default: 'query'.
+
+ Returns:
+ Tensor: Keypoints coordinates of query images.
+ """
+
+ if keyword == 'query':
+ query_kpt = torch.stack([
+ torch.tensor(info[f'{keyword}_joints_3d']).to(device)
+ for info in img_meta
+ ], dim=0)[:, :, :2] # [bs, num_query, 2]
+ else:
+ query_kpt = []
+ for info in img_meta:
+ if isinstance(info[f'{keyword}_joints_3d'][0], torch.Tensor):
+ samples = torch.stack(info[f'{keyword}_joints_3d'])
+ else:
+ samples = np.array(info[f'{keyword}_joints_3d'])
+ query_kpt.append(torch.tensor(samples).to(device)[:, :, :2])
+            query_kpt = torch.stack(query_kpt, dim=0)  # [bs, num_samples, num_query, 2]
+ return query_kpt
+
+
+ # UNMODIFIED
+ def show_result(self,
+ img,
+ result,
+ skeleton=None,
+ kpt_score_thr=0.3,
+ bbox_color='green',
+ pose_kpt_color=None,
+ pose_limb_color=None,
+ radius=4,
+ text_color=(255, 0, 0),
+ thickness=1,
+ font_scale=0.5,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (list[dict]): The results to draw over `img`
+ (bbox_result, pose_result).
+ kpt_score_thr (float, optional): Minimum score of keypoints
+ to be shown. Default: 0.3.
+ bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+ pose_kpt_color (np.array[Nx3]`): Color of N keypoints.
+ If None, do not draw keypoints.
+ pose_limb_color (np.array[Mx3]): Color of M limbs.
+ If None, do not draw limbs.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+
+ Returns:
+ Tensor: Visualized img, only if not `show` or `out_file`.
+ """
+
+ img = mmcv.imread(img)
+ img = img.copy()
+ img_h, img_w, _ = img.shape
+
+ bbox_result = []
+ pose_result = []
+ for res in result:
+ bbox_result.append(res['bbox'])
+ pose_result.append(res['keypoints'])
+
+ if len(bbox_result) > 0:
+ bboxes = np.vstack(bbox_result)
+ # draw bounding boxes
+ mmcv.imshow_bboxes(
+ img,
+ bboxes,
+ colors=bbox_color,
+ top_k=-1,
+ thickness=thickness,
+ show=False,
+ win_name=win_name,
+ wait_time=wait_time,
+ out_file=None)
+
+ for person_id, kpts in enumerate(pose_result):
+ # draw each point on image
+ if pose_kpt_color is not None:
+ assert len(pose_kpt_color) == len(kpts), (
+ len(pose_kpt_color), len(kpts))
+ for kid, kpt in enumerate(kpts):
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(
+ kpt[1]), kpt[2]
+ if kpt_score > kpt_score_thr:
+ img_copy = img.copy()
+ r, g, b = pose_kpt_color[kid]
+ cv2.circle(img_copy, (int(x_coord), int(y_coord)),
+ radius, (int(r), int(g), int(b)), -1)
+ transparency = max(0, min(1, kpt_score))
+ cv2.addWeighted(
+ img_copy,
+ transparency,
+ img,
+ 1 - transparency,
+ 0,
+ dst=img)
+
+ # draw limbs
+ if skeleton is not None and pose_limb_color is not None:
+ assert len(pose_limb_color) == len(skeleton)
+ for sk_id, sk in enumerate(skeleton):
+ pos1 = (int(kpts[sk[0] - 1, 0]), int(kpts[sk[0] - 1,
+ 1]))
+ pos2 = (int(kpts[sk[1] - 1, 0]), int(kpts[sk[1] - 1,
+ 1]))
+ if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0
+ and pos1[1] < img_h and pos2[0] > 0
+ and pos2[0] < img_w and pos2[1] > 0
+ and pos2[1] < img_h
+ and kpts[sk[0] - 1, 2] > kpt_score_thr
+ and kpts[sk[1] - 1, 2] > kpt_score_thr):
+ img_copy = img.copy()
+ X = (pos1[0], pos2[0])
+ Y = (pos1[1], pos2[1])
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
+ angle = math.degrees(
+ math.atan2(Y[0] - Y[1], X[0] - X[1]))
+ stickwidth = 2
+ polygon = cv2.ellipse2Poly(
+ (int(mX), int(mY)),
+ (int(length / 2), int(stickwidth)), int(angle),
+ 0, 360, 1)
+
+ r, g, b = pose_limb_color[sk_id]
+ cv2.fillConvexPoly(img_copy, polygon,
+ (int(r), int(g), int(b)))
+ transparency = max(
+ 0,
+ min(
+ 1, 0.5 *
+ (kpts[sk[0] - 1, 2] + kpts[sk[1] - 1, 2])))
+ cv2.addWeighted(
+ img_copy,
+ transparency,
+ img,
+ 1 - transparency,
+ 0,
+ dst=img)
+
+ if show:
+ height, width = img.shape[:2]
+ max_ = max(height, width)
+
+ factor = min(1, 800 / max_)
+ enlarge = cv2.resize(
+ img, (0, 0),
+ fx=factor,
+ fy=factor,
+ interpolation=cv2.INTER_CUBIC)
+ imshow(enlarge, win_name, wait_time)
+
+ if out_file is not None:
+ imwrite(img, out_file)
+
+ return img
\ No newline at end of file
diff --git a/models/models/keypoint_heads/__init__.py b/models/models/keypoint_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b42289e92a1e7a99103a2e532a41e107b68589d3
--- /dev/null
+++ b/models/models/keypoint_heads/__init__.py
@@ -0,0 +1,3 @@
+from .head import PoseHead
+
+__all__ = ['PoseHead']
diff --git a/models/models/keypoint_heads/head.py b/models/models/keypoint_heads/head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1d0bc69cebe8aeaa1434fa5478c5fe8cb1f711
--- /dev/null
+++ b/models/models/keypoint_heads/head.py
@@ -0,0 +1,368 @@
+from copy import deepcopy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (Conv2d, Linear, xavier_init)
+from mmcv.cnn.bricks.transformer import build_positional_encoding
+from mmpose.core.evaluation import keypoint_pck_accuracy
+from mmpose.core.post_processing import transform_preds
+from mmpose.models import HEADS
+from mmpose.models.utils.ops import resize
+
+from models.models.utils import build_transformer
+
+
+def inverse_sigmoid(x, eps=1e-3):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
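+
+# inverse_sigmoid is a clamped logit: for y in (eps, 1 - eps),
+# sigmoid(inverse_sigmoid(y)) == y. It lets the head refine keypoint
+# estimates additively in unbounded logit space (DETR-style iterative
+# refinement) before squashing back to [0, 1] with sigmoid.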
+
+
+class TokenDecodeMLP(nn.Module):
+ '''
+ The MLP used to predict coordinates from the support keypoints tokens.
+ '''
+
+ def __init__(self,
+ in_channels,
+ hidden_channels,
+ out_channels=2,
+ num_layers=3):
+ super(TokenDecodeMLP, self).__init__()
+ layers = []
+ for i in range(num_layers):
+ if i == 0:
+ layers.append(nn.Linear(in_channels, hidden_channels))
+ layers.append(nn.GELU())
+ else:
+ layers.append(nn.Linear(hidden_channels, hidden_channels))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_channels, out_channels))
+ self.mlp = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.mlp(x)
+
+
+@HEADS.register_module()
+class PoseHead(nn.Module):
+ '''
+    In the two-stage regression variant (A3), the proposal generator is moved into the transformer.
+    Each valid proposal is augmented with a positional embedding to better regress its location.
+ '''
+
+ def __init__(self,
+ in_channels,
+ transformer=None,
+ positional_encoding=dict(
+ type='SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ encoder_positional_encoding=dict(
+ type='SinePositionalEncoding',
+ num_feats=512,
+ normalize=True),
+ share_kpt_branch=False,
+ num_decoder_layer=3,
+ with_heatmap_loss=False,
+ with_bb_loss=False,
+ bb_temperature=0.2,
+ heatmap_loss_weight=2.0,
+ support_order_dropout=-1,
+ extra=None,
+ train_cfg=None,
+ test_cfg=None):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.positional_encoding = build_positional_encoding(positional_encoding)
+ self.encoder_positional_encoding = build_positional_encoding(encoder_positional_encoding)
+ self.transformer = build_transformer(transformer)
+ self.embed_dims = self.transformer.d_model
+ self.with_heatmap_loss = with_heatmap_loss
+ self.with_bb_loss = with_bb_loss
+ self.bb_temperature = bb_temperature
+ self.heatmap_loss_weight = heatmap_loss_weight
+ self.support_order_dropout = support_order_dropout
+
+ assert 'num_feats' in positional_encoding
+ num_feats = positional_encoding['num_feats']
+ assert num_feats * 2 == self.embed_dims, 'embed_dims should' \
+ f' be exactly 2 times of num_feats. Found {self.embed_dims}' \
+ f' and {num_feats}.'
+ if extra is not None and not isinstance(extra, dict):
+ raise TypeError('extra should be dict or None.')
+ """Initialize layers of the transformer head."""
+ self.input_proj = Conv2d(self.in_channels, self.embed_dims, kernel_size=1)
+ self.query_proj = Linear(self.in_channels, self.embed_dims)
+ # Instantiate the proposal generator and subsequent keypoint branch.
+ kpt_branch = TokenDecodeMLP(
+ in_channels=self.embed_dims, hidden_channels=self.embed_dims)
+ if share_kpt_branch:
+ self.kpt_branch = nn.ModuleList(
+ [kpt_branch for i in range(num_decoder_layer)])
+ else:
+ self.kpt_branch = nn.ModuleList(
+ [deepcopy(kpt_branch) for i in range(num_decoder_layer)])
+
+ self.train_cfg = {} if train_cfg is None else train_cfg
+ self.test_cfg = {} if test_cfg is None else test_cfg
+ self.target_type = self.test_cfg.get('target_type', 'GaussianHeatMap')
+
+ def init_weights(self):
+ for m in self.modules():
+ if hasattr(m, 'weight') and m.weight.dim() > 1:
+ xavier_init(m, distribution='uniform')
+ """Initialize weights of the transformer head."""
+ # The initialization for transformer is important
+ self.transformer.init_weights()
+ # initialization for input_proj & prediction head
+ for mlp in self.kpt_branch:
+ nn.init.constant_(mlp.mlp[-1].weight.data, 0)
+ nn.init.constant_(mlp.mlp[-1].bias.data, 0)
+ nn.init.xavier_uniform_(self.input_proj.weight, gain=1)
+ nn.init.constant_(self.input_proj.bias, 0)
+
+ nn.init.xavier_uniform_(self.query_proj.weight, gain=1)
+ nn.init.constant_(self.query_proj.bias, 0)
+
+ def forward(self, x, feature_s, target_s, mask_s, skeleton):
+ """"Forward function for a single feature level.
+
+ Args:
+ x (Tensor): Input feature from backbone's single stage, shape
+ [bs, c, h, w].
+
+        Returns:
+            output_kpts (Tensor): Normalized keypoint coordinates predicted
+                by each decoder layer, shape [nb_dec, bs, num_query, 2].
+            initial_proposals (Tensor): Initial keypoint proposals in
+                normalized coordinates, shape [bs, num_query, 2].
+            similarity_map (Tensor): Support-query similarity map used for
+                the heatmap loss, shape [bs, num_query, h, w].
+        """
+        # Construct the binary masks used by the transformer.
+        # NOTE: following the official DETR repo, non-zero values represent
+        # ignored positions, while zero values mean valid positions.
+
+ # process query image feature
+ x = self.input_proj(x)
+ bs, dim, h, w = x.shape
+
+ # Disable the support keypoint positional embedding
+ support_order_embedding = x.new_zeros((bs, self.embed_dims, 1, target_s[0].shape[1])).to(torch.bool)
+
+ # Feature map pos embedding
+ masks = x.new_zeros((x.shape[0], x.shape[2], x.shape[3])).to(torch.bool)
+ pos_embed = self.positional_encoding(masks)
+
+ # process keypoint token feature
+ query_embed_list = []
+ for i, (feature, target) in enumerate(zip(feature_s, target_s)):
+ # resize the support feature back to the heatmap sizes.
+ resized_feature = resize(
+ input=feature,
+ size=target.shape[-2:],
+ mode='bilinear',
+ align_corners=False)
+ target = target / (target.sum(dim=-1).sum(dim=-1)[:, :, None, None] + 1e-8)
+ support_keypoints = target.flatten(2) @ resized_feature.flatten(2).permute(0, 2, 1)
+ query_embed_list.append(support_keypoints)
+
+ support_keypoints = torch.mean(torch.stack(query_embed_list, dim=0), 0)
+ support_keypoints = support_keypoints * mask_s
+ support_keypoints = self.query_proj(support_keypoints)
+ masks_query = (~mask_s.to(torch.bool)).squeeze(-1) # True indicating this query matched no actual joints.
+
+ # outs_dec: [nb_dec, bs, num_query, c]
+ # memory: [bs, c, h, w]
+ # x = Query image feature, support_keypoints = Support keypoint feature
+ outs_dec, initial_proposals, out_points, similarity_map = self.transformer(x,
+ masks,
+ support_keypoints,
+ pos_embed,
+ support_order_embedding,
+ masks_query,
+ self.positional_encoding,
+ self.kpt_branch,
+ skeleton)
+
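+        # Iterative refinement: each decoder layer predicts a delta that is
+        # added to the previous estimate in logit space, then mapped back to
+        # normalized [0, 1] coordinates with sigmoid.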
+ output_kpts = []
+ for idx in range(outs_dec.shape[0]):
+ layer_delta_unsig = self.kpt_branch[idx](outs_dec[idx])
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(
+ out_points[idx])
+ output_kpts.append(layer_outputs_unsig.sigmoid())
+
+ return torch.stack(output_kpts, dim=0), initial_proposals, similarity_map
+
+ def get_loss(self, output, initial_proposals, similarity_map, target, target_heatmap, target_weight, target_sizes):
+ # Calculate top-down keypoint loss.
+ losses = dict()
+ # denormalize the predicted coordinates.
+ num_dec_layer, bs, nq = output.shape[:3]
+ target_sizes = target_sizes.to(output.device) # [bs, 1, 2]
+ target = target / target_sizes
+ target = target[None, :, :, :].repeat(num_dec_layer, 1, 1, 1)
+
+ # set the weight for unset query point to be zero
+ normalizer = target_weight.squeeze(dim=-1).sum(dim=-1) # [bs, ]
+ normalizer[normalizer == 0] = 1
+
+ # compute the heatmap loss
+ if self.with_heatmap_loss:
+ losses['heatmap_loss'] = self.heatmap_loss(
+ similarity_map, target_heatmap, target_weight,
+ normalizer) * self.heatmap_loss_weight
+
+        # compute l1 loss for initial_proposals
+ proposal_l1_loss = F.l1_loss(
+ initial_proposals, target[0], reduction="none")
+ proposal_l1_loss = proposal_l1_loss.sum(
+ dim=-1, keepdim=False) * target_weight.squeeze(dim=-1)
+ proposal_l1_loss = proposal_l1_loss.sum(
+ dim=-1, keepdim=False) / normalizer # [bs, ]
+ losses['proposal_loss'] = proposal_l1_loss.sum() / bs
+
+ # compute l1 loss for each layer
+ for idx in range(num_dec_layer):
+ layer_output, layer_target = output[idx], target[idx]
+ l1_loss = F.l1_loss(
+ layer_output, layer_target, reduction="none") # [bs, query, 2]
+ l1_loss = l1_loss.sum(
+ dim=-1, keepdim=False) * target_weight.squeeze(
+ dim=-1) # [bs, query]
+ # normalize the loss for each sample with the number of visible joints
+ l1_loss = l1_loss.sum(dim=-1, keepdim=False) / normalizer # [bs, ]
+ losses['l1_loss' + '_layer' + str(idx)] = l1_loss.sum() / bs
+
+ return losses
+
+ def get_max_coords(self, heatmap, heatmap_size=64):
+ B, C, H, W = heatmap.shape
+ heatmap = heatmap.view(B, C, -1)
+ max_cor = heatmap.argmax(dim=2)
+ row, col = torch.floor(max_cor / heatmap_size), max_cor % heatmap_size
+ support_joints = torch.cat((row.unsqueeze(-1), col.unsqueeze(-1)), dim=-1)
+ return support_joints
+
+ def heatmap_loss(self, similarity_map, target_heatmap, target_weight,
+ normalizer):
+ # similarity_map: [bs, num_query, h, w]
+ # target_heatmap: [bs, num_query, sh, sw]
+ # target_weight: [bs, num_query, 1]
+
+ # preprocess the similarity_map
+ h, w = similarity_map.shape[-2:]
+ # similarity_map = torch.clamp(similarity_map, 0.0, None)
+ similarity_map = similarity_map.sigmoid()
+
+ target_heatmap = F.interpolate(
+ target_heatmap, size=(h, w), mode='bilinear')
+        target_heatmap = (target_heatmap /
+                          (target_heatmap.max(dim=-1)[0].max(dim=-1)[0] +
+                           1e-10)[:, :, None, None])  # normalize each query's heatmap to a max of 1
+
+ l2_loss = F.mse_loss(
+ similarity_map, target_heatmap, reduction="none") # bs, nq, h, w
+ l2_loss = l2_loss * target_weight[:, :, :, None] # bs, nq, h, w
+ l2_loss = l2_loss.flatten(2, 3).sum(-1) / (h * w) # bs, nq
+ l2_loss = l2_loss.sum(-1) / normalizer # bs,
+
+ return l2_loss.mean()
+
+ def get_accuracy(self, output, target, target_weight, target_sizes, height=256):
+ """Calculate accuracy for top-down keypoint loss.
+
+        Args:
+            output (torch.Tensor[NxKx2]): estimated keypoints, normalized to [0, 1].
+            target (torch.Tensor[NxKx2]): gt keypoints in ABSOLUTE coordinates.
+            target_weight (torch.Tensor[NxKx1]): weights across different joint types.
+            target_sizes (torch.Tensor[Nx2]): shapes of the image.
+        """
+        # NOTE: In POMNet, PCK is estimated at 1/8 resolution; the evaluation here differs slightly.
+
+ accuracy = dict()
+ output = output * float(height)
+ output, target, target_weight, target_sizes = (
+ output.detach().cpu().numpy(), target.detach().cpu().numpy(),
+ target_weight.squeeze(-1).long().detach().cpu().numpy(),
+ target_sizes.squeeze(1).detach().cpu().numpy())
+
+ _, avg_acc, _ = keypoint_pck_accuracy(
+ output,
+ target,
+            target_weight.astype(bool),
+ thr=0.2,
+ normalize=target_sizes)
+ accuracy['acc_pose'] = float(avg_acc)
+
+ return accuracy
+
+ def decode(self, img_metas, output, img_size, **kwargs):
+ """Decode the predicted keypoints from prediction.
+
+        Args:
+            img_metas (list(dict)): Information about data augmentation.
+                By default this includes:
+                - "query_image_file": path to the query image file
+                - "query_center": center of the query bbox
+                - "query_scale": scale of the query bbox
+                - "query_bbox_score": score of the query bbox (optional)
+            output (np.ndarray[N, K, 2]): predicted keypoint coordinates,
+                normalized to [0, 1].
+        """
+ batch_size = len(img_metas)
+ W, H = img_size
+ output = output * np.array([W, H])[None, None, :] # [bs, query, 2], coordinates with recovered shapes.
+
+        if 'bbox_id' in img_metas[0] or 'query_bbox_id' in img_metas[0]:
+ bbox_ids = []
+ else:
+ bbox_ids = None
+
+ c = np.zeros((batch_size, 2), dtype=np.float32)
+ s = np.zeros((batch_size, 2), dtype=np.float32)
+ image_paths = []
+ score = np.ones(batch_size)
+ for i in range(batch_size):
+ c[i, :] = img_metas[i]['query_center']
+ s[i, :] = img_metas[i]['query_scale']
+ image_paths.append(img_metas[i]['query_image_file'])
+
+ if 'query_bbox_score' in img_metas[i]:
+ score[i] = np.array(
+ img_metas[i]['query_bbox_score']).reshape(-1)
+ if 'bbox_id' in img_metas[i]:
+ bbox_ids.append(img_metas[i]['bbox_id'])
+ elif 'query_bbox_id' in img_metas[i]:
+ bbox_ids.append(img_metas[i]['query_bbox_id'])
+
+        preds = np.zeros(output.shape)
+        for idx in range(output.shape[0]):
+            preds[idx] = transform_preds(
+                output[idx],
+                c[idx],
+                s[idx], [W, H],
+                use_udp=self.test_cfg.get('use_udp', False))
+
+ all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)
+ all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
+ all_preds[:, :, 0:2] = preds[:, :, 0:2]
+ all_preds[:, :, 2:3] = 1.0
+ all_boxes[:, 0:2] = c[:, 0:2]
+ all_boxes[:, 2:4] = s[:, 0:2]
+ all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
+ all_boxes[:, 5] = score
+
+ result = {}
+
+ result['preds'] = all_preds
+ result['boxes'] = all_boxes
+ result['image_paths'] = image_paths
+ result['bbox_ids'] = bbox_ids
+
+ return result
diff --git a/models/models/utils/__init__.py b/models/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..01719d3c9f951e9cfa78fe653d511da964b02235
--- /dev/null
+++ b/models/models/utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import build_linear_layer, build_transformer, build_backbone
+from .encoder_decoder import EncoderDecoder
+from .positional_encoding import (LearnedPositionalEncoding,
+ SinePositionalEncoding)
+from .transformer import (DetrTransformerDecoderLayer, DetrTransformerDecoder,
+ DetrTransformerEncoder, DynamicConv)
+
+__all__ = [
+ 'build_transformer', 'build_backbone', 'build_linear_layer', 'DetrTransformerDecoderLayer',
+ 'DetrTransformerDecoder', 'DetrTransformerEncoder',
+ 'LearnedPositionalEncoding', 'SinePositionalEncoding',
+ 'EncoderDecoder',
+]
diff --git a/models/models/utils/builder.py b/models/models/utils/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..111c33d6831c0a6e453a15637a48a44e4ff27714
--- /dev/null
+++ b/models/models/utils/builder.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.utils import Registry, build_from_cfg
+
+TRANSFORMER = Registry('Transformer')
+BACKBONES = Registry('BACKBONES')
+LINEAR_LAYERS = Registry('linear layers')
+
+
+def build_backbone(cfg, default_args=None):
+ """Build backbone."""
+ return build_from_cfg(cfg, BACKBONES, default_args)
+
+
+def build_transformer(cfg, default_args=None):
+ """Builder for Transformer."""
+ return build_from_cfg(cfg, TRANSFORMER, default_args)
+
+
+LINEAR_LAYERS.register_module('Linear', module=nn.Linear)
+
+
+def build_linear_layer(cfg, *args, **kwargs):
+ """Build linear layer.
+ Args:
+ cfg (None or dict): The linear layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an linear layer.
+ args (argument list): Arguments passed to the `__init__`
+ method of the corresponding linear layer.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of the corresponding linear layer.
+ Returns:
+ nn.Module: Created linear layer.
+ """
+ if cfg is None:
+ cfg_ = dict(type='Linear')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in LINEAR_LAYERS:
+ raise KeyError(f'Unrecognized linear type {layer_type}')
+ else:
+ linear_layer = LINEAR_LAYERS.get(layer_type)
+
+ layer = linear_layer(*args, **kwargs, **cfg_)
+
+ return layer
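+
+
+# Usage sketch (illustrative, using the registry contents above):
+#   layer = build_linear_layer(dict(type='Linear'), 256, 128)
+#   assert isinstance(layer, nn.Linear)  # -> nn.Linear(256, 128)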
diff --git a/models/models/utils/encoder_decoder.py b/models/models/utils/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e544329a2e6ec9324f72ec2f3c8fbb1535821c6
--- /dev/null
+++ b/models/models/utils/encoder_decoder.py
@@ -0,0 +1,574 @@
+import copy
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import xavier_init
+from torch import Tensor
+
+from models.models.utils.builder import TRANSFORMER
+
+
+def inverse_sigmoid(x, eps=1e-3):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
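+
+# Sanity check (illustrative): away from the eps-clamped boundaries,
+# inverse_sigmoid inverts torch.sigmoid, e.g.
+#   x = torch.tensor([0.1, 0.5, 0.9])
+#   torch.allclose(inverse_sigmoid(x).sigmoid(), x, atol=1e-4)  # -> True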
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.gelu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
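+
+# Usage sketch: MLP(input_dim=256, hidden_dim=256, output_dim=2, num_layers=3)
+# maps a [nq, bs, 256] tensor to [nq, bs, 2]; note the activation between
+# hidden layers is GELU, not the ReLU typically implied by "FFN".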
+
+
+class ProposalGenerator(nn.Module):
+
+ def __init__(self, hidden_dim, proj_dim, dynamic_proj_dim):
+ super().__init__()
+ self.support_proj = nn.Linear(hidden_dim, proj_dim)
+ self.query_proj = nn.Linear(hidden_dim, proj_dim)
+ self.dynamic_proj = nn.Sequential(
+ nn.Linear(hidden_dim, dynamic_proj_dim),
+ nn.ReLU(),
+ nn.Linear(dynamic_proj_dim, hidden_dim))
+ self.dynamic_act = nn.Tanh()
+
+ def forward(self, query_feat, support_feat, spatial_shape):
+ """
+ Args:
+ support_feat: [query, bs, c]
+ query_feat: [hw, bs, c]
+ spatial_shape: h, w
+ """
+ device = query_feat.device
+ _, bs, c = query_feat.shape
+ h, w = spatial_shape
+        side_normalizer = torch.tensor([w, h]).to(query_feat.device)[None, None,
+                          :]  # [1, 1, 2], used to normalize coordinates to [0, 1]
+
+ query_feat = query_feat.transpose(0, 1)
+ support_feat = support_feat.transpose(0, 1)
+ nq = support_feat.shape[1]
+
+ fs_proj = self.support_proj(support_feat) # [bs, query, c]
+ fq_proj = self.query_proj(query_feat) # [bs, hw, c]
+ pattern_attention = self.dynamic_act(self.dynamic_proj(fs_proj)) # [bs, query, c]
+
+ fs_feat = (pattern_attention + 1) * fs_proj # [bs, query, c]
+ similarity = torch.bmm(fq_proj, fs_feat.transpose(1, 2)) # [bs, hw, query]
+ similarity = similarity.transpose(1, 2).reshape(bs, nq, h, w)
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=device), # (h, w)
+ torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=device))
+
+ # compute softmax and sum up
+ coord_grid = torch.stack([grid_x, grid_y],
+ dim=0).unsqueeze(0).unsqueeze(0).repeat(bs, nq, 1, 1, 1) # [bs, query, 2, h, w]
+ coord_grid = coord_grid.permute(0, 1, 3, 4, 2) # [bs, query, h, w, 2]
+ similarity_softmax = similarity.flatten(2, 3).softmax(dim=-1) # [bs, query, hw]
+ similarity_coord_grid = similarity_softmax[:, :, :, None] * coord_grid.flatten(2, 3)
+ proposal_for_loss = similarity_coord_grid.sum(dim=2, keepdim=False) # [bs, query, 2]
+ proposal_for_loss = proposal_for_loss / side_normalizer
+
+        max_pos = torch.argmax(similarity.reshape(bs, nq, -1), dim=-1, keepdim=True)  # (bs, nq, 1)
+        max_mask = F.one_hot(max_pos, num_classes=w * h)  # (bs, nq, 1, h*w)
+        # the flattened index is row-major over (h, w), so reshape to (h, w)
+        max_mask = max_mask.reshape(bs, nq, h, w).type(torch.float)  # (bs, nq, h, w)
+        local_max_mask = F.max_pool2d(
+            input=max_mask, kernel_size=3, stride=1,
+            padding=1).reshape(bs, nq, h * w, 1)  # (bs, nq, h*w, 1)
+ '''
+ proposal = (similarity_coord_grid * local_max_mask).sum(
+ dim=2, keepdim=False) / torch.count_nonzero(
+ local_max_mask, dim=2)
+ '''
+ # first, extract the local probability map with the mask
+ local_similarity_softmax = similarity_softmax[:, :, :, None] * local_max_mask # (bs, nq, w*h, 1)
+
+ # then, re-normalize the local probability map
+ local_similarity_softmax = local_similarity_softmax / (
+ local_similarity_softmax.sum(dim=-2, keepdim=True) + 1e-10
+ ) # [bs, nq, w*h, 1]
+
+        # point-wise multiplication of the local probability map and the coord grid
+ proposals = local_similarity_softmax * coord_grid.flatten(2, 3) # [bs, nq, w*h, 2]
+
+        # sum over positions to obtain the final coordinate proposals
+ proposals = proposals.sum(dim=2) / side_normalizer # [bs, nq, 2]
+
+ return proposal_for_loss, similarity, proposals
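+
+    # Soft-argmax sketch (shapes as above, for reference): with similarity
+    # sim of shape [bs, nq, h, w] and a pixel-center grid of shape
+    # [bs, nq, h*w, 2],
+    #   prob = sim.flatten(2, 3).softmax(-1)          # [bs, nq, h*w]
+    #   coord = (prob[..., None] * grid).sum(dim=2)   # [bs, nq, 2]
+    # `proposal_for_loss` is this global expectation, while `proposals`
+    # re-normalizes the probability inside a 3x3 window around the argmax,
+    # making the estimate robust to multi-modal similarity maps.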
+
+
+@TRANSFORMER.register_module()
+class EncoderDecoder(nn.Module):
+
+ def __init__(self,
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ graph_decoder=None,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ similarity_proj_dim=256,
+ dynamic_proj_dim=128,
+ return_intermediate_dec=True,
+ look_twice=False,
+ detach_support_feat=False):
+ super().__init__()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
+ activation, normalize_before)
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ decoder_layer = GraphTransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
+ activation, normalize_before, graph_decoder)
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = GraphTransformerDecoder(d_model, decoder_layer, num_decoder_layers, decoder_norm,
+ return_intermediate=return_intermediate_dec,
+ look_twice=look_twice, detach_support_feat=detach_support_feat)
+
+ self.proposal_generator = ProposalGenerator(
+ hidden_dim=d_model,
+ proj_dim=similarity_proj_dim,
+ dynamic_proj_dim=dynamic_proj_dim)
+
+ def init_weights(self):
+ # follow the official DETR to init parameters
+ for m in self.modules():
+ if hasattr(m, 'weight') and m.weight.dim() > 1:
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, src, mask, support_embed, pos_embed, support_order_embed,
+ query_padding_mask, position_embedding, kpt_branch, skeleton, return_attn_map=False):
+
+ bs, c, h, w = src.shape
+
+ src = src.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ support_order_embed = support_order_embed.flatten(2).permute(2, 0, 1)
+ pos_embed = torch.cat((pos_embed, support_order_embed))
+ query_embed = support_embed.transpose(0, 1)
+ mask = mask.flatten(1)
+
+ query_embed, refined_support_embed = self.encoder(
+ src,
+ query_embed,
+ src_key_padding_mask=mask,
+ query_key_padding_mask=query_padding_mask,
+ pos=pos_embed)
+
+ # Generate initial proposals and corresponding positional embedding.
+ initial_proposals_for_loss, similarity_map, initial_proposals = self.proposal_generator(
+ query_embed, refined_support_embed, spatial_shape=[h, w])
+ initial_position_embedding = position_embedding.forward_coordinates(initial_proposals)
+
+ outs_dec, out_points, attn_maps = self.decoder(
+ refined_support_embed,
+ query_embed,
+ memory_key_padding_mask=mask,
+ pos=pos_embed,
+ query_pos=initial_position_embedding,
+ tgt_key_padding_mask=query_padding_mask,
+ position_embedding=position_embedding,
+ initial_proposals=initial_proposals,
+ kpt_branch=kpt_branch,
+ skeleton=skeleton,
+ return_attn_map=return_attn_map)
+
+ return outs_dec.transpose(1, 2), initial_proposals_for_loss, out_points, similarity_map
+
+
+class GraphTransformerDecoder(nn.Module):
+
+ def __init__(self,
+ d_model,
+ decoder_layer,
+ num_layers,
+ norm=None,
+ return_intermediate=False,
+ look_twice=False,
+ detach_support_feat=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+ self.ref_point_head = MLP(d_model, d_model, d_model, num_layers=2)
+ self.look_twice = look_twice
+ self.detach_support_feat = detach_support_feat
+
+ def forward(self,
+ support_feat,
+ query_feat,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None,
+ pos=None,
+ query_pos=None,
+ position_embedding=None,
+ initial_proposals=None,
+ kpt_branch=None,
+ skeleton=None,
+ return_attn_map=False):
+ """
+        position_embedding: class used to compute positional embeddings
+        initial_proposals: [bs, nq, 2], normalized coordinates of the initial proposals
+        kpt_branch: MLPs (one per decoder layer) used to predict the offsets for each query.
+ """
+
+ refined_support_feat = support_feat
+ intermediate = []
+ attn_maps = []
+ bi = initial_proposals.detach()
+ bi_tag = initial_proposals.detach()
+ query_points = [initial_proposals.detach()]
+
+ tgt_key_padding_mask_remove_all_true = tgt_key_padding_mask.clone().to(tgt_key_padding_mask.device)
+ tgt_key_padding_mask_remove_all_true[tgt_key_padding_mask.logical_not().sum(dim=-1) == 0, 0] = False
+
+ for layer_idx, layer in enumerate(self.layers):
+            if layer_idx == 0:  # use the positional embedding from the initial proposals
+ query_pos_embed = query_pos.transpose(0, 1)
+ else:
+ # recalculate the positional embedding
+ query_pos_embed = position_embedding.forward_coordinates(bi)
+ query_pos_embed = query_pos_embed.transpose(0, 1)
+ query_pos_embed = self.ref_point_head(query_pos_embed)
+
+ if self.detach_support_feat:
+ refined_support_feat = refined_support_feat.detach()
+
+ refined_support_feat, attn_map = layer(
+ refined_support_feat,
+ query_feat,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask_remove_all_true,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos_embed,
+ skeleton=skeleton)
+
+ if self.return_intermediate:
+ intermediate.append(self.norm(refined_support_feat))
+
+ if return_attn_map:
+ attn_maps.append(attn_map)
+
+ # update the query coordinates
+ delta_bi = kpt_branch[layer_idx](refined_support_feat.transpose(0, 1))
+
+            # bi_pred is supervised by the loss; bi_tag seeds the next layer
+ if self.look_twice:
+ bi_pred = self.update(bi_tag, delta_bi)
+ bi_tag = self.update(bi, delta_bi)
+ else:
+ bi_tag = self.update(bi, delta_bi)
+ bi_pred = bi_tag
+
+ bi = bi_tag.detach()
+ query_points.append(bi_pred)
+
+ if self.norm is not None:
+ refined_support_feat = self.norm(refined_support_feat)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(refined_support_feat)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), query_points, attn_maps
+
+ return refined_support_feat.unsqueeze(0), query_points, attn_maps
+
+ def update(self, query_coordinates, delta_unsig):
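+        # DETR-style iterative refinement: the new coordinates are
+        # sigmoid(inverse_sigmoid(b) + delta), which keeps predictions in
+        # (0, 1) while each layer regresses an unbounded offset in logit space.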
+ query_coordinates_unsigmoid = inverse_sigmoid(query_coordinates)
+ new_query_coordinates = query_coordinates_unsigmoid + delta_unsig
+ new_query_coordinates = new_query_coordinates.sigmoid()
+ return new_query_coordinates
+
+
+class GraphTransformerDecoderLayer(nn.Module):
+
+ def __init__(self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ graph_decoder=None):
+
+ super().__init__()
+ self.graph_decoder = graph_decoder
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(
+ d_model * 2, nhead, dropout=dropout, vdim=d_model)
+ self.choker = nn.Linear(in_features=2 * d_model, out_features=d_model)
+ # Implementation of Feedforward model
+ if self.graph_decoder is None:
+ self.ffn1 = nn.Linear(d_model, dim_feedforward)
+ self.ffn2 = nn.Linear(dim_feedforward, d_model)
+ elif self.graph_decoder == 'pre':
+ self.ffn1 = GCNLayer(d_model, dim_feedforward, batch_first=False)
+ self.ffn2 = nn.Linear(dim_feedforward, d_model)
+ elif self.graph_decoder == 'post':
+ self.ffn1 = nn.Linear(d_model, dim_feedforward)
+ self.ffn2 = GCNLayer(dim_feedforward, d_model, batch_first=False)
+ else:
+ self.ffn1 = GCNLayer(d_model, dim_feedforward, batch_first=False)
+ self.ffn2 = GCNLayer(dim_feedforward, d_model, batch_first=False)
+
+ self.dropout = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self,
+ refined_support_feat,
+ refined_query_feat,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ skeleton: Optional[list] = None):
+
+ q = k = self.with_pos_embed(refined_support_feat, query_pos + pos[refined_query_feat.shape[0]:])
+ tgt2 = self.self_attn(
+ q,
+ k,
+ value=refined_support_feat,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+
+ refined_support_feat = refined_support_feat + self.dropout1(tgt2)
+ refined_support_feat = self.norm1(refined_support_feat)
+
+ # concatenate the positional embedding with the content feature, instead of direct addition
+ cross_attn_q = torch.cat((refined_support_feat, query_pos + pos[refined_query_feat.shape[0]:]), dim=-1)
+ cross_attn_k = torch.cat((refined_query_feat, pos[:refined_query_feat.shape[0]]), dim=-1)
+
+ tgt2, attn_map = self.multihead_attn(
+ query=cross_attn_q,
+ key=cross_attn_k,
+ value=refined_query_feat,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)
+
+ refined_support_feat = refined_support_feat + self.dropout2(self.choker(tgt2))
+ refined_support_feat = self.norm2(refined_support_feat)
+ if self.graph_decoder is not None:
+ num_pts, b, c = refined_support_feat.shape
+ adj = adj_from_skeleton(num_pts=num_pts,
+ skeleton=skeleton,
+ mask=tgt_key_padding_mask,
+ device=refined_support_feat.device)
+ if self.graph_decoder == 'pre':
+ tgt2 = self.ffn2(self.dropout(self.activation(self.ffn1(refined_support_feat, adj))))
+ elif self.graph_decoder == 'post':
+ tgt2 = self.ffn2(self.dropout(self.activation(self.ffn1(refined_support_feat))), adj)
+ else:
+ tgt2 = self.ffn2(self.dropout(self.activation(self.ffn1(refined_support_feat, adj))), adj)
+ else:
+ tgt2 = self.ffn2(self.dropout(self.activation(self.ffn1(refined_support_feat))))
+ refined_support_feat = refined_support_feat + self.dropout3(tgt2)
+ refined_support_feat = self.norm3(refined_support_feat)
+
+ return refined_support_feat, attn_map
+
+
+class TransformerEncoder(nn.Module):
+
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self,
+ src,
+ query,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ query_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ # src: [hw, bs, c]
+ # query: [num_query, bs, c]
+ # mask: None by default
+ # src_key_padding_mask: [bs, hw]
+ # query_key_padding_mask: [bs, nq]
+ # pos: [hw, bs, c]
+
+ # organize the input
+ # implement the attention mask to mask out the useless points
+ n, bs, c = src.shape
+ src_cat = torch.cat((src, query), dim=0) # [hw + nq, bs, c]
+ mask_cat = torch.cat((src_key_padding_mask, query_key_padding_mask),
+ dim=1) # [bs, hw+nq]
+ output = src_cat
+
+ for layer in self.layers:
+ output = layer(
+ output,
+ query_length=n,
+ src_mask=mask,
+ src_key_padding_mask=mask_cat,
+ pos=pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ # resplit the output into src and query
+ refined_query = output[n:, :, :] # [nq, bs, c]
+ output = output[:n, :, :] # [n, bs, c]
+
+ return output, refined_query
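+
+    # Design note: image tokens (src) and support-keypoint tokens (query) are
+    # concatenated and refined by shared self-attention layers, so this
+    # encoder jointly updates both before splitting them back apart.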
+
+
+class TransformerEncoderLayer(nn.Module):
+
+ def __init__(self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self,
+ src,
+ query_length,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ src = self.with_pos_embed(src, pos)
+ q = k = src
+        # NOTE: unlike the original implementation, the positional embedding is also added to the VALUE.
+ src2 = self.self_attn(
+ q,
+ k,
+ value=src,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+
+def adj_from_skeleton(num_pts, skeleton, mask, device='cuda'):
+ adj_mx = torch.empty(0, device=device)
+ batch_size = len(skeleton)
+ for b in range(batch_size):
+ edges = torch.tensor(skeleton[b])
+ adj = torch.zeros(num_pts, num_pts, device=device)
+ adj[edges[:, 0], edges[:, 1]] = 1
+        adj_mx = torch.cat((adj_mx, adj.unsqueeze(0)), dim=0)
+ trans_adj_mx = torch.transpose(adj_mx, 1, 2)
+ cond = (trans_adj_mx > adj_mx).float()
+ adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
+ adj = adj * ~mask[..., None] * ~mask[:, None]
+ adj = torch.nan_to_num(adj / adj.sum(dim=-1, keepdim=True))
+ adj = torch.stack((torch.diag_embed(~mask), adj), dim=1)
+ return adj
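+
+# Example (illustrative): for a single sample with a 3-keypoint chain 0-1-2,
+#   adj = adj_from_skeleton(3, [[[0, 1], [1, 2]]],
+#                           mask=torch.zeros(1, 3, dtype=torch.bool))
+# returns a [1, 2, 3, 3] tensor: kernel 0 is the self-loop (identity) mask
+# and kernel 1 the row-normalized symmetric adjacency, matching the
+# kernel_size=2 default of GCNLayer below.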
+
+
+class GCNLayer(nn.Module):
+ def __init__(self,
+ in_features,
+ out_features,
+ kernel_size=2,
+ use_bias=True,
+ activation=nn.ReLU(inplace=True),
+ batch_first=True):
+ super(GCNLayer, self).__init__()
+ self.conv = nn.Conv1d(in_features, out_features * kernel_size, kernel_size=1,
+ padding=0, stride=1, dilation=1, bias=use_bias)
+ self.kernel_size = kernel_size
+ self.activation = activation
+ self.batch_first = batch_first
+
+ def forward(self, x, adj):
+ assert adj.size(1) == self.kernel_size
+ if not self.batch_first:
+ x = x.permute(1, 2, 0)
+ else:
+ x = x.transpose(1, 2)
+ x = self.conv(x)
+ b, kc, v = x.size()
+ x = x.view(b, self.kernel_size, kc // self.kernel_size, v)
+ x = torch.einsum('bkcv,bkvw->bcw', (x, adj))
+ if self.activation is not None:
+ x = self.activation(x)
+ if not self.batch_first:
+ x = x.permute(2, 0, 1)
+ else:
+ x = x.transpose(1, 2)
+ return x
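+
+# Usage sketch: with batch_first=False (as in the decoder FFN above),
+#   gcn = GCNLayer(256, 1024, batch_first=False)
+#   out = gcn(x, adj)  # x: [nq, bs, 256], adj: [bs, 2, nq, nq] -> [nq, bs, 1024]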
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+def clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
diff --git a/models/models/utils/positional_encoding.py b/models/models/utils/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bb0e2eae2919ab859e1c2a4c8cc2efb7a35eaa3
--- /dev/null
+++ b/models/models/utils/positional_encoding.py
@@ -0,0 +1,193 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING
+from mmcv.runner import BaseModule
+
+
+# NOTE: SinePositionalEncoding below also supports coordinate input via forward_coordinates.
+
+@POSITIONAL_ENCODING.register_module()
+class SinePositionalEncoding(BaseModule):
+ """Position encoding with sine and cosine functions.
+
+    See `End-to-End Object Detection with Transformers
+    <https://arxiv.org/abs/2005.12872>`_ for details.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. Note the final returned dimension
+ for each position is 2 times of this value.
+ temperature (int, optional): The temperature used for scaling
+ the position embedding. Defaults to 10000.
+ normalize (bool, optional): Whether to normalize the position
+ embedding. Defaults to False.
+ scale (float, optional): A scale factor that scales the position
+ embedding. The scale will be used only when `normalize` is True.
+ Defaults to 2*pi.
+ eps (float, optional): A value added to the denominator for
+ numerical stability. Defaults to 1e-6.
+ offset (float): offset add to embed when do the normalization.
+ Defaults to 0.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ num_feats,
+ temperature=10000,
+ normalize=False,
+ scale=2 * math.pi,
+ eps=1e-6,
+ offset=0.,
+ init_cfg=None):
+ super(SinePositionalEncoding, self).__init__(init_cfg)
+ if normalize:
+ assert isinstance(scale, (float, int)), 'when normalize is set,' \
+ 'scale should be provided and in float or int type, ' \
+ f'found {type(scale)}'
+ self.num_feats = num_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = scale
+ self.eps = eps
+ self.offset = offset
+
+ def forward(self, mask):
+ """Forward function for `SinePositionalEncoding`.
+
+ Args:
+ mask (Tensor): ByteTensor mask. Non-zero values representing
+ ignored positions, while zero values means valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ # For convenience of exporting to ONNX, it's required to convert
+ # `masks` from bool to int.
+ mask = mask.to(torch.int)
+ not_mask = 1 - mask # logical_not
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)  # [bs, h, w], recording the y coordinate of each pixel
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize: # default True
+ y_embed = (y_embed + self.offset) / \
+ (y_embed[:, -1:, :] + self.eps) * self.scale
+ x_embed = (x_embed + self.offset) / \
+ (x_embed[:, :, -1:] + self.eps) * self.scale
+ dim_t = torch.arange(
+ self.num_feats, dtype=torch.float32, device=mask.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)
+ pos_x = x_embed[:, :, :, None] / dim_t # [bs, h, w, num_feats]
+ pos_y = y_embed[:, :, :, None] / dim_t
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ B, H, W = mask.size()
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+ dim=4).view(B, H, W, -1) # [bs, h, w, num_feats]
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+ dim=4).view(B, H, W, -1)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def forward_coordinates(self, coord):
+        """Forward function for normalized coordinate input with shape [bs, kpt, 2].
+
+        Returns:
+            pos (Tensor): position embedding with shape [bs, kpt, num_feats*2]
+        """
+ x_embed, y_embed = coord[:, :, 0], coord[:, :, 1] # [bs, kpt]
+ x_embed = x_embed * self.scale # [bs, kpt]
+ y_embed = y_embed * self.scale
+
+ dim_t = torch.arange(
+ self.num_feats, dtype=torch.float32, device=coord.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)
+
+ pos_x = x_embed[:, :, None] / dim_t # [bs, kpt, num_feats]
+ pos_y = y_embed[:, :, None] / dim_t # [bs, kpt, num_feats]
+ bs, kpt, _ = pos_x.shape
+
+ pos_x = torch.stack(
+ (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()),
+ dim=3).view(bs, kpt, -1) # [bs, kpt, num_feats]
+ pos_y = torch.stack(
+ (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()),
+ dim=3).view(bs, kpt, -1) # [bs, kpt, num_feats]
+ pos = torch.cat((pos_y, pos_x), dim=2) # [bs, kpt, num_feats * 2]
+
+ return pos
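+
+    # Usage sketch: for normalized keypoint coordinates of shape [bs, kpt, 2]
+    # in [0, 1], forward_coordinates returns a [bs, kpt, num_feats * 2]
+    # embedding (y features first, then x), mirroring the spatial grid
+    # encoding produced by forward().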
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_feats={self.num_feats}, '
+ repr_str += f'temperature={self.temperature}, '
+ repr_str += f'normalize={self.normalize}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'eps={self.eps})'
+ return repr_str
+
+
+@POSITIONAL_ENCODING.register_module()
+class LearnedPositionalEncoding(BaseModule):
+ """Position embedding with learnable embedding weights.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. The final returned dimension for
+ each position is 2 times of this value.
+ row_num_embed (int, optional): The dictionary size of row embeddings.
+ Default 50.
+ col_num_embed (int, optional): The dictionary size of col embeddings.
+ Default 50.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_feats,
+ row_num_embed=50,
+ col_num_embed=50,
+ init_cfg=dict(type='Uniform', layer='Embedding')):
+ super(LearnedPositionalEncoding, self).__init__(init_cfg)
+ self.row_embed = nn.Embedding(row_num_embed, num_feats)
+ self.col_embed = nn.Embedding(col_num_embed, num_feats)
+ self.num_feats = num_feats
+ self.row_num_embed = row_num_embed
+ self.col_num_embed = col_num_embed
+
+ def forward(self, mask):
+ """Forward function for `LearnedPositionalEncoding`.
+
+ Args:
+ mask (Tensor): ByteTensor mask. Non-zero values representing
+ ignored positions, while zero values means valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ h, w = mask.shape[-2:]
+ x = torch.arange(w, device=mask.device)
+ y = torch.arange(h, device=mask.device)
+ x_embed = self.col_embed(x)
+ y_embed = self.row_embed(y)
+ pos = torch.cat(
+ (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(
+ 1, w, 1)),
+ dim=-1).permute(2, 0,
+ 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_feats={self.num_feats}, '
+ repr_str += f'row_num_embed={self.row_num_embed}, '
+ repr_str += f'col_num_embed={self.col_num_embed})'
+ return repr_str
diff --git a/models/models/utils/transformer.py b/models/models/utils/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad9c64bc17bff54427b2d8064562f6a1fd6a403
--- /dev/null
+++ b/models/models/utils/transformer.py
@@ -0,0 +1,331 @@
+import torch
+import torch.nn as nn
+from models.models.utils.builder import TRANSFORMER
+from mmcv.cnn import (build_activation_layer, build_norm_layer, xavier_init)
+from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
+ TRANSFORMER_LAYER_SEQUENCE)
+from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
+ TransformerLayerSequence,
+ build_transformer_layer_sequence)
+from mmcv.runner.base_module import BaseModule
+
+
+@TRANSFORMER.register_module()
+class Transformer(BaseModule):
+ """Implements the DETR transformer.
+    Following the official DETR implementation, this module is copy-pasted
+    from torch.nn.Transformer with the following modifications:
+        * positional encodings are passed in MultiheadAttention
+        * the extra LN at the end of the encoder is removed
+        * the decoder returns a stack of activations from all decoding layers
+    See `End-to-End Object Detection with Transformers
+    <https://arxiv.org/abs/2005.12872>`_ for details.
+ Args:
+ encoder (`mmcv.ConfigDict` | Dict): Config of
+ TransformerEncoder. Defaults to None.
+ decoder ((`mmcv.ConfigDict` | Dict)): Config of
+ TransformerDecoder. Defaults to None
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self, encoder=None, decoder=None, init_cfg=None):
+ super(Transformer, self).__init__(init_cfg=init_cfg)
+ self.encoder = build_transformer_layer_sequence(encoder)
+ self.decoder = build_transformer_layer_sequence(decoder)
+ self.embed_dims = self.encoder.embed_dims
+
+ def init_weights(self):
+ # follow the official DETR to init parameters
+ for m in self.modules():
+ if hasattr(m, 'weight') and m.weight.dim() > 1:
+ xavier_init(m, distribution='uniform')
+ self._is_init = True
+
+ def forward(self, x, mask, query_embed, pos_embed, mask_query):
+ """Forward function for `Transformer`.
+ Args:
+ x (Tensor): Input query with shape [bs, c, h, w] where
+ c = embed_dims.
+ mask (Tensor): The key_padding_mask used for encoder and decoder,
+ with shape [bs, h, w].
+ query_embed (Tensor): The query embedding for decoder, with shape
+ [num_query, c].
+ pos_embed (Tensor): The positional encoding for encoder and
+ decoder, with the same shape as `x`.
+ Returns:
+ tuple[Tensor]: results of decoder containing the following tensor.
+ - out_dec: Output from decoder. If return_intermediate_dec \
+ is True output has shape [num_dec_layers, bs,
+ num_query, embed_dims], else has shape [1, bs, \
+ num_query, embed_dims].
+ - memory: Output results from encoder, with shape \
+ [bs, embed_dims, h, w].
+
+ Notes:
+ x: query image features with shape [bs, c, h, w]
+ mask: mask for x with shape [bs, h, w]
+ pos_embed: positional embedding for x with shape [bs, c, h, w]
+ query_embed: sample keypoint features with shape [bs, num_query, c]
+ mask_query: mask for query_embed with shape [bs, num_query]
+ Outputs:
+ out_dec: [num_layers, bs, num_query, c]
+ memory: [bs, c, h, w]
+
+ """
+ bs, c, h, w = x.shape
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c]
+        # [bs, h, w] -> [bs, h*w]. Note: this mask should be all False, since
+        # all images in a batch share the same shape.
+        mask = mask.view(bs, -1)
+        # positional embedding for the memory, i.e., the query image features
+        pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
+ memory = self.encoder(
+ query=x,
+ key=None,
+ value=None,
+ query_pos=pos_embed,
+ query_key_padding_mask=mask) # output memory: [hw, bs, c]
+
+ query_embed = query_embed.permute(1, 0, 2) # [bs, num_query, c] -> [num_query, bs, c]
+ # target = torch.zeros_like(query_embed)
+ # out_dec: [num_layers, num_query, bs, c]
+ out_dec = self.decoder(
+ query=query_embed,
+ key=memory,
+ value=memory,
+ key_pos=pos_embed,
+ # query_pos=query_embed,
+ query_key_padding_mask=mask_query,
+ key_padding_mask=mask)
+ out_dec = out_dec.transpose(1, 2) # [decoder_layer, bs, num_query, c]
+ memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
+ return out_dec, memory
+
+
+@TRANSFORMER_LAYER.register_module()
+class DetrTransformerDecoderLayer(BaseTransformerLayer):
+ """Implements decoder layer in DETR transformer.
+ Args:
+ attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
+ Configs for self_attention or cross_attention, the order
+ should be consistent with it in `operation_order`. If it is
+ a dict, it would be expand to the number of attention in
+ `operation_order`.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ ffn_dropout (float): Probability of an element to be zeroed
+ in ffn. Default 0.0.
+ operation_order (tuple[str]): The execution order of operation
+ in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+ Default:None
+ act_cfg (dict): The activation config for FFNs. Default: `LN`
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: `LN`.
+ ffn_num_fcs (int): The number of fully-connected layers in FFNs.
+ Default:2.
+ """
+
+ def __init__(self,
+ attn_cfgs,
+ feedforward_channels,
+ ffn_dropout=0.0,
+ operation_order=None,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ ffn_num_fcs=2,
+ **kwargs):
+ super(DetrTransformerDecoderLayer, self).__init__(
+ attn_cfgs=attn_cfgs,
+ feedforward_channels=feedforward_channels,
+ ffn_dropout=ffn_dropout,
+ operation_order=operation_order,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ ffn_num_fcs=ffn_num_fcs,
+ **kwargs)
+ # assert len(operation_order) == 6
+ # assert set(operation_order) == set(
+ # ['self_attn', 'norm', 'cross_attn', 'ffn'])
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DetrTransformerEncoder(TransformerLayerSequence):
+ """TransformerEncoder of DETR.
+ Args:
+ post_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`. Only used when `self.pre_norm` is `True`
+ """
+
+ def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):
+ super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
+ if post_norm_cfg is not None:
+ self.post_norm = build_norm_layer(
+ post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
+ else:
+ # assert not self.pre_norm, f'Use prenorm in ' \
+ # f'{self.__class__.__name__},' \
+ # f'Please specify post_norm_cfg'
+ self.post_norm = None
+
+ def forward(self, *args, **kwargs):
+        """Forward function for `TransformerEncoder`.
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
+ if self.post_norm is not None:
+ x = self.post_norm(x)
+ return x
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DetrTransformerDecoder(TransformerLayerSequence):
+ """Implements the decoder in DETR transformer.
+ Args:
+ return_intermediate (bool): Whether to return intermediate outputs.
+ post_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`.
+ """
+
+ def __init__(self,
+ *args,
+ post_norm_cfg=dict(type='LN'),
+ return_intermediate=False,
+ **kwargs):
+
+ super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
+ self.return_intermediate = return_intermediate
+ if post_norm_cfg is not None:
+ self.post_norm = build_norm_layer(post_norm_cfg,
+ self.embed_dims)[1]
+ else:
+ self.post_norm = None
+
+ def forward(self, query, *args, **kwargs):
+ """Forward function for `TransformerDecoder`.
+ Args:
+ query (Tensor): Input query with shape
+ `(num_query, bs, embed_dims)`.
+ Returns:
+ Tensor: Results with shape [1, num_query, bs, embed_dims] when
+ return_intermediate is `False`, otherwise it has shape
+ [num_layers, num_query, bs, embed_dims].
+ """
+ if not self.return_intermediate:
+ x = super().forward(query, *args, **kwargs)
+ if self.post_norm:
+ x = self.post_norm(x)[None]
+ return x
+
+ intermediate = []
+ for layer in self.layers:
+ query = layer(query, *args, **kwargs)
+ if self.return_intermediate:
+ if self.post_norm is not None:
+ intermediate.append(self.post_norm(query))
+ else:
+ intermediate.append(query)
+ return torch.stack(intermediate)
+
+
+@TRANSFORMER.register_module()
+class DynamicConv(BaseModule):
+ """Implements Dynamic Convolution.
+    This module generates parameters for each sample and
+    uses bmm to implement 1x1 convolution. Code is modified
+    from the official SparseR-CNN GitHub repo.
+ Args:
+ in_channels (int): The input feature channel.
+ Defaults to 256.
+ feat_channels (int): The inner feature channel.
+ Defaults to 64.
+ out_channels (int, optional): The output feature channel.
+ When not specified, it will be set to `in_channels`
+ by default
+ input_feat_shape (int): The shape of input feature.
+ Defaults to 7.
+        with_proj (bool): Project two-dimensional feature to
+            one-dimensional feature. Default to True.
+ act_cfg (dict): The activation config for DynamicConv.
+ norm_cfg (dict): Config dict for normalization layer. Default
+ layer normalization.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels=256,
+ feat_channels=64,
+ out_channels=None,
+ input_feat_shape=7,
+ with_proj=True,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super(DynamicConv, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.out_channels_raw = out_channels
+ self.input_feat_shape = input_feat_shape
+ self.with_proj = with_proj
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.out_channels = out_channels if out_channels else in_channels
+
+ self.num_params_in = self.in_channels * self.feat_channels
+ self.num_params_out = self.out_channels * self.feat_channels
+ self.dynamic_layer = nn.Linear(
+ self.in_channels, self.num_params_in + self.num_params_out)
+
+ self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ self.activation = build_activation_layer(act_cfg)
+
+ num_output = self.out_channels * input_feat_shape ** 2
+ if self.with_proj:
+ self.fc_layer = nn.Linear(num_output, self.out_channels)
+ self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ def forward(self, param_feature, input_feature):
+ """Forward function for `DynamicConv`.
+ Args:
+ param_feature (Tensor): The feature can be used
+ to generate the parameter, has shape
+ (num_all_proposals, in_channels).
+ input_feature (Tensor): Feature that
+ interact with parameters, has shape
+ (num_all_proposals, in_channels, H, W).
+ Returns:
+ Tensor: The output feature has shape
+ (num_all_proposals, out_channels).
+ """
+ input_feature = input_feature.flatten(2).permute(2, 0, 1)
+
+ input_feature = input_feature.permute(1, 0, 2)
+ parameters = self.dynamic_layer(param_feature)
+
+ param_in = parameters[:, :self.num_params_in].view(
+ -1, self.in_channels, self.feat_channels)
+ param_out = parameters[:, -self.num_params_out:].view(
+ -1, self.feat_channels, self.out_channels)
+
+ # input_feature has shape (num_all_proposals, H*W, in_channels)
+ # param_in has shape (num_all_proposals, in_channels, feat_channels)
+ # feature has shape (num_all_proposals, H*W, feat_channels)
+ features = torch.bmm(input_feature, param_in)
+ features = self.norm_in(features)
+ features = self.activation(features)
+
+ # param_out has shape (batch_size, feat_channels, out_channels)
+ features = torch.bmm(features, param_out)
+ features = self.norm_out(features)
+ features = self.activation(features)
+
+ if self.with_proj:
+ features = features.flatten(1)
+ features = self.fc_layer(features)
+ features = self.fc_norm(features)
+ features = self.activation(features)
+
+ return features
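+
+
+# Shape sketch with the defaults above: param_feature [N, 256] and
+# input_feature [N, 256, 7, 7] give features [N, 49, 64] -> [N, 49, 256];
+# with_proj=True then flattens to [N, 12544] and projects back to [N, 256].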
diff --git a/models/version.py b/models/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e2ca150513d566dca35e045548b914b1b6bfa7
--- /dev/null
+++ b/models/version.py
@@ -0,0 +1,5 @@
+# GENERATED VERSION FILE
+# TIME: Wed May 31 16:07:32 2023
+__version__ = '0.2.0+818517e'
+short_version = '0.2.0'
+version_info = (0, 2, 0)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bf67d4663089bdd28b7505eadd28f66caef00389
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+json_tricks
+numpy
+opencv-python
+pillow==6.2.2
+xtcocotools
+scipy
+timm
+openxlab
+openmim
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..e641c8d2f5f61b7b69b1554fe494fc6ea26b481f
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,22 @@
+[bdist_wheel]
+universal=1
+
+[aliases]
+test=pytest
+
+[tool:pytest]
+addopts=tests/
+
+[yapf]
+based_on_style = pep8
+blank_line_before_nested_class_or_def = true
+split_before_expression_after_opening_paren = true
+
+[isort]
+line_length = 79
+multi_line_output = 0
+known_standard_library = pkg_resources,setuptools
+known_first_party = mmpose
+known_third_party = cv2,json_tricks,mmcv,mmdet,munkres,numpy,xtcocotools,torch
+no_lines_before = STDLIB,LOCALFOLDER
+default_section = THIRDPARTY
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8405a0573fe6160576442fb35f5da71453b1cd2
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,111 @@
+import os
+import subprocess
+import time
+from setuptools import find_packages, setup
+
+
+def readme():
+ with open('README.md', encoding='utf-8') as f:
+ content = f.read()
+ return content
+
+
+version_file = 'models/version.py'
+
+
+def get_git_hash():
+
+ def _minimal_ext_cmd(cmd):
+ # construct minimal environment
+ env = {}
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ # LANGUAGE is used on win32
+ env['LANGUAGE'] = 'C'
+ env['LANG'] = 'C'
+ env['LC_ALL'] = 'C'
+ out = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ return out
+
+ try:
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+ sha = out.strip().decode('ascii')
+ except OSError:
+ sha = 'unknown'
+
+ return sha
+
+
+def get_hash():
+ if os.path.exists('.git'):
+ sha = get_git_hash()[:7]
+ elif os.path.exists(version_file):
+ try:
+ from models.version import __version__
+ sha = __version__.split('+')[-1]
+ except ImportError:
+ raise ImportError('Unable to get git version')
+ else:
+ sha = 'unknown'
+
+ return sha
+
+
+def write_version_py():
+ content = """# GENERATED VERSION FILE
+# TIME: {}
+__version__ = '{}'
+short_version = '{}'
+version_info = ({})
+"""
+ sha = get_hash()
+ with open('models/VERSION', 'r') as f:
+ SHORT_VERSION = f.read().strip()
+ VERSION_INFO = ', '.join(SHORT_VERSION.split('.'))
+ VERSION = SHORT_VERSION + '+' + sha
+
+ version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
+ VERSION_INFO)
+ with open(version_file, 'w') as f:
+ f.write(version_file_str)
+
+
+def get_version():
+ with open(version_file, 'r') as f:
+ exec(compile(f.read(), version_file, 'exec'))
+ return locals()['__version__']
+
+
+def get_requirements(filename='requirements.txt'):
+ here = os.path.dirname(os.path.realpath(__file__))
+ with open(os.path.join(here, filename), 'r') as f:
+ requires = [line.replace('\n', '') for line in f.readlines()]
+ return requires
+
+
+if __name__ == '__main__':
+ write_version_py()
+ setup(
+ name='pose_anything',
+ version=get_version(),
+        description='Pose Anything: category-agnostic pose estimation.',
+ long_description=readme(),
+ packages=find_packages(exclude=('configs', 'tools', 'demo')),
+ package_data={'pose_anything.ops': ['*/*.so']},
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Operating System :: OS Independent',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ ],
+ license='Apache License 2.0',
+ setup_requires=['pytest-runner', 'cython', 'numpy'],
+ tests_require=['pytest', 'xdoctest'],
+ install_requires=get_requirements(),
+ zip_safe=False)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cce455cad783f5d7d4dc52800da6d74073b6843
--- /dev/null
+++ b/test.py
@@ -0,0 +1,162 @@
+import argparse
+import os
+import os.path as osp
+import random
+import uuid
+
+import mmcv
+import numpy as np
+import torch
+from mmcv import Config, DictAction
+from mmcv.cnn import fuse_conv_bn
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import get_dist_info, init_dist, load_checkpoint
+from models import * # noqa
+from models.datasets import build_dataset
+
+from mmpose.apis import multi_gpu_test, single_gpu_test
+from mmpose.core import wrap_fp16_model
+from mmpose.datasets import build_dataloader
+from mmpose.models import build_posenet
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='mmpose test model')
+ parser.add_argument('config', default=None, help='test config file path')
+ parser.add_argument('checkpoint', default=None, help='checkpoint file')
+ parser.add_argument('--out', help='output result file')
+ parser.add_argument(
+ '--fuse-conv-bn',
+ action='store_true',
+        help='whether to fuse conv and bn; this slightly increases the inference speed')
+ parser.add_argument(
+ '--eval',
+ default=None,
+ nargs='+',
+ help='evaluation metric, which depends on the dataset,'
+ ' e.g., "mAP" for MSCOCO')
+ parser.add_argument(
+ '--permute_keypoints',
+ action='store_true',
+ help='whether to randomly permute keypoints')
+ parser.add_argument(
+ '--gpu_collect',
+ action='store_true',
+ help='whether to use gpu to collect results')
+ parser.add_argument('--tmpdir', help='tmp dir for writing some results')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ default={},
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. For example, '
+ "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+ return args
+
+
+def merge_configs(cfg1, cfg2):
+ # Merge cfg2 into cfg1
+ # Overwrite cfg1 if repeated, ignore if value is None.
+ cfg1 = {} if cfg1 is None else cfg1.copy()
+ cfg2 = {} if cfg2 is None else cfg2
+ for k, v in cfg2.items():
+ if v:
+ cfg1[k] = v
+ return cfg1
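+
+# Example: merge_configs({'metric': 'PCK'}, {'metric': None})  -> {'metric': 'PCK'}
+#          merge_configs({'metric': 'PCK'}, {'metric': 'mAP'}) -> {'metric': 'mAP'}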
+
+
+def main():
+ random.seed(0)
+ np.random.seed(0)
+ torch.manual_seed(0)
+ uuid.UUID(int=0)
+
+ args = parse_args()
+
+ cfg = Config.fromfile(args.config)
+
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+ # cfg.model.pretrained = None
+ cfg.data.test.test_mode = True
+
+ args.work_dir = osp.join('./work_dirs',
+ osp.splitext(osp.basename(args.config))[0])
+ mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
+
+ # init distributed env first, since logger depends on the dist info.
+ if args.launcher == 'none':
+ distributed = False
+ else:
+ distributed = True
+ init_dist(args.launcher, **cfg.dist_params)
+
+ # build the dataloader
+ dataset = build_dataset(cfg.data.test, dict(test_mode=True))
+ dataloader_setting = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=cfg.data.get('workers_per_gpu', 12),
+ dist=distributed,
+ shuffle=False,
+ drop_last=False)
+ dataloader_setting = dict(dataloader_setting,
+ **cfg.data.get('test_dataloader', {}))
+ data_loader = build_dataloader(dataset, **dataloader_setting)
+
+ # build the model and load checkpoint
+ model = build_posenet(cfg.model)
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ wrap_fp16_model(model)
+ load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+ if args.fuse_conv_bn:
+ model = fuse_conv_bn(model)
+
+ if not distributed:
+ model = MMDataParallel(model, device_ids=[0])
+ outputs = single_gpu_test(model, data_loader)
+ else:
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False)
+ outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect)
+
+ rank, _ = get_dist_info()
+ eval_config = cfg.get('evaluation', {})
+ eval_config = merge_configs(eval_config, dict(metric=args.eval))
+
+ if rank == 0:
+ if args.out:
+ print(f'\nwriting results to {args.out}')
+ mmcv.dump(outputs, args.out)
+
+ results = dataset.evaluate(outputs, **eval_config)
+ print('\n')
+ for k, v in sorted(results.items()):
+ print(f'{k}: {v}')
+
+ # save testing log
+ test_log = "./work_dirs/testing_log.txt"
+ with open(test_log, 'a') as f:
+ f.write("** config_file: " + args.config + "\t checkpoint: " + args.checkpoint + "\t \n")
+ for k, v in sorted(results.items()):
+ f.write(f'\t {k}: {v}'+'\n')
+ f.write("********************************************************************\n")
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/dist_test.sh b/tools/dist_test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e4d2a6b29bf93aaf10e1533883d39ac8cbc09432
--- /dev/null
+++ b/tools/dist_test.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+# Copyright (c) OpenMMLab. All rights reserved.
+
+CONFIG=$1
+CHECKPOINT=$2
+GPUS=$3
+PORT=${PORT:-29000}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+ $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
diff --git a/tools/dist_train.sh b/tools/dist_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..536a001dd8c3cc24a704d73088a00d1e5904490a
--- /dev/null
+++ b/tools/dist_train.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+# Copyright (c) OpenMMLab. All rights reserved.
+
+CONFIG=$1
+GPUS=$2
+OUTPUT_DIR=$3
+PORT=${PORT:-29000}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+ $(dirname "$0")/train.py $CONFIG --work-dir $OUTPUT_DIR --launcher pytorch ${@:3}
diff --git a/tools/fix_carfuxion.py b/tools/fix_carfuxion.py
new file mode 100644
index 0000000000000000000000000000000000000000..abcdaae33b483bfaf667986df2f7d429b7bdc81a
--- /dev/null
+++ b/tools/fix_carfuxion.py
@@ -0,0 +1,77 @@
+import json
+import os
+import shutil
+import sys
+import numpy as np
+from xtcocotools.coco import COCO
+
+
+def search_match(bbox, num_keypoints, segmentation):
+ found = []
+ checked = 0
+ for json_file, coco in COCO_DICT.items():
+ cat_ids = coco.getCatIds()
+ for cat_id in cat_ids:
+ img_ids = coco.getImgIds(catIds=cat_id)
+ for img_id in img_ids:
+ annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id, catIds=cat_id))
+ for ann in annotations:
+ checked += 1
+ if (ann['num_keypoints'] == num_keypoints and ann['bbox'] == bbox and ann[
+ 'segmentation'] == segmentation):
+ src_file = coco.loadImgs(img_id)[0]["file_name"]
+ split = "test" if "test" in json_file else "train"
+ found.append((src_file, ann, split))
+ # return src_file, ann, split
+    if len(found) == 0:
+        raise Exception("No match found out of {} annotations".format(checked))
+    elif len(found) > 1:
+        raise Exception("More than one match found: {}".format(found))
+ return found[0]
+
+if __name__ == "__main__":
+
+ carfusion_dir_path = sys.argv[1]
+ mp100_dataset_path = sys.argv[2]
+ os.makedirs('output', exist_ok=True)
+ for cat in ['car', 'bus', 'suv']:
+ os.makedirs(os.path.join('output', cat), exist_ok=True)
+
+
+ COCO_DICT = {}
+ ann_files = os.path.join(carfusion_dir_path, 'annotations')
+ for json_file in os.listdir(ann_files):
+ COCO_DICT[json_file] = COCO(os.path.join(carfusion_dir_path, 'annotations', json_file))
+
+ count = 0
+ print_log = []
+ for json_file in os.listdir(mp100_dataset_path):
+ print("Processing {}".format(json_file))
+ cats = {}
+ coco = COCO(os.path.join(mp100_dataset_path, json_file))
+ cat_ids = coco.getCatIds()
+ for cat_id in cat_ids:
+ category_info = coco.loadCats(cat_id)
+ cat_name = category_info[0]['name']
+ if cat_name in ['car', 'bus', 'suv']:
+ cats[cat_name] = cat_id
+
+
+ for cat_name, cat_id in cats.items():
+ img_ids = coco.getImgIds(catIds=cat_id)
+ count += len(img_ids)
+ print_log.append(f'{json_file} : {cat_name}: {len(img_ids)}')
+ for img_id in img_ids:
+ img = coco.loadImgs(img_id)[0]
+ dst_file_name = img['file_name']
+ annotation = coco.loadAnns(coco.getAnnIds(imgIds=img_id, catIds=cat_id, iscrowd=None))
+ bbox = annotation[0]['bbox']
+ keypoints = annotation[0]['keypoints']
+ segmentation = annotation[0]['segmentation']
+ num_keypoints = annotation[0]['num_keypoints']
+
+                # Search for the matching CarFusion annotation and copy the
+                # source image under the MP-100 file name.
+                src_img, _, split = search_match(bbox, num_keypoints, segmentation)
+                shutil.copyfile(
+                    os.path.join(carfusion_dir_path, split, src_img),
+                    os.path.join('output', dst_file_name))
+
+    for line in print_log:
+        print(line)
+    print('Total: {} images'.format(count))
diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..24de10aa3e38c8da3bf1e4d0bc6e503ea4dbc54c
--- /dev/null
+++ b/tools/slurm_test.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
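+# Usage (env vars are optional, defaults shown in the script):
+#   GPUS=8 bash tools/slurm_test.sh <partition> <job_name> <config> <checkpoint> [test.py args...]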
+
+set -x
+
+PARTITION=$1
+JOB_NAME=$2
+CONFIG=$3
+CHECKPOINT=$4
+GPUS=${GPUS:-8}
+GPUS_PER_NODE=${GPUS_PER_NODE:-8}
+CPUS_PER_TASK=${CPUS_PER_TASK:-5}
+PY_ARGS=${@:5}
+SRUN_ARGS=${SRUN_ARGS:-""}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ ${SRUN_ARGS} \
+ python -u test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
diff --git a/tools/slurm_train.sh b/tools/slurm_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..39a0caadad027e439cf544e0cf1081b257a1ac85
--- /dev/null
+++ b/tools/slurm_train.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
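+# Usage (env vars are optional, defaults shown in the script):
+#   GPUS=8 bash tools/slurm_train.sh <partition> <job_name> <config> <work_dir> [train.py args...]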
+
+set -x
+
+PARTITION=$1
+JOB_NAME=$2
+CONFIG=$3
+WORK_DIR=$4
+GPUS=${GPUS:-8}
+GPUS_PER_NODE=${GPUS_PER_NODE:-8}
+CPUS_PER_TASK=${CPUS_PER_TASK:-5}
+PY_ARGS=${@:5}
+SRUN_ARGS=${SRUN_ARGS:-""}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ ${SRUN_ARGS} \
+ python -u train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
diff --git a/tools/visualization.py b/tools/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a4b95716c0bd5236241e26d516fece0a3cf71fa
--- /dev/null
+++ b/tools/visualization.py
@@ -0,0 +1,67 @@
+import os
+import random
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+colors = [
+ [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
+ [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
+ [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
+ [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]]
+
+
+def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w, skeleton,
+ initial_proposals, prediction, radius=6, out_dir='./heatmaps'):
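+    """Plot support and query images with their keypoints and skeleton limbs.
+
+    Images are saved as '<idx>_support.png' and '<idx>_query.png' under
+    `out_dir`, where `<idx>` continues from the highest index already there.
+    `initial_proposals` is accepted for interface compatibility but not drawn.
+    """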
+    os.makedirs(out_dir, exist_ok=True)  # listdir below fails if out_dir is missing
+    img_names = [img.split("_")[0] for img in os.listdir(out_dir)
+                 if str_is_int(img.split("_")[0])]
+    if len(img_names) > 0:
+        name_idx = max([int(img_name) for img_name in img_names]) + 1
+    else:
+        name_idx = 0
+
+    h = support_img.shape[0]  # prediction is in normalized coords; rescale by image height
+    prediction = prediction[-1].cpu().numpy() * h
+    support_img = (support_img - np.min(support_img)) / (np.max(support_img) - np.min(support_img))
+    query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
+
+    for idx, (img, weights, keypoint) in enumerate(zip([support_img, query_img],
+                                                       [support_w, query_w],
+                                                       [support_kp, prediction])):
+        fig, axes = plt.subplots()
+        plt.imshow(img)
+        for k in range(keypoint.shape[0]):
+            if weights[k] > 0:
+                kp = keypoint[k, :2]
+                c = (1, 0, 0, 0.75) if weights[k] == 1 else (0, 0, 1, 0.6)
+                patch = plt.Circle(kp, radius, color=c)
+                axes.add_patch(patch)
+                axes.text(kp[0], kp[1], k)
+                plt.draw()
+        for i, limb in enumerate(skeleton):
+            kp = keypoint[:, :2]
+            if i > len(colors) - 1:
+                c = [x / 255 for x in random.sample(range(0, 255), 3)]
+            else:
+                c = [x / 255 for x in colors[i]]
+            if weights[limb[0]] > 0 and weights[limb[1]] > 0:
+                patch = plt.Line2D([kp[limb[0], 0], kp[limb[1], 0]],
+                                   [kp[limb[0], 1], kp[limb[1], 1]],
+                                   linewidth=6, color=c, alpha=0.6)
+                axes.add_artist(patch)
+        plt.axis('off')  # hide the axes
+        name = 'support' if idx == 0 else 'query'
+        plt.savefig(os.path.join(out_dir, f'{name_idx}_{name}.png'),
+                    bbox_inches='tight', pad_inches=0)
+        if idx == 1:
+            plt.show()
+        plt.clf()
+        plt.close('all')
+
+
+def str_is_int(s):
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
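+
+
+if __name__ == '__main__':
+    # Minimal smoke test with synthetic data. Every value below is an
+    # illustrative assumption (a 256x256 image, 3 keypoints, 2 limbs), not
+    # something shipped with the tool; plt.show() may block interactively.
+    os.makedirs('./heatmaps', exist_ok=True)
+    demo_img = np.random.rand(256, 256, 3)
+    demo_kps = np.random.rand(3, 2) * 256
+    demo_vis = np.ones(3)
+    # plot_results expects an indexable tensor of normalized predictions and
+    # rescales the last entry by the image height.
+    demo_pred = torch.from_numpy(demo_kps[None] / 256.0)
+    plot_results(demo_img, demo_img, demo_kps, demo_vis, demo_kps, demo_vis,
+                 [[0, 1], [1, 2]], initial_proposals=None, prediction=demo_pred)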
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ed2e93037a56df684a5dd2790b5116e1618524b
--- /dev/null
+++ b/train.py
@@ -0,0 +1,190 @@
+import argparse
+import copy
+import os
+import os.path as osp
+import time
+
+import mmcv
+import torch
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist, set_random_seed
+from mmcv.utils import get_git_hash
+
+from models import * # noqa
+from models.apis import train_model
+from models.datasets import build_dataset
+
+from mmpose import __version__
+from mmpose.models import build_posenet
+from mmpose.utils import collect_env, get_root_logger
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train a pose model')
+    # positional, to match tools/dist_train.sh and tools/slurm_train.sh
+    parser.add_argument('config', help='train config file path')
+ parser.add_argument('--work-dir', default=None, help='the dir to save logs and models')
+ parser.add_argument(
+ '--resume-from', help='the checkpoint file to resume from')
+ parser.add_argument(
+        '--auto-resume',
+        type=bool,
+        default=True,
+        help='automatically detect the latest checkpoint in the work dir '
+        'and resume from it')
+ parser.add_argument(
+ '--no-validate',
+ action='store_true',
+ help='whether not to evaluate the checkpoint during training')
+ group_gpus = parser.add_mutually_exclusive_group()
+ group_gpus.add_argument(
+ '--gpus',
+ type=int,
+ help='number of gpus to use '
+ '(only applicable to non-distributed training)')
+ group_gpus.add_argument(
+ '--gpu-ids',
+ type=int,
+ nargs='+',
+ help='ids of gpus to use '
+ '(only applicable to non-distributed training)')
+ parser.add_argument('--seed', type=int, default=None, help='random seed')
+ parser.add_argument(
+ '--deterministic',
+ action='store_true',
+ help='whether to set deterministic options for CUDNN backend.')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ default={},
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. For example, '
+ "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument(
+ '--autoscale-lr',
+ action='store_true',
+ help='automatically scale lr with the number of gpus')
+ parser.add_argument(
+ '--show',
+ action='store_true',
+ help='whether to display the prediction results in a window.')
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ return args
+
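+# Example invocations (all paths below are illustrative placeholders):
+#   single GPU:  python train.py configs/demo_cfg.py --work-dir work_dirs/demo
+#   multi GPU:   bash tools/dist_train.sh configs/demo_cfg.py 4 work_dirs/demo
+#   slurm:       bash tools/slurm_train.sh <partition> <job> configs/demo_cfg.py work_dirs/demo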
+
+def main():
+ args = parse_args()
+
+ cfg = Config.fromfile(args.config)
+
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+
+    # work_dir is determined in this priority: CLI argument
+    # > `work_dir` field in the config > default derived from the config filename
+ if args.work_dir is not None:
+ # update configs according to CLI args if args.work_dir is not None
+ cfg.work_dir = args.work_dir
+ elif cfg.get('work_dir', None) is None:
+ # use config filename as default work_dir if cfg.work_dir is None
+ cfg.work_dir = osp.join('./work_dirs',
+ osp.splitext(osp.basename(args.config))[0])
+    # auto-resume from the latest checkpoint in the resolved work dir
+    if args.auto_resume:
+        checkpoint = osp.join(cfg.work_dir, 'latest.pth')
+        if osp.exists(checkpoint):
+            cfg.resume_from = checkpoint
+
+ if args.resume_from is not None:
+ cfg.resume_from = args.resume_from
+ if args.gpu_ids is not None:
+ cfg.gpu_ids = args.gpu_ids
+ else:
+ cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+
+ if args.autoscale_lr:
+ # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
+ cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
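+        # e.g. a base lr of 1e-3 tuned for 8 GPUs becomes 5e-4 when
+        # training on 4 GPUs (1e-3 * 4 / 8)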
+
+ # init distributed env first, since logger depends on the dist info.
+ if args.launcher == 'none':
+ distributed = False
+ else:
+ distributed = True
+ init_dist(args.launcher, **cfg.dist_params)
+ # re-set gpu_ids with distributed training mode
+ _, world_size = get_dist_info()
+ cfg.gpu_ids = range(world_size)
+
+ # create work_dir
+ mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+ # init the logger before other steps
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+ logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+ # init the meta dict to record some important information such as
+ # environment info and seed, which will be logged
+ meta = dict()
+ # log env info
+ env_info_dict = collect_env()
+ env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+ dash_line = '-' * 60 + '\n'
+ logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+ dash_line)
+ meta['env_info'] = env_info
+
+ # log some basic info
+ logger.info(f'Distributed training: {distributed}')
+ logger.info(f'Config:\n{cfg.pretty_text}')
+
+    # set random seeds; default to a fixed seed with deterministic CUDNN for
+    # reproducibility when no seed is given on the command line
+    if args.seed is None:
+        args.seed = 1
+        args.deterministic = True
+ if args.seed is not None:
+ logger.info(f'Set random seed to {args.seed}, '
+ f'deterministic: {args.deterministic}')
+ set_random_seed(args.seed, deterministic=args.deterministic)
+ cfg.seed = args.seed
+ meta['seed'] = args.seed
+
+ model = build_posenet(cfg.model)
+ train_datasets = [build_dataset(cfg.data.train)]
+
+ val_dataset = copy.deepcopy(cfg.data.val)
+ val_dataset = build_dataset(val_dataset, dict(test_mode=True))
+
+ if cfg.checkpoint_config is not None:
+ # save mmpose version, config file content
+ # checkpoints as meta data
+ cfg.checkpoint_config.meta = dict(
+ mmpose_version=__version__ + get_git_hash(digits=7),
+ config=cfg.pretty_text,
+ )
+ train_model(
+ model,
+ train_datasets,
+ val_dataset,
+ cfg,
+ distributed=distributed,
+ validate=(not args.no_validate),
+ timestamp=timestamp,
+ meta=meta)
+
+
+if __name__ == '__main__':
+ main()