Duplicate from nev/CoNR
Browse filesCo-authored-by: Stepan Shabalin <[email protected]>
- .gitattributes +33 -0
- .gitignore +14 -0
- LICENSE +7 -0
- README.md +127 -0
- README_chinese.md +124 -0
- app.py +59 -0
- character_sheet_ponytail_example/0.png +3 -0
- character_sheet_ponytail_example/1.png +3 -0
- character_sheet_ponytail_example/2.png +3 -0
- character_sheet_ponytail_example/3.png +3 -0
- conr.py +292 -0
- data_loader.py +275 -0
- images/MAIN.png +3 -0
- infer.sh +13 -0
- model/__init__.py +1 -0
- model/backbone.py +289 -0
- model/decoder_small.py +43 -0
- model/shader.py +290 -0
- model/warplayer.py +56 -0
- notebooks/conr.ipynb +0 -0
- notebooks/conr_chinese.ipynb +0 -0
- poses_template.zip +3 -0
- requirements.txt +9 -0
- train.py +203 -0
- weights/.gitkeep +0 -0
- weights/rgbadecodernet.pth +3 -0
- weights/shader.pth +3 -0
- weights/target_pose_encoder.pth +3 -0
- weights/udpparsernet.pth +3 -0
.gitattributes
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
23 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
33 |
+
poses.zip filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
results/
|
2 |
+
poses/
|
3 |
+
test_data/
|
4 |
+
test_data_pre/
|
5 |
+
character_sheet/
|
6 |
+
# weights/
|
7 |
+
*.mp4
|
8 |
+
*.webm
|
9 |
+
poses.zip
|
10 |
+
gradio_cached_examples/
|
11 |
+
x264/
|
12 |
+
filelist.txt
|
13 |
+
complex_infer.sh
|
14 |
+
__pycache__/
|
LICENSE
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2022 Megvii Inc.
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4 |
+
|
5 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
6 |
+
|
7 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: CoNR
|
3 |
+
emoji: ⚡
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.1.4
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
duplicated_from: nev/CoNR
|
12 |
+
---
|
13 |
+
|
14 |
+
[English](https://github.com/megvii-research/CoNR/blob/main/README.md) | [中文](https://github.com/megvii-research/CoNR/blob/main/README_chinese.md)
|
15 |
+
# Collaborative Neural Rendering using Anime Character Sheets
|
16 |
+
|
17 |
+
|
18 |
+
## [Homepage](https://conr.ml) | Colab [English](https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr.ipynb)/[中文](https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr_chinese.ipynb) | [arXiv](https://arxiv.org/abs/2207.05378)
|
19 |
+
|
20 |
+

|
21 |
+
|
22 |
+
## Introduction
|
23 |
+
|
24 |
+
This project is the official implement of [Collaborative Neural Rendering using Anime Character Sheets](https://arxiv.org/abs/2207.05378), which aims to genarate vivid dancing videos from hand-drawn anime character sheets(ACS). Watch more demos in our [HomePage](https://conr.ml).
|
25 |
+
|
26 |
+
Contributors: [@transpchan](https://github.com/transpchan/), [@P2Oileen](https://github.com/P2Oileen), [@hzwer](https://github.com/hzwer)
|
27 |
+
|
28 |
+
## Usage
|
29 |
+
|
30 |
+
#### Prerequisites
|
31 |
+
|
32 |
+
* NVIDIA GPU + CUDA + CUDNN
|
33 |
+
* Python 3.6
|
34 |
+
|
35 |
+
#### Installation
|
36 |
+
|
37 |
+
* Clone this repository
|
38 |
+
|
39 |
+
```bash
|
40 |
+
git clone https://github.com/megvii-research/CoNR
|
41 |
+
```
|
42 |
+
|
43 |
+
* Dependencies
|
44 |
+
|
45 |
+
To install all the dependencies, please run the following commands.
|
46 |
+
|
47 |
+
```bash
|
48 |
+
cd CoNR
|
49 |
+
pip install -r requirements.txt
|
50 |
+
```
|
51 |
+
|
52 |
+
* Download Weights
|
53 |
+
Download weights from Google Drive. Alternatively, you can download from [Baidu Netdisk](https://pan.baidu.com/s/1U11iIk-DiJodgCveSzB6ig?pwd=RDxc) (password:RDxc).
|
54 |
+
|
55 |
+
```
|
56 |
+
mkdir weights && cd weights
|
57 |
+
gdown https://drive.google.com/uc?id=1M1LEpx70tJ72AIV2TQKr6NE_7mJ7tLYx
|
58 |
+
gdown https://drive.google.com/uc?id=1YvZy3NHkJ6gC3pq_j8agcbEJymHCwJy0
|
59 |
+
gdown https://drive.google.com/uc?id=1AOWZxBvTo9nUf2_9Y7Xe27ZFQuPrnx9i
|
60 |
+
gdown https://drive.google.com/uc?id=19jM1-GcqgGoE1bjmQycQw_vqD9C5e-Jm
|
61 |
+
```
|
62 |
+
|
63 |
+
#### Prepare Inputs
|
64 |
+
We provide two Ultra-Dense Pose sequences for two characters. You can generate more UDPs via 3D models and motions refers to [our paper](https://arxiv.org/abs/2207.05378).
|
65 |
+
[Baidu Netdisk](https://pan.baidu.com/s/1hWvz4iQXnVTaTSb6vu1NBg?pwd=RDxc) (password:RDxc)
|
66 |
+
|
67 |
+
```
|
68 |
+
# for short hair girl
|
69 |
+
gdown https://drive.google.com/uc?id=11HMSaEkN__QiAZSnCuaM6GI143xo62KO
|
70 |
+
unzip short_hair.zip
|
71 |
+
mv short_hair/ poses/
|
72 |
+
|
73 |
+
# for double ponytail girl
|
74 |
+
gdown https://drive.google.com/uc?id=1WNnGVuU0ZLyEn04HzRKzITXqib1wwM4Q
|
75 |
+
unzip double_ponytail.zip
|
76 |
+
mv double_ponytail/ poses/
|
77 |
+
```
|
78 |
+
|
79 |
+
We provide sample inputs of anime character sheets. You can also draw more by yourself.
|
80 |
+
Character sheets need to be cut out from the background and in png format.
|
81 |
+
[Baidu Netdisk](https://pan.baidu.com/s/1shpP90GOMeHke7MuT0-Txw?pwd=RDxc) (password:RDxc)
|
82 |
+
|
83 |
+
```
|
84 |
+
# for short hair girl
|
85 |
+
gdown https://drive.google.com/uc?id=1r-3hUlENSWj81ve2IUPkRKNB81o9WrwT
|
86 |
+
unzip short_hair_images.zip
|
87 |
+
mv short_hair_images/ character_sheet/
|
88 |
+
|
89 |
+
# for double ponytail girl
|
90 |
+
gdown https://drive.google.com/uc?id=1XMrJf9Lk_dWgXyTJhbEK2LZIXL9G3MWc
|
91 |
+
unzip double_ponytail_images.zip
|
92 |
+
mv double_ponytail_images/ character_sheet/
|
93 |
+
```
|
94 |
+
|
95 |
+
#### RUN!
|
96 |
+
* with web UI (powered by [Streamlit](https://streamlit.io/))
|
97 |
+
|
98 |
+
```
|
99 |
+
streamlit run streamlit.py --server.port=8501
|
100 |
+
```
|
101 |
+
then open your browser and visit `localhost:8501`, follow the instructions to genarate video.
|
102 |
+
|
103 |
+
* via terminal
|
104 |
+
|
105 |
+
```
|
106 |
+
mkdir {dir_to_save_result}
|
107 |
+
|
108 |
+
python -m torch.distributed.launch \
|
109 |
+
--nproc_per_node=1 train.py --mode=test \
|
110 |
+
--world_size=1 --dataloaders=2 \
|
111 |
+
--test_input_poses_images={dir_to_poses} \
|
112 |
+
--test_input_person_images={dir_to_character_sheet} \
|
113 |
+
--test_output_dir={dir_to_save_result} \
|
114 |
+
--test_checkpoint_dir={dir_to_weights}
|
115 |
+
|
116 |
+
ffmpeg -r 30 -y -i {dir_to_save_result}/%d.png -r 30 -c:v libx264 output.mp4 -r 30
|
117 |
+
```
|
118 |
+
|
119 |
+
## Citation
|
120 |
+
```bibtex
|
121 |
+
@article{lin2022conr,
|
122 |
+
title={Collaborative Neural Rendering using Anime Character Sheets},
|
123 |
+
author={Lin, Zuzeng and Huang, Ailin and Huang, Zhewei and Hu, Chen and Zhou, Shuchang},
|
124 |
+
journal={arXiv preprint arXiv:2207.05378},
|
125 |
+
year={2022}
|
126 |
+
}
|
127 |
+
```
|
README_chinese.md
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[English](https://github.com/megvii-research/CoNR/blob/main/README.md) | [中文](https://github.com/megvii-research/CoNR/blob/main/README_chinese.md)
|
2 |
+
# 用于二次元手绘设定稿动画化的神经渲染器
|
3 |
+
|
4 |
+
|
5 |
+
## [HomePage](https://conr.ml) | Colab [English](https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr.ipynb)/[中文](https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr_chinese.ipynb) | [arXiv](https://arxiv.org/abs/2207.05378)
|
6 |
+
|
7 |
+

|
8 |
+
|
9 |
+
## Introduction
|
10 |
+
|
11 |
+
该项目为论文[Collaborative Neural Rendering using Anime Character Sheets](https://arxiv.org/abs/2207.05378)的官方复现,旨在从手绘人物设定稿生成生动的舞蹈动画。您可以在我们的[主页](https://conr.ml)中查看更多视频 demo。
|
12 |
+
|
13 |
+
贡献者: [@transpchan](https://github.com/transpchan/), [@P2Oileen](https://github.com/P2Oileen), [@hzwer](https://github.com/hzwer)
|
14 |
+
|
15 |
+
## 使用方法
|
16 |
+
|
17 |
+
#### 需求
|
18 |
+
|
19 |
+
* Nvidia GPU + CUDA + CUDNN
|
20 |
+
* Python 3.6
|
21 |
+
|
22 |
+
#### 安装
|
23 |
+
|
24 |
+
* 克隆该项目
|
25 |
+
|
26 |
+
```bash
|
27 |
+
git clone https://github.com/megvii-research/CoNR
|
28 |
+
```
|
29 |
+
|
30 |
+
* 安装依赖
|
31 |
+
|
32 |
+
请运行以下命令以安装CoNR所需的所有依赖。
|
33 |
+
|
34 |
+
```bash
|
35 |
+
cd CoNR
|
36 |
+
pip install -r requirements.txt
|
37 |
+
```
|
38 |
+
|
39 |
+
* 下载权重
|
40 |
+
运行以下代码,从 Google Drive 下载模型的权重。此外, 你也可以从 [百度云盘](https://pan.baidu.com/s/1U11iIk-DiJodgCveSzB6ig?pwd=RDxc) (password:RDxc)下载权重。
|
41 |
+
|
42 |
+
```
|
43 |
+
mkdir weights && cd weights
|
44 |
+
gdown https://drive.google.com/uc?id=1M1LEpx70tJ72AIV2TQKr6NE_7mJ7tLYx
|
45 |
+
gdown https://drive.google.com/uc?id=1YvZy3NHkJ6gC3pq_j8agcbEJymHCwJy0
|
46 |
+
gdown https://drive.google.com/uc?id=1AOWZxBvTo9nUf2_9Y7Xe27ZFQuPrnx9i
|
47 |
+
gdown https://drive.google.com/uc?id=19jM1-GcqgGoE1bjmQycQw_vqD9C5e-Jm
|
48 |
+
```
|
49 |
+
|
50 |
+
#### Prepare inputs
|
51 |
+
我们为两个不同的人物,准备了两个超密集姿势(Ultra-Dense Pose)序列,从以下代码中二选一运行,即可从 Google Drive 下载。您可以通过任意的3D模型和动作数据,生成更多的超密集姿势序列,参考我们的[论文](https://arxiv.org/abs/2207.05378)。暂不提供官方转换接口。
|
52 |
+
[百度云盘](https://pan.baidu.com/s/1hWvz4iQXnVTaTSb6vu1NBg?pwd=RDxc) (password:RDxc)
|
53 |
+
|
54 |
+
```
|
55 |
+
# 短发女孩的超密集姿势
|
56 |
+
gdown https://drive.google.com/uc?id=11HMSaEkN__QiAZSnCuaM6GI143xo62KO
|
57 |
+
unzip short_hair.zip
|
58 |
+
mv short_hair/ poses/
|
59 |
+
|
60 |
+
# 双马尾女孩的超密集姿势
|
61 |
+
gdown https://drive.google.com/uc?id=1WNnGVuU0ZLyEn04HzRKzITXqib1wwM4Q
|
62 |
+
unzip double_ponytail.zip
|
63 |
+
mv double_ponytail/ poses/
|
64 |
+
```
|
65 |
+
|
66 |
+
我们提供两个人物手绘设定表的样例,从以下代码中二选一运行,即可从 Google Drive下载。您也可以自行绘制。
|
67 |
+
请注意:人物手绘设定表**必须从背景中分割开**,且必须为png格式。
|
68 |
+
[百度云盘](https://pan.baidu.com/s/1shpP90GOMeHke7MuT0-Txw?pwd=RDxc) (password:RDxc)
|
69 |
+
|
70 |
+
```
|
71 |
+
# 短发女孩的手绘设定表
|
72 |
+
gdown https://drive.google.com/uc?id=1r-3hUlENSWj81ve2IUPkRKNB81o9WrwT
|
73 |
+
unzip short_hair_images.zip
|
74 |
+
mv short_hair_images/ character_sheet/
|
75 |
+
|
76 |
+
# 双马尾女孩的手绘设定表
|
77 |
+
gdown https://drive.google.com/uc?id=1XMrJf9Lk_dWgXyTJhbEK2LZIXL9G3MWc
|
78 |
+
unzip double_ponytail_images.zip
|
79 |
+
mv double_ponytail_images/ character_sheet/
|
80 |
+
```
|
81 |
+
|
82 |
+
#### 运行!
|
83 |
+
我们提供两种方案:使用web图形界面,或使用命令行代码运行。
|
84 |
+
|
85 |
+
* 使用web图形界面 (通过 [Streamlit](https://streamlit.io/) 实现)
|
86 |
+
|
87 |
+
运行以下代码:
|
88 |
+
|
89 |
+
```
|
90 |
+
streamlit run streamlit.py --server_port.8501
|
91 |
+
```
|
92 |
+
|
93 |
+
然后打开浏览器并访问 `localhost:8501`, 根据页面内的指示生成视频。
|
94 |
+
|
95 |
+
* 使用命令行代码
|
96 |
+
|
97 |
+
请注意替换`{}`内容,并更换为您放置相应内容的文件夹位置。
|
98 |
+
|
99 |
+
```
|
100 |
+
mkdir {结果保存路径}
|
101 |
+
|
102 |
+
python -m torch.distributed.launch \
|
103 |
+
--nproc_per_node=1 train.py --mode=test \
|
104 |
+
--world_size=1 --dataloaders=2 \
|
105 |
+
--test_input_poses_images={姿势路径} \
|
106 |
+
--test_input_person_images={人物设定表路径} \
|
107 |
+
--test_output_dir={结果保存路径} \
|
108 |
+
--test_checkpoint_dir={权重路径}
|
109 |
+
|
110 |
+
ffmpeg -r 30 -y -i {结果保存路径}/%d.png -r 30 -c:v libx264 output.mp4 -r 30
|
111 |
+
```
|
112 |
+
|
113 |
+
视频结果将生成在 `CoNR/output.mp4`。
|
114 |
+
|
115 |
+
## 引用CoNR
|
116 |
+
```bibtex
|
117 |
+
@article{lin2022conr,
|
118 |
+
title={Collaborative Neural Rendering using Anime Character Sheets},
|
119 |
+
author={Lin, Zuzeng and Huang, Ailin and Huang, Zhewei and Hu, Chen and Zhou, Shuchang},
|
120 |
+
journal={arXiv preprint arXiv:2207.05378},
|
121 |
+
year={2022}
|
122 |
+
}
|
123 |
+
```
|
124 |
+
|
app.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
import base64
|
5 |
+
|
6 |
+
|
7 |
+
def get_base64(bin_file):
|
8 |
+
with open(bin_file, "rb") as f:
|
9 |
+
data = f.read()
|
10 |
+
return base64.b64encode(data).decode()
|
11 |
+
|
12 |
+
|
13 |
+
def conr_fn(character_sheets, pose_zip):
|
14 |
+
os.system("rm character_sheet/*")
|
15 |
+
os.system("rm result/*")
|
16 |
+
os.system("rm poses/*")
|
17 |
+
os.makedirs("character_sheet", exist_ok=True)
|
18 |
+
for i, e in enumerate(character_sheets):
|
19 |
+
with open(f"character_sheet/{i}.png", "wb") as f:
|
20 |
+
e.seek(0)
|
21 |
+
f.write(e.read())
|
22 |
+
e.seek(0)
|
23 |
+
os.makedirs("poses", exist_ok=True)
|
24 |
+
pose_zip.seek(0)
|
25 |
+
open("poses.zip", "wb").write(pose_zip.read())
|
26 |
+
os.system(f"unzip -d poses poses.zip")
|
27 |
+
os.system("sh infer.sh")
|
28 |
+
return "output.mp4"
|
29 |
+
|
30 |
+
|
31 |
+
with gr.Blocks() as ui:
|
32 |
+
gr.Markdown("CoNR demo")
|
33 |
+
gr.Markdown("<a target='_blank' href='https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr.ipynb'> <img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a> [GitHub](https://github.com/megvii-research/CoNR/)")
|
34 |
+
gr.Markdown("Unofficial demo for [CoNR](https://transpchan.github.io/live3d/).")
|
35 |
+
|
36 |
+
with gr.Row():
|
37 |
+
# with gr.Column():
|
38 |
+
# gr.Markdown("## Parse video")
|
39 |
+
# gr.Markdown("TBD")
|
40 |
+
with gr.Column():
|
41 |
+
gr.Markdown("## Animate character")
|
42 |
+
gr.Markdown("Character sheet")
|
43 |
+
character_sheets = gr.File(file_count="multiple")
|
44 |
+
gr.Markdown("Pose zip") # Don't hack
|
45 |
+
pose_video = gr.File(file_count="single")
|
46 |
+
|
47 |
+
# os.system("sh download.sh")
|
48 |
+
run = gr.Button("Run")
|
49 |
+
video = gr.Video()
|
50 |
+
run.click(fn=conr_fn, inputs=[character_sheets, pose_video], outputs=video)
|
51 |
+
|
52 |
+
gr.Markdown("## Examples")
|
53 |
+
sheets = "character_sheet_ponytail_example"
|
54 |
+
gr.Examples(fn=conr_fn, inputs=[character_sheets, pose_video], outputs=video,
|
55 |
+
examples=[[[os.path.join(sheets, x) for x in os.listdir(sheets)], "poses_template.zip"]], cache_examples=True, examples_per_page=1)
|
56 |
+
|
57 |
+
# ui.launch()
|
58 |
+
demo = ui
|
59 |
+
demo.launch()
|
character_sheet_ponytail_example/0.png
ADDED
![]() |
Git LFS Details
|
character_sheet_ponytail_example/1.png
ADDED
![]() |
Git LFS Details
|
character_sheet_ponytail_example/2.png
ADDED
![]() |
Git LFS Details
|
character_sheet_ponytail_example/3.png
ADDED
![]() |
Git LFS Details
|
conr.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from model.backbone import ResEncUnet
|
6 |
+
|
7 |
+
from model.shader import CINN
|
8 |
+
from model.decoder_small import RGBADecoderNet
|
9 |
+
|
10 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
|
12 |
+
|
13 |
+
def UDPClip(x):
|
14 |
+
return torch.clamp(x, min=0, max=1) # NCHW
|
15 |
+
|
16 |
+
|
17 |
+
class CoNR():
|
18 |
+
def __init__(self, args):
|
19 |
+
self.args = args
|
20 |
+
|
21 |
+
self.udpparsernet = ResEncUnet(
|
22 |
+
backbone_name='resnet50_danbo',
|
23 |
+
classes=4,
|
24 |
+
pretrained=(args.local_rank == 0),
|
25 |
+
parametric_upsampling=True,
|
26 |
+
decoder_filters=(512, 384, 256, 128, 32),
|
27 |
+
map_location=device
|
28 |
+
)
|
29 |
+
self.target_pose_encoder = ResEncUnet(
|
30 |
+
backbone_name='resnet18_danbo-4',
|
31 |
+
classes=1,
|
32 |
+
pretrained=(args.local_rank == 0),
|
33 |
+
parametric_upsampling=True,
|
34 |
+
decoder_filters=(512, 384, 256, 128, 32),
|
35 |
+
map_location=device
|
36 |
+
)
|
37 |
+
self.DIM_SHADER_REFERENCE = 4
|
38 |
+
self.shader = CINN(self.DIM_SHADER_REFERENCE)
|
39 |
+
self.rgbadecodernet = RGBADecoderNet(
|
40 |
+
)
|
41 |
+
self.device()
|
42 |
+
self.parser_ckpt = None
|
43 |
+
|
44 |
+
def dist(self):
|
45 |
+
args = self.args
|
46 |
+
if args.distributed:
|
47 |
+
self.udpparsernet = torch.nn.parallel.DistributedDataParallel(
|
48 |
+
self.udpparsernet,
|
49 |
+
device_ids=[
|
50 |
+
args.local_rank],
|
51 |
+
output_device=args.local_rank,
|
52 |
+
broadcast_buffers=False,
|
53 |
+
find_unused_parameters=True
|
54 |
+
)
|
55 |
+
self.target_pose_encoder = torch.nn.parallel.DistributedDataParallel(
|
56 |
+
self.target_pose_encoder,
|
57 |
+
device_ids=[
|
58 |
+
args.local_rank],
|
59 |
+
output_device=args.local_rank,
|
60 |
+
broadcast_buffers=False,
|
61 |
+
find_unused_parameters=True
|
62 |
+
)
|
63 |
+
self.shader = torch.nn.parallel.DistributedDataParallel(
|
64 |
+
self.shader,
|
65 |
+
device_ids=[
|
66 |
+
args.local_rank],
|
67 |
+
output_device=args.local_rank,
|
68 |
+
broadcast_buffers=True
|
69 |
+
)
|
70 |
+
|
71 |
+
self.rgbadecodernet = torch.nn.parallel.DistributedDataParallel(
|
72 |
+
self.rgbadecodernet,
|
73 |
+
device_ids=[
|
74 |
+
args.local_rank],
|
75 |
+
output_device=args.local_rank,
|
76 |
+
broadcast_buffers=True
|
77 |
+
)
|
78 |
+
|
79 |
+
def load_model(self, path):
|
80 |
+
self.udpparsernet.load_state_dict(
|
81 |
+
torch.load('{}/udpparsernet.pth'.format(path), map_location=device))
|
82 |
+
self.target_pose_encoder.load_state_dict(
|
83 |
+
torch.load('{}/target_pose_encoder.pth'.format(path), map_location=device))
|
84 |
+
self.shader.load_state_dict(
|
85 |
+
torch.load('{}/shader.pth'.format(path), map_location=device))
|
86 |
+
self.rgbadecodernet.load_state_dict(
|
87 |
+
torch.load('{}/rgbadecodernet.pth'.format(path), map_location=device))
|
88 |
+
|
89 |
+
def save_model(self, ite_num):
|
90 |
+
self._save_pth(self.udpparsernet,
|
91 |
+
model_name="udpparsernet", ite_num=ite_num)
|
92 |
+
self._save_pth(self.target_pose_encoder,
|
93 |
+
model_name="target_pose_encoder", ite_num=ite_num)
|
94 |
+
self._save_pth(self.shader,
|
95 |
+
model_name="shader", ite_num=ite_num)
|
96 |
+
self._save_pth(self.rgbadecodernet,
|
97 |
+
model_name="rgbadecodernet", ite_num=ite_num)
|
98 |
+
|
99 |
+
def _save_pth(self, net, model_name, ite_num):
|
100 |
+
args = self.args
|
101 |
+
to_save = None
|
102 |
+
if args.distributed:
|
103 |
+
if args.local_rank == 0:
|
104 |
+
to_save = net.module.state_dict()
|
105 |
+
else:
|
106 |
+
to_save = net.state_dict()
|
107 |
+
if to_save:
|
108 |
+
model_dir = os.path.join(
|
109 |
+
os.getcwd(), 'saved_models', args.model_name + os.sep + "checkpoints" + os.sep + "itr_%d" % (ite_num)+os.sep)
|
110 |
+
|
111 |
+
os.makedirs(model_dir, exist_ok=True)
|
112 |
+
torch.save(to_save, model_dir + model_name + ".pth")
|
113 |
+
|
114 |
+
def train(self):
|
115 |
+
self.udpparsernet.train()
|
116 |
+
self.target_pose_encoder.train()
|
117 |
+
self.shader.train()
|
118 |
+
self.rgbadecodernet.train()
|
119 |
+
|
120 |
+
def eval(self):
|
121 |
+
self.udpparsernet.eval()
|
122 |
+
self.target_pose_encoder.eval()
|
123 |
+
self.shader.eval()
|
124 |
+
self.rgbadecodernet.eval()
|
125 |
+
|
126 |
+
def device(self):
|
127 |
+
self.udpparsernet.to(device)
|
128 |
+
self.target_pose_encoder.to(device)
|
129 |
+
self.shader.to(device)
|
130 |
+
self.rgbadecodernet.to(device)
|
131 |
+
|
132 |
+
def data_norm_image(self, data):
|
133 |
+
|
134 |
+
with torch.cuda.amp.autocast(enabled=False):
|
135 |
+
for name in ["character_labels", "pose_label"]:
|
136 |
+
if name in data:
|
137 |
+
data[name] = data[name].to(
|
138 |
+
device, non_blocking=True).float()
|
139 |
+
for name in ["pose_images", "pose_mask", "character_images", "character_masks"]:
|
140 |
+
if name in data:
|
141 |
+
data[name] = data[name].to(
|
142 |
+
device, non_blocking=True).float() / 255.0
|
143 |
+
if "pose_images" in data:
|
144 |
+
data["num_pose_images"] = data["pose_images"].shape[1]
|
145 |
+
data["num_samples"] = data["pose_images"].shape[0]
|
146 |
+
if "character_images" in data:
|
147 |
+
data["num_character_images"] = data["character_images"].shape[1]
|
148 |
+
data["num_samples"] = data["character_images"].shape[0]
|
149 |
+
if "pose_images" in data and "character_images" in data:
|
150 |
+
assert (data["pose_images"].shape[0] ==
|
151 |
+
data["character_images"].shape[0])
|
152 |
+
return data
|
153 |
+
|
154 |
+
def reset_charactersheet(self):
|
155 |
+
self.parser_ckpt = None
|
156 |
+
|
157 |
+
def model_step(self, data, training=False):
|
158 |
+
self.eval()
|
159 |
+
with torch.cuda.amp.autocast(enabled=False):
|
160 |
+
pred = {}
|
161 |
+
if self.parser_ckpt:
|
162 |
+
pred["parser"] = self.parser_ckpt
|
163 |
+
else:
|
164 |
+
pred = self.character_parser_forward(data, pred)
|
165 |
+
self.parser_ckpt = pred["parser"]
|
166 |
+
pred = self.pose_parser_sc_forward(data, pred)
|
167 |
+
pred = self.shader_pose_encoder_forward(data, pred)
|
168 |
+
pred = self.shader_forward(data, pred)
|
169 |
+
return pred
|
170 |
+
|
171 |
+
def shader_forward(self, data, pred={}):
|
172 |
+
assert ("num_character_images" in data), "ERROR: No Character Sheet input."
|
173 |
+
|
174 |
+
character_images_rgb_nmchw, num_character_images = data[
|
175 |
+
"character_images"], data["num_character_images"]
|
176 |
+
# build x_reference_rgb_a_sudp in the draw call
|
177 |
+
shader_character_a_nmchw = data["character_masks"]
|
178 |
+
assert torch.any(torch.mean(shader_character_a_nmchw, (0, 2, 3, 4)) >= 0.95) == False, "ERROR: \
|
179 |
+
No transparent area found in the image, PLEASE separate the foreground of input character sheets.\
|
180 |
+
The website waifucutout.com is recommended to automatically cut out the foreground."
|
181 |
+
|
182 |
+
if shader_character_a_nmchw is None:
|
183 |
+
shader_character_a_nmchw = pred["parser"]["pred"][:, :, 3:4, :, :]
|
184 |
+
x_reference_rgb_a = torch.cat([shader_character_a_nmchw[:, :, :, :, :] * character_images_rgb_nmchw[:, :, :, :, :],
|
185 |
+
shader_character_a_nmchw[:,
|
186 |
+
:, :, :, :],
|
187 |
+
|
188 |
+
], 2)
|
189 |
+
assert (x_reference_rgb_a.shape[2] == self.DIM_SHADER_REFERENCE)
|
190 |
+
# build x_reference_features in the draw call
|
191 |
+
x_reference_features = pred["parser"]["features"]
|
192 |
+
# run cinn shader
|
193 |
+
retdic = self.shader(
|
194 |
+
pred["shader"]["target_pose_features"], x_reference_rgb_a, x_reference_features)
|
195 |
+
pred["shader"].update(retdic)
|
196 |
+
|
197 |
+
# decode rgba
|
198 |
+
if True:
|
199 |
+
dec_out = self.rgbadecodernet(
|
200 |
+
retdic["y_last_remote_features"])
|
201 |
+
y_weighted_x_reference_RGB = dec_out[:, 0:3, :, :]
|
202 |
+
y_weighted_mask_A = dec_out[:, 3:4, :, :]
|
203 |
+
y_weighted_warp_decoded_rgba = torch.cat(
|
204 |
+
(y_weighted_x_reference_RGB*y_weighted_mask_A, y_weighted_mask_A), dim=1
|
205 |
+
)
|
206 |
+
assert(y_weighted_warp_decoded_rgba.shape[1] == 4)
|
207 |
+
assert(
|
208 |
+
y_weighted_warp_decoded_rgba.shape[-1] == character_images_rgb_nmchw.shape[-1])
|
209 |
+
# apply decoded mask to decoded rgb, finishing the draw call
|
210 |
+
pred["shader"]["y_weighted_warp_decoded_rgba"] = y_weighted_warp_decoded_rgba
|
211 |
+
return pred
|
212 |
+
|
213 |
+
def character_parser_forward(self, data, pred={}):
|
214 |
+
if not("num_character_images" in data and "character_images" in data):
|
215 |
+
return pred
|
216 |
+
pred["parser"] = {"pred": None} # create output
|
217 |
+
|
218 |
+
inputs_rgb_nmchw, num_samples, num_character_images = data[
|
219 |
+
"character_images"], data["num_samples"], data["num_character_images"]
|
220 |
+
inputs_rgb_fchw = inputs_rgb_nmchw.view(
|
221 |
+
(num_samples * num_character_images, inputs_rgb_nmchw.shape[2], inputs_rgb_nmchw.shape[3], inputs_rgb_nmchw.shape[4]))
|
222 |
+
|
223 |
+
encoder_out, features = self.udpparsernet(
|
224 |
+
(inputs_rgb_fchw-0.6)/0.2970)
|
225 |
+
|
226 |
+
pred["parser"]["features"] = [features_out.view(
|
227 |
+
(num_samples, num_character_images, features_out.shape[1], features_out.shape[2], features_out.shape[3])) for features_out in features]
|
228 |
+
|
229 |
+
if (encoder_out is not None):
|
230 |
+
|
231 |
+
pred["parser"]["pred"] = UDPClip(encoder_out.view(
|
232 |
+
(num_samples, num_character_images, encoder_out.shape[1], encoder_out.shape[2], encoder_out.shape[3])))
|
233 |
+
|
234 |
+
return pred
|
235 |
+
|
236 |
+
def pose_parser_sc_forward(self, data, pred={}):
|
237 |
+
if not("num_pose_images" in data and "pose_images" in data):
|
238 |
+
return pred
|
239 |
+
inputs_aug_rgb_nmchw, num_samples, num_pose_images = data[
|
240 |
+
"pose_images"], data["num_samples"], data["num_pose_images"]
|
241 |
+
inputs_aug_rgb_fchw = inputs_aug_rgb_nmchw.view(
|
242 |
+
(num_samples * num_pose_images, inputs_aug_rgb_nmchw.shape[2], inputs_aug_rgb_nmchw.shape[3], inputs_aug_rgb_nmchw.shape[4]))
|
243 |
+
|
244 |
+
encoder_out, _ = self.udpparsernet(
|
245 |
+
(inputs_aug_rgb_fchw-0.6)/0.2970)
|
246 |
+
|
247 |
+
encoder_out = encoder_out.view(
|
248 |
+
(num_samples, num_pose_images, encoder_out.shape[1], encoder_out.shape[2], encoder_out.shape[3]))
|
249 |
+
|
250 |
+
# apply sigmoid after eval loss
|
251 |
+
pred["pose_parser"] = {"pred":UDPClip(encoder_out)[:,0,:,:,:]}
|
252 |
+
|
253 |
+
|
254 |
+
return pred
|
255 |
+
|
256 |
+
def shader_pose_encoder_forward(self, data, pred={}):
|
257 |
+
pred["shader"] = {} # create output
|
258 |
+
if "pose_images" in data:
|
259 |
+
pose_images_rgb_nmchw = data["pose_images"]
|
260 |
+
target_gt_rgb = pose_images_rgb_nmchw[:, 0, :, :, :]
|
261 |
+
pred["shader"]["target_gt_rgb"] = target_gt_rgb
|
262 |
+
|
263 |
+
shader_target_a = None
|
264 |
+
if "pose_mask" in data:
|
265 |
+
pred["shader"]["target_gt_a"] = data["pose_mask"]
|
266 |
+
shader_target_a = data["pose_mask"]
|
267 |
+
|
268 |
+
shader_target_sudp = None
|
269 |
+
if "pose_label" in data:
|
270 |
+
shader_target_sudp = data["pose_label"][:, :3, :, :]
|
271 |
+
|
272 |
+
if self.args.test_pose_use_parser_udp:
|
273 |
+
shader_target_sudp = None
|
274 |
+
if shader_target_sudp is None:
|
275 |
+
shader_target_sudp = pred["pose_parser"]["pred"][:, 0:3, :, :]
|
276 |
+
|
277 |
+
if shader_target_a is None:
|
278 |
+
shader_target_a = pred["pose_parser"]["pred"][:, 3:4, :, :]
|
279 |
+
|
280 |
+
# build x_target_sudp_a in the draw call
|
281 |
+
x_target_sudp_a = torch.cat((
|
282 |
+
shader_target_sudp*shader_target_a,
|
283 |
+
shader_target_a
|
284 |
+
), 1)
|
285 |
+
pred["shader"].update({
|
286 |
+
"x_target_sudp_a": x_target_sudp_a
|
287 |
+
})
|
288 |
+
_, features = self.target_pose_encoder(
|
289 |
+
(x_target_sudp_a-0.6)/0.2970, ret_parser_out=False)
|
290 |
+
|
291 |
+
pred["shader"]["target_pose_features"] = features
|
292 |
+
return pred
|
data_loader.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
import os
|
6 |
+
cv2.setNumThreads(1)
|
7 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
8 |
+
|
9 |
+
|
10 |
+
class RandomResizedCropWithAutoCenteringAndZeroPadding (object):
|
11 |
+
def __init__(self, output_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), center_jitter=(0.1, 0.1), size_from_alpha_mask=True):
|
12 |
+
assert isinstance(output_size, (int, tuple))
|
13 |
+
if isinstance(output_size, int):
|
14 |
+
self.output_size = (output_size, output_size)
|
15 |
+
else:
|
16 |
+
assert len(output_size) == 2
|
17 |
+
self.output_size = output_size
|
18 |
+
assert isinstance(scale, tuple)
|
19 |
+
assert isinstance(ratio, tuple)
|
20 |
+
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
21 |
+
raise ValueError("Scale and ratio should be of kind (min, max)")
|
22 |
+
self.size_from_alpha_mask = size_from_alpha_mask
|
23 |
+
self.scale = scale
|
24 |
+
self.ratio = ratio
|
25 |
+
assert isinstance(center_jitter, tuple)
|
26 |
+
self.center_jitter = center_jitter
|
27 |
+
|
28 |
+
def __call__(self, sample):
|
29 |
+
imidx, image = sample['imidx'], sample["image_np"]
|
30 |
+
if "labels" in sample:
|
31 |
+
label = sample["labels"]
|
32 |
+
else:
|
33 |
+
label = None
|
34 |
+
|
35 |
+
im_h, im_w = image.shape[:2]
|
36 |
+
if self.size_from_alpha_mask and image.shape[2] == 4:
|
37 |
+
# compute bbox from alpha mask
|
38 |
+
bbox_left, bbox_top, bbox_w, bbox_h = cv2.boundingRect(
|
39 |
+
(image[:, :, 3] > 0).astype(np.uint8))
|
40 |
+
else:
|
41 |
+
bbox_left, bbox_top = 0, 0
|
42 |
+
bbox_h, bbox_w = image.shape[:2]
|
43 |
+
if bbox_h <= 1 and bbox_w <= 1:
|
44 |
+
sample["bad"] = 0
|
45 |
+
else:
|
46 |
+
# detect too small image here
|
47 |
+
alpha_varea = np.sum((image[:, :, 3] > 0).astype(np.uint8))
|
48 |
+
image_area = image.shape[0]*image.shape[1]
|
49 |
+
if alpha_varea/image_area < 0.001:
|
50 |
+
sample["bad"] = alpha_varea
|
51 |
+
# detect bad image
|
52 |
+
if "bad" in sample:
|
53 |
+
# baddata_dir = os.path.join(os.getcwd(), 'test_data', "baddata" + os.sep)
|
54 |
+
# save_output(str(imidx)+".png",image,label,baddata_dir)
|
55 |
+
bbox_h, bbox_w = image.shape[:2]
|
56 |
+
sample["image_np"] = np.zeros(
|
57 |
+
[self.output_size[0], self.output_size[1], image.shape[2]], dtype=image.dtype)
|
58 |
+
if label is not None:
|
59 |
+
sample["labels"] = np.zeros(
|
60 |
+
[self.output_size[0], self.output_size[1], 4], dtype=label.dtype)
|
61 |
+
|
62 |
+
return sample
|
63 |
+
|
64 |
+
# compute default area by making sure output_size contains bbox_w * bbox_h
|
65 |
+
|
66 |
+
jitter_h = np.random.uniform(-bbox_h *
|
67 |
+
self.center_jitter[0], bbox_h*self.center_jitter[0])
|
68 |
+
jitter_w = np.random.uniform(-bbox_w *
|
69 |
+
self.center_jitter[1], bbox_w*self.center_jitter[1])
|
70 |
+
|
71 |
+
# h/w
|
72 |
+
target_aspect_ratio = np.exp(
|
73 |
+
np.log(self.output_size[0]/self.output_size[1]) +
|
74 |
+
np.random.uniform(np.log(self.ratio[0]), np.log(self.ratio[1]))
|
75 |
+
)
|
76 |
+
|
77 |
+
source_aspect_ratio = bbox_h/bbox_w
|
78 |
+
|
79 |
+
if target_aspect_ratio < source_aspect_ratio:
|
80 |
+
# same w, target has larger h, use h to align
|
81 |
+
target_height = bbox_h * \
|
82 |
+
np.random.uniform(self.scale[0], self.scale[1])
|
83 |
+
virtual_h = int(
|
84 |
+
round(target_height))
|
85 |
+
virtual_w = int(
|
86 |
+
round(target_height / target_aspect_ratio)) # h/w
|
87 |
+
else:
|
88 |
+
# same w, source has larger h, use w to align
|
89 |
+
target_width = bbox_w * \
|
90 |
+
np.random.uniform(self.scale[0], self.scale[1])
|
91 |
+
virtual_h = int(
|
92 |
+
round(target_width * target_aspect_ratio)) # h/w
|
93 |
+
virtual_w = int(
|
94 |
+
round(target_width))
|
95 |
+
|
96 |
+
# print("required aspect ratio:", target_aspect_ratio)
|
97 |
+
|
98 |
+
virtual_top = int(round(bbox_top + jitter_h - (virtual_h-bbox_h)/2))
|
99 |
+
virutal_left = int(round(bbox_left + jitter_w - (virtual_w-bbox_w)/2))
|
100 |
+
|
101 |
+
if virtual_top < 0:
|
102 |
+
top_padding = abs(virtual_top)
|
103 |
+
crop_top = 0
|
104 |
+
else:
|
105 |
+
top_padding = 0
|
106 |
+
crop_top = virtual_top
|
107 |
+
if virutal_left < 0:
|
108 |
+
left_padding = abs(virutal_left)
|
109 |
+
crop_left = 0
|
110 |
+
else:
|
111 |
+
left_padding = 0
|
112 |
+
crop_left = virutal_left
|
113 |
+
if virtual_top+virtual_h > im_h:
|
114 |
+
bottom_padding = abs(im_h-(virtual_top+virtual_h))
|
115 |
+
crop_bottom = im_h
|
116 |
+
else:
|
117 |
+
bottom_padding = 0
|
118 |
+
crop_bottom = virtual_top+virtual_h
|
119 |
+
if virutal_left+virtual_w > im_w:
|
120 |
+
right_padding = abs(im_w-(virutal_left+virtual_w))
|
121 |
+
crop_right = im_w
|
122 |
+
else:
|
123 |
+
right_padding = 0
|
124 |
+
crop_right = virutal_left+virtual_w
|
125 |
+
# crop
|
126 |
+
|
127 |
+
image = image[crop_top:crop_bottom, crop_left: crop_right]
|
128 |
+
if label is not None:
|
129 |
+
label = label[crop_top:crop_bottom, crop_left: crop_right]
|
130 |
+
|
131 |
+
# pad
|
132 |
+
if top_padding + bottom_padding + left_padding + right_padding > 0:
|
133 |
+
padding = ((top_padding, bottom_padding),
|
134 |
+
(left_padding, right_padding), (0, 0))
|
135 |
+
# print("padding", padding)
|
136 |
+
image = np.pad(image, padding, mode='constant')
|
137 |
+
if label is not None:
|
138 |
+
label = np.pad(label, padding, mode='constant')
|
139 |
+
|
140 |
+
if image.shape[0]/image.shape[1] - virtual_h/virtual_w > 0.001:
|
141 |
+
print("virtual aspect ratio:", virtual_h/virtual_w)
|
142 |
+
print("image aspect ratio:", image.shape[0]/image.shape[1])
|
143 |
+
assert (image.shape[0]/image.shape[1] - virtual_h/virtual_w < 0.001)
|
144 |
+
sample["crop"] = np.array(
|
145 |
+
[im_h, im_w, crop_top, crop_bottom, crop_left, crop_right, top_padding, bottom_padding, left_padding, right_padding, image.shape[0], image.shape[1]])
|
146 |
+
|
147 |
+
# resize
|
148 |
+
if self.output_size[1] != image.shape[1] or self.output_size[0] != image.shape[0]:
|
149 |
+
if self.output_size[1] > image.shape[1] and self.output_size[0] > image.shape[0]:
|
150 |
+
# enlarging
|
151 |
+
image = cv2.resize(
|
152 |
+
image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_LINEAR)
|
153 |
+
else:
|
154 |
+
# shrinking
|
155 |
+
image = cv2.resize(
|
156 |
+
image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_AREA)
|
157 |
+
|
158 |
+
if label is not None:
|
159 |
+
label = cv2.resize(label, (self.output_size[1], self.output_size[0]),
|
160 |
+
interpolation=cv2.INTER_NEAREST_EXACT)
|
161 |
+
|
162 |
+
assert image.shape[0] == self.output_size[0] and image.shape[1] == self.output_size[1]
|
163 |
+
sample['imidx'], sample["image_np"] = imidx, image
|
164 |
+
if label is not None:
|
165 |
+
assert label.shape[0] == self.output_size[0] and label.shape[1] == self.output_size[1]
|
166 |
+
sample["labels"] = label
|
167 |
+
|
168 |
+
return sample
|
169 |
+
|
170 |
+
|
171 |
+
class FileDataset(Dataset):
|
172 |
+
def __init__(self, image_names_list, fg_img_lbl_transform=None, shader_pose_use_gt_udp_test=True, shader_target_use_gt_rgb_debug=False):
|
173 |
+
self.image_name_list = image_names_list
|
174 |
+
self.fg_img_lbl_transform = fg_img_lbl_transform
|
175 |
+
self.shader_pose_use_gt_udp_test = shader_pose_use_gt_udp_test
|
176 |
+
self.shader_target_use_gt_rgb_debug = shader_target_use_gt_rgb_debug
|
177 |
+
|
178 |
+
def __len__(self):
|
179 |
+
return len(self.image_name_list)
|
180 |
+
|
181 |
+
def get_gt_from_disk(self, idx, imname, read_label):
|
182 |
+
if read_label:
|
183 |
+
# read label
|
184 |
+
with open(imname, mode="rb") as bio:
|
185 |
+
if imname.find(".npz") > 0:
|
186 |
+
label_np = np.load(bio, allow_pickle=True)[
|
187 |
+
'i'].astype(np.float32, copy=False)
|
188 |
+
else:
|
189 |
+
label_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(bio.read(
|
190 |
+
), np.uint8), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
|
191 |
+
assert (4 == label_np.shape[2])
|
192 |
+
# fake image out of valid label
|
193 |
+
image_np = (label_np*255).clip(0, 255).astype(np.uint8, copy=False)
|
194 |
+
# assemble sample
|
195 |
+
sample = {'imidx': np.array(
|
196 |
+
[idx]), "image_np": image_np, "labels": label_np}
|
197 |
+
|
198 |
+
else:
|
199 |
+
# read image as unit8
|
200 |
+
with open(imname, mode="rb") as bio:
|
201 |
+
image_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(
|
202 |
+
bio.read(), np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
|
203 |
+
# image_np = Image.open(bio)
|
204 |
+
# image_np = np.array(image_np)
|
205 |
+
assert (3 == len(image_np.shape))
|
206 |
+
if (image_np.shape[2] == 4):
|
207 |
+
mask_np = image_np[:, :, 3:4]
|
208 |
+
image_np = (image_np[:, :, :3] *
|
209 |
+
(image_np[:, :, 3][:, :, np.newaxis]/255.0)).clip(0, 255).astype(np.uint8, copy=False)
|
210 |
+
elif (image_np.shape[2] == 3):
|
211 |
+
# generate a fake mask
|
212 |
+
# Fool-proofing
|
213 |
+
mask_np = np.ones(
|
214 |
+
(image_np.shape[0], image_np.shape[1], 1), dtype=np.uint8)*255
|
215 |
+
print("WARN: transparent background is preferred for image ", imname)
|
216 |
+
else:
|
217 |
+
raise ValueError("weird shape of image ", imname, image_np)
|
218 |
+
image_np = np.concatenate((image_np, mask_np), axis=2)
|
219 |
+
sample = {'imidx': np.array(
|
220 |
+
[idx]), "image_np": image_np}
|
221 |
+
|
222 |
+
# apply fg_img_lbl_transform
|
223 |
+
if self.fg_img_lbl_transform:
|
224 |
+
sample = self.fg_img_lbl_transform(sample)
|
225 |
+
|
226 |
+
if "labels" in sample:
|
227 |
+
# return UDP as 4chn XYZV float tensor
|
228 |
+
if "float" not in str(sample["labels"].dtype):
|
229 |
+
sample["labels"] = sample["labels"].astype(np.float32) / np.iinfo(sample["labels"].dtype).max
|
230 |
+
sample["labels"] = torch.from_numpy(
|
231 |
+
sample["labels"].transpose((2, 0, 1)))
|
232 |
+
assert (sample["labels"].dtype == torch.float32)
|
233 |
+
|
234 |
+
if "image_np" in sample:
|
235 |
+
# return image as 3chn RGB uint8 tensor and 1chn A uint8 tensor
|
236 |
+
sample["mask"] = torch.from_numpy(
|
237 |
+
sample["image_np"][:, :, 3:4].transpose((2, 0, 1)))
|
238 |
+
assert (sample["mask"].dtype == torch.uint8)
|
239 |
+
sample["image"] = torch.from_numpy(
|
240 |
+
sample["image_np"][:, :, :3].transpose((2, 0, 1)))
|
241 |
+
|
242 |
+
assert (sample["image"].dtype == torch.uint8)
|
243 |
+
del sample["image_np"]
|
244 |
+
return sample
|
245 |
+
|
246 |
+
def __getitem__(self, idx):
|
247 |
+
sample = {
|
248 |
+
'imidx': np.array([idx])}
|
249 |
+
target = self.get_gt_from_disk(
|
250 |
+
idx, imname=self.image_name_list[idx][0], read_label=self.shader_pose_use_gt_udp_test)
|
251 |
+
if self.shader_target_use_gt_rgb_debug:
|
252 |
+
sample["pose_images"] = torch.stack([target["image"]])
|
253 |
+
sample["pose_mask"] = target["mask"]
|
254 |
+
elif self.shader_pose_use_gt_udp_test:
|
255 |
+
sample["pose_label"] = target["labels"]
|
256 |
+
sample["pose_mask"] = target["mask"]
|
257 |
+
else:
|
258 |
+
sample["pose_images"] = torch.stack([target["image"]])
|
259 |
+
if "crop" in target:
|
260 |
+
sample["pose_crop"] = target["crop"]
|
261 |
+
character_images = []
|
262 |
+
character_masks = []
|
263 |
+
for i in range(1, len(self.image_name_list[idx])):
|
264 |
+
source = self.get_gt_from_disk(
|
265 |
+
idx, self.image_name_list[idx][i], read_label=False)
|
266 |
+
character_images.append(source["image"])
|
267 |
+
character_masks.append(source["mask"])
|
268 |
+
character_images = torch.stack(character_images)
|
269 |
+
character_masks = torch.stack(character_masks)
|
270 |
+
sample.update({
|
271 |
+
"character_images": character_images,
|
272 |
+
"character_masks": character_masks
|
273 |
+
})
|
274 |
+
# do not make fake labels in inference
|
275 |
+
return sample
|
images/MAIN.png
ADDED
![]() |
Git LFS Details
|
infer.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
rm -r "./results"
|
2 |
+
mkdir "./results"
|
3 |
+
|
4 |
+
python3 train.py --mode=test \
|
5 |
+
--world_size=1 --dataloaders=2 \
|
6 |
+
--test_input_poses_images=./poses/ \
|
7 |
+
--test_input_person_images=./character_sheet/ \
|
8 |
+
--test_output_dir=./results/ \
|
9 |
+
--test_checkpoint_dir=./weights/
|
10 |
+
|
11 |
+
echo Generating Video...
|
12 |
+
ffmpeg -r 30 -y -i ./results/%d.png -r 30 -c:v libx264 -pix_fmt yuv420p output.mp4 -crf 18 -r 30
|
13 |
+
echo DONE.
|
model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
model/backbone.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This code was mostly taken from backbone-unet by mkisantal:
|
3 |
+
https://github.com/mkisantal/backboned-unet/blob/master/backboned_unet/unet.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision import models
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch
|
12 |
+
from torchvision import models
|
13 |
+
|
14 |
+
|
15 |
+
class AdaptiveConcatPool2d(nn.Module):
|
16 |
+
"""
|
17 |
+
Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.
|
18 |
+
Source: Fastai. This code was taken from the fastai library at url
|
19 |
+
https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, sz=None):
|
23 |
+
"Output will be 2*sz or 2 if sz is None"
|
24 |
+
super().__init__()
|
25 |
+
self.output_size = sz or 1
|
26 |
+
self.ap = nn.AdaptiveAvgPool2d(self.output_size)
|
27 |
+
self.mp = nn.AdaptiveMaxPool2d(self.output_size)
|
28 |
+
|
29 |
+
def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
|
30 |
+
|
31 |
+
|
32 |
+
class MyNorm(nn.Module):
|
33 |
+
def __init__(self, num_channels):
|
34 |
+
super(MyNorm, self).__init__()
|
35 |
+
self.norm = nn.InstanceNorm2d(
|
36 |
+
num_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
x = self.norm(x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
def resnet_fastai(model, pretrained, url, replace_first_layer=None, replace_maxpool_layer=None, progress=True, map_location=None, **kwargs):
|
44 |
+
cut = -2
|
45 |
+
s = model(pretrained=False, **kwargs)
|
46 |
+
if replace_maxpool_layer is not None:
|
47 |
+
s.maxpool = replace_maxpool_layer
|
48 |
+
if replace_first_layer is not None:
|
49 |
+
body = nn.Sequential(replace_first_layer, *list(s.children())[1:cut])
|
50 |
+
else:
|
51 |
+
body = nn.Sequential(*list(s.children())[:cut])
|
52 |
+
|
53 |
+
if pretrained:
|
54 |
+
state = torch.hub.load_state_dict_from_url(url,
|
55 |
+
progress=progress, map_location=map_location)
|
56 |
+
if replace_first_layer is not None:
|
57 |
+
for each in list(state.keys()).copy():
|
58 |
+
if each.find("0.0.") == 0:
|
59 |
+
del state[each]
|
60 |
+
body_tail = nn.Sequential(body)
|
61 |
+
ret = body_tail.load_state_dict(state, strict=False)
|
62 |
+
return body
|
63 |
+
|
64 |
+
|
65 |
+
def get_backbone(name, pretrained=True, map_location=None):
|
66 |
+
""" Loading backbone, defining names for skip-connections and encoder output. """
|
67 |
+
|
68 |
+
first_layer_for_4chn = nn.Conv2d(
|
69 |
+
4, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
70 |
+
max_pool_layer_replace = nn.Conv2d(
|
71 |
+
64, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
72 |
+
# loading backbone model
|
73 |
+
if name == 'resnet18':
|
74 |
+
backbone = models.resnet18(pretrained=pretrained)
|
75 |
+
if name == 'resnet18-4':
|
76 |
+
backbone = models.resnet18(pretrained=pretrained)
|
77 |
+
backbone.conv1 = first_layer_for_4chn
|
78 |
+
elif name == 'resnet34':
|
79 |
+
backbone = models.resnet34(pretrained=pretrained)
|
80 |
+
elif name == 'resnet50':
|
81 |
+
backbone = models.resnet50(pretrained=False, norm_layer=MyNorm)
|
82 |
+
backbone.maxpool = max_pool_layer_replace
|
83 |
+
elif name == 'resnet101':
|
84 |
+
backbone = models.resnet101(pretrained=pretrained)
|
85 |
+
elif name == 'resnet152':
|
86 |
+
backbone = models.resnet152(pretrained=pretrained)
|
87 |
+
elif name == 'vgg16':
|
88 |
+
backbone = models.vgg16_bn(pretrained=pretrained).features
|
89 |
+
elif name == 'vgg19':
|
90 |
+
backbone = models.vgg19_bn(pretrained=pretrained).features
|
91 |
+
elif name == 'resnet18_danbo-4':
|
92 |
+
backbone = resnet_fastai(models.resnet18, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet18-3f77756f.pth",
|
93 |
+
pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_first_layer=first_layer_for_4chn)
|
94 |
+
elif name == 'resnet50_danbo':
|
95 |
+
backbone = resnet_fastai(models.resnet50, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
|
96 |
+
pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_maxpool_layer=max_pool_layer_replace)
|
97 |
+
elif name == 'densenet121':
|
98 |
+
backbone = models.densenet121(pretrained=True).features
|
99 |
+
elif name == 'densenet161':
|
100 |
+
backbone = models.densenet161(pretrained=True).features
|
101 |
+
elif name == 'densenet169':
|
102 |
+
backbone = models.densenet169(pretrained=True).features
|
103 |
+
elif name == 'densenet201':
|
104 |
+
backbone = models.densenet201(pretrained=True).features
|
105 |
+
else:
|
106 |
+
raise NotImplemented(
|
107 |
+
'{} backbone model is not implemented so far.'.format(name))
|
108 |
+
#print(backbone)
|
109 |
+
# specifying skip feature and output names
|
110 |
+
if name.startswith('resnet'):
|
111 |
+
feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
|
112 |
+
backbone_output = 'layer4'
|
113 |
+
elif name == 'vgg16':
|
114 |
+
# TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
|
115 |
+
feature_names = ['5', '12', '22', '32', '42']
|
116 |
+
backbone_output = '43'
|
117 |
+
elif name == 'vgg19':
|
118 |
+
feature_names = ['5', '12', '25', '38', '51']
|
119 |
+
backbone_output = '52'
|
120 |
+
elif name.startswith('densenet'):
|
121 |
+
feature_names = [None, 'relu0', 'denseblock1',
|
122 |
+
'denseblock2', 'denseblock3']
|
123 |
+
backbone_output = 'denseblock4'
|
124 |
+
elif name == 'unet_encoder':
|
125 |
+
feature_names = ['module1', 'module2', 'module3', 'module4']
|
126 |
+
backbone_output = 'module5'
|
127 |
+
else:
|
128 |
+
raise NotImplemented(
|
129 |
+
'{} backbone model is not implemented so far.'.format(name))
|
130 |
+
if name.find('_danbo') > 0:
|
131 |
+
feature_names = [None, '2', '4', '5', '6']
|
132 |
+
backbone_output = '7'
|
133 |
+
return backbone, feature_names, backbone_output
|
134 |
+
|
135 |
+
|
136 |
+
class UpsampleBlock(nn.Module):
|
137 |
+
|
138 |
+
# TODO: separate parametric and non-parametric classes?
|
139 |
+
# TODO: skip connection concatenated OR added
|
140 |
+
|
141 |
+
def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
|
142 |
+
super(UpsampleBlock, self).__init__()
|
143 |
+
|
144 |
+
self.parametric = parametric
|
145 |
+
ch_out = ch_in/2 if ch_out is None else ch_out
|
146 |
+
|
147 |
+
# first convolution: either transposed conv, or conv following the skip connection
|
148 |
+
if parametric:
|
149 |
+
# versions: kernel=4 padding=1, kernel=2 padding=0
|
150 |
+
self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
|
151 |
+
stride=2, padding=1, output_padding=0, bias=(not use_bn))
|
152 |
+
self.bn1 = MyNorm(ch_out) if use_bn else None
|
153 |
+
else:
|
154 |
+
self.up = None
|
155 |
+
ch_in = ch_in + skip_in
|
156 |
+
self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
|
157 |
+
stride=1, padding=1, bias=(not use_bn))
|
158 |
+
self.bn1 = MyNorm(ch_out) if use_bn else None
|
159 |
+
|
160 |
+
self.relu = nn.ReLU(inplace=True)
|
161 |
+
|
162 |
+
# second convolution
|
163 |
+
conv2_in = ch_out if not parametric else ch_out + skip_in
|
164 |
+
self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
|
165 |
+
stride=1, padding=1, bias=(not use_bn))
|
166 |
+
self.bn2 = MyNorm(ch_out) if use_bn else None
|
167 |
+
|
168 |
+
def forward(self, x, skip_connection=None):
|
169 |
+
|
170 |
+
x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
|
171 |
+
align_corners=None)
|
172 |
+
if self.parametric:
|
173 |
+
x = self.bn1(x) if self.bn1 is not None else x
|
174 |
+
x = self.relu(x)
|
175 |
+
|
176 |
+
if skip_connection is not None:
|
177 |
+
x = torch.cat([x, skip_connection], dim=1)
|
178 |
+
|
179 |
+
if not self.parametric:
|
180 |
+
x = self.conv1(x)
|
181 |
+
x = self.bn1(x) if self.bn1 is not None else x
|
182 |
+
x = self.relu(x)
|
183 |
+
x = self.conv2(x)
|
184 |
+
x = self.bn2(x) if self.bn2 is not None else x
|
185 |
+
x = self.relu(x)
|
186 |
+
|
187 |
+
return x
|
188 |
+
|
189 |
+
|
190 |
+
class ResEncUnet(nn.Module):
|
191 |
+
|
192 |
+
""" U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
|
193 |
+
|
194 |
+
def __init__(self,
|
195 |
+
backbone_name,
|
196 |
+
pretrained=True,
|
197 |
+
encoder_freeze=False,
|
198 |
+
classes=21,
|
199 |
+
decoder_filters=(512, 256, 128, 64, 32),
|
200 |
+
parametric_upsampling=True,
|
201 |
+
shortcut_features='default',
|
202 |
+
decoder_use_instancenorm=True,
|
203 |
+
map_location=None
|
204 |
+
):
|
205 |
+
super(ResEncUnet, self).__init__()
|
206 |
+
|
207 |
+
self.backbone_name = backbone_name
|
208 |
+
|
209 |
+
self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(
|
210 |
+
backbone_name, pretrained=pretrained, map_location=map_location)
|
211 |
+
shortcut_chs, bb_out_chs = self.infer_skip_channels()
|
212 |
+
if shortcut_features != 'default':
|
213 |
+
self.shortcut_features = shortcut_features
|
214 |
+
|
215 |
+
# build decoder part
|
216 |
+
self.upsample_blocks = nn.ModuleList()
|
217 |
+
# avoiding having more blocks than skip connections
|
218 |
+
decoder_filters = decoder_filters[:len(self.shortcut_features)]
|
219 |
+
decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
|
220 |
+
num_blocks = len(self.shortcut_features)
|
221 |
+
for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
|
222 |
+
self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
|
223 |
+
skip_in=shortcut_chs[num_blocks-i-1],
|
224 |
+
parametric=parametric_upsampling,
|
225 |
+
use_bn=decoder_use_instancenorm))
|
226 |
+
self.final_conv = nn.Conv2d(
|
227 |
+
decoder_filters[-1], classes, kernel_size=(1, 1))
|
228 |
+
|
229 |
+
if encoder_freeze:
|
230 |
+
self.freeze_encoder()
|
231 |
+
|
232 |
+
def freeze_encoder(self):
|
233 |
+
""" Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
|
234 |
+
|
235 |
+
for param in self.backbone.parameters():
|
236 |
+
param.requires_grad = False
|
237 |
+
|
238 |
+
def forward(self, *input, ret_parser_out=True):
|
239 |
+
""" Forward propagation in U-Net. """
|
240 |
+
|
241 |
+
x, features = self.forward_backbone(*input)
|
242 |
+
output_feature = [x]
|
243 |
+
for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
|
244 |
+
skip_features = features[skip_name]
|
245 |
+
if skip_features is not None:
|
246 |
+
output_feature.append(skip_features)
|
247 |
+
if ret_parser_out:
|
248 |
+
x = upsample_block(x, skip_features)
|
249 |
+
if ret_parser_out:
|
250 |
+
x = self.final_conv(x)
|
251 |
+
# apply sigmoid later
|
252 |
+
else:
|
253 |
+
x = None
|
254 |
+
|
255 |
+
return x, output_feature
|
256 |
+
|
257 |
+
def forward_backbone(self, x):
|
258 |
+
""" Forward propagation in backbone encoder network. """
|
259 |
+
|
260 |
+
features = {None: None} if None in self.shortcut_features else dict()
|
261 |
+
for name, child in self.backbone.named_children():
|
262 |
+
x = child(x)
|
263 |
+
if name in self.shortcut_features:
|
264 |
+
features[name] = x
|
265 |
+
if name == self.bb_out_name:
|
266 |
+
break
|
267 |
+
|
268 |
+
return x, features
|
269 |
+
|
270 |
+
def infer_skip_channels(self):
|
271 |
+
""" Getting the number of channels at skip connections and at the output of the encoder. """
|
272 |
+
if self.backbone_name.find("-4") > 0:
|
273 |
+
x = torch.zeros(1, 4, 224, 224)
|
274 |
+
else:
|
275 |
+
x = torch.zeros(1, 3, 224, 224)
|
276 |
+
has_fullres_features = self.backbone_name.startswith(
|
277 |
+
'vgg') or self.backbone_name == 'unet_encoder'
|
278 |
+
# only VGG has features at full resolution
|
279 |
+
channels = [] if has_fullres_features else [0]
|
280 |
+
|
281 |
+
# forward run in backbone to count channels (dirty solution but works for *any* Module)
|
282 |
+
for name, child in self.backbone.named_children():
|
283 |
+
x = child(x)
|
284 |
+
if name in self.shortcut_features:
|
285 |
+
channels.append(x.shape[1])
|
286 |
+
if name == self.bb_out_name:
|
287 |
+
out_channels = x.shape[1]
|
288 |
+
break
|
289 |
+
return channels, out_channels
|
model/decoder_small.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from torch import nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class ResBlock2d(nn.Module):
|
9 |
+
def __init__(self, in_features, kernel_size, padding):
|
10 |
+
super(ResBlock2d, self).__init__()
|
11 |
+
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
12 |
+
padding=padding)
|
13 |
+
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
14 |
+
padding=padding)
|
15 |
+
|
16 |
+
self.norm1 = nn.Conv2d(
|
17 |
+
in_channels=in_features, out_channels=in_features, kernel_size=1)
|
18 |
+
self.norm2 = nn.Conv2d(
|
19 |
+
in_channels=in_features, out_channels=in_features, kernel_size=1)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
out = self.norm1(x)
|
23 |
+
out = F.relu(out, inplace=True)
|
24 |
+
out = self.conv1(out)
|
25 |
+
out = self.norm2(out)
|
26 |
+
out = F.relu(out, inplace=True)
|
27 |
+
out = self.conv2(out)
|
28 |
+
out += x
|
29 |
+
return out
|
30 |
+
|
31 |
+
|
32 |
+
class RGBADecoderNet(nn.Module):
|
33 |
+
def __init__(self, c=64, out_planes=4, num_bottleneck_blocks=1):
|
34 |
+
super(RGBADecoderNet, self).__init__()
|
35 |
+
self.conv_rgba = nn.Sequential(nn.Conv2d(c, out_planes, kernel_size=3, stride=1,
|
36 |
+
padding=1, dilation=1, bias=True))
|
37 |
+
self.bottleneck = torch.nn.Sequential()
|
38 |
+
for i in range(num_bottleneck_blocks):
|
39 |
+
self.bottleneck.add_module(
|
40 |
+
'r' + str(i), ResBlock2d(c, kernel_size=(3, 3), padding=(1, 1)))
|
41 |
+
|
42 |
+
def forward(self, features_weighted_mask_atfeaturesscale_list=[]):
|
43 |
+
return torch.sigmoid(self.conv_rgba(self.bottleneck(features_weighted_mask_atfeaturesscale_list.pop(0))))
|
model/shader.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .warplayer import warp_features
|
5 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
6 |
+
|
7 |
+
|
8 |
+
class DecoderBlock(nn.Module):
|
9 |
+
def __init__(self, in_planes, c=224, out_msgs=0, out_locals=0, block_nums=1, out_masks=1, out_local_flows=32, out_msgs_flows=32, out_feat_flows=0):
|
10 |
+
|
11 |
+
super(DecoderBlock, self).__init__()
|
12 |
+
self.conv0 = nn.Sequential(
|
13 |
+
nn.Conv2d(in_planes, c, 3, 2, 1),
|
14 |
+
nn.PReLU(c),
|
15 |
+
nn.Conv2d(c, c, 3, 2, 1),
|
16 |
+
nn.PReLU(c),
|
17 |
+
)
|
18 |
+
|
19 |
+
self.convblocks = nn.ModuleList()
|
20 |
+
for i in range(block_nums):
|
21 |
+
self.convblocks.append(nn.Sequential(
|
22 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
23 |
+
nn.PReLU(c),
|
24 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
25 |
+
nn.PReLU(c),
|
26 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
27 |
+
nn.PReLU(c),
|
28 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
29 |
+
nn.PReLU(c),
|
30 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
31 |
+
nn.PReLU(c),
|
32 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
33 |
+
nn.PReLU(c),
|
34 |
+
))
|
35 |
+
self.out_flows = 2
|
36 |
+
self.out_msgs = out_msgs
|
37 |
+
self.out_msgs_flows = out_msgs_flows if out_msgs > 0 else 0
|
38 |
+
self.out_locals = out_locals
|
39 |
+
self.out_local_flows = out_local_flows if out_locals > 0 else 0
|
40 |
+
self.out_masks = out_masks
|
41 |
+
self.out_feat_flows = out_feat_flows
|
42 |
+
|
43 |
+
self.conv_last = nn.Sequential(
|
44 |
+
nn.ConvTranspose2d(c, c, 4, 2, 1),
|
45 |
+
nn.PReLU(c),
|
46 |
+
nn.ConvTranspose2d(c, self.out_flows+self.out_msgs+self.out_msgs_flows +
|
47 |
+
self.out_locals+self.out_local_flows+self.out_masks+self.out_feat_flows, 4, 2, 1),
|
48 |
+
)
|
49 |
+
|
50 |
+
def forward(self, accumulated_flow, *other):
|
51 |
+
x = [accumulated_flow]
|
52 |
+
for each in other:
|
53 |
+
if each is not None:
|
54 |
+
assert(accumulated_flow.shape[-1] == each.shape[-1]), "decoder want {}, but get {}".format(
|
55 |
+
accumulated_flow.shape, each.shape)
|
56 |
+
x.append(each)
|
57 |
+
feat = self.conv0(torch.cat(x, dim=1))
|
58 |
+
for convblock1 in self.convblocks:
|
59 |
+
feat = convblock1(feat) + feat
|
60 |
+
feat = self.conv_last(feat)
|
61 |
+
prev = 0
|
62 |
+
flow = feat[:, prev:prev+self.out_flows, :, :]
|
63 |
+
prev += self.out_flows
|
64 |
+
message = feat[:, prev:prev+self.out_msgs,
|
65 |
+
:, :] if self.out_msgs > 0 else None
|
66 |
+
prev += self.out_msgs
|
67 |
+
message_flow = feat[:, prev:prev + self.out_msgs_flows,
|
68 |
+
:, :] if self.out_msgs_flows > 0 else None
|
69 |
+
prev += self.out_msgs_flows
|
70 |
+
local_message = feat[:, prev:prev + self.out_locals,
|
71 |
+
:, :] if self.out_locals > 0 else None
|
72 |
+
prev += self.out_locals
|
73 |
+
local_message_flow = feat[:, prev:prev+self.out_local_flows,
|
74 |
+
:, :] if self.out_local_flows > 0 else None
|
75 |
+
prev += self.out_local_flows
|
76 |
+
mask = torch.sigmoid(
|
77 |
+
feat[:, prev:prev+self.out_masks, :, :]) if self.out_masks > 0 else None
|
78 |
+
prev += self.out_masks
|
79 |
+
feat_flow = feat[:, prev:prev+self.out_feat_flows,
|
80 |
+
:, :] if self.out_feat_flows > 0 else None
|
81 |
+
prev += self.out_feat_flows
|
82 |
+
return flow, mask, message, message_flow, local_message, local_message_flow, feat_flow
|
83 |
+
|
84 |
+
|
85 |
+
class CINN(nn.Module):
|
86 |
+
def __init__(self, DIM_SHADER_REFERENCE, target_feature_chns=[512, 256, 128, 64, 64], feature_chns=[2048, 1024, 512, 256, 64], out_msgs_chn=[2048, 1024, 512, 256, 64, 64], out_locals_chn=[2048, 1024, 512, 256, 64, 0], block_num=[1, 1, 1, 1, 1, 2], block_chn_num=[224, 224, 224, 224, 224, 224]):
|
87 |
+
super(CINN, self).__init__()
|
88 |
+
|
89 |
+
self.in_msgs_chn = [0, *out_msgs_chn[:-1]]
|
90 |
+
self.in_locals_chn = [0, *out_locals_chn[:-1]]
|
91 |
+
|
92 |
+
self.decoder_blocks = nn.ModuleList()
|
93 |
+
self.feed_weighted = True
|
94 |
+
if self.feed_weighted:
|
95 |
+
in_planes = 2+2+DIM_SHADER_REFERENCE*2
|
96 |
+
else:
|
97 |
+
in_planes = 2+DIM_SHADER_REFERENCE
|
98 |
+
for each_target_feature_chns, each_feature_chns, each_out_msgs_chn, each_out_locals_chn, each_in_msgs_chn, each_in_locals_chn, each_block_num, each_block_chn_num in zip(target_feature_chns, feature_chns, out_msgs_chn, out_locals_chn, self.in_msgs_chn, self.in_locals_chn, block_num, block_chn_num):
|
99 |
+
self.decoder_blocks.append(
|
100 |
+
DecoderBlock(in_planes+each_target_feature_chns+each_feature_chns+each_in_locals_chn+each_in_msgs_chn, c=each_block_chn_num, block_nums=each_block_num, out_msgs=each_out_msgs_chn, out_locals=each_out_locals_chn, out_masks=2+each_out_locals_chn))
|
101 |
+
for i in range(len(feature_chns), len(out_locals_chn)):
|
102 |
+
#print("append extra block", i, "msg",
|
103 |
+
# out_msgs_chn[i], "local", out_locals_chn[i], "block", block_num[i])
|
104 |
+
self.decoder_blocks.append(
|
105 |
+
DecoderBlock(in_planes+self.in_msgs_chn[i]+self.in_locals_chn[i], c=block_chn_num[i], block_nums=block_num[i], out_msgs=out_msgs_chn[i], out_locals=out_locals_chn[i], out_masks=2+out_msgs_chn[i], out_feat_flows=0))
|
106 |
+
|
107 |
+
def apply_flow(self, mask, message, message_flow, local_message, local_message_flow, x_reference, accumulated_flow, each_x_reference_features=None, each_x_reference_features_flow=None):
|
108 |
+
if each_x_reference_features is not None:
|
109 |
+
size_from = each_x_reference_features
|
110 |
+
else:
|
111 |
+
size_from = x_reference
|
112 |
+
f_size = (size_from.shape[2], size_from.shape[3])
|
113 |
+
accumulated_flow = self.flow_rescale(
|
114 |
+
accumulated_flow, size_from)
|
115 |
+
# mask = warp_features(F.interpolate(
|
116 |
+
# mask, size=f_size, mode="bilinear"), accumulated_flow) if mask is not None else None
|
117 |
+
mask = F.interpolate(
|
118 |
+
mask, size=f_size, mode="bilinear") if mask is not None else None
|
119 |
+
message = F.interpolate(
|
120 |
+
message, size=f_size, mode="bilinear") if message is not None else None
|
121 |
+
message_flow = self.flow_rescale(
|
122 |
+
message_flow, size_from) if message_flow is not None else None
|
123 |
+
message = warp_features(
|
124 |
+
message, message_flow) if message_flow is not None else message
|
125 |
+
|
126 |
+
local_message = F.interpolate(
|
127 |
+
local_message, size=f_size, mode="bilinear") if local_message is not None else None
|
128 |
+
local_message_flow = self.flow_rescale(
|
129 |
+
local_message_flow, size_from) if local_message_flow is not None else None
|
130 |
+
local_message = warp_features(
|
131 |
+
local_message, local_message_flow) if local_message_flow is not None else local_message
|
132 |
+
|
133 |
+
warp_x_reference = warp_features(F.interpolate(
|
134 |
+
x_reference, size=f_size, mode="bilinear"), accumulated_flow)
|
135 |
+
|
136 |
+
each_x_reference_features_flow = self.flow_rescale(
|
137 |
+
each_x_reference_features_flow, size_from) if (each_x_reference_features is not None and each_x_reference_features_flow is not None) else None
|
138 |
+
warp_each_x_reference_features = warp_features(
|
139 |
+
each_x_reference_features, each_x_reference_features_flow) if each_x_reference_features_flow is not None else each_x_reference_features
|
140 |
+
|
141 |
+
return mask, message, local_message, warp_x_reference, accumulated_flow, warp_each_x_reference_features, each_x_reference_features_flow
|
142 |
+
|
143 |
+
def forward(self, x_target_features=[], x_reference=None, x_reference_features=[]):
|
144 |
+
y_flow = []
|
145 |
+
y_feat_flow = []
|
146 |
+
|
147 |
+
y_local_message = []
|
148 |
+
y_warp_x_reference = []
|
149 |
+
y_warp_x_reference_features = []
|
150 |
+
|
151 |
+
y_weighted_flow = []
|
152 |
+
y_weighted_mask = []
|
153 |
+
y_weighted_message = []
|
154 |
+
y_weighted_x_reference = []
|
155 |
+
y_weighted_x_reference_features = []
|
156 |
+
|
157 |
+
for pyrlevel, ifblock in enumerate(self.decoder_blocks):
|
158 |
+
stacked_wref = []
|
159 |
+
stacked_feat = []
|
160 |
+
stacked_anci = []
|
161 |
+
stacked_flow = []
|
162 |
+
stacked_mask = []
|
163 |
+
stacked_mesg = []
|
164 |
+
stacked_locm = []
|
165 |
+
stacked_feat_flow = []
|
166 |
+
for view_id in range(x_reference.shape[1]): # NMCHW
|
167 |
+
|
168 |
+
if pyrlevel == 0:
|
169 |
+
# create from zero flow
|
170 |
+
feat_ev = x_reference_features[pyrlevel][:,
|
171 |
+
view_id, :, :, :] if pyrlevel < len(x_reference_features) else None
|
172 |
+
|
173 |
+
accumulated_flow = torch.zeros_like(
|
174 |
+
feat_ev[:, :2, :, :]).to(device)
|
175 |
+
accumulated_feat_flow = torch.zeros_like(
|
176 |
+
feat_ev[:, :32, :, :]).to(device)
|
177 |
+
# domestic inputs
|
178 |
+
warp_x_reference = F.interpolate(x_reference[:, view_id, :, :, :], size=(
|
179 |
+
feat_ev.shape[-2], feat_ev.shape[-1]), mode="bilinear")
|
180 |
+
warp_x_reference_features = feat_ev
|
181 |
+
|
182 |
+
local_message = None
|
183 |
+
# federated inputs
|
184 |
+
weighted_flow = accumulated_flow if self.feed_weighted else None
|
185 |
+
weighted_wref = warp_x_reference if self.feed_weighted else None
|
186 |
+
weighted_message = None
|
187 |
+
else:
|
188 |
+
# resume from last layer
|
189 |
+
accumulated_flow = y_flow[-1][:, view_id, :, :, :]
|
190 |
+
accumulated_feat_flow = y_feat_flow[-1][:,
|
191 |
+
view_id, :, :, :] if y_feat_flow[-1] is not None else None
|
192 |
+
# domestic inputs
|
193 |
+
warp_x_reference = y_warp_x_reference[-1][:,
|
194 |
+
view_id, :, :, :]
|
195 |
+
warp_x_reference_features = y_warp_x_reference_features[-1][:,
|
196 |
+
view_id, :, :, :] if y_warp_x_reference_features[-1] is not None else None
|
197 |
+
local_message = y_local_message[-1][:, view_id, :,
|
198 |
+
:, :] if len(y_local_message) > 0 else None
|
199 |
+
|
200 |
+
# federated inputs
|
201 |
+
weighted_flow = y_weighted_flow[-1] if self.feed_weighted else None
|
202 |
+
weighted_wref = y_weighted_x_reference[-1] if self.feed_weighted else None
|
203 |
+
weighted_message = y_weighted_message[-1] if len(
|
204 |
+
y_weighted_message) > 0 else None
|
205 |
+
scaled_x_target = x_target_features[pyrlevel][:, :, :, :].detach() if pyrlevel < len(
|
206 |
+
x_target_features) else None
|
207 |
+
# compute flow
|
208 |
+
residual_flow, mask, message, message_flow, local_message, local_message_flow, residual_feat_flow = ifblock(
|
209 |
+
accumulated_flow, scaled_x_target, warp_x_reference, warp_x_reference_features, weighted_flow, weighted_wref, weighted_message, local_message)
|
210 |
+
accumulated_flow = residual_flow + accumulated_flow
|
211 |
+
accumulated_feat_flow = accumulated_flow
|
212 |
+
|
213 |
+
feat_ev = x_reference_features[pyrlevel+1][:,
|
214 |
+
view_id, :, :, :] if pyrlevel+1 < len(x_reference_features) else None
|
215 |
+
mask, message, local_message, warp_x_reference, accumulated_flow, warp_x_reference_features, accumulated_feat_flow = self.apply_flow(
|
216 |
+
mask, message, message_flow, local_message, local_message_flow, x_reference[:, view_id, :, :, :], accumulated_flow, feat_ev, accumulated_feat_flow)
|
217 |
+
stacked_flow.append(accumulated_flow)
|
218 |
+
if accumulated_feat_flow is not None:
|
219 |
+
stacked_feat_flow.append(accumulated_feat_flow)
|
220 |
+
stacked_mask.append(mask)
|
221 |
+
if message is not None:
|
222 |
+
stacked_mesg.append(message)
|
223 |
+
if local_message is not None:
|
224 |
+
stacked_locm.append(local_message)
|
225 |
+
stacked_wref.append(warp_x_reference)
|
226 |
+
if warp_x_reference_features is not None:
|
227 |
+
stacked_feat.append(warp_x_reference_features)
|
228 |
+
|
229 |
+
stacked_flow = torch.stack(stacked_flow, dim=1) # M*NCHW -> NMCHW
|
230 |
+
stacked_feat_flow = torch.stack(stacked_feat_flow, dim=1) if len(
|
231 |
+
stacked_feat_flow) > 0 else None
|
232 |
+
stacked_mask = torch.stack(
|
233 |
+
stacked_mask, dim=1)
|
234 |
+
|
235 |
+
stacked_mesg = torch.stack(stacked_mesg, dim=1) if len(
|
236 |
+
stacked_mesg) > 0 else None
|
237 |
+
stacked_locm = torch.stack(stacked_locm, dim=1) if len(
|
238 |
+
stacked_locm) > 0 else None
|
239 |
+
|
240 |
+
stacked_wref = torch.stack(stacked_wref, dim=1)
|
241 |
+
stacked_feat = torch.stack(stacked_feat, dim=1) if len(
|
242 |
+
stacked_feat) > 0 else None
|
243 |
+
stacked_anci = torch.stack(stacked_anci, dim=1) if len(
|
244 |
+
stacked_anci) > 0 else None
|
245 |
+
y_flow.append(stacked_flow)
|
246 |
+
y_feat_flow.append(stacked_feat_flow)
|
247 |
+
|
248 |
+
y_warp_x_reference.append(stacked_wref)
|
249 |
+
y_warp_x_reference_features.append(stacked_feat)
|
250 |
+
# compute normalized confidence
|
251 |
+
stacked_contrib = torch.nn.functional.softmax(stacked_mask, dim=1)
|
252 |
+
|
253 |
+
# torch.sum to remove temp dimension M from NMCHW --> NCHW
|
254 |
+
weighted_flow = torch.sum(
|
255 |
+
stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_flow, dim=1)
|
256 |
+
weighted_mask = torch.sum(
|
257 |
+
stacked_contrib[:, :, 0:1, :, :] * stacked_mask[:, :, 0:1, :, :], dim=1)
|
258 |
+
weighted_wref = torch.sum(
|
259 |
+
stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_wref, dim=1) if stacked_wref is not None else None
|
260 |
+
weighted_feat = torch.sum(
|
261 |
+
stacked_mask[:, :, 1:2, :, :] * stacked_contrib[:, :, 1:2, :, :] * stacked_feat, dim=1) if stacked_feat is not None else None
|
262 |
+
weighted_mesg = torch.sum(
|
263 |
+
stacked_mask[:, :, 2:, :, :] * stacked_contrib[:, :, 2:, :, :] * stacked_mesg, dim=1) if stacked_mesg is not None else None
|
264 |
+
y_weighted_flow.append(weighted_flow)
|
265 |
+
y_weighted_mask.append(weighted_mask)
|
266 |
+
if weighted_mesg is not None:
|
267 |
+
y_weighted_message.append(weighted_mesg)
|
268 |
+
if stacked_locm is not None:
|
269 |
+
y_local_message.append(stacked_locm)
|
270 |
+
y_weighted_message.append(weighted_mesg)
|
271 |
+
y_weighted_x_reference.append(weighted_wref)
|
272 |
+
y_weighted_x_reference_features.append(weighted_feat)
|
273 |
+
|
274 |
+
if weighted_feat is not None:
|
275 |
+
y_weighted_x_reference_features.append(weighted_feat)
|
276 |
+
return {
|
277 |
+
"y_last_remote_features": [weighted_mesg],
|
278 |
+
}
|
279 |
+
|
280 |
+
def flow_rescale(self, prev_flow, each_x_reference_features):
|
281 |
+
if prev_flow is None:
|
282 |
+
prev_flow = torch.zeros_like(
|
283 |
+
each_x_reference_features[:, :2]).to(device)
|
284 |
+
else:
|
285 |
+
up_scale_factor = each_x_reference_features.shape[-1] / \
|
286 |
+
prev_flow.shape[-1]
|
287 |
+
if up_scale_factor != 1:
|
288 |
+
prev_flow = F.interpolate(prev_flow, scale_factor=up_scale_factor, mode="bilinear",
|
289 |
+
align_corners=False, recompute_scale_factor=False) * up_scale_factor
|
290 |
+
return prev_flow
|
model/warplayer.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
5 |
+
backwarp_tenGrid = {}
|
6 |
+
|
7 |
+
|
8 |
+
def warp(tenInput, tenFlow):
|
9 |
+
with torch.cuda.amp.autocast(enabled=False):
|
10 |
+
k = (str(tenFlow.device), str(tenFlow.size()))
|
11 |
+
if k not in backwarp_tenGrid:
|
12 |
+
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
|
13 |
+
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
14 |
+
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
|
15 |
+
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
16 |
+
backwarp_tenGrid[k] = torch.cat(
|
17 |
+
[tenHorizontal, tenVertical], 1).to(device)
|
18 |
+
|
19 |
+
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
|
20 |
+
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
|
21 |
+
|
22 |
+
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
|
23 |
+
if tenInput.dtype != g.dtype:
|
24 |
+
g = g.to(tenInput.dtype)
|
25 |
+
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
26 |
+
# "zeros" "border"
|
27 |
+
|
28 |
+
|
29 |
+
def warp_features(inp, flow, ):
|
30 |
+
groups = flow.shape[1]//2 # NCHW
|
31 |
+
samples = inp.shape[0]
|
32 |
+
h = inp.shape[2]
|
33 |
+
w = inp.shape[3]
|
34 |
+
assert(flow.shape[0] == samples and flow.shape[2]
|
35 |
+
== h and flow.shape[3] == w)
|
36 |
+
chns = inp.shape[1]
|
37 |
+
chns_per_group = chns // groups
|
38 |
+
assert(flow.shape[1] % 2 == 0)
|
39 |
+
assert(chns % groups == 0)
|
40 |
+
inp = inp.contiguous().view(samples*groups, chns_per_group, h, w)
|
41 |
+
flow = flow.contiguous().view(samples*groups, 2, h, w)
|
42 |
+
feat = warp(inp, flow)
|
43 |
+
feat = feat.view(samples, chns, h, w)
|
44 |
+
return feat
|
45 |
+
|
46 |
+
|
47 |
+
def flow2rgb(flow_map_np):
|
48 |
+
h, w, _ = flow_map_np.shape
|
49 |
+
rgb_map = np.ones((h, w, 3)).astype(np.float32)/2.0
|
50 |
+
normalized_flow_map = np.concatenate(
|
51 |
+
(flow_map_np[:, :, 0:1]/h/2.0, flow_map_np[:, :, 1:2]/w/2.0), axis=2)
|
52 |
+
rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
|
53 |
+
rgb_map[:, :, 1] -= 0.5 * \
|
54 |
+
(normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
|
55 |
+
rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
|
56 |
+
return (rgb_map.clip(0, 1)*255.0).astype(np.uint8)
|
notebooks/conr.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/conr_chinese.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
poses_template.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d6a3ed9527dd6b04275bf4c81704ac50c9b8cd1bf4a26b1d3ec6658e200ff57
|
3 |
+
size 10662874
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pillow>=7.2.0
|
2 |
+
numpy>=1.16
|
3 |
+
tqdm>=4.35.0
|
4 |
+
torch>=1.3.0
|
5 |
+
opencv-python>=4.5.2
|
6 |
+
scikit-image>=0.14.0
|
7 |
+
torchvision>=0.2.1
|
8 |
+
lpips>=0.1.3
|
9 |
+
gdown
|
train.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from datetime import datetime
|
5 |
+
from distutils.util import strtobool
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from torchvision import transforms
|
11 |
+
from data_loader import (FileDataset,
|
12 |
+
RandomResizedCropWithAutoCenteringAndZeroPadding)
|
13 |
+
from torch.utils.data.distributed import DistributedSampler
|
14 |
+
from conr import CoNR
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
def data_sampler(dataset, shuffle, distributed):
|
18 |
+
|
19 |
+
if distributed:
|
20 |
+
return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
|
21 |
+
|
22 |
+
if shuffle:
|
23 |
+
return torch.utils.data.RandomSampler(dataset)
|
24 |
+
|
25 |
+
else:
|
26 |
+
return torch.utils.data.SequentialSampler(dataset)
|
27 |
+
|
28 |
+
def save_output(image_name, inputs_v, d_dir=".", crop=None):
|
29 |
+
import cv2
|
30 |
+
|
31 |
+
inputs_v = inputs_v.detach().squeeze()
|
32 |
+
input_np = torch.clamp(inputs_v*255, 0, 255).byte().cpu().numpy().transpose(
|
33 |
+
(1, 2, 0))
|
34 |
+
# cv2.setNumThreads(1)
|
35 |
+
out_render_scale = cv2.cvtColor(input_np, cv2.COLOR_RGBA2BGRA)
|
36 |
+
if crop is not None:
|
37 |
+
crop = crop.cpu().numpy()[0]
|
38 |
+
output_img = np.zeros((crop[0], crop[1], 4), dtype=np.uint8)
|
39 |
+
before_resize_scale = cv2.resize(
|
40 |
+
out_render_scale, (crop[5]-crop[4]+crop[8]+crop[9], crop[3]-crop[2]+crop[6]+crop[7]), interpolation=cv2.INTER_AREA) # w,h
|
41 |
+
output_img[crop[2]:crop[3], crop[4]:crop[5]] = before_resize_scale[crop[6]:before_resize_scale.shape[0] -
|
42 |
+
crop[7], crop[8]:before_resize_scale.shape[1]-crop[9]]
|
43 |
+
else:
|
44 |
+
output_img = out_render_scale
|
45 |
+
cv2.imwrite(d_dir+"/"+image_name.split(os.sep)[-1]+'.png',
|
46 |
+
output_img
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
def test():
|
51 |
+
source_names_list = []
|
52 |
+
for name in sorted(os.listdir(args.test_input_person_images)):
|
53 |
+
thissource = os.path.join(args.test_input_person_images, name)
|
54 |
+
if os.path.isfile(thissource):
|
55 |
+
source_names_list.append(thissource)
|
56 |
+
if os.path.isdir(thissource):
|
57 |
+
print("skipping empty folder :"+thissource)
|
58 |
+
|
59 |
+
image_names_list = []
|
60 |
+
for name in sorted(os.listdir(args.test_input_poses_images)):
|
61 |
+
thistarget = os.path.join(args.test_input_poses_images, name)
|
62 |
+
if os.path.isfile(thistarget):
|
63 |
+
image_names_list.append([thistarget, *source_names_list])
|
64 |
+
if os.path.isdir(thistarget):
|
65 |
+
print("skipping folder :"+thistarget)
|
66 |
+
print(image_names_list)
|
67 |
+
|
68 |
+
print("---building models")
|
69 |
+
conrmodel = CoNR(args)
|
70 |
+
conrmodel.load_model(path=args.test_checkpoint_dir)
|
71 |
+
conrmodel.dist()
|
72 |
+
infer(args, conrmodel, image_names_list)
|
73 |
+
|
74 |
+
|
75 |
+
def infer(args, humanflowmodel, image_names_list):
|
76 |
+
print("---test images: ", len(image_names_list))
|
77 |
+
test_salobj_dataset = FileDataset(image_names_list=image_names_list,
|
78 |
+
fg_img_lbl_transform=transforms.Compose([
|
79 |
+
RandomResizedCropWithAutoCenteringAndZeroPadding(
|
80 |
+
(args.dataloader_imgsize, args.dataloader_imgsize), scale=(1, 1), ratio=(1.0, 1.0), center_jitter=(0.0, 0.0)
|
81 |
+
)]),
|
82 |
+
shader_pose_use_gt_udp_test=not args.test_pose_use_parser_udp,
|
83 |
+
shader_target_use_gt_rgb_debug=False
|
84 |
+
)
|
85 |
+
sampler = data_sampler(test_salobj_dataset, shuffle=False,
|
86 |
+
distributed=args.distributed)
|
87 |
+
train_data = DataLoader(test_salobj_dataset,
|
88 |
+
batch_size=1,
|
89 |
+
shuffle=False,sampler=sampler,
|
90 |
+
num_workers=args.dataloaders)
|
91 |
+
|
92 |
+
# start testing
|
93 |
+
|
94 |
+
train_num = train_data.__len__()
|
95 |
+
time_stamp = time.time()
|
96 |
+
prev_frame_rgb = []
|
97 |
+
prev_frame_a = []
|
98 |
+
|
99 |
+
pbar = tqdm(range(train_num), ncols=100)
|
100 |
+
for i, data in enumerate(train_data):
|
101 |
+
data_time_interval = time.time() - time_stamp
|
102 |
+
time_stamp = time.time()
|
103 |
+
with torch.no_grad():
|
104 |
+
data["character_images"] = torch.cat(
|
105 |
+
[data["character_images"], *prev_frame_rgb], dim=1)
|
106 |
+
data["character_masks"] = torch.cat(
|
107 |
+
[data["character_masks"], *prev_frame_a], dim=1)
|
108 |
+
data = humanflowmodel.data_norm_image(data)
|
109 |
+
pred = humanflowmodel.model_step(data, training=False)
|
110 |
+
# remember to call humanflowmodel.reset_charactersheet() if you change character .
|
111 |
+
|
112 |
+
train_time_interval = time.time() - time_stamp
|
113 |
+
time_stamp = time.time()
|
114 |
+
if args.local_rank == 0:
|
115 |
+
pbar.set_description(f"Epoch {i}/{train_num}")
|
116 |
+
pbar.set_postfix({"data_time": data_time_interval, "train_time":train_time_interval})
|
117 |
+
pbar.update(1)
|
118 |
+
|
119 |
+
with torch.no_grad():
|
120 |
+
|
121 |
+
if args.test_output_video:
|
122 |
+
pred_img = pred["shader"]["y_weighted_warp_decoded_rgba"]
|
123 |
+
save_output(
|
124 |
+
str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir, crop=data["pose_crop"])
|
125 |
+
|
126 |
+
if args.test_output_udp:
|
127 |
+
pred_img = pred["shader"]["x_target_sudp_a"]
|
128 |
+
save_output(
|
129 |
+
"udp_"+str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir)
|
130 |
+
|
131 |
+
|
132 |
+
def build_args():
|
133 |
+
parser = argparse.ArgumentParser()
|
134 |
+
# distributed learning settings
|
135 |
+
parser.add_argument("--world_size", type=int, default=1,
|
136 |
+
help='world size')
|
137 |
+
parser.add_argument("--local_rank", type=int, default=0,
|
138 |
+
help='local_rank, DON\'T change it')
|
139 |
+
|
140 |
+
# model settings
|
141 |
+
parser.add_argument('--dataloader_imgsize', type=int, default=256,
|
142 |
+
help='Input image size of the model')
|
143 |
+
parser.add_argument('--batch_size', type=int, default=4,
|
144 |
+
help='minibatch size')
|
145 |
+
parser.add_argument('--model_name', default='model_result',
|
146 |
+
help='Name of the experiment')
|
147 |
+
parser.add_argument('--dataloaders', type=int, default=2,
|
148 |
+
help='Num of dataloaders')
|
149 |
+
parser.add_argument('--mode', default="test", choices=['train', 'test'],
|
150 |
+
help='Training mode or Testing mode')
|
151 |
+
|
152 |
+
# i/o settings
|
153 |
+
parser.add_argument('--test_input_person_images',
|
154 |
+
type=str, default="./character_sheet/",
|
155 |
+
help='Directory to input character sheets')
|
156 |
+
parser.add_argument('--test_input_poses_images', type=str,
|
157 |
+
default="./test_data/",
|
158 |
+
help='Directory to input UDP sequences or pose images')
|
159 |
+
parser.add_argument('--test_checkpoint_dir', type=str,
|
160 |
+
default='./weights/',
|
161 |
+
help='Directory to model weights')
|
162 |
+
parser.add_argument('--test_output_dir', type=str,
|
163 |
+
default="./results/",
|
164 |
+
help='Directory to output images')
|
165 |
+
|
166 |
+
# output content settings
|
167 |
+
parser.add_argument('--test_output_video', type=strtobool, default=True,
|
168 |
+
help='Whether to output the final result of CoNR, \
|
169 |
+
images will be output to test_output_dir while True.')
|
170 |
+
parser.add_argument('--test_output_udp', type=strtobool, default=False,
|
171 |
+
help='Whether to output UDP generated from UDP detector, \
|
172 |
+
this is meaningful ONLY when test_input_poses_images \
|
173 |
+
is not UDP sequences but pose images. Meanwhile, \
|
174 |
+
test_pose_use_parser_udp need to be True')
|
175 |
+
|
176 |
+
# UDP detector settings
|
177 |
+
parser.add_argument('--test_pose_use_parser_udp',
|
178 |
+
type=strtobool, default=False,
|
179 |
+
help='Whether to use UDP detector to generate UDP from pngs, \
|
180 |
+
pose input MUST be pose images instead of UDP sequences \
|
181 |
+
while True')
|
182 |
+
|
183 |
+
args = parser.parse_args()
|
184 |
+
|
185 |
+
args.distributed = (args.world_size > 1)
|
186 |
+
if args.local_rank == 0:
|
187 |
+
print("batch_size:", args.batch_size, flush=True)
|
188 |
+
if args.distributed:
|
189 |
+
if args.local_rank == 0:
|
190 |
+
print("world_size: ", args.world_size)
|
191 |
+
torch.distributed.init_process_group(
|
192 |
+
backend="nccl", init_method="env://", world_size=args.world_size)
|
193 |
+
torch.cuda.set_device(args.local_rank)
|
194 |
+
torch.backends.cudnn.benchmark = True
|
195 |
+
else:
|
196 |
+
args.local_rank = 0
|
197 |
+
|
198 |
+
return args
|
199 |
+
|
200 |
+
|
201 |
+
if __name__ == "__main__":
|
202 |
+
args = build_args()
|
203 |
+
test()
|
weights/.gitkeep
ADDED
File without changes
|
weights/rgbadecodernet.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1cc48cf4b3b6f7c4856bec8299e57466836eff3bffa73f518965fdb75bc16fd
|
3 |
+
size 341897
|
weights/shader.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5c8fa74e07db0e0a853fe2aa8d299c854b5c957037dcf4858dc0ca755bec9a95
|
3 |
+
size 384615535
|
weights/target_pose_encoder.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a077437779043b543b015d56ca4f521668645c9ad1cd67ee843aa8a94bf59034
|
3 |
+
size 107891883
|
weights/udpparsernet.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:13bb149d18146277cfb11f90b96ae0ecb25828a1ce57ae47ba04d5468cd3322e
|
3 |
+
size 228957615
|