Spaces: simkell /

simkell, nev committed
Commit 4ef9016 · 0 Parent(s)

Duplicate from nev/CoNR


Co-authored-by: Stepan Shabalin <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,33 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.png filter=lfs diff=lfs merge=lfs -text
33
+ poses.zip filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
1
+ results/
2
+ poses/
3
+ test_data/
4
+ test_data_pre/
5
+ character_sheet/
6
+ # weights/
7
+ *.mp4
8
+ *.webm
9
+ poses.zip
10
+ gradio_cached_examples/
11
+ x264/
12
+ filelist.txt
13
+ complex_infer.sh
14
+ __pycache__/
LICENSE ADDED
@@ -0,0 +1,7 @@
1
+ Copyright 2022 Megvii Inc.
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4
+
5
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
+
7
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md ADDED
@@ -0,0 +1,127 @@
1
+ ---
2
+ title: CoNR
3
+ emoji: ⚡
4
+ colorFrom: gray
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.1.4
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: nev/CoNR
12
+ ---
13
+
14
+ [English](https://github.com/megvii-research/CoNR/blob/main/README.md) | [中文](https://github.com/megvii-research/CoNR/blob/main/README_chinese.md)
15
+ # Collaborative Neural Rendering using Anime Character Sheets
16
+
17
+
18
+ ## [Homepage](https://conr.ml) | Colab [English](https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr.ipynb)/[中文](https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr_chinese.ipynb) | [arXiv](https://arxiv.org/abs/2207.05378)
19
+
20
+ ![image](images/MAIN.png)
21
+
22
+ ## Introduction
23
+
24
+ This project is the official implementation of [Collaborative Neural Rendering using Anime Character Sheets](https://arxiv.org/abs/2207.05378), which aims to generate vivid dancing videos from hand-drawn anime character sheets (ACS). Watch more demos on our [homepage](https://conr.ml).
25
+
26
+ Contributors: [@transpchan](https://github.com/transpchan/), [@P2Oileen](https://github.com/P2Oileen), [@hzwer](https://github.com/hzwer)
27
+
28
+ ## Usage
29
+
30
+ #### Prerequisites
31
+
32
+ * NVIDIA GPU + CUDA + CUDNN
33
+ * Python 3.6
34
+
35
+ #### Installation
36
+
37
+ * Clone this repository
38
+
39
+ ```bash
40
+ git clone https://github.com/megvii-research/CoNR
41
+ ```
42
+
43
+ * Dependencies
44
+
45
+ To install all the dependencies, please run the following commands.
46
+
47
+ ```bash
48
+ cd CoNR
49
+ pip install -r requirements.txt
50
+ ```
51
+
52
+ * Download Weights
53
+ Download the weights from Google Drive with the commands below. Alternatively, you can download them from [Baidu Netdisk](https://pan.baidu.com/s/1U11iIk-DiJodgCveSzB6ig?pwd=RDxc) (password: RDxc).
54
+
55
+ ```
56
+ mkdir weights && cd weights
57
+ gdown https://drive.google.com/uc?id=1M1LEpx70tJ72AIV2TQKr6NE_7mJ7tLYx
58
+ gdown https://drive.google.com/uc?id=1YvZy3NHkJ6gC3pq_j8agcbEJymHCwJy0
59
+ gdown https://drive.google.com/uc?id=1AOWZxBvTo9nUf2_9Y7Xe27ZFQuPrnx9i
60
+ gdown https://drive.google.com/uc?id=19jM1-GcqgGoE1bjmQycQw_vqD9C5e-Jm
61
+ ```
62
+
63
+ #### Prepare Inputs
64
+ We provide two Ultra-Dense Pose (UDP) sequences for two characters. You can generate more UDP sequences from 3D models and motion data; see [our paper](https://arxiv.org/abs/2207.05378) for details.
65
+ [Baidu Netdisk](https://pan.baidu.com/s/1hWvz4iQXnVTaTSb6vu1NBg?pwd=RDxc) (password:RDxc)
66
+
67
+ ```
68
+ # for short hair girl
69
+ gdown https://drive.google.com/uc?id=11HMSaEkN__QiAZSnCuaM6GI143xo62KO
70
+ unzip short_hair.zip
71
+ mv short_hair/ poses/
72
+
73
+ # for double ponytail girl
74
+ gdown https://drive.google.com/uc?id=1WNnGVuU0ZLyEn04HzRKzITXqib1wwM4Q
75
+ unzip double_ponytail.zip
76
+ mv double_ponytail/ poses/
77
+ ```
78
+
79
+ We provide sample anime character sheets as inputs. You can also draw your own.
80
+ Character sheets must be cut out from the background and saved in PNG format.
81
+ [Baidu Netdisk](https://pan.baidu.com/s/1shpP90GOMeHke7MuT0-Txw?pwd=RDxc) (password:RDxc)
82
+
83
+ ```
84
+ # for short hair girl
85
+ gdown https://drive.google.com/uc?id=1r-3hUlENSWj81ve2IUPkRKNB81o9WrwT
86
+ unzip short_hair_images.zip
87
+ mv short_hair_images/ character_sheet/
88
+
89
+ # for double ponytail girl
90
+ gdown https://drive.google.com/uc?id=1XMrJf9Lk_dWgXyTJhbEK2LZIXL9G3MWc
91
+ unzip double_ponytail_images.zip
92
+ mv double_ponytail_images/ character_sheet/
93
+ ```
94
+
95
+ #### RUN!
96
+ * with web UI (powered by [Streamlit](https://streamlit.io/))
97
+
98
+ ```
99
+ streamlit run streamlit.py --server.port=8501
100
+ ```
101
+ Then open your browser, visit `localhost:8501`, and follow the instructions to generate the video.
102
+
103
+ * via terminal
104
+
105
+ ```
106
+ mkdir {dir_to_save_result}
107
+
108
+ python -m torch.distributed.launch \
109
+ --nproc_per_node=1 train.py --mode=test \
110
+ --world_size=1 --dataloaders=2 \
111
+ --test_input_poses_images={dir_to_poses} \
112
+ --test_input_person_images={dir_to_character_sheet} \
113
+ --test_output_dir={dir_to_save_result} \
114
+ --test_checkpoint_dir={dir_to_weights}
115
+
116
+ ffmpeg -r 30 -y -i {dir_to_save_result}/%d.png -c:v libx264 -r 30 output.mp4
117
+ ```
118
+
119
+ ## Citation
120
+ ```bibtex
121
+ @article{lin2022conr,
122
+ title={Collaborative Neural Rendering using Anime Character Sheets},
123
+ author={Lin, Zuzeng and Huang, Ailin and Huang, Zhewei and Hu, Chen and Zhou, Shuchang},
124
+ journal={arXiv preprint arXiv:2207.05378},
125
+ year={2022}
126
+ }
127
+ ```
README_chinese.md ADDED
@@ -0,0 +1,124 @@
1
+ [English](https://github.com/megvii-research/CoNR/blob/main/README.md) | [中文](https://github.com/megvii-research/CoNR/blob/main/README_chinese.md)
2
+ # A Neural Renderer for Animating Hand-Drawn Anime Character Sheets
3
+
4
+
5
+ ## [HomePage](https://conr.ml) | Colab [English](https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr.ipynb)/[中文](https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr_chinese.ipynb) | [arXiv](https://arxiv.org/abs/2207.05378)
6
+
7
+ ![image](images/MAIN.png)
8
+
9
+ ## Introduction
10
+
11
+ This project is the official implementation of the paper [Collaborative Neural Rendering using Anime Character Sheets](https://arxiv.org/abs/2207.05378), which aims to generate vivid dancing animations from hand-drawn character sheets. Watch more video demos on our [homepage](https://conr.ml).
12
+
13
+ Contributors: [@transpchan](https://github.com/transpchan/), [@P2Oileen](https://github.com/P2Oileen), [@hzwer](https://github.com/hzwer)
14
+
15
+ ## Usage
16
+
17
+ #### Prerequisites
18
+
19
+ * Nvidia GPU + CUDA + CUDNN
20
+ * Python 3.6
21
+
22
+ #### Installation
23
+
24
+ * Clone this repository
25
+
26
+ ```bash
27
+ git clone https://github.com/megvii-research/CoNR
28
+ ```
29
+
30
+ * Dependencies
31
+
32
+ Run the following commands to install all the dependencies required by CoNR.
33
+
34
+ ```bash
35
+ cd CoNR
36
+ pip install -r requirements.txt
37
+ ```
38
+
39
+ * Download weights
40
+ Run the following commands to download the model weights from Google Drive. Alternatively, you can download them from [Baidu Netdisk](https://pan.baidu.com/s/1U11iIk-DiJodgCveSzB6ig?pwd=RDxc) (password: RDxc).
41
+
42
+ ```
43
+ mkdir weights && cd weights
44
+ gdown https://drive.google.com/uc?id=1M1LEpx70tJ72AIV2TQKr6NE_7mJ7tLYx
45
+ gdown https://drive.google.com/uc?id=1YvZy3NHkJ6gC3pq_j8agcbEJymHCwJy0
46
+ gdown https://drive.google.com/uc?id=1AOWZxBvTo9nUf2_9Y7Xe27ZFQuPrnx9i
47
+ gdown https://drive.google.com/uc?id=19jM1-GcqgGoE1bjmQycQw_vqD9C5e-Jm
48
+ ```
49
+
50
+ #### Prepare inputs
51
+ We provide two Ultra-Dense Pose (UDP) sequences for two different characters; run either of the command groups below to download one from Google Drive. You can generate more UDP sequences from arbitrary 3D models and motion data, as described in [our paper](https://arxiv.org/abs/2207.05378). An official conversion tool is not provided yet.
52
+ [Baidu Netdisk](https://pan.baidu.com/s/1hWvz4iQXnVTaTSb6vu1NBg?pwd=RDxc) (password: RDxc)
53
+
54
+ ```
55
+ # UDP sequence for the short-hair girl
56
+ gdown https://drive.google.com/uc?id=11HMSaEkN__QiAZSnCuaM6GI143xo62KO
57
+ unzip short_hair.zip
58
+ mv short_hair/ poses/
59
+
60
+ # UDP sequence for the double-ponytail girl
61
+ gdown https://drive.google.com/uc?id=1WNnGVuU0ZLyEn04HzRKzITXqib1wwM4Q
62
+ unzip double_ponytail.zip
63
+ mv double_ponytail/ poses/
64
+ ```
65
+
66
+ We provide sample hand-drawn character sheets for the two characters; run either of the command groups below to download them from Google Drive. You can also draw your own.
67
+ Note: character sheets **must be cut out from the background** and must be in PNG format.
68
+ [Baidu Netdisk](https://pan.baidu.com/s/1shpP90GOMeHke7MuT0-Txw?pwd=RDxc) (password: RDxc)
69
+
70
+ ```
71
+ # character sheet for the short-hair girl
72
+ gdown https://drive.google.com/uc?id=1r-3hUlENSWj81ve2IUPkRKNB81o9WrwT
73
+ unzip short_hair_images.zip
74
+ mv short_hair_images/ character_sheet/
75
+
76
+ # character sheet for the double-ponytail girl
77
+ gdown https://drive.google.com/uc?id=1XMrJf9Lk_dWgXyTJhbEK2LZIXL9G3MWc
78
+ unzip double_ponytail_images.zip
79
+ mv double_ponytail_images/ character_sheet/
80
+ ```
81
+
82
+ #### Run!
83
+ We provide two options: run with the web UI, or via the command line.
84
+
85
+ * with web UI (powered by [Streamlit](https://streamlit.io/))
86
+
87
+ Run:
88
+
89
+ ```
90
+ streamlit run streamlit.py --server.port=8501
91
+ ```
92
+
93
+ Then open your browser, visit `localhost:8501`, and follow the on-page instructions to generate the video.
94
+
95
+ * via terminal
96
+
97
+ Note: replace the `{}` placeholders below with the directories where you placed the corresponding files.
98
+
99
+ ```
100
+ mkdir {dir_to_save_result}
101
+
102
+ python -m torch.distributed.launch \
103
+ --nproc_per_node=1 train.py --mode=test \
104
+ --world_size=1 --dataloaders=2 \
105
+ --test_input_poses_images={dir_to_poses} \
106
+ --test_input_person_images={dir_to_character_sheet} \
107
+ --test_output_dir={dir_to_save_result} \
108
+ --test_checkpoint_dir={dir_to_weights}
109
+
110
+ ffmpeg -r 30 -y -i {dir_to_save_result}/%d.png -c:v libx264 -r 30 output.mp4
111
+ ```
112
+
113
+ The resulting video will be saved to `CoNR/output.mp4`.
114
+
115
+ ## Citation
116
+ ```bibtex
117
+ @article{lin2022conr,
118
+ title={Collaborative Neural Rendering using Anime Character Sheets},
119
+ author={Lin, Zuzeng and Huang, Ailin and Huang, Zhewei and Hu, Chen and Zhou, Shuchang},
120
+ journal={arXiv preprint arXiv:2207.05378},
121
+ year={2022}
122
+ }
123
+ ```
124
+
app.py ADDED
@@ -0,0 +1,59 @@
1
+ import numpy as np
2
+ import gradio as gr
3
+ import os
4
+ import base64
5
+
6
+
7
+ def get_base64(bin_file):
8
+ with open(bin_file, "rb") as f:
9
+ data = f.read()
10
+ return base64.b64encode(data).decode()
11
+
12
+
13
+ def conr_fn(character_sheets, pose_zip):
14
+ os.system("rm character_sheet/*")
15
+ os.system("rm result/*")
16
+ os.system("rm poses/*")
17
+ os.makedirs("character_sheet", exist_ok=True)
18
+ for i, e in enumerate(character_sheets):
19
+ with open(f"character_sheet/{i}.png", "wb") as f:
20
+ e.seek(0)
21
+ f.write(e.read())
22
+ e.seek(0)
23
+ os.makedirs("poses", exist_ok=True)
24
+ pose_zip.seek(0)
25
+ open("poses.zip", "wb").write(pose_zip.read())
26
+ os.system(f"unzip -d poses poses.zip")
27
+ os.system("sh infer.sh")
28
+ return "output.mp4"
29
+
30
+
31
+ with gr.Blocks() as ui:
32
+ gr.Markdown("CoNR demo")
33
+ gr.Markdown("<a target='_blank' href='https://colab.research.google.com/github/megvii-research/CoNR/blob/main/notebooks/conr.ipynb'> <img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a> [GitHub](https://github.com/megvii-research/CoNR/)")
34
+ gr.Markdown("Unofficial demo for [CoNR](https://transpchan.github.io/live3d/).")
35
+
36
+ with gr.Row():
37
+ # with gr.Column():
38
+ # gr.Markdown("## Parse video")
39
+ # gr.Markdown("TBD")
40
+ with gr.Column():
41
+ gr.Markdown("## Animate character")
42
+ gr.Markdown("Character sheet")
43
+ character_sheets = gr.File(file_count="multiple")
44
+ gr.Markdown("Pose zip") # Don't hack
45
+ pose_video = gr.File(file_count="single")
46
+
47
+ # os.system("sh download.sh")
48
+ run = gr.Button("Run")
49
+ video = gr.Video()
50
+ run.click(fn=conr_fn, inputs=[character_sheets, pose_video], outputs=video)
51
+
52
+ gr.Markdown("## Examples")
53
+ sheets = "character_sheet_ponytail_example"
54
+ gr.Examples(fn=conr_fn, inputs=[character_sheets, pose_video], outputs=video,
55
+ examples=[[[os.path.join(sheets, x) for x in os.listdir(sheets)], "poses_template.zip"]], cache_examples=True, examples_per_page=1)
56
+
57
+ # ui.launch()
58
+ demo = ui
59
+ demo.launch()
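
For readers who want to exercise `conr_fn` without the Gradio UI, a hedged sketch follows. It reuses the example assets referenced above (`character_sheet_ponytail_example/*.png`, `poses_template.zip`) and assumes `weights/` and `infer.sh` are already in place; it is not part of the committed app.

```python
# Hypothetical driver for conr_fn (defined in app.py above); assumes conr_fn is
# in scope (importing app.py as-is also launches the demo) and that weights/
# and infer.sh exist, as the Space expects.
import os

sheets_dir = "character_sheet_ponytail_example"
sheet_files = [open(os.path.join(sheets_dir, name), "rb")
               for name in sorted(os.listdir(sheets_dir))]
try:
    with open("poses_template.zip", "rb") as pose_zip:
        video_path = conr_fn(sheet_files, pose_zip)  # runs infer.sh, returns "output.mp4"
finally:
    for f in sheet_files:
        f.close()
print(video_path)
```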
character_sheet_ponytail_example/0.png ADDED

Git LFS Details

  • SHA256: 282b26246caa2d3450d17c4968ebd46c99e6131962cd72edeb32693a6f191218
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
character_sheet_ponytail_example/1.png ADDED

Git LFS Details

  • SHA256: 481b6d9ed13fb7a7d07f8fc29a5ec58dba2ca3d7790036cbbeef7b3439658706
  • Pointer size: 131 Bytes
  • Size of remote file: 882 kB
character_sheet_ponytail_example/2.png ADDED

Git LFS Details

  • SHA256: ecea695a6a2072ba286e577b9b1b2ecffc01b48cc88724e1f2199fae95fe7b5a
  • Pointer size: 131 Bytes
  • Size of remote file: 837 kB
character_sheet_ponytail_example/3.png ADDED

Git LFS Details

  • SHA256: 41bae8f8a3ccca9ae582e998c916673dba34dece461d9aaac0b3e60e638307b5
  • Pointer size: 131 Bytes
  • Size of remote file: 829 kB
conr.py ADDED
@@ -0,0 +1,292 @@
1
+
2
+ import os
3
+ import torch
4
+
5
+ from model.backbone import ResEncUnet
6
+
7
+ from model.shader import CINN
8
+ from model.decoder_small import RGBADecoderNet
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+
13
+ def UDPClip(x):
14
+ return torch.clamp(x, min=0, max=1) # NCHW
15
+
16
+
17
+ class CoNR():
18
+ def __init__(self, args):
19
+ self.args = args
20
+
21
+ self.udpparsernet = ResEncUnet(
22
+ backbone_name='resnet50_danbo',
23
+ classes=4,
24
+ pretrained=(args.local_rank == 0),
25
+ parametric_upsampling=True,
26
+ decoder_filters=(512, 384, 256, 128, 32),
27
+ map_location=device
28
+ )
29
+ self.target_pose_encoder = ResEncUnet(
30
+ backbone_name='resnet18_danbo-4',
31
+ classes=1,
32
+ pretrained=(args.local_rank == 0),
33
+ parametric_upsampling=True,
34
+ decoder_filters=(512, 384, 256, 128, 32),
35
+ map_location=device
36
+ )
37
+ self.DIM_SHADER_REFERENCE = 4
38
+ self.shader = CINN(self.DIM_SHADER_REFERENCE)
39
+ self.rgbadecodernet = RGBADecoderNet(
40
+ )
41
+ self.device()
42
+ self.parser_ckpt = None
43
+
44
+ def dist(self):
45
+ args = self.args
46
+ if args.distributed:
47
+ self.udpparsernet = torch.nn.parallel.DistributedDataParallel(
48
+ self.udpparsernet,
49
+ device_ids=[
50
+ args.local_rank],
51
+ output_device=args.local_rank,
52
+ broadcast_buffers=False,
53
+ find_unused_parameters=True
54
+ )
55
+ self.target_pose_encoder = torch.nn.parallel.DistributedDataParallel(
56
+ self.target_pose_encoder,
57
+ device_ids=[
58
+ args.local_rank],
59
+ output_device=args.local_rank,
60
+ broadcast_buffers=False,
61
+ find_unused_parameters=True
62
+ )
63
+ self.shader = torch.nn.parallel.DistributedDataParallel(
64
+ self.shader,
65
+ device_ids=[
66
+ args.local_rank],
67
+ output_device=args.local_rank,
68
+ broadcast_buffers=True
69
+ )
70
+
71
+ self.rgbadecodernet = torch.nn.parallel.DistributedDataParallel(
72
+ self.rgbadecodernet,
73
+ device_ids=[
74
+ args.local_rank],
75
+ output_device=args.local_rank,
76
+ broadcast_buffers=True
77
+ )
78
+
79
+ def load_model(self, path):
80
+ self.udpparsernet.load_state_dict(
81
+ torch.load('{}/udpparsernet.pth'.format(path), map_location=device))
82
+ self.target_pose_encoder.load_state_dict(
83
+ torch.load('{}/target_pose_encoder.pth'.format(path), map_location=device))
84
+ self.shader.load_state_dict(
85
+ torch.load('{}/shader.pth'.format(path), map_location=device))
86
+ self.rgbadecodernet.load_state_dict(
87
+ torch.load('{}/rgbadecodernet.pth'.format(path), map_location=device))
88
+
89
+ def save_model(self, ite_num):
90
+ self._save_pth(self.udpparsernet,
91
+ model_name="udpparsernet", ite_num=ite_num)
92
+ self._save_pth(self.target_pose_encoder,
93
+ model_name="target_pose_encoder", ite_num=ite_num)
94
+ self._save_pth(self.shader,
95
+ model_name="shader", ite_num=ite_num)
96
+ self._save_pth(self.rgbadecodernet,
97
+ model_name="rgbadecodernet", ite_num=ite_num)
98
+
99
+ def _save_pth(self, net, model_name, ite_num):
100
+ args = self.args
101
+ to_save = None
102
+ if args.distributed:
103
+ if args.local_rank == 0:
104
+ to_save = net.module.state_dict()
105
+ else:
106
+ to_save = net.state_dict()
107
+ if to_save:
108
+ model_dir = os.path.join(
109
+ os.getcwd(), 'saved_models', args.model_name + os.sep + "checkpoints" + os.sep + "itr_%d" % (ite_num)+os.sep)
110
+
111
+ os.makedirs(model_dir, exist_ok=True)
112
+ torch.save(to_save, model_dir + model_name + ".pth")
113
+
114
+ def train(self):
115
+ self.udpparsernet.train()
116
+ self.target_pose_encoder.train()
117
+ self.shader.train()
118
+ self.rgbadecodernet.train()
119
+
120
+ def eval(self):
121
+ self.udpparsernet.eval()
122
+ self.target_pose_encoder.eval()
123
+ self.shader.eval()
124
+ self.rgbadecodernet.eval()
125
+
126
+ def device(self):
127
+ self.udpparsernet.to(device)
128
+ self.target_pose_encoder.to(device)
129
+ self.shader.to(device)
130
+ self.rgbadecodernet.to(device)
131
+
132
+ def data_norm_image(self, data):
133
+
134
+ with torch.cuda.amp.autocast(enabled=False):
135
+ for name in ["character_labels", "pose_label"]:
136
+ if name in data:
137
+ data[name] = data[name].to(
138
+ device, non_blocking=True).float()
139
+ for name in ["pose_images", "pose_mask", "character_images", "character_masks"]:
140
+ if name in data:
141
+ data[name] = data[name].to(
142
+ device, non_blocking=True).float() / 255.0
143
+ if "pose_images" in data:
144
+ data["num_pose_images"] = data["pose_images"].shape[1]
145
+ data["num_samples"] = data["pose_images"].shape[0]
146
+ if "character_images" in data:
147
+ data["num_character_images"] = data["character_images"].shape[1]
148
+ data["num_samples"] = data["character_images"].shape[0]
149
+ if "pose_images" in data and "character_images" in data:
150
+ assert (data["pose_images"].shape[0] ==
151
+ data["character_images"].shape[0])
152
+ return data
153
+
154
+ def reset_charactersheet(self):
155
+ self.parser_ckpt = None
156
+
157
+ def model_step(self, data, training=False):
158
+ self.eval()
159
+ with torch.cuda.amp.autocast(enabled=False):
160
+ pred = {}
161
+ if self.parser_ckpt:
162
+ pred["parser"] = self.parser_ckpt
163
+ else:
164
+ pred = self.character_parser_forward(data, pred)
165
+ self.parser_ckpt = pred["parser"]
166
+ pred = self.pose_parser_sc_forward(data, pred)
167
+ pred = self.shader_pose_encoder_forward(data, pred)
168
+ pred = self.shader_forward(data, pred)
169
+ return pred
170
+
171
+ def shader_forward(self, data, pred={}):
172
+ assert ("num_character_images" in data), "ERROR: No Character Sheet input."
173
+
174
+ character_images_rgb_nmchw, num_character_images = data[
175
+ "character_images"], data["num_character_images"]
176
+ # build x_reference_rgb_a_sudp in the draw call
177
+ shader_character_a_nmchw = data["character_masks"]
178
+ assert torch.any(torch.mean(shader_character_a_nmchw, (0, 2, 3, 4)) >= 0.95) == False, "ERROR: \
179
+ No transparent area found in the image, PLEASE separate the foreground of input character sheets.\
180
+ The website waifucutout.com is recommended to automatically cut out the foreground."
181
+
182
+ if shader_character_a_nmchw is None:
183
+ shader_character_a_nmchw = pred["parser"]["pred"][:, :, 3:4, :, :]
184
+ x_reference_rgb_a = torch.cat([shader_character_a_nmchw[:, :, :, :, :] * character_images_rgb_nmchw[:, :, :, :, :],
185
+ shader_character_a_nmchw[:,
186
+ :, :, :, :],
187
+
188
+ ], 2)
189
+ assert (x_reference_rgb_a.shape[2] == self.DIM_SHADER_REFERENCE)
190
+ # build x_reference_features in the draw call
191
+ x_reference_features = pred["parser"]["features"]
192
+ # run cinn shader
193
+ retdic = self.shader(
194
+ pred["shader"]["target_pose_features"], x_reference_rgb_a, x_reference_features)
195
+ pred["shader"].update(retdic)
196
+
197
+ # decode rgba
198
+ if True:
199
+ dec_out = self.rgbadecodernet(
200
+ retdic["y_last_remote_features"])
201
+ y_weighted_x_reference_RGB = dec_out[:, 0:3, :, :]
202
+ y_weighted_mask_A = dec_out[:, 3:4, :, :]
203
+ y_weighted_warp_decoded_rgba = torch.cat(
204
+ (y_weighted_x_reference_RGB*y_weighted_mask_A, y_weighted_mask_A), dim=1
205
+ )
206
+ assert(y_weighted_warp_decoded_rgba.shape[1] == 4)
207
+ assert(
208
+ y_weighted_warp_decoded_rgba.shape[-1] == character_images_rgb_nmchw.shape[-1])
209
+ # apply decoded mask to decoded rgb, finishing the draw call
210
+ pred["shader"]["y_weighted_warp_decoded_rgba"] = y_weighted_warp_decoded_rgba
211
+ return pred
212
+
213
+ def character_parser_forward(self, data, pred={}):
214
+ if not("num_character_images" in data and "character_images" in data):
215
+ return pred
216
+ pred["parser"] = {"pred": None} # create output
217
+
218
+ inputs_rgb_nmchw, num_samples, num_character_images = data[
219
+ "character_images"], data["num_samples"], data["num_character_images"]
220
+ inputs_rgb_fchw = inputs_rgb_nmchw.view(
221
+ (num_samples * num_character_images, inputs_rgb_nmchw.shape[2], inputs_rgb_nmchw.shape[3], inputs_rgb_nmchw.shape[4]))
222
+
223
+ encoder_out, features = self.udpparsernet(
224
+ (inputs_rgb_fchw-0.6)/0.2970)
225
+
226
+ pred["parser"]["features"] = [features_out.view(
227
+ (num_samples, num_character_images, features_out.shape[1], features_out.shape[2], features_out.shape[3])) for features_out in features]
228
+
229
+ if (encoder_out is not None):
230
+
231
+ pred["parser"]["pred"] = UDPClip(encoder_out.view(
232
+ (num_samples, num_character_images, encoder_out.shape[1], encoder_out.shape[2], encoder_out.shape[3])))
233
+
234
+ return pred
235
+
236
+ def pose_parser_sc_forward(self, data, pred={}):
237
+ if not("num_pose_images" in data and "pose_images" in data):
238
+ return pred
239
+ inputs_aug_rgb_nmchw, num_samples, num_pose_images = data[
240
+ "pose_images"], data["num_samples"], data["num_pose_images"]
241
+ inputs_aug_rgb_fchw = inputs_aug_rgb_nmchw.view(
242
+ (num_samples * num_pose_images, inputs_aug_rgb_nmchw.shape[2], inputs_aug_rgb_nmchw.shape[3], inputs_aug_rgb_nmchw.shape[4]))
243
+
244
+ encoder_out, _ = self.udpparsernet(
245
+ (inputs_aug_rgb_fchw-0.6)/0.2970)
246
+
247
+ encoder_out = encoder_out.view(
248
+ (num_samples, num_pose_images, encoder_out.shape[1], encoder_out.shape[2], encoder_out.shape[3]))
249
+
250
+ # apply sigmoid after eval loss
251
+ pred["pose_parser"] = {"pred":UDPClip(encoder_out)[:,0,:,:,:]}
252
+
253
+
254
+ return pred
255
+
256
+ def shader_pose_encoder_forward(self, data, pred={}):
257
+ pred["shader"] = {} # create output
258
+ if "pose_images" in data:
259
+ pose_images_rgb_nmchw = data["pose_images"]
260
+ target_gt_rgb = pose_images_rgb_nmchw[:, 0, :, :, :]
261
+ pred["shader"]["target_gt_rgb"] = target_gt_rgb
262
+
263
+ shader_target_a = None
264
+ if "pose_mask" in data:
265
+ pred["shader"]["target_gt_a"] = data["pose_mask"]
266
+ shader_target_a = data["pose_mask"]
267
+
268
+ shader_target_sudp = None
269
+ if "pose_label" in data:
270
+ shader_target_sudp = data["pose_label"][:, :3, :, :]
271
+
272
+ if self.args.test_pose_use_parser_udp:
273
+ shader_target_sudp = None
274
+ if shader_target_sudp is None:
275
+ shader_target_sudp = pred["pose_parser"]["pred"][:, 0:3, :, :]
276
+
277
+ if shader_target_a is None:
278
+ shader_target_a = pred["pose_parser"]["pred"][:, 3:4, :, :]
279
+
280
+ # build x_target_sudp_a in the draw call
281
+ x_target_sudp_a = torch.cat((
282
+ shader_target_sudp*shader_target_a,
283
+ shader_target_a
284
+ ), 1)
285
+ pred["shader"].update({
286
+ "x_target_sudp_a": x_target_sudp_a
287
+ })
288
+ _, features = self.target_pose_encoder(
289
+ (x_target_sudp_a-0.6)/0.2970, ret_parser_out=False)
290
+
291
+ pred["shader"]["target_pose_features"] = features
292
+ return pred
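
A minimal inference sketch of how this `CoNR` class is presumably driven (load weights once, then call `model_step` per batch). The `Namespace` fields mirror the attributes conr.py reads; `dataloader` is a placeholder for batches shaped like those produced by data_loader.py below, so this is a hedged sketch rather than the committed train.py.

```python
# Hedged sketch, not the committed train.py: wires up CoNR as its methods suggest.
from argparse import Namespace
import torch
from conr import CoNR

args = Namespace(local_rank=0, distributed=False,
                 test_pose_use_parser_udp=False, model_name="conr")
model = CoNR(args)
model.load_model("./weights")   # expects udpparsernet.pth, target_pose_encoder.pth, shader.pth, rgbadecodernet.pth
model.reset_charactersheet()    # drop any cached character-sheet features

with torch.no_grad():
    for data in dataloader:     # placeholder: batches from data_loader.FileDataset
        data = model.data_norm_image(data)
        pred = model.model_step(data, training=False)
        rgba = pred["shader"]["y_weighted_warp_decoded_rgba"]  # NCHW RGBA in [0, 1]
```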
data_loader.py ADDED
@@ -0,0 +1,275 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import os
6
+ cv2.setNumThreads(1)
7
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
8
+
9
+
10
+ class RandomResizedCropWithAutoCenteringAndZeroPadding (object):
11
+ def __init__(self, output_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), center_jitter=(0.1, 0.1), size_from_alpha_mask=True):
12
+ assert isinstance(output_size, (int, tuple))
13
+ if isinstance(output_size, int):
14
+ self.output_size = (output_size, output_size)
15
+ else:
16
+ assert len(output_size) == 2
17
+ self.output_size = output_size
18
+ assert isinstance(scale, tuple)
19
+ assert isinstance(ratio, tuple)
20
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
21
+ raise ValueError("Scale and ratio should be of kind (min, max)")
22
+ self.size_from_alpha_mask = size_from_alpha_mask
23
+ self.scale = scale
24
+ self.ratio = ratio
25
+ assert isinstance(center_jitter, tuple)
26
+ self.center_jitter = center_jitter
27
+
28
+ def __call__(self, sample):
29
+ imidx, image = sample['imidx'], sample["image_np"]
30
+ if "labels" in sample:
31
+ label = sample["labels"]
32
+ else:
33
+ label = None
34
+
35
+ im_h, im_w = image.shape[:2]
36
+ if self.size_from_alpha_mask and image.shape[2] == 4:
37
+ # compute bbox from alpha mask
38
+ bbox_left, bbox_top, bbox_w, bbox_h = cv2.boundingRect(
39
+ (image[:, :, 3] > 0).astype(np.uint8))
40
+ else:
41
+ bbox_left, bbox_top = 0, 0
42
+ bbox_h, bbox_w = image.shape[:2]
43
+ if bbox_h <= 1 and bbox_w <= 1:
44
+ sample["bad"] = 0
45
+ else:
46
+ # detect too small image here
47
+ alpha_varea = np.sum((image[:, :, 3] > 0).astype(np.uint8))
48
+ image_area = image.shape[0]*image.shape[1]
49
+ if alpha_varea/image_area < 0.001:
50
+ sample["bad"] = alpha_varea
51
+ # detect bad image
52
+ if "bad" in sample:
53
+ # baddata_dir = os.path.join(os.getcwd(), 'test_data', "baddata" + os.sep)
54
+ # save_output(str(imidx)+".png",image,label,baddata_dir)
55
+ bbox_h, bbox_w = image.shape[:2]
56
+ sample["image_np"] = np.zeros(
57
+ [self.output_size[0], self.output_size[1], image.shape[2]], dtype=image.dtype)
58
+ if label is not None:
59
+ sample["labels"] = np.zeros(
60
+ [self.output_size[0], self.output_size[1], 4], dtype=label.dtype)
61
+
62
+ return sample
63
+
64
+ # compute default area by making sure output_size contains bbox_w * bbox_h
65
+
66
+ jitter_h = np.random.uniform(-bbox_h *
67
+ self.center_jitter[0], bbox_h*self.center_jitter[0])
68
+ jitter_w = np.random.uniform(-bbox_w *
69
+ self.center_jitter[1], bbox_w*self.center_jitter[1])
70
+
71
+ # h/w
72
+ target_aspect_ratio = np.exp(
73
+ np.log(self.output_size[0]/self.output_size[1]) +
74
+ np.random.uniform(np.log(self.ratio[0]), np.log(self.ratio[1]))
75
+ )
76
+
77
+ source_aspect_ratio = bbox_h/bbox_w
78
+
79
+ if target_aspect_ratio < source_aspect_ratio:
80
+ # same w, target has larger h, use h to align
81
+ target_height = bbox_h * \
82
+ np.random.uniform(self.scale[0], self.scale[1])
83
+ virtual_h = int(
84
+ round(target_height))
85
+ virtual_w = int(
86
+ round(target_height / target_aspect_ratio)) # h/w
87
+ else:
88
+ # same w, source has larger h, use w to align
89
+ target_width = bbox_w * \
90
+ np.random.uniform(self.scale[0], self.scale[1])
91
+ virtual_h = int(
92
+ round(target_width * target_aspect_ratio)) # h/w
93
+ virtual_w = int(
94
+ round(target_width))
95
+
96
+ # print("required aspect ratio:", target_aspect_ratio)
97
+
98
+ virtual_top = int(round(bbox_top + jitter_h - (virtual_h-bbox_h)/2))
99
+ virutal_left = int(round(bbox_left + jitter_w - (virtual_w-bbox_w)/2))
100
+
101
+ if virtual_top < 0:
102
+ top_padding = abs(virtual_top)
103
+ crop_top = 0
104
+ else:
105
+ top_padding = 0
106
+ crop_top = virtual_top
107
+ if virutal_left < 0:
108
+ left_padding = abs(virutal_left)
109
+ crop_left = 0
110
+ else:
111
+ left_padding = 0
112
+ crop_left = virutal_left
113
+ if virtual_top+virtual_h > im_h:
114
+ bottom_padding = abs(im_h-(virtual_top+virtual_h))
115
+ crop_bottom = im_h
116
+ else:
117
+ bottom_padding = 0
118
+ crop_bottom = virtual_top+virtual_h
119
+ if virutal_left+virtual_w > im_w:
120
+ right_padding = abs(im_w-(virutal_left+virtual_w))
121
+ crop_right = im_w
122
+ else:
123
+ right_padding = 0
124
+ crop_right = virutal_left+virtual_w
125
+ # crop
126
+
127
+ image = image[crop_top:crop_bottom, crop_left: crop_right]
128
+ if label is not None:
129
+ label = label[crop_top:crop_bottom, crop_left: crop_right]
130
+
131
+ # pad
132
+ if top_padding + bottom_padding + left_padding + right_padding > 0:
133
+ padding = ((top_padding, bottom_padding),
134
+ (left_padding, right_padding), (0, 0))
135
+ # print("padding", padding)
136
+ image = np.pad(image, padding, mode='constant')
137
+ if label is not None:
138
+ label = np.pad(label, padding, mode='constant')
139
+
140
+ if image.shape[0]/image.shape[1] - virtual_h/virtual_w > 0.001:
141
+ print("virtual aspect ratio:", virtual_h/virtual_w)
142
+ print("image aspect ratio:", image.shape[0]/image.shape[1])
143
+ assert (image.shape[0]/image.shape[1] - virtual_h/virtual_w < 0.001)
144
+ sample["crop"] = np.array(
145
+ [im_h, im_w, crop_top, crop_bottom, crop_left, crop_right, top_padding, bottom_padding, left_padding, right_padding, image.shape[0], image.shape[1]])
146
+
147
+ # resize
148
+ if self.output_size[1] != image.shape[1] or self.output_size[0] != image.shape[0]:
149
+ if self.output_size[1] > image.shape[1] and self.output_size[0] > image.shape[0]:
150
+ # enlarging
151
+ image = cv2.resize(
152
+ image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_LINEAR)
153
+ else:
154
+ # shrinking
155
+ image = cv2.resize(
156
+ image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_AREA)
157
+
158
+ if label is not None:
159
+ label = cv2.resize(label, (self.output_size[1], self.output_size[0]),
160
+ interpolation=cv2.INTER_NEAREST_EXACT)
161
+
162
+ assert image.shape[0] == self.output_size[0] and image.shape[1] == self.output_size[1]
163
+ sample['imidx'], sample["image_np"] = imidx, image
164
+ if label is not None:
165
+ assert label.shape[0] == self.output_size[0] and label.shape[1] == self.output_size[1]
166
+ sample["labels"] = label
167
+
168
+ return sample
169
+
170
+
171
+ class FileDataset(Dataset):
172
+ def __init__(self, image_names_list, fg_img_lbl_transform=None, shader_pose_use_gt_udp_test=True, shader_target_use_gt_rgb_debug=False):
173
+ self.image_name_list = image_names_list
174
+ self.fg_img_lbl_transform = fg_img_lbl_transform
175
+ self.shader_pose_use_gt_udp_test = shader_pose_use_gt_udp_test
176
+ self.shader_target_use_gt_rgb_debug = shader_target_use_gt_rgb_debug
177
+
178
+ def __len__(self):
179
+ return len(self.image_name_list)
180
+
181
+ def get_gt_from_disk(self, idx, imname, read_label):
182
+ if read_label:
183
+ # read label
184
+ with open(imname, mode="rb") as bio:
185
+ if imname.find(".npz") > 0:
186
+ label_np = np.load(bio, allow_pickle=True)[
187
+ 'i'].astype(np.float32, copy=False)
188
+ else:
189
+ label_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(bio.read(
190
+ ), np.uint8), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
191
+ assert (4 == label_np.shape[2])
192
+ # fake image out of valid label
193
+ image_np = (label_np*255).clip(0, 255).astype(np.uint8, copy=False)
194
+ # assemble sample
195
+ sample = {'imidx': np.array(
196
+ [idx]), "image_np": image_np, "labels": label_np}
197
+
198
+ else:
199
+ # read image as unit8
200
+ with open(imname, mode="rb") as bio:
201
+ image_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(
202
+ bio.read(), np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
203
+ # image_np = Image.open(bio)
204
+ # image_np = np.array(image_np)
205
+ assert (3 == len(image_np.shape))
206
+ if (image_np.shape[2] == 4):
207
+ mask_np = image_np[:, :, 3:4]
208
+ image_np = (image_np[:, :, :3] *
209
+ (image_np[:, :, 3][:, :, np.newaxis]/255.0)).clip(0, 255).astype(np.uint8, copy=False)
210
+ elif (image_np.shape[2] == 3):
211
+ # generate a fake mask
212
+ # Fool-proofing
213
+ mask_np = np.ones(
214
+ (image_np.shape[0], image_np.shape[1], 1), dtype=np.uint8)*255
215
+ print("WARN: transparent background is preferred for image ", imname)
216
+ else:
217
+ raise ValueError("weird shape of image ", imname, image_np)
218
+ image_np = np.concatenate((image_np, mask_np), axis=2)
219
+ sample = {'imidx': np.array(
220
+ [idx]), "image_np": image_np}
221
+
222
+ # apply fg_img_lbl_transform
223
+ if self.fg_img_lbl_transform:
224
+ sample = self.fg_img_lbl_transform(sample)
225
+
226
+ if "labels" in sample:
227
+ # return UDP as 4chn XYZV float tensor
228
+ if "float" not in str(sample["labels"].dtype):
229
+ sample["labels"] = sample["labels"].astype(np.float32) / np.iinfo(sample["labels"].dtype).max
230
+ sample["labels"] = torch.from_numpy(
231
+ sample["labels"].transpose((2, 0, 1)))
232
+ assert (sample["labels"].dtype == torch.float32)
233
+
234
+ if "image_np" in sample:
235
+ # return image as 3chn RGB uint8 tensor and 1chn A uint8 tensor
236
+ sample["mask"] = torch.from_numpy(
237
+ sample["image_np"][:, :, 3:4].transpose((2, 0, 1)))
238
+ assert (sample["mask"].dtype == torch.uint8)
239
+ sample["image"] = torch.from_numpy(
240
+ sample["image_np"][:, :, :3].transpose((2, 0, 1)))
241
+
242
+ assert (sample["image"].dtype == torch.uint8)
243
+ del sample["image_np"]
244
+ return sample
245
+
246
+ def __getitem__(self, idx):
247
+ sample = {
248
+ 'imidx': np.array([idx])}
249
+ target = self.get_gt_from_disk(
250
+ idx, imname=self.image_name_list[idx][0], read_label=self.shader_pose_use_gt_udp_test)
251
+ if self.shader_target_use_gt_rgb_debug:
252
+ sample["pose_images"] = torch.stack([target["image"]])
253
+ sample["pose_mask"] = target["mask"]
254
+ elif self.shader_pose_use_gt_udp_test:
255
+ sample["pose_label"] = target["labels"]
256
+ sample["pose_mask"] = target["mask"]
257
+ else:
258
+ sample["pose_images"] = torch.stack([target["image"]])
259
+ if "crop" in target:
260
+ sample["pose_crop"] = target["crop"]
261
+ character_images = []
262
+ character_masks = []
263
+ for i in range(1, len(self.image_name_list[idx])):
264
+ source = self.get_gt_from_disk(
265
+ idx, self.image_name_list[idx][i], read_label=False)
266
+ character_images.append(source["image"])
267
+ character_masks.append(source["mask"])
268
+ character_images = torch.stack(character_images)
269
+ character_masks = torch.stack(character_masks)
270
+ sample.update({
271
+ "character_images": character_images,
272
+ "character_masks": character_masks
273
+ })
274
+ # do not make fake labels in inference
275
+ return sample
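
For clarity, a hedged sketch of how `FileDataset` appears to be fed for inference: each entry of `image_names_list` is `[pose_file, sheet_1, sheet_2, ...]`, with the pose read as a UDP label when `shader_pose_use_gt_udp_test` is True. The glob patterns are illustrative placeholders.

```python
# Hedged usage sketch for FileDataset above; directory names are placeholders.
import glob
from torch.utils.data import DataLoader
from data_loader import FileDataset

sheets = sorted(glob.glob("character_sheet/*.png"))
poses = sorted(glob.glob("poses/*"))
# one sample per target pose; every sample reuses the same character sheets
image_names_list = [[pose] + sheets for pose in poses]

dataset = FileDataset(image_names_list, shader_pose_use_gt_udp_test=True)
loader = DataLoader(dataset, batch_size=1, num_workers=2)
sample = next(iter(loader))
# keys match what conr.py's data_norm_image expects:
# "pose_label"/"pose_mask" (or "pose_images"), "character_images", "character_masks"
```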
images/MAIN.png ADDED

Git LFS Details

  • SHA256: 6749cc7fd06b2e9b10345655ef93b2b65ba96dcb308c6e9bb6ef762cdb7db76e
  • Pointer size: 131 Bytes
  • Size of remote file: 805 kB
infer.sh ADDED
@@ -0,0 +1,13 @@
1
+ rm -r "./results"
2
+ mkdir "./results"
3
+
4
+ python3 train.py --mode=test \
5
+ --world_size=1 --dataloaders=2 \
6
+ --test_input_poses_images=./poses/ \
7
+ --test_input_person_images=./character_sheet/ \
8
+ --test_output_dir=./results/ \
9
+ --test_checkpoint_dir=./weights/
10
+
11
+ echo Generating Video...
12
+ ffmpeg -r 30 -y -i ./results/%d.png -c:v libx264 -pix_fmt yuv420p -crf 18 -r 30 output.mp4
13
+ echo DONE.
model/__init__.py ADDED
@@ -0,0 +1 @@
1
+
model/backbone.py ADDED
@@ -0,0 +1,289 @@
1
+ """
2
+ This code was mostly taken from backbone-unet by mkisantal:
3
+ https://github.com/mkisantal/backboned-unet/blob/master/backboned_unet/unet.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+ from torch.nn import functional as F
9
+
10
+ import torch.nn as nn
11
+ import torch
12
+ from torchvision import models
13
+
14
+
15
+ class AdaptiveConcatPool2d(nn.Module):
16
+ """
17
+ Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.
18
+ Source: Fastai. This code was taken from the fastai library at url
19
+ https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
20
+ """
21
+
22
+ def __init__(self, sz=None):
23
+ "Output will be 2*sz or 2 if sz is None"
24
+ super().__init__()
25
+ self.output_size = sz or 1
26
+ self.ap = nn.AdaptiveAvgPool2d(self.output_size)
27
+ self.mp = nn.AdaptiveMaxPool2d(self.output_size)
28
+
29
+ def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
30
+
31
+
32
+ class MyNorm(nn.Module):
33
+ def __init__(self, num_channels):
34
+ super(MyNorm, self).__init__()
35
+ self.norm = nn.InstanceNorm2d(
36
+ num_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
37
+
38
+ def forward(self, x):
39
+ x = self.norm(x)
40
+ return x
41
+
42
+
43
+ def resnet_fastai(model, pretrained, url, replace_first_layer=None, replace_maxpool_layer=None, progress=True, map_location=None, **kwargs):
44
+ cut = -2
45
+ s = model(pretrained=False, **kwargs)
46
+ if replace_maxpool_layer is not None:
47
+ s.maxpool = replace_maxpool_layer
48
+ if replace_first_layer is not None:
49
+ body = nn.Sequential(replace_first_layer, *list(s.children())[1:cut])
50
+ else:
51
+ body = nn.Sequential(*list(s.children())[:cut])
52
+
53
+ if pretrained:
54
+ state = torch.hub.load_state_dict_from_url(url,
55
+ progress=progress, map_location=map_location)
56
+ if replace_first_layer is not None:
57
+ for each in list(state.keys()).copy():
58
+ if each.find("0.0.") == 0:
59
+ del state[each]
60
+ body_tail = nn.Sequential(body)
61
+ ret = body_tail.load_state_dict(state, strict=False)
62
+ return body
63
+
64
+
65
+ def get_backbone(name, pretrained=True, map_location=None):
66
+ """ Loading backbone, defining names for skip-connections and encoder output. """
67
+
68
+ first_layer_for_4chn = nn.Conv2d(
69
+ 4, 64, kernel_size=7, stride=2, padding=3, bias=False)
70
+ max_pool_layer_replace = nn.Conv2d(
71
+ 64, 64, kernel_size=3, stride=2, padding=1, bias=False)
72
+ # loading backbone model
73
+ if name == 'resnet18':
74
+ backbone = models.resnet18(pretrained=pretrained)
75
+ if name == 'resnet18-4':
76
+ backbone = models.resnet18(pretrained=pretrained)
77
+ backbone.conv1 = first_layer_for_4chn
78
+ elif name == 'resnet34':
79
+ backbone = models.resnet34(pretrained=pretrained)
80
+ elif name == 'resnet50':
81
+ backbone = models.resnet50(pretrained=False, norm_layer=MyNorm)
82
+ backbone.maxpool = max_pool_layer_replace
83
+ elif name == 'resnet101':
84
+ backbone = models.resnet101(pretrained=pretrained)
85
+ elif name == 'resnet152':
86
+ backbone = models.resnet152(pretrained=pretrained)
87
+ elif name == 'vgg16':
88
+ backbone = models.vgg16_bn(pretrained=pretrained).features
89
+ elif name == 'vgg19':
90
+ backbone = models.vgg19_bn(pretrained=pretrained).features
91
+ elif name == 'resnet18_danbo-4':
92
+ backbone = resnet_fastai(models.resnet18, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet18-3f77756f.pth",
93
+ pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_first_layer=first_layer_for_4chn)
94
+ elif name == 'resnet50_danbo':
95
+ backbone = resnet_fastai(models.resnet50, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
96
+ pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_maxpool_layer=max_pool_layer_replace)
97
+ elif name == 'densenet121':
98
+ backbone = models.densenet121(pretrained=True).features
99
+ elif name == 'densenet161':
100
+ backbone = models.densenet161(pretrained=True).features
101
+ elif name == 'densenet169':
102
+ backbone = models.densenet169(pretrained=True).features
103
+ elif name == 'densenet201':
104
+ backbone = models.densenet201(pretrained=True).features
105
+ else:
106
+ raise NotImplementedError(
107
+ '{} backbone model is not implemented so far.'.format(name))
108
+ #print(backbone)
109
+ # specifying skip feature and output names
110
+ if name.startswith('resnet'):
111
+ feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
112
+ backbone_output = 'layer4'
113
+ elif name == 'vgg16':
114
+ # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
115
+ feature_names = ['5', '12', '22', '32', '42']
116
+ backbone_output = '43'
117
+ elif name == 'vgg19':
118
+ feature_names = ['5', '12', '25', '38', '51']
119
+ backbone_output = '52'
120
+ elif name.startswith('densenet'):
121
+ feature_names = [None, 'relu0', 'denseblock1',
122
+ 'denseblock2', 'denseblock3']
123
+ backbone_output = 'denseblock4'
124
+ elif name == 'unet_encoder':
125
+ feature_names = ['module1', 'module2', 'module3', 'module4']
126
+ backbone_output = 'module5'
127
+ else:
128
+ raise NotImplementedError(
129
+ '{} backbone model is not implemented so far.'.format(name))
130
+ if name.find('_danbo') > 0:
131
+ feature_names = [None, '2', '4', '5', '6']
132
+ backbone_output = '7'
133
+ return backbone, feature_names, backbone_output
134
+
135
+
136
+ class UpsampleBlock(nn.Module):
137
+
138
+ # TODO: separate parametric and non-parametric classes?
139
+ # TODO: skip connection concatenated OR added
140
+
141
+ def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
142
+ super(UpsampleBlock, self).__init__()
143
+
144
+ self.parametric = parametric
145
+ ch_out = ch_in/2 if ch_out is None else ch_out
146
+
147
+ # first convolution: either transposed conv, or conv following the skip connection
148
+ if parametric:
149
+ # versions: kernel=4 padding=1, kernel=2 padding=0
150
+ self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
151
+ stride=2, padding=1, output_padding=0, bias=(not use_bn))
152
+ self.bn1 = MyNorm(ch_out) if use_bn else None
153
+ else:
154
+ self.up = None
155
+ ch_in = ch_in + skip_in
156
+ self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
157
+ stride=1, padding=1, bias=(not use_bn))
158
+ self.bn1 = MyNorm(ch_out) if use_bn else None
159
+
160
+ self.relu = nn.ReLU(inplace=True)
161
+
162
+ # second convolution
163
+ conv2_in = ch_out if not parametric else ch_out + skip_in
164
+ self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
165
+ stride=1, padding=1, bias=(not use_bn))
166
+ self.bn2 = MyNorm(ch_out) if use_bn else None
167
+
168
+ def forward(self, x, skip_connection=None):
169
+
170
+ x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
171
+ align_corners=None)
172
+ if self.parametric:
173
+ x = self.bn1(x) if self.bn1 is not None else x
174
+ x = self.relu(x)
175
+
176
+ if skip_connection is not None:
177
+ x = torch.cat([x, skip_connection], dim=1)
178
+
179
+ if not self.parametric:
180
+ x = self.conv1(x)
181
+ x = self.bn1(x) if self.bn1 is not None else x
182
+ x = self.relu(x)
183
+ x = self.conv2(x)
184
+ x = self.bn2(x) if self.bn2 is not None else x
185
+ x = self.relu(x)
186
+
187
+ return x
188
+
189
+
190
+ class ResEncUnet(nn.Module):
191
+
192
+ """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
193
+
194
+ def __init__(self,
195
+ backbone_name,
196
+ pretrained=True,
197
+ encoder_freeze=False,
198
+ classes=21,
199
+ decoder_filters=(512, 256, 128, 64, 32),
200
+ parametric_upsampling=True,
201
+ shortcut_features='default',
202
+ decoder_use_instancenorm=True,
203
+ map_location=None
204
+ ):
205
+ super(ResEncUnet, self).__init__()
206
+
207
+ self.backbone_name = backbone_name
208
+
209
+ self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(
210
+ backbone_name, pretrained=pretrained, map_location=map_location)
211
+ shortcut_chs, bb_out_chs = self.infer_skip_channels()
212
+ if shortcut_features != 'default':
213
+ self.shortcut_features = shortcut_features
214
+
215
+ # build decoder part
216
+ self.upsample_blocks = nn.ModuleList()
217
+ # avoiding having more blocks than skip connections
218
+ decoder_filters = decoder_filters[:len(self.shortcut_features)]
219
+ decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
220
+ num_blocks = len(self.shortcut_features)
221
+ for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
222
+ self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
223
+ skip_in=shortcut_chs[num_blocks-i-1],
224
+ parametric=parametric_upsampling,
225
+ use_bn=decoder_use_instancenorm))
226
+ self.final_conv = nn.Conv2d(
227
+ decoder_filters[-1], classes, kernel_size=(1, 1))
228
+
229
+ if encoder_freeze:
230
+ self.freeze_encoder()
231
+
232
+ def freeze_encoder(self):
233
+ """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
234
+
235
+ for param in self.backbone.parameters():
236
+ param.requires_grad = False
237
+
238
+ def forward(self, *input, ret_parser_out=True):
239
+ """ Forward propagation in U-Net. """
240
+
241
+ x, features = self.forward_backbone(*input)
242
+ output_feature = [x]
243
+ for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
244
+ skip_features = features[skip_name]
245
+ if skip_features is not None:
246
+ output_feature.append(skip_features)
247
+ if ret_parser_out:
248
+ x = upsample_block(x, skip_features)
249
+ if ret_parser_out:
250
+ x = self.final_conv(x)
251
+ # apply sigmoid later
252
+ else:
253
+ x = None
254
+
255
+ return x, output_feature
256
+
257
+ def forward_backbone(self, x):
258
+ """ Forward propagation in backbone encoder network. """
259
+
260
+ features = {None: None} if None in self.shortcut_features else dict()
261
+ for name, child in self.backbone.named_children():
262
+ x = child(x)
263
+ if name in self.shortcut_features:
264
+ features[name] = x
265
+ if name == self.bb_out_name:
266
+ break
267
+
268
+ return x, features
269
+
270
+ def infer_skip_channels(self):
271
+ """ Getting the number of channels at skip connections and at the output of the encoder. """
272
+ if self.backbone_name.find("-4") > 0:
273
+ x = torch.zeros(1, 4, 224, 224)
274
+ else:
275
+ x = torch.zeros(1, 3, 224, 224)
276
+ has_fullres_features = self.backbone_name.startswith(
277
+ 'vgg') or self.backbone_name == 'unet_encoder'
278
+ # only VGG has features at full resolution
279
+ channels = [] if has_fullres_features else [0]
280
+
281
+ # forward run in backbone to count channels (dirty solution but works for *any* Module)
282
+ for name, child in self.backbone.named_children():
283
+ x = child(x)
284
+ if name in self.shortcut_features:
285
+ channels.append(x.shape[1])
286
+ if name == self.bb_out_name:
287
+ out_channels = x.shape[1]
288
+ break
289
+ return channels, out_channels
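
A short sketch of how conr.py instantiates and calls this `ResEncUnet` as the UDP parser (four output channels); the normalization constants are the ones conr.py applies, and the shapes are illustrative.

```python
# Sketch mirroring conr.py's use of ResEncUnet above; no new functionality.
import torch
from model.backbone import ResEncUnet

net = ResEncUnet(backbone_name='resnet50_danbo', classes=4, pretrained=False,
                 parametric_upsampling=True,
                 decoder_filters=(512, 384, 256, 128, 32))
x = torch.rand(1, 3, 256, 256)            # an RGB character-sheet crop
out, features = net((x - 0.6) / 0.2970)   # same normalization as conr.py
print(out.shape)       # torch.Size([1, 4, 256, 256]): UDP channels before clipping
print(len(features))   # encoder output plus the skip features consumed by the shader
```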
model/decoder_small.py ADDED
@@ -0,0 +1,43 @@
1
+
2
+
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ import torch
6
+
7
+
8
+ class ResBlock2d(nn.Module):
9
+ def __init__(self, in_features, kernel_size, padding):
10
+ super(ResBlock2d, self).__init__()
11
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
12
+ padding=padding)
13
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
14
+ padding=padding)
15
+
16
+ self.norm1 = nn.Conv2d(
17
+ in_channels=in_features, out_channels=in_features, kernel_size=1)
18
+ self.norm2 = nn.Conv2d(
19
+ in_channels=in_features, out_channels=in_features, kernel_size=1)
20
+
21
+ def forward(self, x):
22
+ out = self.norm1(x)
23
+ out = F.relu(out, inplace=True)
24
+ out = self.conv1(out)
25
+ out = self.norm2(out)
26
+ out = F.relu(out, inplace=True)
27
+ out = self.conv2(out)
28
+ out += x
29
+ return out
30
+
31
+
32
+ class RGBADecoderNet(nn.Module):
33
+ def __init__(self, c=64, out_planes=4, num_bottleneck_blocks=1):
34
+ super(RGBADecoderNet, self).__init__()
35
+ self.conv_rgba = nn.Sequential(nn.Conv2d(c, out_planes, kernel_size=3, stride=1,
36
+ padding=1, dilation=1, bias=True))
37
+ self.bottleneck = torch.nn.Sequential()
38
+ for i in range(num_bottleneck_blocks):
39
+ self.bottleneck.add_module(
40
+ 'r' + str(i), ResBlock2d(c, kernel_size=(3, 3), padding=(1, 1)))
41
+
42
+ def forward(self, features_weighted_mask_atfeaturesscale_list=[]):
43
+ return torch.sigmoid(self.conv_rgba(self.bottleneck(features_weighted_mask_atfeaturesscale_list.pop(0))))
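
A tiny sketch of the `RGBADecoderNet` calling convention used by conr.py above: it pops the first tensor from a feature list (64 channels by default) and returns a sigmoid-activated RGBA map. Shapes are illustrative.

```python
# Sketch of the RGBADecoderNet interface above; shapes are illustrative only.
import torch
from model.decoder_small import RGBADecoderNet

decoder = RGBADecoderNet()               # defaults: c=64 input channels, 4 output planes
feats = [torch.rand(1, 64, 256, 256)]    # e.g. the shader's y_last_remote_features
rgba = decoder(feats)                    # forward pops the first entry of the list
rgb, alpha = rgba[:, 0:3], rgba[:, 3:4]  # already in [0, 1] thanks to the sigmoid
print(rgba.shape)                        # torch.Size([1, 4, 256, 256])
```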
model/shader.py ADDED
@@ -0,0 +1,290 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .warplayer import warp_features
5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
+
7
+
8
+ class DecoderBlock(nn.Module):
9
+ def __init__(self, in_planes, c=224, out_msgs=0, out_locals=0, block_nums=1, out_masks=1, out_local_flows=32, out_msgs_flows=32, out_feat_flows=0):
10
+
11
+ super(DecoderBlock, self).__init__()
12
+ self.conv0 = nn.Sequential(
13
+ nn.Conv2d(in_planes, c, 3, 2, 1),
14
+ nn.PReLU(c),
15
+ nn.Conv2d(c, c, 3, 2, 1),
16
+ nn.PReLU(c),
17
+ )
18
+
19
+ self.convblocks = nn.ModuleList()
20
+ for i in range(block_nums):
21
+ self.convblocks.append(nn.Sequential(
22
+ nn.Conv2d(c, c, 3, 1, 1),
23
+ nn.PReLU(c),
24
+ nn.Conv2d(c, c, 3, 1, 1),
25
+ nn.PReLU(c),
26
+ nn.Conv2d(c, c, 3, 1, 1),
27
+ nn.PReLU(c),
28
+ nn.Conv2d(c, c, 3, 1, 1),
29
+ nn.PReLU(c),
30
+ nn.Conv2d(c, c, 3, 1, 1),
31
+ nn.PReLU(c),
32
+ nn.Conv2d(c, c, 3, 1, 1),
33
+ nn.PReLU(c),
34
+ ))
35
+ self.out_flows = 2
36
+ self.out_msgs = out_msgs
37
+ self.out_msgs_flows = out_msgs_flows if out_msgs > 0 else 0
38
+ self.out_locals = out_locals
39
+ self.out_local_flows = out_local_flows if out_locals > 0 else 0
40
+ self.out_masks = out_masks
41
+ self.out_feat_flows = out_feat_flows
42
+
43
+ self.conv_last = nn.Sequential(
44
+ nn.ConvTranspose2d(c, c, 4, 2, 1),
45
+ nn.PReLU(c),
46
+ nn.ConvTranspose2d(c, self.out_flows+self.out_msgs+self.out_msgs_flows +
47
+ self.out_locals+self.out_local_flows+self.out_masks+self.out_feat_flows, 4, 2, 1),
48
+ )
49
+
50
+ def forward(self, accumulated_flow, *other):
51
+ x = [accumulated_flow]
52
+ for each in other:
53
+ if each is not None:
54
+ assert(accumulated_flow.shape[-1] == each.shape[-1]), "decoder want {}, but get {}".format(
55
+ accumulated_flow.shape, each.shape)
56
+ x.append(each)
57
+ feat = self.conv0(torch.cat(x, dim=1))
58
+ for convblock1 in self.convblocks:
59
+ feat = convblock1(feat) + feat
60
+ feat = self.conv_last(feat)
61
+ prev = 0
62
+ flow = feat[:, prev:prev+self.out_flows, :, :]
63
+ prev += self.out_flows
64
+ message = feat[:, prev:prev+self.out_msgs,
65
+ :, :] if self.out_msgs > 0 else None
66
+ prev += self.out_msgs
67
+ message_flow = feat[:, prev:prev + self.out_msgs_flows,
68
+ :, :] if self.out_msgs_flows > 0 else None
69
+ prev += self.out_msgs_flows
70
+ local_message = feat[:, prev:prev + self.out_locals,
71
+ :, :] if self.out_locals > 0 else None
72
+ prev += self.out_locals
73
+ local_message_flow = feat[:, prev:prev+self.out_local_flows,
74
+ :, :] if self.out_local_flows > 0 else None
75
+ prev += self.out_local_flows
76
+ mask = torch.sigmoid(
77
+ feat[:, prev:prev+self.out_masks, :, :]) if self.out_masks > 0 else None
78
+ prev += self.out_masks
79
+ feat_flow = feat[:, prev:prev+self.out_feat_flows,
80
+ :, :] if self.out_feat_flows > 0 else None
81
+ prev += self.out_feat_flows
82
+ return flow, mask, message, message_flow, local_message, local_message_flow, feat_flow
83
+
84
+
85
+ class CINN(nn.Module):
86
+ def __init__(self, DIM_SHADER_REFERENCE, target_feature_chns=[512, 256, 128, 64, 64], feature_chns=[2048, 1024, 512, 256, 64], out_msgs_chn=[2048, 1024, 512, 256, 64, 64], out_locals_chn=[2048, 1024, 512, 256, 64, 0], block_num=[1, 1, 1, 1, 1, 2], block_chn_num=[224, 224, 224, 224, 224, 224]):
87
+ super(CINN, self).__init__()
88
+
89
+ self.in_msgs_chn = [0, *out_msgs_chn[:-1]]
90
+ self.in_locals_chn = [0, *out_locals_chn[:-1]]
91
+
92
+ self.decoder_blocks = nn.ModuleList()
93
+ self.feed_weighted = True
94
+ if self.feed_weighted:
95
+ in_planes = 2+2+DIM_SHADER_REFERENCE*2
96
+ else:
97
+ in_planes = 2+DIM_SHADER_REFERENCE
98
+ for each_target_feature_chns, each_feature_chns, each_out_msgs_chn, each_out_locals_chn, each_in_msgs_chn, each_in_locals_chn, each_block_num, each_block_chn_num in zip(target_feature_chns, feature_chns, out_msgs_chn, out_locals_chn, self.in_msgs_chn, self.in_locals_chn, block_num, block_chn_num):
99
+ self.decoder_blocks.append(
100
+ DecoderBlock(in_planes+each_target_feature_chns+each_feature_chns+each_in_locals_chn+each_in_msgs_chn, c=each_block_chn_num, block_nums=each_block_num, out_msgs=each_out_msgs_chn, out_locals=each_out_locals_chn, out_masks=2+each_out_locals_chn))
101
+ for i in range(len(feature_chns), len(out_locals_chn)):
102
+ #print("append extra block", i, "msg",
103
+ # out_msgs_chn[i], "local", out_locals_chn[i], "block", block_num[i])
104
+ self.decoder_blocks.append(
105
+ DecoderBlock(in_planes+self.in_msgs_chn[i]+self.in_locals_chn[i], c=block_chn_num[i], block_nums=block_num[i], out_msgs=out_msgs_chn[i], out_locals=out_locals_chn[i], out_masks=2+out_msgs_chn[i], out_feat_flows=0))
106
+
107
+ def apply_flow(self, mask, message, message_flow, local_message, local_message_flow, x_reference, accumulated_flow, each_x_reference_features=None, each_x_reference_features_flow=None):
108
+ if each_x_reference_features is not None:
109
+ size_from = each_x_reference_features
110
+ else:
111
+ size_from = x_reference
112
+ f_size = (size_from.shape[2], size_from.shape[3])
113
+ accumulated_flow = self.flow_rescale(
114
+ accumulated_flow, size_from)
115
+ # mask = warp_features(F.interpolate(
116
+ # mask, size=f_size, mode="bilinear"), accumulated_flow) if mask is not None else None
117
+ mask = F.interpolate(
118
+ mask, size=f_size, mode="bilinear") if mask is not None else None
119
+ message = F.interpolate(
120
+ message, size=f_size, mode="bilinear") if message is not None else None
121
+ message_flow = self.flow_rescale(
122
+ message_flow, size_from) if message_flow is not None else None
123
+ message = warp_features(
124
+ message, message_flow) if message_flow is not None else message
125
+
126
+ local_message = F.interpolate(
127
+ local_message, size=f_size, mode="bilinear") if local_message is not None else None
128
+ local_message_flow = self.flow_rescale(
129
+ local_message_flow, size_from) if local_message_flow is not None else None
130
+ local_message = warp_features(
131
+ local_message, local_message_flow) if local_message_flow is not None else local_message
132
+
133
+ warp_x_reference = warp_features(F.interpolate(
134
+ x_reference, size=f_size, mode="bilinear"), accumulated_flow)
135
+
136
+ each_x_reference_features_flow = self.flow_rescale(
137
+ each_x_reference_features_flow, size_from) if (each_x_reference_features is not None and each_x_reference_features_flow is not None) else None
138
+ warp_each_x_reference_features = warp_features(
139
+ each_x_reference_features, each_x_reference_features_flow) if each_x_reference_features_flow is not None else each_x_reference_features
140
+
141
+ return mask, message, local_message, warp_x_reference, accumulated_flow, warp_each_x_reference_features, each_x_reference_features_flow
142
+
143
+ def forward(self, x_target_features=[], x_reference=None, x_reference_features=[]):
144
+ y_flow = []
145
+ y_feat_flow = []
146
+
147
+ y_local_message = []
148
+ y_warp_x_reference = []
149
+ y_warp_x_reference_features = []
150
+
151
+ y_weighted_flow = []
152
+ y_weighted_mask = []
153
+ y_weighted_message = []
154
+ y_weighted_x_reference = []
155
+ y_weighted_x_reference_features = []
156
+
157
+ for pyrlevel, ifblock in enumerate(self.decoder_blocks):
158
+ stacked_wref = []
159
+ stacked_feat = []
160
+ stacked_anci = []
161
+ stacked_flow = []
162
+ stacked_mask = []
163
+ stacked_mesg = []
164
+ stacked_locm = []
165
+ stacked_feat_flow = []
166
+ for view_id in range(x_reference.shape[1]): # NMCHW
167
+
168
+ if pyrlevel == 0:
169
+ # create from zero flow
170
+ feat_ev = x_reference_features[pyrlevel][:,
171
+ view_id, :, :, :] if pyrlevel < len(x_reference_features) else None
172
+
173
+ accumulated_flow = torch.zeros_like(
174
+ feat_ev[:, :2, :, :]).to(device)
175
+ accumulated_feat_flow = torch.zeros_like(
176
+ feat_ev[:, :32, :, :]).to(device)
177
+ # domestic inputs
178
+ warp_x_reference = F.interpolate(x_reference[:, view_id, :, :, :], size=(
179
+ feat_ev.shape[-2], feat_ev.shape[-1]), mode="bilinear")
180
+ warp_x_reference_features = feat_ev
181
+
182
+ local_message = None
183
+ # federated inputs
184
+ weighted_flow = accumulated_flow if self.feed_weighted else None
185
+ weighted_wref = warp_x_reference if self.feed_weighted else None
186
+ weighted_message = None
187
+ else:
188
+ # resume from last layer
189
+ accumulated_flow = y_flow[-1][:, view_id, :, :, :]
190
+ accumulated_feat_flow = y_feat_flow[-1][:,
191
+ view_id, :, :, :] if y_feat_flow[-1] is not None else None
192
+ # domestic inputs
193
+ warp_x_reference = y_warp_x_reference[-1][:,
194
+ view_id, :, :, :]
195
+ warp_x_reference_features = y_warp_x_reference_features[-1][:,
196
+ view_id, :, :, :] if y_warp_x_reference_features[-1] is not None else None
197
+ local_message = y_local_message[-1][:, view_id, :,
198
+ :, :] if len(y_local_message) > 0 else None
199
+
200
+ # federated inputs
201
+ weighted_flow = y_weighted_flow[-1] if self.feed_weighted else None
202
+ weighted_wref = y_weighted_x_reference[-1] if self.feed_weighted else None
203
+ weighted_message = y_weighted_message[-1] if len(
204
+ y_weighted_message) > 0 else None
205
+ scaled_x_target = x_target_features[pyrlevel][:, :, :, :].detach() if pyrlevel < len(
206
+ x_target_features) else None
207
+ # compute flow
208
+ residual_flow, mask, message, message_flow, local_message, local_message_flow, residual_feat_flow = ifblock(
209
+ accumulated_flow, scaled_x_target, warp_x_reference, warp_x_reference_features, weighted_flow, weighted_wref, weighted_message, local_message)
210
+ accumulated_flow = residual_flow + accumulated_flow
211
+ accumulated_feat_flow = accumulated_flow
212
+
213
+ feat_ev = x_reference_features[pyrlevel+1][:,
214
+ view_id, :, :, :] if pyrlevel+1 < len(x_reference_features) else None
215
+ mask, message, local_message, warp_x_reference, accumulated_flow, warp_x_reference_features, accumulated_feat_flow = self.apply_flow(
216
+ mask, message, message_flow, local_message, local_message_flow, x_reference[:, view_id, :, :, :], accumulated_flow, feat_ev, accumulated_feat_flow)
217
+ stacked_flow.append(accumulated_flow)
218
+ if accumulated_feat_flow is not None:
219
+ stacked_feat_flow.append(accumulated_feat_flow)
220
+ stacked_mask.append(mask)
221
+ if message is not None:
222
+ stacked_mesg.append(message)
223
+ if local_message is not None:
224
+ stacked_locm.append(local_message)
225
+ stacked_wref.append(warp_x_reference)
226
+ if warp_x_reference_features is not None:
227
+ stacked_feat.append(warp_x_reference_features)
228
+
229
+ stacked_flow = torch.stack(stacked_flow, dim=1) # M*NCHW -> NMCHW
230
+ stacked_feat_flow = torch.stack(stacked_feat_flow, dim=1) if len(
231
+ stacked_feat_flow) > 0 else None
232
+ stacked_mask = torch.stack(
233
+ stacked_mask, dim=1)
234
+
235
+ stacked_mesg = torch.stack(stacked_mesg, dim=1) if len(
236
+ stacked_mesg) > 0 else None
237
+ stacked_locm = torch.stack(stacked_locm, dim=1) if len(
238
+ stacked_locm) > 0 else None
239
+
240
+ stacked_wref = torch.stack(stacked_wref, dim=1)
241
+ stacked_feat = torch.stack(stacked_feat, dim=1) if len(
242
+ stacked_feat) > 0 else None
243
+ stacked_anci = torch.stack(stacked_anci, dim=1) if len(
244
+ stacked_anci) > 0 else None
245
+ y_flow.append(stacked_flow)
246
+ y_feat_flow.append(stacked_feat_flow)
247
+
248
+ y_warp_x_reference.append(stacked_wref)
249
+ y_warp_x_reference_features.append(stacked_feat)
250
+ # compute normalized confidence
251
+ stacked_contrib = torch.nn.functional.softmax(stacked_mask, dim=1)
252
+
253
+ # torch.sum to remove temp dimension M from NMCHW --> NCHW
254
+ weighted_flow = torch.sum(
255
+ stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_flow, dim=1)
256
+ weighted_mask = torch.sum(
257
+ stacked_contrib[:, :, 0:1, :, :] * stacked_mask[:, :, 0:1, :, :], dim=1)
258
+ weighted_wref = torch.sum(
259
+ stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_wref, dim=1) if stacked_wref is not None else None
260
+ weighted_feat = torch.sum(
261
+ stacked_mask[:, :, 1:2, :, :] * stacked_contrib[:, :, 1:2, :, :] * stacked_feat, dim=1) if stacked_feat is not None else None
262
+ weighted_mesg = torch.sum(
263
+ stacked_mask[:, :, 2:, :, :] * stacked_contrib[:, :, 2:, :, :] * stacked_mesg, dim=1) if stacked_mesg is not None else None
264
+ y_weighted_flow.append(weighted_flow)
265
+ y_weighted_mask.append(weighted_mask)
266
+ if weighted_mesg is not None:
267
+ y_weighted_message.append(weighted_mesg)
268
+ if stacked_locm is not None:
269
+ y_local_message.append(stacked_locm)
270
+ y_weighted_message.append(weighted_mesg)
271
+ y_weighted_x_reference.append(weighted_wref)
272
+ y_weighted_x_reference_features.append(weighted_feat)
273
+
274
+ if weighted_feat is not None:
275
+ y_weighted_x_reference_features.append(weighted_feat)
276
+ return {
277
+ "y_last_remote_features": [weighted_mesg],
278
+ }
279
+
280
+ def flow_rescale(self, prev_flow, each_x_reference_features):
281
+ if prev_flow is None:
282
+ prev_flow = torch.zeros_like(
283
+ each_x_reference_features[:, :2]).to(device)
284
+ else:
285
+ up_scale_factor = each_x_reference_features.shape[-1] / \
286
+ prev_flow.shape[-1]
287
+ if up_scale_factor != 1:
288
+ prev_flow = F.interpolate(prev_flow, scale_factor=up_scale_factor, mode="bilinear",
289
+ align_corners=False, recompute_scale_factor=False) * up_scale_factor
290
+ return prev_flow
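
The decoder above packs all of its outputs (residual flow, masks, messages and their flows) into one tensor from conv_last and slices it back apart in forward(). A minimal smoke test, not part of this commit, makes that bookkeeping concrete; the import path model.cinn is only a guess for wherever this module lives in the repo, and the small channel counts are arbitrary:

    import torch
    from model.cinn import DecoderBlock  # hypothetical path; substitute the module's real name

    # Small, arbitrary channel counts; the defaults in the commit are larger.
    blk = DecoderBlock(in_planes=2 + 64, c=64, out_msgs=8, out_locals=4,
                       block_nums=1, out_masks=1)
    flow = torch.zeros(1, 2, 32, 32)    # accumulated flow, NCHW, H and W divisible by 4
    feat = torch.zeros(1, 64, 32, 32)   # extra conditioning input with matching H and W
    res_flow, mask, msg, msg_flow, loc, loc_flow, feat_flow = blk(flow, feat)
    print(res_flow.shape)  # torch.Size([1, 2, 32, 32]): residual flow at the input resolution
    print(mask.shape)      # torch.Size([1, 1, 32, 32]): sigmoid-activated confidence mask
    print(msg.shape)       # torch.Size([1, 8, 32, 32]): message channels (msg_flow and loc_flow get 32 each by default)
    print(feat_flow)       # None, since out_feat_flows defaults to 0

The two stride-2 convolutions in conv0 are undone by the two transposed convolutions in conv_last, which is why the sliced outputs come back at the input resolution as long as H and W are divisible by 4.
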
model/warplayer.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+ backwarp_tenGrid = {}
6
+
7
+
8
+ def warp(tenInput, tenFlow):
9
+ with torch.cuda.amp.autocast(enabled=False):
10
+ k = (str(tenFlow.device), str(tenFlow.size()))
11
+ if k not in backwarp_tenGrid:
12
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
13
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
14
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
15
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
16
+ backwarp_tenGrid[k] = torch.cat(
17
+ [tenHorizontal, tenVertical], 1).to(device)
18
+
19
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
20
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
21
+
22
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
23
+ if tenInput.dtype != g.dtype:
24
+ g = g.to(tenInput.dtype)
25
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
26
+ # "zeros" "border"
27
+
28
+
29
+ def warp_features(inp, flow):
30
+ groups = flow.shape[1]//2 # NCHW
31
+ samples = inp.shape[0]
32
+ h = inp.shape[2]
33
+ w = inp.shape[3]
34
+ assert(flow.shape[0] == samples and flow.shape[2]
35
+ == h and flow.shape[3] == w)
36
+ chns = inp.shape[1]
37
+ chns_per_group = chns // groups
38
+ assert(flow.shape[1] % 2 == 0)
39
+ assert(chns % groups == 0)
40
+ inp = inp.contiguous().view(samples*groups, chns_per_group, h, w)
41
+ flow = flow.contiguous().view(samples*groups, 2, h, w)
42
+ feat = warp(inp, flow)
43
+ feat = feat.view(samples, chns, h, w)
44
+ return feat
45
+
46
+
47
+ def flow2rgb(flow_map_np):
48
+ h, w, _ = flow_map_np.shape
49
+ rgb_map = np.ones((h, w, 3)).astype(np.float32)/2.0
50
+ normalized_flow_map = np.concatenate(
51
+ (flow_map_np[:, :, 0:1]/h/2.0, flow_map_np[:, :, 1:2]/w/2.0), axis=2)
52
+ rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
53
+ rgb_map[:, :, 1] -= 0.5 * \
54
+ (normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
55
+ rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
56
+ return (rgb_map.clip(0, 1)*255.0).astype(np.uint8)
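
warp_features reshapes its input so that each 2-channel (dx, dy) pair in the flow warps one group of feature channels; with a single flow pair it behaves like a plain backward warp. A quick sanity check, not shipped with this commit and assuming the repository root is on the import path: a zero flow should return the input unchanged, and flow2rgb renders a flow field as a uint8 image.

    import numpy as np
    import torch
    from model.warplayer import warp_features, flow2rgb

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # mirror the module's device choice
    feat = torch.rand(1, 64, 32, 32, device=dev)
    zero_flow = torch.zeros(1, 2, 32, 32, device=dev)   # one (dx, dy) pair shared by all 64 channels
    warped = warp_features(feat, zero_flow)
    print(torch.allclose(warped, feat, atol=1e-4))       # True: zero flow is an identity warp

    flow_np = np.random.randn(32, 32, 2).astype(np.float32)
    print(flow2rgb(flow_np).shape, flow2rgb(flow_np).dtype)  # (32, 32, 3) uint8
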
notebooks/conr.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/conr_chinese.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
poses_template.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d6a3ed9527dd6b04275bf4c81704ac50c9b8cd1bf4a26b1d3ec6658e200ff57
3
+ size 10662874
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ pillow>=7.2.0
2
+ numpy>=1.16
3
+ tqdm>=4.35.0
4
+ torch>=1.3.0
5
+ opencv-python>=4.5.2
6
+ scikit-image>=0.14.0
7
+ torchvision>=0.2.1
8
+ lpips>=0.1.3
9
+ gdown
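
These are lower-bound pins, so a quick way to confirm an environment satisfies them is simply to print what is installed. A small sketch (package names taken from the list above):

    from importlib.metadata import PackageNotFoundError, version

    for pkg in ["pillow", "numpy", "tqdm", "torch", "opencv-python",
                "scikit-image", "torchvision", "lpips", "gdown"]:
        try:
            print(f"{pkg}=={version(pkg)}")
        except PackageNotFoundError:
            print(f"{pkg} is missing")
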
train.py ADDED
@@ -0,0 +1,203 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ from datetime import datetime
5
+ from distutils.util import strtobool
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+ from torchvision import transforms
11
+ from data_loader import (FileDataset,
12
+ RandomResizedCropWithAutoCenteringAndZeroPadding)
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from conr import CoNR
15
+ from tqdm import tqdm
16
+
17
+ def data_sampler(dataset, shuffle, distributed):
18
+
19
+ if distributed:
20
+ return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
21
+
22
+ if shuffle:
23
+ return torch.utils.data.RandomSampler(dataset)
24
+
25
+ else:
26
+ return torch.utils.data.SequentialSampler(dataset)
27
+
28
+ def save_output(image_name, inputs_v, d_dir=".", crop=None):
29
+ import cv2
30
+
31
+ inputs_v = inputs_v.detach().squeeze()
32
+ input_np = torch.clamp(inputs_v*255, 0, 255).byte().cpu().numpy().transpose(
33
+ (1, 2, 0))
34
+ # cv2.setNumThreads(1)
35
+ out_render_scale = cv2.cvtColor(input_np, cv2.COLOR_RGBA2BGRA)
36
+ if crop is not None:
37
+ crop = crop.cpu().numpy()[0]
38
+ output_img = np.zeros((crop[0], crop[1], 4), dtype=np.uint8)
39
+ before_resize_scale = cv2.resize(
40
+ out_render_scale, (crop[5]-crop[4]+crop[8]+crop[9], crop[3]-crop[2]+crop[6]+crop[7]), interpolation=cv2.INTER_AREA) # w,h
41
+ output_img[crop[2]:crop[3], crop[4]:crop[5]] = before_resize_scale[crop[6]:before_resize_scale.shape[0] -
42
+ crop[7], crop[8]:before_resize_scale.shape[1]-crop[9]]
43
+ else:
44
+ output_img = out_render_scale
45
+ cv2.imwrite(d_dir+"/"+image_name.split(os.sep)[-1]+'.png',
46
+ output_img
47
+ )
48
+
49
+
50
+ def test():
51
+ source_names_list = []
52
+ for name in sorted(os.listdir(args.test_input_person_images)):
53
+ thissource = os.path.join(args.test_input_person_images, name)
54
+ if os.path.isfile(thissource):
55
+ source_names_list.append(thissource)
56
+ if os.path.isdir(thissource):
57
+ print("skipping empty folder :"+thissource)
58
+
59
+ image_names_list = []
60
+ for name in sorted(os.listdir(args.test_input_poses_images)):
61
+ thistarget = os.path.join(args.test_input_poses_images, name)
62
+ if os.path.isfile(thistarget):
63
+ image_names_list.append([thistarget, *source_names_list])
64
+ if os.path.isdir(thistarget):
65
+ print("skipping folder :"+thistarget)
66
+ print(image_names_list)
67
+
68
+ print("---building models")
69
+ conrmodel = CoNR(args)
70
+ conrmodel.load_model(path=args.test_checkpoint_dir)
71
+ conrmodel.dist()
72
+ infer(args, conrmodel, image_names_list)
73
+
74
+
75
+ def infer(args, humanflowmodel, image_names_list):
76
+ print("---test images: ", len(image_names_list))
77
+ test_salobj_dataset = FileDataset(image_names_list=image_names_list,
78
+ fg_img_lbl_transform=transforms.Compose([
79
+ RandomResizedCropWithAutoCenteringAndZeroPadding(
80
+ (args.dataloader_imgsize, args.dataloader_imgsize), scale=(1, 1), ratio=(1.0, 1.0), center_jitter=(0.0, 0.0)
81
+ )]),
82
+ shader_pose_use_gt_udp_test=not args.test_pose_use_parser_udp,
83
+ shader_target_use_gt_rgb_debug=False
84
+ )
85
+ sampler = data_sampler(test_salobj_dataset, shuffle=False,
86
+ distributed=args.distributed)
87
+ train_data = DataLoader(test_salobj_dataset,
88
+ batch_size=1,
89
+ shuffle=False,sampler=sampler,
90
+ num_workers=args.dataloaders)
91
+
92
+ # start testing
93
+
94
+ train_num = train_data.__len__()
95
+ time_stamp = time.time()
96
+ prev_frame_rgb = []
97
+ prev_frame_a = []
98
+
99
+ pbar = tqdm(range(train_num), ncols=100)
100
+ for i, data in enumerate(train_data):
101
+ data_time_interval = time.time() - time_stamp
102
+ time_stamp = time.time()
103
+ with torch.no_grad():
104
+ data["character_images"] = torch.cat(
105
+ [data["character_images"], *prev_frame_rgb], dim=1)
106
+ data["character_masks"] = torch.cat(
107
+ [data["character_masks"], *prev_frame_a], dim=1)
108
+ data = humanflowmodel.data_norm_image(data)
109
+ pred = humanflowmodel.model_step(data, training=False)
110
+ # remember to call humanflowmodel.reset_charactersheet() if you change character .
111
+
112
+ train_time_interval = time.time() - time_stamp
113
+ time_stamp = time.time()
114
+ if args.local_rank == 0:
115
+ pbar.set_description(f"Epoch {i}/{train_num}")
116
+ pbar.set_postfix({"data_time": data_time_interval, "train_time":train_time_interval})
117
+ pbar.update(1)
118
+
119
+ with torch.no_grad():
120
+
121
+ if args.test_output_video:
122
+ pred_img = pred["shader"]["y_weighted_warp_decoded_rgba"]
123
+ save_output(
124
+ str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir, crop=data["pose_crop"])
125
+
126
+ if args.test_output_udp:
127
+ pred_img = pred["shader"]["x_target_sudp_a"]
128
+ save_output(
129
+ "udp_"+str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir)
130
+
131
+
132
+ def build_args():
133
+ parser = argparse.ArgumentParser()
134
+ # distributed learning settings
135
+ parser.add_argument("--world_size", type=int, default=1,
136
+ help='world size')
137
+ parser.add_argument("--local_rank", type=int, default=0,
138
+ help='local_rank, DON\'T change it')
139
+
140
+ # model settings
141
+ parser.add_argument('--dataloader_imgsize', type=int, default=256,
142
+ help='Input image size of the model')
143
+ parser.add_argument('--batch_size', type=int, default=4,
144
+ help='minibatch size')
145
+ parser.add_argument('--model_name', default='model_result',
146
+ help='Name of the experiment')
147
+ parser.add_argument('--dataloaders', type=int, default=2,
148
+ help='Num of dataloaders')
149
+ parser.add_argument('--mode', default="test", choices=['train', 'test'],
150
+ help='Training mode or Testing mode')
151
+
152
+ # i/o settings
153
+ parser.add_argument('--test_input_person_images',
154
+ type=str, default="./character_sheet/",
155
+ help='Directory to input character sheets')
156
+ parser.add_argument('--test_input_poses_images', type=str,
157
+ default="./test_data/",
158
+ help='Directory to input UDP sequences or pose images')
159
+ parser.add_argument('--test_checkpoint_dir', type=str,
160
+ default='./weights/',
161
+ help='Directory to model weights')
162
+ parser.add_argument('--test_output_dir', type=str,
163
+ default="./results/",
164
+ help='Directory to output images')
165
+
166
+ # output content settings
167
+ parser.add_argument('--test_output_video', type=strtobool, default=True,
168
+ help='Whether to output the final result of CoNR, \
169
+ images will be output to test_output_dir while True.')
170
+ parser.add_argument('--test_output_udp', type=strtobool, default=False,
171
+ help='Whether to output UDP generated from UDP detector, \
172
+ this is meaningful ONLY when test_input_poses_images \
173
+ is not UDP sequences but pose images. Meanwhile, \
174
+ test_pose_use_parser_udp need to be True')
175
+
176
+ # UDP detector settings
177
+ parser.add_argument('--test_pose_use_parser_udp',
178
+ type=strtobool, default=False,
179
+ help='Whether to use UDP detector to generate UDP from pngs, \
180
+ pose input MUST be pose images instead of UDP sequences \
181
+ while True')
182
+
183
+ args = parser.parse_args()
184
+
185
+ args.distributed = (args.world_size > 1)
186
+ if args.local_rank == 0:
187
+ print("batch_size:", args.batch_size, flush=True)
188
+ if args.distributed:
189
+ if args.local_rank == 0:
190
+ print("world_size: ", args.world_size)
191
+ torch.distributed.init_process_group(
192
+ backend="nccl", init_method="env://", world_size=args.world_size)
193
+ torch.cuda.set_device(args.local_rank)
194
+ torch.backends.cudnn.benchmark = True
195
+ else:
196
+ args.local_rank = 0
197
+
198
+ return args
199
+
200
+
201
+ if __name__ == "__main__":
202
+ args = build_args()
203
+ test()
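
Despite the file name, the __main__ block as committed only runs the test path; build_args wires the default directories to ./character_sheet/, ./test_data/, ./weights/ and ./results/. A hedged example of driving it programmatically, assuming the input directories are already populated (character sheet images plus a UDP or pose sequence), the LFS weights below have been pulled, and the output directory exists; flag names come from build_args above and the paths are just its defaults:

    import subprocess
    import sys

    subprocess.run([
        sys.executable, "train.py",
        "--mode", "test",
        "--test_input_person_images", "./character_sheet/",
        "--test_input_poses_images", "./test_data/",
        "--test_checkpoint_dir", "./weights/",
        "--test_output_dir", "./results/",
        "--test_output_video", "True",
    ], check=True)

Outputs land in --test_output_dir as numbered PNGs via save_output; per the help text above, pass --test_pose_use_parser_udp True (and optionally --test_output_udp True) when the pose inputs are plain images rather than UDP sequences.
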
weights/.gitkeep ADDED
File without changes
weights/rgbadecodernet.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1cc48cf4b3b6f7c4856bec8299e57466836eff3bffa73f518965fdb75bc16fd
3
+ size 341897
weights/shader.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c8fa74e07db0e0a853fe2aa8d299c854b5c957037dcf4858dc0ca755bec9a95
3
+ size 384615535
weights/target_pose_encoder.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a077437779043b543b015d56ca4f521668645c9ad1cd67ee843aa8a94bf59034
3
+ size 107891883
weights/udpparsernet.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13bb149d18146277cfb11f90b96ae0ecb25828a1ce57ae47ba04d5468cd3322e
3
+ size 228957615