Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +31 -0
- models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/.gitattributes +35 -0
- models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/LICENSE +201 -0
- models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/README.md +65 -0
- models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/config.json +29 -0
- models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/diffusion_pytorch_model.safetensors +3 -0
- models/CogVideo/cogvideox_vae.safetensors +3 -0
- models/InfiniteYou/aes_stage2_img_proj.bin +3 -0
- models/InfiniteYou/sim_stage1_img_proj.bin +3 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/.gitattributes +35 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/README.md +71 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/added_tokens.json +1026 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/config.json +138 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/configuration_florence2.py +340 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/generation_config.json +4 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/merges.txt +0 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/model.safetensors +3 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/modeling_florence2.py +0 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/preprocessor_config.json +33 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/processing_florence2.py +1088 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/special_tokens_map.json +0 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/tokenizer.json +0 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/tokenizer_config.json +0 -0
- models/LLM/Florence-2-large-PromptGen-v2.0/vocab.json +0 -0
- models/RMBG/BEN2/BEN2.py +1401 -0
- models/RMBG/BEN2/BEN2_Base.pth +3 -0
- models/RMBG/BEN2/__pycache__/BEN2.cpython-310.pyc +0 -0
- models/RMBG/RMBG-2.0/BiRefNet_config.py +11 -0
- models/RMBG/RMBG-2.0/birefnet.py +2244 -0
- models/RMBG/RMBG-2.0/config.json +20 -0
- models/RMBG/RMBG-2.0/model.safetensors +3 -0
- models/TTS/DiffRhythm/.gitattributes +36 -0
- models/TTS/DiffRhythm/.locks/models--OpenMuQ--MuQ-MuLan-large/1d2f0a1aedbc66ea23e7fef7985c875c3e98c08d.lock +0 -0
- models/TTS/DiffRhythm/.locks/models--OpenMuQ--MuQ-MuLan-large/d42ae3f7cb9b66759ee0089ddc70e2f28b130c2d8ba621457358272d32dd0444.lock +0 -0
- models/TTS/DiffRhythm/.locks/models--OpenMuQ--MuQ-large-msd-iter/334df3de2832ec1acfd8b6ce54e7de4073401fe821f7ec0ad0d954832be2d26a.lock +0 -0
- models/TTS/DiffRhythm/.locks/models--OpenMuQ--MuQ-large-msd-iter/fec6c73f7b811281b440462fcf4d98c7953c3d94.lock +0 -0
- models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/1960141250d189366dfb76630ba794a9c104ec07.lock +0 -0
- models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/34ddbd64a4cd3f2d9d8a9120d3662d0bf91baead.lock +0 -0
- models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/463f3414782c1c9405828c9b31bfa36dda1f45c5.lock +0 -0
- models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/6fd4797bc397c3b8b55d6bb5740366b57e6a3ce91c04c77f22aafc0c128e6feb.lock +0 -0
- models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/db9af13bf09fd3028ca32be90d3fb66d5e470399.lock +0 -0
- models/TTS/DiffRhythm/LICENSE +201 -0
- models/TTS/DiffRhythm/LICENSE.md +58 -0
- models/TTS/DiffRhythm/README.md +76 -0
- models/TTS/DiffRhythm/cfm_full_model.pt +3 -0
- models/TTS/DiffRhythm/cfm_model.pt +3 -0
- models/TTS/DiffRhythm/config.json +13 -0
- models/TTS/DiffRhythm/models--OpenMuQ--MuQ-MuLan-large/.no_exist/8a081dbcf84edd47ea7db3c4ecb8fd1ec1ddacfe/model.safetensors +0 -0
- models/TTS/DiffRhythm/models--OpenMuQ--MuQ-MuLan-large/blobs/1d2f0a1aedbc66ea23e7fef7985c875c3e98c08d +41 -0
- models/TTS/DiffRhythm/models--OpenMuQ--MuQ-MuLan-large/blobs/d42ae3f7cb9b66759ee0089ddc70e2f28b130c2d8ba621457358272d32dd0444 +3 -0
.gitattributes
CHANGED
@@ -872,3 +872,34 @@ custom_nodes/was-node-suite-comfyui/repos/SAM/assets/notebook2.png filter=lfs di
|
|
872 |
custom_nodes/was-node-suite-comfyui/repos/SAM/demo/src/assets/data/dogs.jpg filter=lfs diff=lfs merge=lfs -text
|
873 |
custom_nodes/was-node-suite-comfyui/repos/SAM/notebooks/images/groceries.jpg filter=lfs diff=lfs merge=lfs -text
|
874 |
custom_nodes/was-node-suite-comfyui/repos/SAM/notebooks/images/truck.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
872 |
custom_nodes/was-node-suite-comfyui/repos/SAM/demo/src/assets/data/dogs.jpg filter=lfs diff=lfs merge=lfs -text
|
873 |
custom_nodes/was-node-suite-comfyui/repos/SAM/notebooks/images/groceries.jpg filter=lfs diff=lfs merge=lfs -text
|
874 |
custom_nodes/was-node-suite-comfyui/repos/SAM/notebooks/images/truck.jpg filter=lfs diff=lfs merge=lfs -text
|
875 |
+
models/TTS/DiffRhythm/models--OpenMuQ--MuQ-MuLan-large/blobs/d42ae3f7cb9b66759ee0089ddc70e2f28b130c2d8ba621457358272d32dd0444 filter=lfs diff=lfs merge=lfs -text
|
876 |
+
models/TTS/DiffRhythm/models--OpenMuQ--MuQ-large-msd-iter/blobs/334df3de2832ec1acfd8b6ce54e7de4073401fe821f7ec0ad0d954832be2d26a filter=lfs diff=lfs merge=lfs -text
|
877 |
+
models/TTS/DiffRhythm/models--xlm-roberta-base/blobs/6fd4797bc397c3b8b55d6bb5740366b57e6a3ce91c04c77f22aafc0c128e6feb filter=lfs diff=lfs merge=lfs -text
|
878 |
+
models/TTS/DiffRhythm/models--xlm-roberta-base/blobs/db9af13bf09fd3028ca32be90d3fb66d5e470399 filter=lfs diff=lfs merge=lfs -text
|
879 |
+
models/TTS/DiffRhythm/src/ASLP.jpg filter=lfs diff=lfs merge=lfs -text
|
880 |
+
models/TTS/DiffRhythm/src/diffrhythm.jpg filter=lfs diff=lfs merge=lfs -text
|
881 |
+
models/blip/models--Salesforce--blip-image-captioning-base/blobs/9339497cee045b8434a4ebf8f5a30e2f83984e7695a53030e99283a5786693d9 filter=lfs diff=lfs merge=lfs -text
|
882 |
+
models/diffusers/models--ZhengPeng7--BiRefNet/blobs/9ab37426bf4de0567af6b5d21b16151357149139362e6e8992021b8ce356a154 filter=lfs diff=lfs merge=lfs -text
|
883 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/0d9c7c663217d1c3d44a6deed4e1cf1ac09fbc2c4137c47de1e3d74c959833de filter=lfs diff=lfs merge=lfs -text
|
884 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/37a587d0ff3d9dda0d8ab59d65342c0242ffb909573d8d998d599e3401d3d7e9 filter=lfs diff=lfs merge=lfs -text
|
885 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/5e830704a83aa938dfaf23da308100a1c44b83fa084283abf1d163ea727e5f7a filter=lfs diff=lfs merge=lfs -text
|
886 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/893d67a23f4693ed42cdab4cbad7fe3e727cf59609c40da28a46b5470f9ed082 filter=lfs diff=lfs merge=lfs -text
|
887 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/a5640855b301fcdbceddfa90ae8066cd9414aff020552a201a255ecf2059da00 filter=lfs diff=lfs merge=lfs -text
|
888 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/afc8e28272cd15db3919bacdb6918ce9c1ed22e96cb12c4d5ed0fba823529e38 filter=lfs diff=lfs merge=lfs -text
|
889 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86 filter=lfs diff=lfs merge=lfs -text
|
890 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/d86a3038eacaa720682cb9b1da3c49fecf8a3ded605af4def6061eaa18903eb8 filter=lfs diff=lfs merge=lfs -text
|
891 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/ec87bffd1923e8b2774a6d240c922a41f6143081d52cf83b8fe39e9d838c893e filter=lfs diff=lfs merge=lfs -text
|
892 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/blobs/f5b59a26851551b67ae1fe58d32e76486e1e812def4696a4bea97f16604d40a3 filter=lfs diff=lfs merge=lfs -text
|
893 |
+
models/diffusers/models--black-forest-labs--FLUX.1-dev/snapshots/0ef5fff789c832c5c7f4e127f94c8b54bbcced44/dev_grid.jpg filter=lfs diff=lfs merge=lfs -text
|
894 |
+
models/diffusers/models--madebyollin--sdxl-vae-fp16-fix/blobs/1b909373b28f2137098b0fd9dbc6f97f8410854f31f84ddc9fa04b077b0ace2c filter=lfs diff=lfs merge=lfs -text
|
895 |
+
models/diffusers/models--stabilityai--stable-diffusion-xl-base-1.0/blobs/1598f3d24932bcfe6634e8b618ea1e30ab1d57f5aad13a6d2de446d2199f2341 filter=lfs diff=lfs merge=lfs -text
|
896 |
+
models/diffusers/models--stabilityai--stable-diffusion-xl-base-1.0/blobs/27ed3b02e09638568e99d4398c67bc654dde04e6c0db61fb2d21dba630e7058a filter=lfs diff=lfs merge=lfs -text
|
897 |
+
models/diffusers/models--stabilityai--stable-diffusion-xl-base-1.0/blobs/357650fbfb3c7b4d94c1f5fd7664da819ad1ff5a839430484b4ec422d03f710a filter=lfs diff=lfs merge=lfs -text
|
898 |
+
models/diffusers/models--stabilityai--stable-diffusion-xl-base-1.0/blobs/3a6032f63d37ae02bbc74ccd6a27440578cd71701f96532229d0154f55a8d3ff filter=lfs diff=lfs merge=lfs -text
|
899 |
+
models/diffusers/models--stabilityai--stable-diffusion-xl-base-1.0/blobs/5c3d6454dd2d23414b56aa1b5858a72487a656937847b6fea8d0606d7a42cdbc filter=lfs diff=lfs merge=lfs -text
|
900 |
+
models/echo_mimic/pose/sapiens_1b_goliath_best_goliath_AP_639_torchscript.pt2 filter=lfs diff=lfs merge=lfs -text
|
901 |
+
models/inpaint/inpaint_v26.fooocus.patch filter=lfs diff=lfs merge=lfs -text
|
902 |
+
models/text_encoders/models--openai--clip-vit-large-patch14/blobs/a2bf730a0c7debf160f7a6b50b3aaf3703e7e88ac73de7a314903141db026dcb filter=lfs diff=lfs merge=lfs -text
|
903 |
+
models/text_encoders/models--xlabs-ai--xflux_text_encoders/blobs/a5640855b301fcdbceddfa90ae8066cd9414aff020552a201a255ecf2059da00 filter=lfs diff=lfs merge=lfs -text
|
904 |
+
models/text_encoders/models--xlabs-ai--xflux_text_encoders/blobs/d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86 filter=lfs diff=lfs merge=lfs -text
|
905 |
+
models/text_encoders/models--xlabs-ai--xflux_text_encoders/blobs/ec87bffd1923e8b2774a6d240c922a41f6143081d52cf83b8fe39e9d838c893e filter=lfs diff=lfs merge=lfs -text
|
models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2024 CogVideo Model Team @ Zhipu AI
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/README.md
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
tags:
|
6 |
+
- cogvideox
|
7 |
+
- video-generation
|
8 |
+
- video-to-video
|
9 |
+
- controlnet
|
10 |
+
- diffusers
|
11 |
+
---
|
12 |
+
|
13 |
+
|
14 |
+
<video controls autoplay src="https://cdn-uploads.huggingface.co/production/uploads/63fde49f6315a264aba6a7ed/VFtwr_VimGF6g51PGQYwN.mp4"></video>
|
15 |
+
|
16 |
+
<video controls autoplay src="https://cdn-uploads.huggingface.co/production/uploads/63fde49f6315a264aba6a7ed/YaCSr74Iiw6nuqtT1Gtei.mp4"></video>
|
17 |
+
|
18 |
+
### ComfyUI
|
19 |
+
<a href="https://github.com/kijai/ComfyUI-CogVideoXWrapper">ComfyUI-CogVideoXWrapper
|
20 |
+
</a> supports controlnet pipeline. See an <a href="https://github.com/kijai/ComfyUI-CogVideoXWrapper/blob/main/examples/cogvideox_1_0_2b_controlnet_02.json">example
|
21 |
+
</a> file.
|
22 |
+
|
23 |
+
### How to
|
24 |
+
Clone repo
|
25 |
+
```bash
|
26 |
+
git clone https://github.com/TheDenk/cogvideox-controlnet.git
|
27 |
+
cd cogvideox-controlnet
|
28 |
+
```
|
29 |
+
|
30 |
+
Create venv
|
31 |
+
```bash
|
32 |
+
python -m venv venv
|
33 |
+
source venv/bin/activate
|
34 |
+
```
|
35 |
+
|
36 |
+
Install requirements
|
37 |
+
```bash
|
38 |
+
pip install -r requirements.txt
|
39 |
+
```
|
40 |
+
|
41 |
+
### Inference examples
|
42 |
+
#### Inference with cli
|
43 |
+
```bash
|
44 |
+
python -m inference.cli_demo \
|
45 |
+
--video_path "resources/car.mp4" \
|
46 |
+
--prompt "car is moving among mountains" \
|
47 |
+
--controlnet_type "hed" \
|
48 |
+
--base_model_path THUDM/CogVideoX-2b \
|
49 |
+
--controlnet_model_path TheDenk/cogvideox-2b-controlnet-hed-v1
|
50 |
+
```
|
51 |
+
|
52 |
+
#### Inference with Gradio
|
53 |
+
```bash
|
54 |
+
python -m inference.gradio_web_demo \
|
55 |
+
--controlnet_type "hed" \
|
56 |
+
--base_model_path THUDM/CogVideoX-2b \
|
57 |
+
--controlnet_model_path TheDenk/cogvideox-2b-controlnet-hed-v1
|
58 |
+
```
|
59 |
+
|
60 |
+
|
61 |
+
## Acknowledgements
|
62 |
+
Original code and models [CogVideoX](https://github.com/THUDM/CogVideo/tree/main).
|
63 |
+
|
64 |
+
## Contacts
|
65 |
+
<p>Issues should be raised directly in the repository. For professional support and recommendations please <a>[email protected]</a>.</p>
|
models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/config.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "CogVideoXControlnet",
|
3 |
+
"_diffusers_version": "0.31.0.dev0",
|
4 |
+
"activation_fn": "gelu-approximate",
|
5 |
+
"attention_bias": true,
|
6 |
+
"attention_head_dim": 64,
|
7 |
+
"downscale_coef": 8,
|
8 |
+
"dropout": 0.0,
|
9 |
+
"flip_sin_to_cos": true,
|
10 |
+
"freq_shift": 0,
|
11 |
+
"in_channels": 3,
|
12 |
+
"max_text_seq_length": 226,
|
13 |
+
"norm_elementwise_affine": true,
|
14 |
+
"norm_eps": 1e-05,
|
15 |
+
"num_attention_heads": 30,
|
16 |
+
"num_layers": 8,
|
17 |
+
"patch_size": 2,
|
18 |
+
"sample_frames": 49,
|
19 |
+
"sample_height": 60,
|
20 |
+
"sample_width": 90,
|
21 |
+
"spatial_interpolation_scale": 1.875,
|
22 |
+
"temporal_compression_ratio": 4,
|
23 |
+
"temporal_interpolation_scale": 1.0,
|
24 |
+
"time_embed_dim": 512,
|
25 |
+
"timestep_activation_fn": "silu",
|
26 |
+
"use_learned_positional_embeddings": false,
|
27 |
+
"use_rotary_positional_embeddings": false,
|
28 |
+
"vae_channels": 16
|
29 |
+
}
|
models/CogVideo/ControlNet/cogvideox-2b-controlnet-hed-v1/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ffa03a9155ebaeac449b6656420117f8a52c04a6337523e10582d2f657631634
|
3 |
+
size 1833149512
|
models/CogVideo/cogvideox_vae.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bd47d57ad948ff80da0af0cb2e4dcdef65073aba59bccfd383ada9a7d1c02024
|
3 |
+
size 431221142
|
models/InfiniteYou/aes_stage2_img_proj.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f85518431d7367de30b9558a939c003462a7e331e8c8b929146916dc27e471d6
|
3 |
+
size 338413026
|
models/InfiniteYou/sim_stage1_img_proj.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b7a8a1b6fecf2731b2ac64bde33a1885014949c8c08f3efaebd0665bb0e8ad8f
|
3 |
+
size 338413026
|
models/LLM/Florence-2-large-PromptGen-v2.0/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
models/LLM/Florence-2-large-PromptGen-v2.0/README.md
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
---
|
4 |
+
# Florence-2-large-PromptGen v2.0
|
5 |
+
This upgrade is based on PromptGen 1.5 with some new features to the model:
|
6 |
+
|
7 |
+
## Features:
|
8 |
+
* Improved caption quality for \<GENERATE_TAGS\>, \<DETAILED_CAPTION\> and \<MORE_DETAILED_CAPTION\>.
|
9 |
+
<img style="width:100%; hight:100%" src="https://msdn.miaoshouai.com/miaoshou/bo/2024-11-05_03-15-15.png" />
|
10 |
+
<img style="width:100%; hight:100%" src="https://msdn.miaoshouai.com/miaoshou/bo/2024-11-05_03-40-29.png" />
|
11 |
+
* A new \<ANALYZE\> instruction, which helps the model to better understands the image composition of the input image.
|
12 |
+
<img style="width:100%; hight:100%" src="https://msdn.miaoshouai.com/miaoshou/bo/2024-11-05_03-42-58.png" />
|
13 |
+
<img style="width:100%; hight:100%" src="https://msdn.miaoshouai.com/miaoshou/bo/2024-11-05_07-42-36.png" />
|
14 |
+
* Memory efficient compare to other models! This is a really light weight caption model that allows you to use a little more than 1G of VRAM and produce lightening fast and high quality image captions.
|
15 |
+
<img style="width:100%; hight:100%" src="https://msdn.miaoshouai.com/miaoshou/bo/2024-09-05_12-56-39.png" />
|
16 |
+
* Designed to handle image captions for Flux model for both T5XXL CLIP and CLIP_L, the Miaoshou Tagger new node called "Flux CLIP Text Encode" which eliminates the need to run two separate tagger tools for caption creation. You can easily populate both CLIPs in a single generation, significantly boosting speed when working with Flux models.
|
17 |
+
<img style="width:100%; hight:100%" src="https://msdn.miaoshouai.com/miaoshou/bo/2024-09-05_14-11-02.png" />
|
18 |
+
|
19 |
+
## Instruction prompt:
|
20 |
+
\<GENERATE_TAGS\> generate prompt as danbooru style tags<br>
|
21 |
+
\<CAPTION\> a one line caption for the image<br>
|
22 |
+
\<DETAILED_CAPTION\> a structured caption format which detects the position of the subjects in the image<br>
|
23 |
+
\<MORE_DETAILED_CAPTION\> a very detailed description for the image<br>
|
24 |
+
\<ANALYZE\> image composition analysis mode<br>
|
25 |
+
\<MIXED_CAPTION\> a mixed caption style of more detailed caption and tags, this is extremely useful for FLUX model when using T5XXL and CLIP_L together. A new node in MiaoshouTagger ComfyUI is added to support this instruction.<br>
|
26 |
+
\<MIXED_CAPTION_PLUS\> Combine the power of mixed caption with analyze.<br>
|
27 |
+
|
28 |
+
## Version History:
|
29 |
+
For version 2.0, you will notice the following
|
30 |
+
1. \<ANALYZE\> along with a beta node in ComfyUI for partial image analysis
|
31 |
+
2. A new instruction for \<MIXED_CAPTION_PLUS\>
|
32 |
+
3. A much improve accuracy for \<GENERATE_TAGS\>, \<DETAILED_CAPTION\> and \<MORE_DETAILED_CAPTION\>
|
33 |
+
|
34 |
+
|
35 |
+
## How to use:
|
36 |
+
|
37 |
+
To use this model, you can load it directly from the Hugging Face Model Hub:
|
38 |
+
|
39 |
+
```python
|
40 |
+
|
41 |
+
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-large-PromptGen-v2.0", trust_remote_code=True)
|
42 |
+
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-large-PromptGen-v2.0", trust_remote_code=True)
|
43 |
+
|
44 |
+
prompt = "<MORE_DETAILED_CAPTION>"
|
45 |
+
|
46 |
+
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
|
47 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
48 |
+
|
49 |
+
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
|
50 |
+
|
51 |
+
generated_ids = model.generate(
|
52 |
+
input_ids=inputs["input_ids"],
|
53 |
+
pixel_values=inputs["pixel_values"],
|
54 |
+
max_new_tokens=1024,
|
55 |
+
do_sample=False,
|
56 |
+
num_beams=3
|
57 |
+
)
|
58 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
59 |
+
|
60 |
+
parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
|
61 |
+
|
62 |
+
print(parsed_answer)
|
63 |
+
```
|
64 |
+
|
65 |
+
## Use under MiaoshouAI Tagger ComfyUI
|
66 |
+
If you just want to use this model, you can use it under ComfyUI-Miaoshouai-Tagger
|
67 |
+
|
68 |
+
https://github.com/miaoshouai/ComfyUI-Miaoshouai-Tagger
|
69 |
+
|
70 |
+
A detailed use and install instruction is already there.
|
71 |
+
(If you have already installed MiaoshouAI Tagger, you need to update the node in ComfyUI Manager first or use git pull to get the latest update.)
|
models/LLM/Florence-2-large-PromptGen-v2.0/added_tokens.json
ADDED
@@ -0,0 +1,1026 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"</cap>": 51270,
|
3 |
+
"</dcap>": 51274,
|
4 |
+
"</grounding>": 51276,
|
5 |
+
"</ncap>": 51272,
|
6 |
+
"</ocr>": 50268,
|
7 |
+
"</od>": 50266,
|
8 |
+
"</poly>": 51287,
|
9 |
+
"</proposal>": 51285,
|
10 |
+
"</region_cap>": 51281,
|
11 |
+
"</region_to_desciption>": 51283,
|
12 |
+
"</seg>": 51278,
|
13 |
+
"<and>": 51288,
|
14 |
+
"<cap>": 51269,
|
15 |
+
"<dcap>": 51273,
|
16 |
+
"<grounding>": 51275,
|
17 |
+
"<loc_0>": 50269,
|
18 |
+
"<loc_100>": 50369,
|
19 |
+
"<loc_101>": 50370,
|
20 |
+
"<loc_102>": 50371,
|
21 |
+
"<loc_103>": 50372,
|
22 |
+
"<loc_104>": 50373,
|
23 |
+
"<loc_105>": 50374,
|
24 |
+
"<loc_106>": 50375,
|
25 |
+
"<loc_107>": 50376,
|
26 |
+
"<loc_108>": 50377,
|
27 |
+
"<loc_109>": 50378,
|
28 |
+
"<loc_10>": 50279,
|
29 |
+
"<loc_110>": 50379,
|
30 |
+
"<loc_111>": 50380,
|
31 |
+
"<loc_112>": 50381,
|
32 |
+
"<loc_113>": 50382,
|
33 |
+
"<loc_114>": 50383,
|
34 |
+
"<loc_115>": 50384,
|
35 |
+
"<loc_116>": 50385,
|
36 |
+
"<loc_117>": 50386,
|
37 |
+
"<loc_118>": 50387,
|
38 |
+
"<loc_119>": 50388,
|
39 |
+
"<loc_11>": 50280,
|
40 |
+
"<loc_120>": 50389,
|
41 |
+
"<loc_121>": 50390,
|
42 |
+
"<loc_122>": 50391,
|
43 |
+
"<loc_123>": 50392,
|
44 |
+
"<loc_124>": 50393,
|
45 |
+
"<loc_125>": 50394,
|
46 |
+
"<loc_126>": 50395,
|
47 |
+
"<loc_127>": 50396,
|
48 |
+
"<loc_128>": 50397,
|
49 |
+
"<loc_129>": 50398,
|
50 |
+
"<loc_12>": 50281,
|
51 |
+
"<loc_130>": 50399,
|
52 |
+
"<loc_131>": 50400,
|
53 |
+
"<loc_132>": 50401,
|
54 |
+
"<loc_133>": 50402,
|
55 |
+
"<loc_134>": 50403,
|
56 |
+
"<loc_135>": 50404,
|
57 |
+
"<loc_136>": 50405,
|
58 |
+
"<loc_137>": 50406,
|
59 |
+
"<loc_138>": 50407,
|
60 |
+
"<loc_139>": 50408,
|
61 |
+
"<loc_13>": 50282,
|
62 |
+
"<loc_140>": 50409,
|
63 |
+
"<loc_141>": 50410,
|
64 |
+
"<loc_142>": 50411,
|
65 |
+
"<loc_143>": 50412,
|
66 |
+
"<loc_144>": 50413,
|
67 |
+
"<loc_145>": 50414,
|
68 |
+
"<loc_146>": 50415,
|
69 |
+
"<loc_147>": 50416,
|
70 |
+
"<loc_148>": 50417,
|
71 |
+
"<loc_149>": 50418,
|
72 |
+
"<loc_14>": 50283,
|
73 |
+
"<loc_150>": 50419,
|
74 |
+
"<loc_151>": 50420,
|
75 |
+
"<loc_152>": 50421,
|
76 |
+
"<loc_153>": 50422,
|
77 |
+
"<loc_154>": 50423,
|
78 |
+
"<loc_155>": 50424,
|
79 |
+
"<loc_156>": 50425,
|
80 |
+
"<loc_157>": 50426,
|
81 |
+
"<loc_158>": 50427,
|
82 |
+
"<loc_159>": 50428,
|
83 |
+
"<loc_15>": 50284,
|
84 |
+
"<loc_160>": 50429,
|
85 |
+
"<loc_161>": 50430,
|
86 |
+
"<loc_162>": 50431,
|
87 |
+
"<loc_163>": 50432,
|
88 |
+
"<loc_164>": 50433,
|
89 |
+
"<loc_165>": 50434,
|
90 |
+
"<loc_166>": 50435,
|
91 |
+
"<loc_167>": 50436,
|
92 |
+
"<loc_168>": 50437,
|
93 |
+
"<loc_169>": 50438,
|
94 |
+
"<loc_16>": 50285,
|
95 |
+
"<loc_170>": 50439,
|
96 |
+
"<loc_171>": 50440,
|
97 |
+
"<loc_172>": 50441,
|
98 |
+
"<loc_173>": 50442,
|
99 |
+
"<loc_174>": 50443,
|
100 |
+
"<loc_175>": 50444,
|
101 |
+
"<loc_176>": 50445,
|
102 |
+
"<loc_177>": 50446,
|
103 |
+
"<loc_178>": 50447,
|
104 |
+
"<loc_179>": 50448,
|
105 |
+
"<loc_17>": 50286,
|
106 |
+
"<loc_180>": 50449,
|
107 |
+
"<loc_181>": 50450,
|
108 |
+
"<loc_182>": 50451,
|
109 |
+
"<loc_183>": 50452,
|
110 |
+
"<loc_184>": 50453,
|
111 |
+
"<loc_185>": 50454,
|
112 |
+
"<loc_186>": 50455,
|
113 |
+
"<loc_187>": 50456,
|
114 |
+
"<loc_188>": 50457,
|
115 |
+
"<loc_189>": 50458,
|
116 |
+
"<loc_18>": 50287,
|
117 |
+
"<loc_190>": 50459,
|
118 |
+
"<loc_191>": 50460,
|
119 |
+
"<loc_192>": 50461,
|
120 |
+
"<loc_193>": 50462,
|
121 |
+
"<loc_194>": 50463,
|
122 |
+
"<loc_195>": 50464,
|
123 |
+
"<loc_196>": 50465,
|
124 |
+
"<loc_197>": 50466,
|
125 |
+
"<loc_198>": 50467,
|
126 |
+
"<loc_199>": 50468,
|
127 |
+
"<loc_19>": 50288,
|
128 |
+
"<loc_1>": 50270,
|
129 |
+
"<loc_200>": 50469,
|
130 |
+
"<loc_201>": 50470,
|
131 |
+
"<loc_202>": 50471,
|
132 |
+
"<loc_203>": 50472,
|
133 |
+
"<loc_204>": 50473,
|
134 |
+
"<loc_205>": 50474,
|
135 |
+
"<loc_206>": 50475,
|
136 |
+
"<loc_207>": 50476,
|
137 |
+
"<loc_208>": 50477,
|
138 |
+
"<loc_209>": 50478,
|
139 |
+
"<loc_20>": 50289,
|
140 |
+
"<loc_210>": 50479,
|
141 |
+
"<loc_211>": 50480,
|
142 |
+
"<loc_212>": 50481,
|
143 |
+
"<loc_213>": 50482,
|
144 |
+
"<loc_214>": 50483,
|
145 |
+
"<loc_215>": 50484,
|
146 |
+
"<loc_216>": 50485,
|
147 |
+
"<loc_217>": 50486,
|
148 |
+
"<loc_218>": 50487,
|
149 |
+
"<loc_219>": 50488,
|
150 |
+
"<loc_21>": 50290,
|
151 |
+
"<loc_220>": 50489,
|
152 |
+
"<loc_221>": 50490,
|
153 |
+
"<loc_222>": 50491,
|
154 |
+
"<loc_223>": 50492,
|
155 |
+
"<loc_224>": 50493,
|
156 |
+
"<loc_225>": 50494,
|
157 |
+
"<loc_226>": 50495,
|
158 |
+
"<loc_227>": 50496,
|
159 |
+
"<loc_228>": 50497,
|
160 |
+
"<loc_229>": 50498,
|
161 |
+
"<loc_22>": 50291,
|
162 |
+
"<loc_230>": 50499,
|
163 |
+
"<loc_231>": 50500,
|
164 |
+
"<loc_232>": 50501,
|
165 |
+
"<loc_233>": 50502,
|
166 |
+
"<loc_234>": 50503,
|
167 |
+
"<loc_235>": 50504,
|
168 |
+
"<loc_236>": 50505,
|
169 |
+
"<loc_237>": 50506,
|
170 |
+
"<loc_238>": 50507,
|
171 |
+
"<loc_239>": 50508,
|
172 |
+
"<loc_23>": 50292,
|
173 |
+
"<loc_240>": 50509,
|
174 |
+
"<loc_241>": 50510,
|
175 |
+
"<loc_242>": 50511,
|
176 |
+
"<loc_243>": 50512,
|
177 |
+
"<loc_244>": 50513,
|
178 |
+
"<loc_245>": 50514,
|
179 |
+
"<loc_246>": 50515,
|
180 |
+
"<loc_247>": 50516,
|
181 |
+
"<loc_248>": 50517,
|
182 |
+
"<loc_249>": 50518,
|
183 |
+
"<loc_24>": 50293,
|
184 |
+
"<loc_250>": 50519,
|
185 |
+
"<loc_251>": 50520,
|
186 |
+
"<loc_252>": 50521,
|
187 |
+
"<loc_253>": 50522,
|
188 |
+
"<loc_254>": 50523,
|
189 |
+
"<loc_255>": 50524,
|
190 |
+
"<loc_256>": 50525,
|
191 |
+
"<loc_257>": 50526,
|
192 |
+
"<loc_258>": 50527,
|
193 |
+
"<loc_259>": 50528,
|
194 |
+
"<loc_25>": 50294,
|
195 |
+
"<loc_260>": 50529,
|
196 |
+
"<loc_261>": 50530,
|
197 |
+
"<loc_262>": 50531,
|
198 |
+
"<loc_263>": 50532,
|
199 |
+
"<loc_264>": 50533,
|
200 |
+
"<loc_265>": 50534,
|
201 |
+
"<loc_266>": 50535,
|
202 |
+
"<loc_267>": 50536,
|
203 |
+
"<loc_268>": 50537,
|
204 |
+
"<loc_269>": 50538,
|
205 |
+
"<loc_26>": 50295,
|
206 |
+
"<loc_270>": 50539,
|
207 |
+
"<loc_271>": 50540,
|
208 |
+
"<loc_272>": 50541,
|
209 |
+
"<loc_273>": 50542,
|
210 |
+
"<loc_274>": 50543,
|
211 |
+
"<loc_275>": 50544,
|
212 |
+
"<loc_276>": 50545,
|
213 |
+
"<loc_277>": 50546,
|
214 |
+
"<loc_278>": 50547,
|
215 |
+
"<loc_279>": 50548,
|
216 |
+
"<loc_27>": 50296,
|
217 |
+
"<loc_280>": 50549,
|
218 |
+
"<loc_281>": 50550,
|
219 |
+
"<loc_282>": 50551,
|
220 |
+
"<loc_283>": 50552,
|
221 |
+
"<loc_284>": 50553,
|
222 |
+
"<loc_285>": 50554,
|
223 |
+
"<loc_286>": 50555,
|
224 |
+
"<loc_287>": 50556,
|
225 |
+
"<loc_288>": 50557,
|
226 |
+
"<loc_289>": 50558,
|
227 |
+
"<loc_28>": 50297,
|
228 |
+
"<loc_290>": 50559,
|
229 |
+
"<loc_291>": 50560,
|
230 |
+
"<loc_292>": 50561,
|
231 |
+
"<loc_293>": 50562,
|
232 |
+
"<loc_294>": 50563,
|
233 |
+
"<loc_295>": 50564,
|
234 |
+
"<loc_296>": 50565,
|
235 |
+
"<loc_297>": 50566,
|
236 |
+
"<loc_298>": 50567,
|
237 |
+
"<loc_299>": 50568,
|
238 |
+
"<loc_29>": 50298,
|
239 |
+
"<loc_2>": 50271,
|
240 |
+
"<loc_300>": 50569,
|
241 |
+
"<loc_301>": 50570,
|
242 |
+
"<loc_302>": 50571,
|
243 |
+
"<loc_303>": 50572,
|
244 |
+
"<loc_304>": 50573,
|
245 |
+
"<loc_305>": 50574,
|
246 |
+
"<loc_306>": 50575,
|
247 |
+
"<loc_307>": 50576,
|
248 |
+
"<loc_308>": 50577,
|
249 |
+
"<loc_309>": 50578,
|
250 |
+
"<loc_30>": 50299,
|
251 |
+
"<loc_310>": 50579,
|
252 |
+
"<loc_311>": 50580,
|
253 |
+
"<loc_312>": 50581,
|
254 |
+
"<loc_313>": 50582,
|
255 |
+
"<loc_314>": 50583,
|
256 |
+
"<loc_315>": 50584,
|
257 |
+
"<loc_316>": 50585,
|
258 |
+
"<loc_317>": 50586,
|
259 |
+
"<loc_318>": 50587,
|
260 |
+
"<loc_319>": 50588,
|
261 |
+
"<loc_31>": 50300,
|
262 |
+
"<loc_320>": 50589,
|
263 |
+
"<loc_321>": 50590,
|
264 |
+
"<loc_322>": 50591,
|
265 |
+
"<loc_323>": 50592,
|
266 |
+
"<loc_324>": 50593,
|
267 |
+
"<loc_325>": 50594,
|
268 |
+
"<loc_326>": 50595,
|
269 |
+
"<loc_327>": 50596,
|
270 |
+
"<loc_328>": 50597,
|
271 |
+
"<loc_329>": 50598,
|
272 |
+
"<loc_32>": 50301,
|
273 |
+
"<loc_330>": 50599,
|
274 |
+
"<loc_331>": 50600,
|
275 |
+
"<loc_332>": 50601,
|
276 |
+
"<loc_333>": 50602,
|
277 |
+
"<loc_334>": 50603,
|
278 |
+
"<loc_335>": 50604,
|
279 |
+
"<loc_336>": 50605,
|
280 |
+
"<loc_337>": 50606,
|
281 |
+
"<loc_338>": 50607,
|
282 |
+
"<loc_339>": 50608,
|
283 |
+
"<loc_33>": 50302,
|
284 |
+
"<loc_340>": 50609,
|
285 |
+
"<loc_341>": 50610,
|
286 |
+
"<loc_342>": 50611,
|
287 |
+
"<loc_343>": 50612,
|
288 |
+
"<loc_344>": 50613,
|
289 |
+
"<loc_345>": 50614,
|
290 |
+
"<loc_346>": 50615,
|
291 |
+
"<loc_347>": 50616,
|
292 |
+
"<loc_348>": 50617,
|
293 |
+
"<loc_349>": 50618,
|
294 |
+
"<loc_34>": 50303,
|
295 |
+
"<loc_350>": 50619,
|
296 |
+
"<loc_351>": 50620,
|
297 |
+
"<loc_352>": 50621,
|
298 |
+
"<loc_353>": 50622,
|
299 |
+
"<loc_354>": 50623,
|
300 |
+
"<loc_355>": 50624,
|
301 |
+
"<loc_356>": 50625,
|
302 |
+
"<loc_357>": 50626,
|
303 |
+
"<loc_358>": 50627,
|
304 |
+
"<loc_359>": 50628,
|
305 |
+
"<loc_35>": 50304,
|
306 |
+
"<loc_360>": 50629,
|
307 |
+
"<loc_361>": 50630,
|
308 |
+
"<loc_362>": 50631,
|
309 |
+
"<loc_363>": 50632,
|
310 |
+
"<loc_364>": 50633,
|
311 |
+
"<loc_365>": 50634,
|
312 |
+
"<loc_366>": 50635,
|
313 |
+
"<loc_367>": 50636,
|
314 |
+
"<loc_368>": 50637,
|
315 |
+
"<loc_369>": 50638,
|
316 |
+
"<loc_36>": 50305,
|
317 |
+
"<loc_370>": 50639,
|
318 |
+
"<loc_371>": 50640,
|
319 |
+
"<loc_372>": 50641,
|
320 |
+
"<loc_373>": 50642,
|
321 |
+
"<loc_374>": 50643,
|
322 |
+
"<loc_375>": 50644,
|
323 |
+
"<loc_376>": 50645,
|
324 |
+
"<loc_377>": 50646,
|
325 |
+
"<loc_378>": 50647,
|
326 |
+
"<loc_379>": 50648,
|
327 |
+
"<loc_37>": 50306,
|
328 |
+
"<loc_380>": 50649,
|
329 |
+
"<loc_381>": 50650,
|
330 |
+
"<loc_382>": 50651,
|
331 |
+
"<loc_383>": 50652,
|
332 |
+
"<loc_384>": 50653,
|
333 |
+
"<loc_385>": 50654,
|
334 |
+
"<loc_386>": 50655,
|
335 |
+
"<loc_387>": 50656,
|
336 |
+
"<loc_388>": 50657,
|
337 |
+
"<loc_389>": 50658,
|
338 |
+
"<loc_38>": 50307,
|
339 |
+
"<loc_390>": 50659,
|
340 |
+
"<loc_391>": 50660,
|
341 |
+
"<loc_392>": 50661,
|
342 |
+
"<loc_393>": 50662,
|
343 |
+
"<loc_394>": 50663,
|
344 |
+
"<loc_395>": 50664,
|
345 |
+
"<loc_396>": 50665,
|
346 |
+
"<loc_397>": 50666,
|
347 |
+
"<loc_398>": 50667,
|
348 |
+
"<loc_399>": 50668,
|
349 |
+
"<loc_39>": 50308,
|
350 |
+
"<loc_3>": 50272,
|
351 |
+
"<loc_400>": 50669,
|
352 |
+
"<loc_401>": 50670,
|
353 |
+
"<loc_402>": 50671,
|
354 |
+
"<loc_403>": 50672,
|
355 |
+
"<loc_404>": 50673,
|
356 |
+
"<loc_405>": 50674,
|
357 |
+
"<loc_406>": 50675,
|
358 |
+
"<loc_407>": 50676,
|
359 |
+
"<loc_408>": 50677,
|
360 |
+
"<loc_409>": 50678,
|
361 |
+
"<loc_40>": 50309,
|
362 |
+
"<loc_410>": 50679,
|
363 |
+
"<loc_411>": 50680,
|
364 |
+
"<loc_412>": 50681,
|
365 |
+
"<loc_413>": 50682,
|
366 |
+
"<loc_414>": 50683,
|
367 |
+
"<loc_415>": 50684,
|
368 |
+
"<loc_416>": 50685,
|
369 |
+
"<loc_417>": 50686,
|
370 |
+
"<loc_418>": 50687,
|
371 |
+
"<loc_419>": 50688,
|
372 |
+
"<loc_41>": 50310,
|
373 |
+
"<loc_420>": 50689,
|
374 |
+
"<loc_421>": 50690,
|
375 |
+
"<loc_422>": 50691,
|
376 |
+
"<loc_423>": 50692,
|
377 |
+
"<loc_424>": 50693,
|
378 |
+
"<loc_425>": 50694,
|
379 |
+
"<loc_426>": 50695,
|
380 |
+
"<loc_427>": 50696,
|
381 |
+
"<loc_428>": 50697,
|
382 |
+
"<loc_429>": 50698,
|
383 |
+
"<loc_42>": 50311,
|
384 |
+
"<loc_430>": 50699,
|
385 |
+
"<loc_431>": 50700,
|
386 |
+
"<loc_432>": 50701,
|
387 |
+
"<loc_433>": 50702,
|
388 |
+
"<loc_434>": 50703,
|
389 |
+
"<loc_435>": 50704,
|
390 |
+
"<loc_436>": 50705,
|
391 |
+
"<loc_437>": 50706,
|
392 |
+
"<loc_438>": 50707,
|
393 |
+
"<loc_439>": 50708,
|
394 |
+
"<loc_43>": 50312,
|
395 |
+
"<loc_440>": 50709,
|
396 |
+
"<loc_441>": 50710,
|
397 |
+
"<loc_442>": 50711,
|
398 |
+
"<loc_443>": 50712,
|
399 |
+
"<loc_444>": 50713,
|
400 |
+
"<loc_445>": 50714,
|
401 |
+
"<loc_446>": 50715,
|
402 |
+
"<loc_447>": 50716,
|
403 |
+
"<loc_448>": 50717,
|
404 |
+
"<loc_449>": 50718,
|
405 |
+
"<loc_44>": 50313,
|
406 |
+
"<loc_450>": 50719,
|
407 |
+
"<loc_451>": 50720,
|
408 |
+
"<loc_452>": 50721,
|
409 |
+
"<loc_453>": 50722,
|
410 |
+
"<loc_454>": 50723,
|
411 |
+
"<loc_455>": 50724,
|
412 |
+
"<loc_456>": 50725,
|
413 |
+
"<loc_457>": 50726,
|
414 |
+
"<loc_458>": 50727,
|
415 |
+
"<loc_459>": 50728,
|
416 |
+
"<loc_45>": 50314,
|
417 |
+
"<loc_460>": 50729,
|
418 |
+
"<loc_461>": 50730,
|
419 |
+
"<loc_462>": 50731,
|
420 |
+
"<loc_463>": 50732,
|
421 |
+
"<loc_464>": 50733,
|
422 |
+
"<loc_465>": 50734,
|
423 |
+
"<loc_466>": 50735,
|
424 |
+
"<loc_467>": 50736,
|
425 |
+
"<loc_468>": 50737,
|
426 |
+
"<loc_469>": 50738,
|
427 |
+
"<loc_46>": 50315,
|
428 |
+
"<loc_470>": 50739,
|
429 |
+
"<loc_471>": 50740,
|
430 |
+
"<loc_472>": 50741,
|
431 |
+
"<loc_473>": 50742,
|
432 |
+
"<loc_474>": 50743,
|
433 |
+
"<loc_475>": 50744,
|
434 |
+
"<loc_476>": 50745,
|
435 |
+
"<loc_477>": 50746,
|
436 |
+
"<loc_478>": 50747,
|
437 |
+
"<loc_479>": 50748,
|
438 |
+
"<loc_47>": 50316,
|
439 |
+
"<loc_480>": 50749,
|
440 |
+
"<loc_481>": 50750,
|
441 |
+
"<loc_482>": 50751,
|
442 |
+
"<loc_483>": 50752,
|
443 |
+
"<loc_484>": 50753,
|
444 |
+
"<loc_485>": 50754,
|
445 |
+
"<loc_486>": 50755,
|
446 |
+
"<loc_487>": 50756,
|
447 |
+
"<loc_488>": 50757,
|
448 |
+
"<loc_489>": 50758,
|
449 |
+
"<loc_48>": 50317,
|
450 |
+
"<loc_490>": 50759,
|
451 |
+
"<loc_491>": 50760,
|
452 |
+
"<loc_492>": 50761,
|
453 |
+
"<loc_493>": 50762,
|
454 |
+
"<loc_494>": 50763,
|
455 |
+
"<loc_495>": 50764,
|
456 |
+
"<loc_496>": 50765,
|
457 |
+
"<loc_497>": 50766,
|
458 |
+
"<loc_498>": 50767,
|
459 |
+
"<loc_499>": 50768,
|
460 |
+
"<loc_49>": 50318,
|
461 |
+
"<loc_4>": 50273,
|
462 |
+
"<loc_500>": 50769,
|
463 |
+
"<loc_501>": 50770,
|
464 |
+
"<loc_502>": 50771,
|
465 |
+
"<loc_503>": 50772,
|
466 |
+
"<loc_504>": 50773,
|
467 |
+
"<loc_505>": 50774,
|
468 |
+
"<loc_506>": 50775,
|
469 |
+
"<loc_507>": 50776,
|
470 |
+
"<loc_508>": 50777,
|
471 |
+
"<loc_509>": 50778,
|
472 |
+
"<loc_50>": 50319,
|
473 |
+
"<loc_510>": 50779,
|
474 |
+
"<loc_511>": 50780,
|
475 |
+
"<loc_512>": 50781,
|
476 |
+
"<loc_513>": 50782,
|
477 |
+
"<loc_514>": 50783,
|
478 |
+
"<loc_515>": 50784,
|
479 |
+
"<loc_516>": 50785,
|
480 |
+
"<loc_517>": 50786,
|
481 |
+
"<loc_518>": 50787,
|
482 |
+
"<loc_519>": 50788,
|
483 |
+
"<loc_51>": 50320,
|
484 |
+
"<loc_520>": 50789,
|
485 |
+
"<loc_521>": 50790,
|
486 |
+
"<loc_522>": 50791,
|
487 |
+
"<loc_523>": 50792,
|
488 |
+
"<loc_524>": 50793,
|
489 |
+
"<loc_525>": 50794,
|
490 |
+
"<loc_526>": 50795,
|
491 |
+
"<loc_527>": 50796,
|
492 |
+
"<loc_528>": 50797,
|
493 |
+
"<loc_529>": 50798,
|
494 |
+
"<loc_52>": 50321,
|
495 |
+
"<loc_530>": 50799,
|
496 |
+
"<loc_531>": 50800,
|
497 |
+
"<loc_532>": 50801,
|
498 |
+
"<loc_533>": 50802,
|
499 |
+
"<loc_534>": 50803,
|
500 |
+
"<loc_535>": 50804,
|
501 |
+
"<loc_536>": 50805,
|
502 |
+
"<loc_537>": 50806,
|
503 |
+
"<loc_538>": 50807,
|
504 |
+
"<loc_539>": 50808,
|
505 |
+
"<loc_53>": 50322,
|
506 |
+
"<loc_540>": 50809,
|
507 |
+
"<loc_541>": 50810,
|
508 |
+
"<loc_542>": 50811,
|
509 |
+
"<loc_543>": 50812,
|
510 |
+
"<loc_544>": 50813,
|
511 |
+
"<loc_545>": 50814,
|
512 |
+
"<loc_546>": 50815,
|
513 |
+
"<loc_547>": 50816,
|
514 |
+
"<loc_548>": 50817,
|
515 |
+
"<loc_549>": 50818,
|
516 |
+
"<loc_54>": 50323,
|
517 |
+
"<loc_550>": 50819,
|
518 |
+
"<loc_551>": 50820,
|
519 |
+
"<loc_552>": 50821,
|
520 |
+
"<loc_553>": 50822,
|
521 |
+
"<loc_554>": 50823,
|
522 |
+
"<loc_555>": 50824,
|
523 |
+
"<loc_556>": 50825,
|
524 |
+
"<loc_557>": 50826,
|
525 |
+
"<loc_558>": 50827,
|
526 |
+
"<loc_559>": 50828,
|
527 |
+
"<loc_55>": 50324,
|
528 |
+
"<loc_560>": 50829,
|
529 |
+
"<loc_561>": 50830,
|
530 |
+
"<loc_562>": 50831,
|
531 |
+
"<loc_563>": 50832,
|
532 |
+
"<loc_564>": 50833,
|
533 |
+
"<loc_565>": 50834,
|
534 |
+
"<loc_566>": 50835,
|
535 |
+
"<loc_567>": 50836,
|
536 |
+
"<loc_568>": 50837,
|
537 |
+
"<loc_569>": 50838,
|
538 |
+
"<loc_56>": 50325,
|
539 |
+
"<loc_570>": 50839,
|
540 |
+
"<loc_571>": 50840,
|
541 |
+
"<loc_572>": 50841,
|
542 |
+
"<loc_573>": 50842,
|
543 |
+
"<loc_574>": 50843,
|
544 |
+
"<loc_575>": 50844,
|
545 |
+
"<loc_576>": 50845,
|
546 |
+
"<loc_577>": 50846,
|
547 |
+
"<loc_578>": 50847,
|
548 |
+
"<loc_579>": 50848,
|
549 |
+
"<loc_57>": 50326,
|
550 |
+
"<loc_580>": 50849,
|
551 |
+
"<loc_581>": 50850,
|
552 |
+
"<loc_582>": 50851,
|
553 |
+
"<loc_583>": 50852,
|
554 |
+
"<loc_584>": 50853,
|
555 |
+
"<loc_585>": 50854,
|
556 |
+
"<loc_586>": 50855,
|
557 |
+
"<loc_587>": 50856,
|
558 |
+
"<loc_588>": 50857,
|
559 |
+
"<loc_589>": 50858,
|
560 |
+
"<loc_58>": 50327,
|
561 |
+
"<loc_590>": 50859,
|
562 |
+
"<loc_591>": 50860,
|
563 |
+
"<loc_592>": 50861,
|
564 |
+
"<loc_593>": 50862,
|
565 |
+
"<loc_594>": 50863,
|
566 |
+
"<loc_595>": 50864,
|
567 |
+
"<loc_596>": 50865,
|
568 |
+
"<loc_597>": 50866,
|
569 |
+
"<loc_598>": 50867,
|
570 |
+
"<loc_599>": 50868,
|
571 |
+
"<loc_59>": 50328,
|
572 |
+
"<loc_5>": 50274,
|
573 |
+
"<loc_600>": 50869,
|
574 |
+
"<loc_601>": 50870,
|
575 |
+
"<loc_602>": 50871,
|
576 |
+
"<loc_603>": 50872,
|
577 |
+
"<loc_604>": 50873,
|
578 |
+
"<loc_605>": 50874,
|
579 |
+
"<loc_606>": 50875,
|
580 |
+
"<loc_607>": 50876,
|
581 |
+
"<loc_608>": 50877,
|
582 |
+
"<loc_609>": 50878,
|
583 |
+
"<loc_60>": 50329,
|
584 |
+
"<loc_610>": 50879,
|
585 |
+
"<loc_611>": 50880,
|
586 |
+
"<loc_612>": 50881,
|
587 |
+
"<loc_613>": 50882,
|
588 |
+
"<loc_614>": 50883,
|
589 |
+
"<loc_615>": 50884,
|
590 |
+
"<loc_616>": 50885,
|
591 |
+
"<loc_617>": 50886,
|
592 |
+
"<loc_618>": 50887,
|
593 |
+
"<loc_619>": 50888,
|
594 |
+
"<loc_61>": 50330,
|
595 |
+
"<loc_620>": 50889,
|
596 |
+
"<loc_621>": 50890,
|
597 |
+
"<loc_622>": 50891,
|
598 |
+
"<loc_623>": 50892,
|
599 |
+
"<loc_624>": 50893,
|
600 |
+
"<loc_625>": 50894,
|
601 |
+
"<loc_626>": 50895,
|
602 |
+
"<loc_627>": 50896,
|
603 |
+
"<loc_628>": 50897,
|
604 |
+
"<loc_629>": 50898,
|
605 |
+
"<loc_62>": 50331,
|
606 |
+
"<loc_630>": 50899,
|
607 |
+
"<loc_631>": 50900,
|
608 |
+
"<loc_632>": 50901,
|
609 |
+
"<loc_633>": 50902,
|
610 |
+
"<loc_634>": 50903,
|
611 |
+
"<loc_635>": 50904,
|
612 |
+
"<loc_636>": 50905,
|
613 |
+
"<loc_637>": 50906,
|
614 |
+
"<loc_638>": 50907,
|
615 |
+
"<loc_639>": 50908,
|
616 |
+
"<loc_63>": 50332,
|
617 |
+
"<loc_640>": 50909,
|
618 |
+
"<loc_641>": 50910,
|
619 |
+
"<loc_642>": 50911,
|
620 |
+
"<loc_643>": 50912,
|
621 |
+
"<loc_644>": 50913,
|
622 |
+
"<loc_645>": 50914,
|
623 |
+
"<loc_646>": 50915,
|
624 |
+
"<loc_647>": 50916,
|
625 |
+
"<loc_648>": 50917,
|
626 |
+
"<loc_649>": 50918,
|
627 |
+
"<loc_64>": 50333,
|
628 |
+
"<loc_650>": 50919,
|
629 |
+
"<loc_651>": 50920,
|
630 |
+
"<loc_652>": 50921,
|
631 |
+
"<loc_653>": 50922,
|
632 |
+
"<loc_654>": 50923,
|
633 |
+
"<loc_655>": 50924,
|
634 |
+
"<loc_656>": 50925,
|
635 |
+
"<loc_657>": 50926,
|
636 |
+
"<loc_658>": 50927,
|
637 |
+
"<loc_659>": 50928,
|
638 |
+
"<loc_65>": 50334,
|
639 |
+
"<loc_660>": 50929,
|
640 |
+
"<loc_661>": 50930,
|
641 |
+
"<loc_662>": 50931,
|
642 |
+
"<loc_663>": 50932,
|
643 |
+
"<loc_664>": 50933,
|
644 |
+
"<loc_665>": 50934,
|
645 |
+
"<loc_666>": 50935,
|
646 |
+
"<loc_667>": 50936,
|
647 |
+
"<loc_668>": 50937,
|
648 |
+
"<loc_669>": 50938,
|
649 |
+
"<loc_66>": 50335,
|
650 |
+
"<loc_670>": 50939,
|
651 |
+
"<loc_671>": 50940,
|
652 |
+
"<loc_672>": 50941,
|
653 |
+
"<loc_673>": 50942,
|
654 |
+
"<loc_674>": 50943,
|
655 |
+
"<loc_675>": 50944,
|
656 |
+
"<loc_676>": 50945,
|
657 |
+
"<loc_677>": 50946,
|
658 |
+
"<loc_678>": 50947,
|
659 |
+
"<loc_679>": 50948,
|
660 |
+
"<loc_67>": 50336,
|
661 |
+
"<loc_680>": 50949,
|
662 |
+
"<loc_681>": 50950,
|
663 |
+
"<loc_682>": 50951,
|
664 |
+
"<loc_683>": 50952,
|
665 |
+
"<loc_684>": 50953,
|
666 |
+
"<loc_685>": 50954,
|
667 |
+
"<loc_686>": 50955,
|
668 |
+
"<loc_687>": 50956,
|
669 |
+
"<loc_688>": 50957,
|
670 |
+
"<loc_689>": 50958,
|
671 |
+
"<loc_68>": 50337,
|
672 |
+
"<loc_690>": 50959,
|
673 |
+
"<loc_691>": 50960,
|
674 |
+
"<loc_692>": 50961,
|
675 |
+
"<loc_693>": 50962,
|
676 |
+
"<loc_694>": 50963,
|
677 |
+
"<loc_695>": 50964,
|
678 |
+
"<loc_696>": 50965,
|
679 |
+
"<loc_697>": 50966,
|
680 |
+
"<loc_698>": 50967,
|
681 |
+
"<loc_699>": 50968,
|
682 |
+
"<loc_69>": 50338,
|
683 |
+
"<loc_6>": 50275,
|
684 |
+
"<loc_700>": 50969,
|
685 |
+
"<loc_701>": 50970,
|
686 |
+
"<loc_702>": 50971,
|
687 |
+
"<loc_703>": 50972,
|
688 |
+
"<loc_704>": 50973,
|
689 |
+
"<loc_705>": 50974,
|
690 |
+
"<loc_706>": 50975,
|
691 |
+
"<loc_707>": 50976,
|
692 |
+
"<loc_708>": 50977,
|
693 |
+
"<loc_709>": 50978,
|
694 |
+
"<loc_70>": 50339,
|
695 |
+
"<loc_710>": 50979,
|
696 |
+
"<loc_711>": 50980,
|
697 |
+
"<loc_712>": 50981,
|
698 |
+
"<loc_713>": 50982,
|
699 |
+
"<loc_714>": 50983,
|
700 |
+
"<loc_715>": 50984,
|
701 |
+
"<loc_716>": 50985,
|
702 |
+
"<loc_717>": 50986,
|
703 |
+
"<loc_718>": 50987,
|
704 |
+
"<loc_719>": 50988,
|
705 |
+
"<loc_71>": 50340,
|
706 |
+
"<loc_720>": 50989,
|
707 |
+
"<loc_721>": 50990,
|
708 |
+
"<loc_722>": 50991,
|
709 |
+
"<loc_723>": 50992,
|
710 |
+
"<loc_724>": 50993,
|
711 |
+
"<loc_725>": 50994,
|
712 |
+
"<loc_726>": 50995,
|
713 |
+
"<loc_727>": 50996,
|
714 |
+
"<loc_728>": 50997,
|
715 |
+
"<loc_729>": 50998,
|
716 |
+
"<loc_72>": 50341,
|
717 |
+
"<loc_730>": 50999,
|
718 |
+
"<loc_731>": 51000,
|
719 |
+
"<loc_732>": 51001,
|
720 |
+
"<loc_733>": 51002,
|
721 |
+
"<loc_734>": 51003,
|
722 |
+
"<loc_735>": 51004,
|
723 |
+
"<loc_736>": 51005,
|
724 |
+
"<loc_737>": 51006,
|
725 |
+
"<loc_738>": 51007,
|
726 |
+
"<loc_739>": 51008,
|
727 |
+
"<loc_73>": 50342,
|
728 |
+
"<loc_740>": 51009,
|
729 |
+
"<loc_741>": 51010,
|
730 |
+
"<loc_742>": 51011,
|
731 |
+
"<loc_743>": 51012,
|
732 |
+
"<loc_744>": 51013,
|
733 |
+
"<loc_745>": 51014,
|
734 |
+
"<loc_746>": 51015,
|
735 |
+
"<loc_747>": 51016,
|
736 |
+
"<loc_748>": 51017,
|
737 |
+
"<loc_749>": 51018,
|
738 |
+
"<loc_74>": 50343,
|
739 |
+
"<loc_750>": 51019,
|
740 |
+
"<loc_751>": 51020,
|
741 |
+
"<loc_752>": 51021,
|
742 |
+
"<loc_753>": 51022,
|
743 |
+
"<loc_754>": 51023,
|
744 |
+
"<loc_755>": 51024,
|
745 |
+
"<loc_756>": 51025,
|
746 |
+
"<loc_757>": 51026,
|
747 |
+
"<loc_758>": 51027,
|
748 |
+
"<loc_759>": 51028,
|
749 |
+
"<loc_75>": 50344,
|
750 |
+
"<loc_760>": 51029,
|
751 |
+
"<loc_761>": 51030,
|
752 |
+
"<loc_762>": 51031,
|
753 |
+
"<loc_763>": 51032,
|
754 |
+
"<loc_764>": 51033,
|
755 |
+
"<loc_765>": 51034,
|
756 |
+
"<loc_766>": 51035,
|
757 |
+
"<loc_767>": 51036,
|
758 |
+
"<loc_768>": 51037,
|
759 |
+
"<loc_769>": 51038,
|
760 |
+
"<loc_76>": 50345,
|
761 |
+
"<loc_770>": 51039,
|
762 |
+
"<loc_771>": 51040,
|
763 |
+
"<loc_772>": 51041,
|
764 |
+
"<loc_773>": 51042,
|
765 |
+
"<loc_774>": 51043,
|
766 |
+
"<loc_775>": 51044,
|
767 |
+
"<loc_776>": 51045,
|
768 |
+
"<loc_777>": 51046,
|
769 |
+
"<loc_778>": 51047,
|
770 |
+
"<loc_779>": 51048,
|
771 |
+
"<loc_77>": 50346,
|
772 |
+
"<loc_780>": 51049,
|
773 |
+
"<loc_781>": 51050,
|
774 |
+
"<loc_782>": 51051,
|
775 |
+
"<loc_783>": 51052,
|
776 |
+
"<loc_784>": 51053,
|
777 |
+
"<loc_785>": 51054,
|
778 |
+
"<loc_786>": 51055,
|
779 |
+
"<loc_787>": 51056,
|
780 |
+
"<loc_788>": 51057,
|
781 |
+
"<loc_789>": 51058,
|
782 |
+
"<loc_78>": 50347,
|
783 |
+
"<loc_790>": 51059,
|
784 |
+
"<loc_791>": 51060,
|
785 |
+
"<loc_792>": 51061,
|
786 |
+
"<loc_793>": 51062,
|
787 |
+
"<loc_794>": 51063,
|
788 |
+
"<loc_795>": 51064,
|
789 |
+
"<loc_796>": 51065,
|
790 |
+
"<loc_797>": 51066,
|
791 |
+
"<loc_798>": 51067,
|
792 |
+
"<loc_799>": 51068,
|
793 |
+
"<loc_79>": 50348,
|
794 |
+
"<loc_7>": 50276,
|
795 |
+
"<loc_800>": 51069,
|
796 |
+
"<loc_801>": 51070,
|
797 |
+
"<loc_802>": 51071,
|
798 |
+
"<loc_803>": 51072,
|
799 |
+
"<loc_804>": 51073,
|
800 |
+
"<loc_805>": 51074,
|
801 |
+
"<loc_806>": 51075,
|
802 |
+
"<loc_807>": 51076,
|
803 |
+
"<loc_808>": 51077,
|
804 |
+
"<loc_809>": 51078,
|
805 |
+
"<loc_80>": 50349,
|
806 |
+
"<loc_810>": 51079,
|
807 |
+
"<loc_811>": 51080,
|
808 |
+
"<loc_812>": 51081,
|
809 |
+
"<loc_813>": 51082,
|
810 |
+
"<loc_814>": 51083,
|
811 |
+
"<loc_815>": 51084,
|
812 |
+
"<loc_816>": 51085,
|
813 |
+
"<loc_817>": 51086,
|
814 |
+
"<loc_818>": 51087,
|
815 |
+
"<loc_819>": 51088,
|
816 |
+
"<loc_81>": 50350,
|
817 |
+
"<loc_820>": 51089,
|
818 |
+
"<loc_821>": 51090,
|
819 |
+
"<loc_822>": 51091,
|
820 |
+
"<loc_823>": 51092,
|
821 |
+
"<loc_824>": 51093,
|
822 |
+
"<loc_825>": 51094,
|
823 |
+
"<loc_826>": 51095,
|
824 |
+
"<loc_827>": 51096,
|
825 |
+
"<loc_828>": 51097,
|
826 |
+
"<loc_829>": 51098,
|
827 |
+
"<loc_82>": 50351,
|
828 |
+
"<loc_830>": 51099,
|
829 |
+
"<loc_831>": 51100,
|
830 |
+
"<loc_832>": 51101,
|
831 |
+
"<loc_833>": 51102,
|
832 |
+
"<loc_834>": 51103,
|
833 |
+
"<loc_835>": 51104,
|
834 |
+
"<loc_836>": 51105,
|
835 |
+
"<loc_837>": 51106,
|
836 |
+
"<loc_838>": 51107,
|
837 |
+
"<loc_839>": 51108,
|
838 |
+
"<loc_83>": 50352,
|
839 |
+
"<loc_840>": 51109,
|
840 |
+
"<loc_841>": 51110,
|
841 |
+
"<loc_842>": 51111,
|
842 |
+
"<loc_843>": 51112,
|
843 |
+
"<loc_844>": 51113,
|
844 |
+
"<loc_845>": 51114,
|
845 |
+
"<loc_846>": 51115,
|
846 |
+
"<loc_847>": 51116,
|
847 |
+
"<loc_848>": 51117,
|
848 |
+
"<loc_849>": 51118,
|
849 |
+
"<loc_84>": 50353,
|
850 |
+
"<loc_850>": 51119,
|
851 |
+
"<loc_851>": 51120,
|
852 |
+
"<loc_852>": 51121,
|
853 |
+
"<loc_853>": 51122,
|
854 |
+
"<loc_854>": 51123,
|
855 |
+
"<loc_855>": 51124,
|
856 |
+
"<loc_856>": 51125,
|
857 |
+
"<loc_857>": 51126,
|
858 |
+
"<loc_858>": 51127,
|
859 |
+
"<loc_859>": 51128,
|
860 |
+
"<loc_85>": 50354,
|
861 |
+
"<loc_860>": 51129,
|
862 |
+
"<loc_861>": 51130,
|
863 |
+
"<loc_862>": 51131,
|
864 |
+
"<loc_863>": 51132,
|
865 |
+
"<loc_864>": 51133,
|
866 |
+
"<loc_865>": 51134,
|
867 |
+
"<loc_866>": 51135,
|
868 |
+
"<loc_867>": 51136,
|
869 |
+
"<loc_868>": 51137,
|
870 |
+
"<loc_869>": 51138,
|
871 |
+
"<loc_86>": 50355,
|
872 |
+
"<loc_870>": 51139,
|
873 |
+
"<loc_871>": 51140,
|
874 |
+
"<loc_872>": 51141,
|
875 |
+
"<loc_873>": 51142,
|
876 |
+
"<loc_874>": 51143,
|
877 |
+
"<loc_875>": 51144,
|
878 |
+
"<loc_876>": 51145,
|
879 |
+
"<loc_877>": 51146,
|
880 |
+
"<loc_878>": 51147,
|
881 |
+
"<loc_879>": 51148,
|
882 |
+
"<loc_87>": 50356,
|
883 |
+
"<loc_880>": 51149,
|
884 |
+
"<loc_881>": 51150,
|
885 |
+
"<loc_882>": 51151,
|
886 |
+
"<loc_883>": 51152,
|
887 |
+
"<loc_884>": 51153,
|
888 |
+
"<loc_885>": 51154,
|
889 |
+
"<loc_886>": 51155,
|
890 |
+
"<loc_887>": 51156,
|
891 |
+
"<loc_888>": 51157,
|
892 |
+
"<loc_889>": 51158,
|
893 |
+
"<loc_88>": 50357,
|
894 |
+
"<loc_890>": 51159,
|
895 |
+
"<loc_891>": 51160,
|
896 |
+
"<loc_892>": 51161,
|
897 |
+
"<loc_893>": 51162,
|
898 |
+
"<loc_894>": 51163,
|
899 |
+
"<loc_895>": 51164,
|
900 |
+
"<loc_896>": 51165,
|
901 |
+
"<loc_897>": 51166,
|
902 |
+
"<loc_898>": 51167,
|
903 |
+
"<loc_899>": 51168,
|
904 |
+
"<loc_89>": 50358,
|
905 |
+
"<loc_8>": 50277,
|
906 |
+
"<loc_900>": 51169,
|
907 |
+
"<loc_901>": 51170,
|
908 |
+
"<loc_902>": 51171,
|
909 |
+
"<loc_903>": 51172,
|
910 |
+
"<loc_904>": 51173,
|
911 |
+
"<loc_905>": 51174,
|
912 |
+
"<loc_906>": 51175,
|
913 |
+
"<loc_907>": 51176,
|
914 |
+
"<loc_908>": 51177,
|
915 |
+
"<loc_909>": 51178,
|
916 |
+
"<loc_90>": 50359,
|
917 |
+
"<loc_910>": 51179,
|
918 |
+
"<loc_911>": 51180,
|
919 |
+
"<loc_912>": 51181,
|
920 |
+
"<loc_913>": 51182,
|
921 |
+
"<loc_914>": 51183,
|
922 |
+
"<loc_915>": 51184,
|
923 |
+
"<loc_916>": 51185,
|
924 |
+
"<loc_917>": 51186,
|
925 |
+
"<loc_918>": 51187,
|
926 |
+
"<loc_919>": 51188,
|
927 |
+
"<loc_91>": 50360,
|
928 |
+
"<loc_920>": 51189,
|
929 |
+
"<loc_921>": 51190,
|
930 |
+
"<loc_922>": 51191,
|
931 |
+
"<loc_923>": 51192,
|
932 |
+
"<loc_924>": 51193,
|
933 |
+
"<loc_925>": 51194,
|
934 |
+
"<loc_926>": 51195,
|
935 |
+
"<loc_927>": 51196,
|
936 |
+
"<loc_928>": 51197,
|
937 |
+
"<loc_929>": 51198,
|
938 |
+
"<loc_92>": 50361,
|
939 |
+
"<loc_930>": 51199,
|
940 |
+
"<loc_931>": 51200,
|
941 |
+
"<loc_932>": 51201,
|
942 |
+
"<loc_933>": 51202,
|
943 |
+
"<loc_934>": 51203,
|
944 |
+
"<loc_935>": 51204,
|
945 |
+
"<loc_936>": 51205,
|
946 |
+
"<loc_937>": 51206,
|
947 |
+
"<loc_938>": 51207,
|
948 |
+
"<loc_939>": 51208,
|
949 |
+
"<loc_93>": 50362,
|
950 |
+
"<loc_940>": 51209,
|
951 |
+
"<loc_941>": 51210,
|
952 |
+
"<loc_942>": 51211,
|
953 |
+
"<loc_943>": 51212,
|
954 |
+
"<loc_944>": 51213,
|
955 |
+
"<loc_945>": 51214,
|
956 |
+
"<loc_946>": 51215,
|
957 |
+
"<loc_947>": 51216,
|
958 |
+
"<loc_948>": 51217,
|
959 |
+
"<loc_949>": 51218,
|
960 |
+
"<loc_94>": 50363,
|
961 |
+
"<loc_950>": 51219,
|
962 |
+
"<loc_951>": 51220,
|
963 |
+
"<loc_952>": 51221,
|
964 |
+
"<loc_953>": 51222,
|
965 |
+
"<loc_954>": 51223,
|
966 |
+
"<loc_955>": 51224,
|
967 |
+
"<loc_956>": 51225,
|
968 |
+
"<loc_957>": 51226,
|
969 |
+
"<loc_958>": 51227,
|
970 |
+
"<loc_959>": 51228,
|
971 |
+
"<loc_95>": 50364,
|
972 |
+
"<loc_960>": 51229,
|
973 |
+
"<loc_961>": 51230,
|
974 |
+
"<loc_962>": 51231,
|
975 |
+
"<loc_963>": 51232,
|
976 |
+
"<loc_964>": 51233,
|
977 |
+
"<loc_965>": 51234,
|
978 |
+
"<loc_966>": 51235,
|
979 |
+
"<loc_967>": 51236,
|
980 |
+
"<loc_968>": 51237,
|
981 |
+
"<loc_969>": 51238,
|
982 |
+
"<loc_96>": 50365,
|
983 |
+
"<loc_970>": 51239,
|
984 |
+
"<loc_971>": 51240,
|
985 |
+
"<loc_972>": 51241,
|
986 |
+
"<loc_973>": 51242,
|
987 |
+
"<loc_974>": 51243,
|
988 |
+
"<loc_975>": 51244,
|
989 |
+
"<loc_976>": 51245,
|
990 |
+
"<loc_977>": 51246,
|
991 |
+
"<loc_978>": 51247,
|
992 |
+
"<loc_979>": 51248,
|
993 |
+
"<loc_97>": 50366,
|
994 |
+
"<loc_980>": 51249,
|
995 |
+
"<loc_981>": 51250,
|
996 |
+
"<loc_982>": 51251,
|
997 |
+
"<loc_983>": 51252,
|
998 |
+
"<loc_984>": 51253,
|
999 |
+
"<loc_985>": 51254,
|
1000 |
+
"<loc_986>": 51255,
|
1001 |
+
"<loc_987>": 51256,
|
1002 |
+
"<loc_988>": 51257,
|
1003 |
+
"<loc_989>": 51258,
|
1004 |
+
"<loc_98>": 50367,
|
1005 |
+
"<loc_990>": 51259,
|
1006 |
+
"<loc_991>": 51260,
|
1007 |
+
"<loc_992>": 51261,
|
1008 |
+
"<loc_993>": 51262,
|
1009 |
+
"<loc_994>": 51263,
|
1010 |
+
"<loc_995>": 51264,
|
1011 |
+
"<loc_996>": 51265,
|
1012 |
+
"<loc_997>": 51266,
|
1013 |
+
"<loc_998>": 51267,
|
1014 |
+
"<loc_999>": 51268,
|
1015 |
+
"<loc_99>": 50368,
|
1016 |
+
"<loc_9>": 50278,
|
1017 |
+
"<ncap>": 51271,
|
1018 |
+
"<ocr>": 50267,
|
1019 |
+
"<od>": 50265,
|
1020 |
+
"<poly>": 51286,
|
1021 |
+
"<proposal>": 51284,
|
1022 |
+
"<region_cap>": 51280,
|
1023 |
+
"<region_to_desciption>": 51282,
|
1024 |
+
"<seg>": 51277,
|
1025 |
+
"<sep>": 51279
|
1026 |
+
}
|
models/LLM/Florence-2-large-PromptGen-v2.0/config.json
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "microsoft/Florence-2-large",
|
3 |
+
"architectures": [
|
4 |
+
"Florence2ForConditionalGeneration"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_florence2.Florence2Config",
|
8 |
+
"AutoModelForCausalLM": "modeling_florence2.Florence2ForConditionalGeneration"
|
9 |
+
},
|
10 |
+
"bos_token_id": 0,
|
11 |
+
"eos_token_id": 2,
|
12 |
+
"ignore_index": -100,
|
13 |
+
"is_encoder_decoder": true,
|
14 |
+
"model_type": "florence2",
|
15 |
+
"pad_token_id": 1,
|
16 |
+
"projection_dim": 1024,
|
17 |
+
"text_config": {
|
18 |
+
"_attn_implementation_autoset": true,
|
19 |
+
"_name_or_path": "",
|
20 |
+
"activation_dropout": 0.1,
|
21 |
+
"activation_function": "gelu",
|
22 |
+
"add_bias_logits": false,
|
23 |
+
"add_cross_attention": false,
|
24 |
+
"add_final_layer_norm": false,
|
25 |
+
"architectures": null,
|
26 |
+
"attention_dropout": 0.1,
|
27 |
+
"bad_words_ids": null,
|
28 |
+
"begin_suppress_tokens": null,
|
29 |
+
"bos_token_id": 0,
|
30 |
+
"chunk_size_feed_forward": 0,
|
31 |
+
"classif_dropout": 0.1,
|
32 |
+
"classifier_dropout": 0.0,
|
33 |
+
"cross_attention_hidden_size": null,
|
34 |
+
"d_model": 1024,
|
35 |
+
"decoder_attention_heads": 16,
|
36 |
+
"decoder_ffn_dim": 4096,
|
37 |
+
"decoder_layerdrop": 0.0,
|
38 |
+
"decoder_layers": 12,
|
39 |
+
"decoder_start_token_id": 2,
|
40 |
+
"diversity_penalty": 0.0,
|
41 |
+
"do_sample": false,
|
42 |
+
"dropout": 0.1,
|
43 |
+
"early_stopping": true,
|
44 |
+
"encoder_attention_heads": 16,
|
45 |
+
"encoder_ffn_dim": 4096,
|
46 |
+
"encoder_layerdrop": 0.0,
|
47 |
+
"encoder_layers": 12,
|
48 |
+
"encoder_no_repeat_ngram_size": 0,
|
49 |
+
"eos_token_id": 2,
|
50 |
+
"exponential_decay_length_penalty": null,
|
51 |
+
"finetuning_task": null,
|
52 |
+
"forced_bos_token_id": 0,
|
53 |
+
"forced_eos_token_id": 2,
|
54 |
+
"gradient_checkpointing": false,
|
55 |
+
"id2label": {
|
56 |
+
"0": "LABEL_0",
|
57 |
+
"1": "LABEL_1",
|
58 |
+
"2": "LABEL_2"
|
59 |
+
},
|
60 |
+
"init_std": 0.02,
|
61 |
+
"is_decoder": false,
|
62 |
+
"is_encoder_decoder": true,
|
63 |
+
"label2id": {
|
64 |
+
"LABEL_0": 0,
|
65 |
+
"LABEL_1": 1,
|
66 |
+
"LABEL_2": 2
|
67 |
+
},
|
68 |
+
"length_penalty": 1.0,
|
69 |
+
"max_length": 20,
|
70 |
+
"max_position_embeddings": 1024,
|
71 |
+
"min_length": 0,
|
72 |
+
"model_type": "florence2_language",
|
73 |
+
"no_repeat_ngram_size": 3,
|
74 |
+
"normalize_before": false,
|
75 |
+
"num_beam_groups": 1,
|
76 |
+
"num_beams": 3,
|
77 |
+
"num_hidden_layers": 12,
|
78 |
+
"num_return_sequences": 1,
|
79 |
+
"output_attentions": false,
|
80 |
+
"output_hidden_states": false,
|
81 |
+
"output_scores": false,
|
82 |
+
"pad_token_id": 1,
|
83 |
+
"prefix": null,
|
84 |
+
"problem_type": null,
|
85 |
+
"pruned_heads": {},
|
86 |
+
"remove_invalid_values": false,
|
87 |
+
"repetition_penalty": 1.0,
|
88 |
+
"return_dict": true,
|
89 |
+
"return_dict_in_generate": false,
|
90 |
+
"scale_embedding": false,
|
91 |
+
"sep_token_id": null,
|
92 |
+
"suppress_tokens": null,
|
93 |
+
"task_specific_params": null,
|
94 |
+
"temperature": 1.0,
|
95 |
+
"tf_legacy_loss": false,
|
96 |
+
"tie_encoder_decoder": false,
|
97 |
+
"tie_word_embeddings": true,
|
98 |
+
"tokenizer_class": null,
|
99 |
+
"top_k": 50,
|
100 |
+
"top_p": 1.0,
|
101 |
+
"torch_dtype": null,
|
102 |
+
"torchscript": false,
|
103 |
+
"typical_p": 1.0,
|
104 |
+
"use_bfloat16": false,
|
105 |
+
"use_cache": true,
|
106 |
+
"vocab_size": 51289
|
107 |
+
},
|
108 |
+
"torch_dtype": "float32",
|
109 |
+
"transformers_version": "4.46.1",
|
110 |
+
"vision_config": {
|
111 |
+
"model_type": "davit",
|
112 |
+
"drop_path_rate": 0.1,
|
113 |
+
"patch_size": [7, 3, 3, 3],
|
114 |
+
"patch_stride": [4, 2, 2, 2],
|
115 |
+
"patch_padding": [3, 1, 1, 1],
|
116 |
+
"patch_prenorm": [false, true, true, true],
|
117 |
+
"enable_checkpoint": false,
|
118 |
+
"dim_embed": [256, 512, 1024, 2048],
|
119 |
+
"num_heads": [8, 16, 32, 64],
|
120 |
+
"num_groups": [8, 16, 32, 64],
|
121 |
+
"depths": [1, 1, 9, 1],
|
122 |
+
"window_size": 12,
|
123 |
+
"projection_dim": 1024,
|
124 |
+
"visual_temporal_embedding": {
|
125 |
+
"type": "COSINE",
|
126 |
+
"max_temporal_embeddings": 100
|
127 |
+
},
|
128 |
+
"image_pos_embed": {
|
129 |
+
"type": "learned_abs_2d",
|
130 |
+
"max_pos_embeddings": 50
|
131 |
+
},
|
132 |
+
"image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
|
133 |
+
},
|
134 |
+
"vocab_size": 51289,
|
135 |
+
"torch_dtype": "float16",
|
136 |
+
"transformers_version": "4.41.0.dev0",
|
137 |
+
"is_encoder_decoder": true
|
138 |
+
}
|
models/LLM/Florence-2-large-PromptGen-v2.0/configuration_florence2.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import warnings
|
15 |
+
""" Florence-2 configuration"""
|
16 |
+
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
from transformers import AutoConfig
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
class Florence2VisionConfig(PretrainedConfig):
|
26 |
+
r"""
|
27 |
+
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
28 |
+
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
29 |
+
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
|
30 |
+
|
31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
32 |
+
documentation from [`PretrainedConfig`] for more information.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
36 |
+
The dropout rate of the drop path layer.
|
37 |
+
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
|
38 |
+
The patch size of the image.
|
39 |
+
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
|
40 |
+
The patch stride of the image.
|
41 |
+
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
|
42 |
+
The patch padding of the image.
|
43 |
+
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
|
44 |
+
Whether to apply layer normalization before the patch embedding layer.
|
45 |
+
enable_checkpoint (`bool`, *optional*, defaults to False):
|
46 |
+
Whether to enable checkpointing.
|
47 |
+
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
|
48 |
+
The dimension of the embedding layer.
|
49 |
+
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
50 |
+
The number of attention heads.
|
51 |
+
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
52 |
+
The number of groups.
|
53 |
+
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
|
54 |
+
The depth of the model.
|
55 |
+
window_size (`int`, *optional*, defaults to 12):
|
56 |
+
The window size of the model.
|
57 |
+
projection_dim (`int`, *optional*, defaults to 1024):
|
58 |
+
The dimension of the projection layer.
|
59 |
+
visual_temporal_embedding (`dict`, *optional*):
|
60 |
+
The configuration of the visual temporal embedding.
|
61 |
+
image_pos_embed (`dict`, *optional*):
|
62 |
+
The configuration of the image position embedding.
|
63 |
+
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
|
64 |
+
The source of the image feature.
|
65 |
+
Example:
|
66 |
+
|
67 |
+
```python
|
68 |
+
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
|
69 |
+
|
70 |
+
>>> # Initializing a Florence2 Vision style configuration
|
71 |
+
>>> configuration = Florence2VisionConfig()
|
72 |
+
|
73 |
+
>>> # Initializing a model (with random weights)
|
74 |
+
>>> model = Florence2VisionModel(configuration)
|
75 |
+
|
76 |
+
>>> # Accessing the model configuration
|
77 |
+
>>> configuration = model.config
|
78 |
+
```"""
|
79 |
+
|
80 |
+
model_type = "florence2_vision"
|
81 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
82 |
+
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
drop_path_rate=0.1,
|
86 |
+
patch_size=[7, 3, 3, 3],
|
87 |
+
patch_stride=[4, 2, 2, 2],
|
88 |
+
patch_padding=[3, 1, 1, 1],
|
89 |
+
patch_prenorm=[False, True, True, True],
|
90 |
+
enable_checkpoint=False,
|
91 |
+
dim_embed=[256, 512, 1024, 2048],
|
92 |
+
num_heads=[8, 16, 32, 64],
|
93 |
+
num_groups=[8, 16, 32, 64],
|
94 |
+
depths=[1, 1, 9, 1],
|
95 |
+
window_size=12,
|
96 |
+
projection_dim=1024,
|
97 |
+
visual_temporal_embedding=None,
|
98 |
+
image_pos_embed=None,
|
99 |
+
image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
|
100 |
+
**kwargs,
|
101 |
+
):
|
102 |
+
self.drop_path_rate = drop_path_rate
|
103 |
+
self.patch_size = patch_size
|
104 |
+
self.patch_stride = patch_stride
|
105 |
+
self.patch_padding = patch_padding
|
106 |
+
self.patch_prenorm = patch_prenorm
|
107 |
+
self.enable_checkpoint = enable_checkpoint
|
108 |
+
self.dim_embed = dim_embed
|
109 |
+
self.num_heads = num_heads
|
110 |
+
self.num_groups = num_groups
|
111 |
+
self.depths = depths
|
112 |
+
self.window_size = window_size
|
113 |
+
self.projection_dim = projection_dim
|
114 |
+
self.visual_temporal_embedding = visual_temporal_embedding
|
115 |
+
self.image_pos_embed = image_pos_embed
|
116 |
+
self.image_feature_source = image_feature_source
|
117 |
+
|
118 |
+
super().__init__(**kwargs)
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
class Florence2LanguageConfig(PretrainedConfig):
|
123 |
+
r"""
|
124 |
+
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
125 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
126 |
+
defaults will yield a similar configuration to that of the BART
|
127 |
+
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
|
128 |
+
|
129 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
130 |
+
documentation from [`PretrainedConfig`] for more information.
|
131 |
+
|
132 |
+
|
133 |
+
Args:
|
134 |
+
vocab_size (`int`, *optional*, defaults to 51289):
|
135 |
+
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
|
136 |
+
`inputs_ids` passed when calling [`Florence2LanguageModel`].
|
137 |
+
d_model (`int`, *optional*, defaults to 1024):
|
138 |
+
Dimensionality of the layers and the pooler layer.
|
139 |
+
encoder_layers (`int`, *optional*, defaults to 12):
|
140 |
+
Number of encoder layers.
|
141 |
+
decoder_layers (`int`, *optional*, defaults to 12):
|
142 |
+
Number of decoder layers.
|
143 |
+
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
144 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
145 |
+
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
146 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
147 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
148 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
149 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
150 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
151 |
+
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
152 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
153 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
154 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
155 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
156 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
157 |
+
The dropout ratio for the attention probabilities.
|
158 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
159 |
+
The dropout ratio for activations inside the fully connected layer.
|
160 |
+
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
161 |
+
The dropout ratio for classifier.
|
162 |
+
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
163 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
164 |
+
just in case (e.g., 512 or 1024 or 2048).
|
165 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
166 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
167 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
168 |
+
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
169 |
+
for more details.
|
170 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
171 |
+
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
172 |
+
for more details.
|
173 |
+
scale_embedding (`bool`, *optional*, defaults to `False`):
|
174 |
+
Scale embeddings by diving by sqrt(d_model).
|
175 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
176 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
177 |
+
num_labels (`int`, *optional*, defaults to 3):
|
178 |
+
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
|
179 |
+
forced_eos_token_id (`int`, *optional*, defaults to 2):
|
180 |
+
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
|
181 |
+
`eos_token_id`.
|
182 |
+
|
183 |
+
Example:
|
184 |
+
|
185 |
+
```python
|
186 |
+
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
|
187 |
+
|
188 |
+
>>> # Initializing a Florence2 Language style configuration
|
189 |
+
>>> configuration = Florence2LanguageConfig()
|
190 |
+
|
191 |
+
>>> # Initializing a model (with random weights)
|
192 |
+
>>> model = Florence2LangaugeModel(configuration)
|
193 |
+
|
194 |
+
>>> # Accessing the model configuration
|
195 |
+
>>> configuration = model.config
|
196 |
+
```"""
|
197 |
+
|
198 |
+
model_type = "florence2_language"
|
199 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
200 |
+
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
201 |
+
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
vocab_size=51289,
|
205 |
+
max_position_embeddings=1024,
|
206 |
+
encoder_layers=12,
|
207 |
+
encoder_ffn_dim=4096,
|
208 |
+
encoder_attention_heads=16,
|
209 |
+
decoder_layers=12,
|
210 |
+
decoder_ffn_dim=4096,
|
211 |
+
decoder_attention_heads=16,
|
212 |
+
encoder_layerdrop=0.0,
|
213 |
+
decoder_layerdrop=0.0,
|
214 |
+
activation_function="gelu",
|
215 |
+
d_model=1024,
|
216 |
+
dropout=0.1,
|
217 |
+
attention_dropout=0.0,
|
218 |
+
activation_dropout=0.0,
|
219 |
+
init_std=0.02,
|
220 |
+
classifier_dropout=0.0,
|
221 |
+
scale_embedding=False,
|
222 |
+
use_cache=True,
|
223 |
+
num_labels=3,
|
224 |
+
pad_token_id=1,
|
225 |
+
bos_token_id=0,
|
226 |
+
eos_token_id=2,
|
227 |
+
is_encoder_decoder=True,
|
228 |
+
decoder_start_token_id=2,
|
229 |
+
forced_eos_token_id=2,
|
230 |
+
**kwargs,
|
231 |
+
):
|
232 |
+
self.vocab_size = vocab_size
|
233 |
+
self.max_position_embeddings = max_position_embeddings
|
234 |
+
self.d_model = d_model
|
235 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
236 |
+
self.encoder_layers = encoder_layers
|
237 |
+
self.encoder_attention_heads = encoder_attention_heads
|
238 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
239 |
+
self.decoder_layers = decoder_layers
|
240 |
+
self.decoder_attention_heads = decoder_attention_heads
|
241 |
+
self.dropout = dropout
|
242 |
+
self.attention_dropout = attention_dropout
|
243 |
+
self.activation_dropout = activation_dropout
|
244 |
+
self.activation_function = activation_function
|
245 |
+
self.init_std = init_std
|
246 |
+
self.encoder_layerdrop = encoder_layerdrop
|
247 |
+
self.decoder_layerdrop = decoder_layerdrop
|
248 |
+
self.classifier_dropout = classifier_dropout
|
249 |
+
self.use_cache = use_cache
|
250 |
+
self.num_hidden_layers = encoder_layers
|
251 |
+
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
252 |
+
|
253 |
+
super().__init__(
|
254 |
+
num_labels=num_labels,
|
255 |
+
pad_token_id=pad_token_id,
|
256 |
+
bos_token_id=bos_token_id,
|
257 |
+
eos_token_id=eos_token_id,
|
258 |
+
is_encoder_decoder=is_encoder_decoder,
|
259 |
+
decoder_start_token_id=decoder_start_token_id,
|
260 |
+
forced_eos_token_id=forced_eos_token_id,
|
261 |
+
**kwargs,
|
262 |
+
)
|
263 |
+
|
264 |
+
# ensure backward compatibility for BART CNN models
|
265 |
+
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
266 |
+
self.forced_bos_token_id = self.bos_token_id
|
267 |
+
warnings.warn(
|
268 |
+
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
|
269 |
+
"The config can simply be saved and uploaded again to be fixed."
|
270 |
+
)
|
271 |
+
|
272 |
+
class Florence2Config(PretrainedConfig):
|
273 |
+
r"""
|
274 |
+
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
275 |
+
Florence-2 model according to the specified arguments, defining the model architecture.
|
276 |
+
|
277 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
278 |
+
documentation from [`PretrainedConfig`] for more information.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
vision_config (`Florence2VisionConfig`, *optional*):
|
282 |
+
Custom vision config or dict
|
283 |
+
text_config (`Union[AutoConfig, dict]`, *optional*):
|
284 |
+
The config object of the text backbone.
|
285 |
+
ignore_index (`int`, *optional*, defaults to -100):
|
286 |
+
The ignore index for the loss function.
|
287 |
+
vocab_size (`int`, *optional*, defaults to 51289):
|
288 |
+
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
|
289 |
+
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
|
290 |
+
projection_dim (`int`, *optional*, defaults to 1024):
|
291 |
+
Dimension of the multimodal projection space.
|
292 |
+
|
293 |
+
Example:
|
294 |
+
|
295 |
+
```python
|
296 |
+
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
|
297 |
+
|
298 |
+
>>> # Initializing a clip-like vision config
|
299 |
+
>>> vision_config = CLIPVisionConfig()
|
300 |
+
|
301 |
+
>>> # Initializing a Bart config
|
302 |
+
>>> text_config = BartConfig()
|
303 |
+
|
304 |
+
>>> # Initializing a Florence-2 configuration
|
305 |
+
>>> configuration = Florence2Config(vision_config, text_config)
|
306 |
+
|
307 |
+
>>> # Initializing a model from the florence-2 configuration
|
308 |
+
>>> model = Florence2ForConditionalGeneration(configuration)
|
309 |
+
|
310 |
+
>>> # Accessing the model configuration
|
311 |
+
>>> configuration = model.config
|
312 |
+
```"""
|
313 |
+
|
314 |
+
model_type = "florence2"
|
315 |
+
is_composition = False
|
316 |
+
|
317 |
+
def __init__(
|
318 |
+
self,
|
319 |
+
vision_config=None,
|
320 |
+
text_config=None,
|
321 |
+
ignore_index=-100,
|
322 |
+
vocab_size=51289,
|
323 |
+
projection_dim=1024,
|
324 |
+
**kwargs,
|
325 |
+
):
|
326 |
+
self.ignore_index = ignore_index
|
327 |
+
self.vocab_size = vocab_size
|
328 |
+
self.projection_dim = projection_dim
|
329 |
+
if vision_config is not None:
|
330 |
+
vision_config = PretrainedConfig(**vision_config)
|
331 |
+
self.vision_config = vision_config
|
332 |
+
self.vocab_size = self.vocab_size
|
333 |
+
|
334 |
+
self.text_config = text_config
|
335 |
+
if text_config is not None:
|
336 |
+
self.text_config = Florence2LanguageConfig(**text_config)
|
337 |
+
|
338 |
+
|
339 |
+
super().__init__(**kwargs)
|
340 |
+
|
models/LLM/Florence-2-large-PromptGen-v2.0/generation_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"num_beams": 3,
|
3 |
+
"transformers_version": "4.46.1"
|
4 |
+
}
|
models/LLM/Florence-2-large-PromptGen-v2.0/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/LLM/Florence-2-large-PromptGen-v2.0/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:95b6441fb8e3a96b1f6ec0ac894a7632ea49fc77c0dd623a7a53d1d879390321
|
3 |
+
size 3291921348
|
models/LLM/Florence-2-large-PromptGen-v2.0/modeling_florence2.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/LLM/Florence-2-large-PromptGen-v2.0/preprocessor_config.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoProcessor": "processing_florence2.Florence2Processor"
|
4 |
+
},
|
5 |
+
"crop_size": {
|
6 |
+
"height": 768,
|
7 |
+
"width": 768
|
8 |
+
},
|
9 |
+
"do_center_crop": false,
|
10 |
+
"do_convert_rgb": null,
|
11 |
+
"do_normalize": true,
|
12 |
+
"do_rescale": true,
|
13 |
+
"do_resize": true,
|
14 |
+
"image_mean": [
|
15 |
+
0.485,
|
16 |
+
0.456,
|
17 |
+
0.406
|
18 |
+
],
|
19 |
+
"image_processor_type": "CLIPImageProcessor",
|
20 |
+
"image_seq_length": 577,
|
21 |
+
"image_std": [
|
22 |
+
0.229,
|
23 |
+
0.224,
|
24 |
+
0.225
|
25 |
+
],
|
26 |
+
"processor_class": "Florence2Processor",
|
27 |
+
"resample": 3,
|
28 |
+
"rescale_factor": 0.00392156862745098,
|
29 |
+
"size": {
|
30 |
+
"height": 768,
|
31 |
+
"width": 768
|
32 |
+
}
|
33 |
+
}
|
models/LLM/Florence-2-large-PromptGen-v2.0/processing_florence2.py
ADDED
@@ -0,0 +1,1088 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Microsoft and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Processor class for Florence-2.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import re
|
20 |
+
import logging
|
21 |
+
from typing import List, Optional, Union
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
import torch
|
25 |
+
|
26 |
+
from transformers.feature_extraction_utils import BatchFeature
|
27 |
+
from transformers.image_utils import ImageInput, is_valid_image
|
28 |
+
from transformers.processing_utils import ProcessorMixin
|
29 |
+
from transformers.tokenization_utils_base import (
|
30 |
+
PaddingStrategy,
|
31 |
+
PreTokenizedInput,
|
32 |
+
TextInput,
|
33 |
+
TruncationStrategy,
|
34 |
+
)
|
35 |
+
from transformers.utils import TensorType
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
|
40 |
+
# Copied from transformers.models.idefics2.processing_idefics2.is_url
|
41 |
+
def is_url(val) -> bool:
|
42 |
+
return isinstance(val, str) and val.startswith("http")
|
43 |
+
|
44 |
+
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
|
45 |
+
def is_image_or_image_url(elem):
|
46 |
+
return is_url(elem) or is_valid_image(elem)
|
47 |
+
|
48 |
+
|
49 |
+
def _is_str_or_image(elem):
|
50 |
+
return isinstance(elem, (str)) or is_image_or_image_url(elem)
|
51 |
+
|
52 |
+
|
53 |
+
class Florence2Processor(ProcessorMixin):
|
54 |
+
r"""
|
55 |
+
Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
|
56 |
+
|
57 |
+
[`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the
|
58 |
+
[`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
image_processor ([`CLIPImageProcessor`], *optional*):
|
62 |
+
The image processor is a required input.
|
63 |
+
tokenizer ([`BartTokenizerFast`], *optional*):
|
64 |
+
The tokenizer is a required input.
|
65 |
+
"""
|
66 |
+
|
67 |
+
attributes = ["image_processor", "tokenizer"]
|
68 |
+
image_processor_class = "CLIPImageProcessor"
|
69 |
+
tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
|
70 |
+
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
image_processor=None,
|
74 |
+
tokenizer=None,
|
75 |
+
):
|
76 |
+
if image_processor is None:
|
77 |
+
raise ValueError("You need to specify an `image_processor`.")
|
78 |
+
if tokenizer is None:
|
79 |
+
raise ValueError("You need to specify a `tokenizer`.")
|
80 |
+
if not hasattr(image_processor, "image_seq_length"):
|
81 |
+
raise ValueError("Image processor is missing an `image_seq_length` attribute.")
|
82 |
+
|
83 |
+
self.image_seq_length = image_processor.image_seq_length
|
84 |
+
|
85 |
+
tokens_to_add = {
|
86 |
+
'additional_special_tokens': \
|
87 |
+
tokenizer.additional_special_tokens + \
|
88 |
+
['<od>', '</od>', '<ocr>', '</ocr>'] + \
|
89 |
+
[f'<loc_{x}>' for x in range(1000)] + \
|
90 |
+
['<cap>', '</cap>', '<ncap>', '</ncap>','<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>']
|
91 |
+
}
|
92 |
+
tokenizer.add_special_tokens(tokens_to_add)
|
93 |
+
|
94 |
+
self.tasks_answer_post_processing_type = {
|
95 |
+
'<OCR>': 'pure_text',
|
96 |
+
'<OCR_WITH_REGION>': 'ocr',
|
97 |
+
'<CAPTION>': 'pure_text',
|
98 |
+
'<DETAILED_CAPTION>': 'pure_text',
|
99 |
+
'<MORE_DETAILED_CAPTION>': 'pure_text',
|
100 |
+
'<OD>': 'description_with_bboxes',
|
101 |
+
'<DENSE_REGION_CAPTION>': 'description_with_bboxes',
|
102 |
+
'<CAPTION_TO_PHRASE_GROUNDING>': "phrase_grounding",
|
103 |
+
'<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
|
104 |
+
'<REGION_TO_SEGMENTATION>': 'polygons',
|
105 |
+
'<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
|
106 |
+
'<REGION_TO_CATEGORY>': 'pure_text',
|
107 |
+
'<REGION_TO_DESCRIPTION>': 'pure_text',
|
108 |
+
'<REGION_TO_OCR>': 'pure_text',
|
109 |
+
'<REGION_PROPOSAL>': 'bboxes'
|
110 |
+
}
|
111 |
+
|
112 |
+
self.task_prompts_without_inputs = {
|
113 |
+
'<OCR>': 'What is the text in the image?',
|
114 |
+
'<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
|
115 |
+
'<CAPTION>': 'What does the image describe?',
|
116 |
+
'<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
|
117 |
+
'<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
|
118 |
+
'<OD>': 'Locate the objects with category name in the image.',
|
119 |
+
'<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
|
120 |
+
'<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
|
121 |
+
}
|
122 |
+
|
123 |
+
self.task_prompts_with_input = {
|
124 |
+
'<CAPTION_TO_PHRASE_GROUNDING>': "Locate the phrases in the caption: {input}",
|
125 |
+
'<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
|
126 |
+
'<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
|
127 |
+
'<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
|
128 |
+
'<REGION_TO_CATEGORY>': 'What is the region {input}?',
|
129 |
+
'<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
|
130 |
+
'<REGION_TO_OCR>': 'What text is in the region {input}?',
|
131 |
+
}
|
132 |
+
|
133 |
+
self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)
|
134 |
+
|
135 |
+
|
136 |
+
super().__init__(image_processor, tokenizer)
|
137 |
+
|
138 |
+
def _construct_prompts(self, text):
|
139 |
+
# replace the task tokens with the task prompts if task token is in the text
|
140 |
+
prompts = []
|
141 |
+
for _text in text:
|
142 |
+
# 1. fixed task prompts without additional inputs
|
143 |
+
for task_token, task_prompt in self.task_prompts_without_inputs.items():
|
144 |
+
if task_token in _text:
|
145 |
+
assert _text == task_token, f"Task token {task_token} should be the only token in the text."
|
146 |
+
_text = task_prompt
|
147 |
+
break
|
148 |
+
# 2. task prompts with additional inputs
|
149 |
+
for task_token, task_prompt in self.task_prompts_with_input.items():
|
150 |
+
if task_token in _text:
|
151 |
+
_text = task_prompt.format(input=_text.replace(task_token, ''))
|
152 |
+
break
|
153 |
+
prompts.append(_text)
|
154 |
+
return prompts
|
155 |
+
|
156 |
+
def __call__(
|
157 |
+
self,
|
158 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
159 |
+
images: ImageInput = None,
|
160 |
+
tokenize_newline_separately: bool = True,
|
161 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
162 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
163 |
+
max_length=None,
|
164 |
+
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
165 |
+
do_resize: bool = None,
|
166 |
+
do_normalize: bool = None,
|
167 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
168 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
169 |
+
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
|
170 |
+
input_data_format: Optional[
|
171 |
+
Union[str, "ChannelDimension"] # noqa: F821
|
172 |
+
] = None,
|
173 |
+
resample: "PILImageResampling" = None, # noqa: F821
|
174 |
+
do_convert_rgb: bool = None,
|
175 |
+
do_thumbnail: bool = None,
|
176 |
+
do_align_long_axis: bool = None,
|
177 |
+
do_rescale: bool = None,
|
178 |
+
) -> BatchFeature:
|
179 |
+
"""
|
180 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
181 |
+
and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
|
182 |
+
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
183 |
+
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
|
184 |
+
of the above two methods for more information.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
188 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
189 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
190 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
191 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
192 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
193 |
+
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
194 |
+
number of channels, H and W are image height and width.
|
195 |
+
tokenize_newline_separately (`bool`, defaults to `True`):
|
196 |
+
Adds a separately tokenized '\n' at the end of the prompt.
|
197 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
198 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
199 |
+
index) among:
|
200 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
201 |
+
sequence if provided).
|
202 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
203 |
+
acceptable input length for the model if that argument is not provided.
|
204 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
205 |
+
lengths).
|
206 |
+
max_length (`int`, *optional*):
|
207 |
+
Maximum length of the returned list and optionally padding length (see above).
|
208 |
+
truncation (`bool`, *optional*):
|
209 |
+
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
210 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
211 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
212 |
+
|
213 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
214 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
215 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
216 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
220 |
+
|
221 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
|
222 |
+
is provided, the `input_ids` will also contain the suffix input ids.
|
223 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
224 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
225 |
+
`None`).
|
226 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
227 |
+
- **labels** -- Labels compatible with training if `suffix` is not None
|
228 |
+
"""
|
229 |
+
|
230 |
+
return_token_type_ids = False
|
231 |
+
|
232 |
+
if images is None:
|
233 |
+
raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")
|
234 |
+
if text is None:
|
235 |
+
logger.warning_once(
|
236 |
+
"You are using Florence-2 without a text prompt."
|
237 |
+
)
|
238 |
+
text = ""
|
239 |
+
|
240 |
+
if isinstance(text, List) and isinstance(images, List):
|
241 |
+
if len(images) < len(text):
|
242 |
+
raise ValueError(
|
243 |
+
f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
|
244 |
+
)
|
245 |
+
if _is_str_or_image(text):
|
246 |
+
text = [text]
|
247 |
+
elif isinstance(text, list) and _is_str_or_image(text[0]):
|
248 |
+
pass
|
249 |
+
|
250 |
+
pixel_values = self.image_processor(
|
251 |
+
images,
|
252 |
+
do_resize=do_resize,
|
253 |
+
do_normalize=do_normalize,
|
254 |
+
return_tensors=return_tensors,
|
255 |
+
image_mean=image_mean,
|
256 |
+
image_std=image_std,
|
257 |
+
input_data_format=input_data_format,
|
258 |
+
data_format=data_format,
|
259 |
+
resample=resample,
|
260 |
+
do_convert_rgb=do_convert_rgb,
|
261 |
+
)["pixel_values"]
|
262 |
+
|
263 |
+
if max_length is not None:
|
264 |
+
max_length -= self.image_seq_length # max_length has to account for the image tokens
|
265 |
+
|
266 |
+
text = self._construct_prompts(text)
|
267 |
+
|
268 |
+
inputs = self.tokenizer(
|
269 |
+
text,
|
270 |
+
return_tensors=return_tensors,
|
271 |
+
padding=padding,
|
272 |
+
max_length=max_length,
|
273 |
+
truncation=truncation,
|
274 |
+
return_token_type_ids=return_token_type_ids,
|
275 |
+
)
|
276 |
+
|
277 |
+
return_data = {**inputs, "pixel_values": pixel_values}
|
278 |
+
|
279 |
+
if return_token_type_ids:
|
280 |
+
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
|
281 |
+
return_data.update({"labels": labels})
|
282 |
+
return BatchFeature(data=return_data)
|
283 |
+
|
284 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2
|
285 |
+
def batch_decode(self, *args, **kwargs):
|
286 |
+
"""
|
287 |
+
This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
288 |
+
refer to the docstring of this method for more information.
|
289 |
+
"""
|
290 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
291 |
+
|
292 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2
|
293 |
+
def decode(self, *args, **kwargs):
|
294 |
+
"""
|
295 |
+
This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
296 |
+
the docstring of this method for more information.
|
297 |
+
"""
|
298 |
+
return self.tokenizer.decode(*args, **kwargs)
|
299 |
+
|
300 |
+
@property
|
301 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2
|
302 |
+
def model_input_names(self):
|
303 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
304 |
+
image_processor_input_names = self.image_processor.model_input_names
|
305 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
306 |
+
|
307 |
+
def post_process_generation(self, text, task, image_size):
|
308 |
+
"""
|
309 |
+
Post-process the output of the model to each of the task outputs.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
text (`str`): The text to post-process.
|
313 |
+
task (`str`): The task to post-process the text for.
|
314 |
+
image_size (`Tuple[int, int]`): The size of the image. height x width.
|
315 |
+
"""
|
316 |
+
|
317 |
+
task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
|
318 |
+
task_answer = self.post_processor(
|
319 |
+
text=text,
|
320 |
+
image_size=image_size,
|
321 |
+
parse_tasks=task_answer_post_processing_type,
|
322 |
+
)[task_answer_post_processing_type]
|
323 |
+
|
324 |
+
if task_answer_post_processing_type == 'pure_text':
|
325 |
+
final_answer = task_answer
|
326 |
+
# remove the special tokens
|
327 |
+
final_answer = final_answer.replace('<s>', '').replace('</s>', '')
|
328 |
+
elif task_answer_post_processing_type in ['od', 'description_with_bboxes', 'bboxes']:
|
329 |
+
od_instances = task_answer
|
330 |
+
bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
|
331 |
+
labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
|
332 |
+
final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
|
333 |
+
elif task_answer_post_processing_type in ['ocr']:
|
334 |
+
bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
|
335 |
+
labels = [str(_od_instance['text']) for _od_instance in task_answer]
|
336 |
+
final_answer = {'quad_boxes': bboxes, 'labels': labels}
|
337 |
+
elif task_answer_post_processing_type in ['phrase_grounding']:
|
338 |
+
bboxes = []
|
339 |
+
labels = []
|
340 |
+
for _grounded_phrase in task_answer:
|
341 |
+
for _bbox in _grounded_phrase['bbox']:
|
342 |
+
bboxes.append(_bbox)
|
343 |
+
labels.append(_grounded_phrase['cat_name'])
|
344 |
+
final_answer = {'bboxes': bboxes, 'labels': labels}
|
345 |
+
elif task_answer_post_processing_type in ['description_with_polygons', 'polygons']:
|
346 |
+
labels = []
|
347 |
+
polygons = []
|
348 |
+
for result in task_answer:
|
349 |
+
label = result['cat_name']
|
350 |
+
_polygons = result['polygons']
|
351 |
+
labels.append(label)
|
352 |
+
polygons.append(_polygons)
|
353 |
+
final_answer = {'polygons': polygons, 'labels': labels}
|
354 |
+
elif task_answer_post_processing_type in ['description_with_bboxes_or_polygons']:
|
355 |
+
bboxes = []
|
356 |
+
bboxes_labels = []
|
357 |
+
polygons = []
|
358 |
+
polygons_labels = []
|
359 |
+
for result in task_answer:
|
360 |
+
label = result['cat_name']
|
361 |
+
if 'polygons' in result:
|
362 |
+
_polygons = result['polygons']
|
363 |
+
polygons.append(_polygons)
|
364 |
+
polygons_labels.append(label)
|
365 |
+
else:
|
366 |
+
_bbox = result['bbox']
|
367 |
+
bboxes.append(_bbox)
|
368 |
+
bboxes_labels.append(label)
|
369 |
+
final_answer = {'bboxes': bboxes, 'bboxes_labels': bboxes_labels, 'polygons': polygons, 'polygons_labels': polygons_labels}
|
370 |
+
else:
|
371 |
+
raise ValueError('Unknown task answer post processing type: {}'.format(task_answer_post_processing_type))
|
372 |
+
|
373 |
+
final_answer = {
|
374 |
+
task: final_answer}
|
375 |
+
return final_answer
|
376 |
+
|
377 |
+
class BoxQuantizer(object):
|
378 |
+
def __init__(self, mode, bins):
|
379 |
+
self.mode = mode
|
380 |
+
self.bins = bins
|
381 |
+
|
382 |
+
def quantize(self, boxes: torch.Tensor, size):
|
383 |
+
bins_w, bins_h = self.bins # Quantization bins.
|
384 |
+
size_w, size_h = size # Original image size.
|
385 |
+
size_per_bin_w = size_w / bins_w
|
386 |
+
size_per_bin_h = size_h / bins_h
|
387 |
+
xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
|
388 |
+
|
389 |
+
if self.mode == 'floor':
|
390 |
+
quantized_xmin = (
|
391 |
+
xmin / size_per_bin_w).floor().clamp(0, bins_w - 1)
|
392 |
+
quantized_ymin = (
|
393 |
+
ymin / size_per_bin_h).floor().clamp(0, bins_h - 1)
|
394 |
+
quantized_xmax = (
|
395 |
+
xmax / size_per_bin_w).floor().clamp(0, bins_w - 1)
|
396 |
+
quantized_ymax = (
|
397 |
+
ymax / size_per_bin_h).floor().clamp(0, bins_h - 1)
|
398 |
+
|
399 |
+
elif self.mode == 'round':
|
400 |
+
raise NotImplementedError()
|
401 |
+
|
402 |
+
else:
|
403 |
+
raise ValueError('Incorrect quantization type.')
|
404 |
+
|
405 |
+
quantized_boxes = torch.cat(
|
406 |
+
(quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1
|
407 |
+
).int()
|
408 |
+
|
409 |
+
return quantized_boxes
|
410 |
+
|
411 |
+
def dequantize(self, boxes: torch.Tensor, size):
|
412 |
+
bins_w, bins_h = self.bins # Quantization bins.
|
413 |
+
size_w, size_h = size # Original image size.
|
414 |
+
size_per_bin_w = size_w / bins_w
|
415 |
+
size_per_bin_h = size_h / bins_h
|
416 |
+
xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
|
417 |
+
|
418 |
+
if self.mode == 'floor':
|
419 |
+
# Add 0.5 to use the center position of the bin as the coordinate.
|
420 |
+
dequantized_xmin = (xmin + 0.5) * size_per_bin_w
|
421 |
+
dequantized_ymin = (ymin + 0.5) * size_per_bin_h
|
422 |
+
dequantized_xmax = (xmax + 0.5) * size_per_bin_w
|
423 |
+
dequantized_ymax = (ymax + 0.5) * size_per_bin_h
|
424 |
+
|
425 |
+
elif self.mode == 'round':
|
426 |
+
raise NotImplementedError()
|
427 |
+
|
428 |
+
else:
|
429 |
+
raise ValueError('Incorrect quantization type.')
|
430 |
+
|
431 |
+
dequantized_boxes = torch.cat(
|
432 |
+
(dequantized_xmin, dequantized_ymin,
|
433 |
+
dequantized_xmax, dequantized_ymax), dim=-1
|
434 |
+
)
|
435 |
+
|
436 |
+
return dequantized_boxes
|
437 |
+
|
438 |
+
|
439 |
+
class CoordinatesQuantizer(object):
|
440 |
+
"""
|
441 |
+
Quantize coornidates (Nx2)
|
442 |
+
"""
|
443 |
+
|
444 |
+
def __init__(self, mode, bins):
|
445 |
+
self.mode = mode
|
446 |
+
self.bins = bins
|
447 |
+
|
448 |
+
def quantize(self, coordinates: torch.Tensor, size):
|
449 |
+
bins_w, bins_h = self.bins # Quantization bins.
|
450 |
+
size_w, size_h = size # Original image size.
|
451 |
+
size_per_bin_w = size_w / bins_w
|
452 |
+
size_per_bin_h = size_h / bins_h
|
453 |
+
assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
|
454 |
+
x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
|
455 |
+
|
456 |
+
if self.mode == 'floor':
|
457 |
+
quantized_x = (x / size_per_bin_w).floor().clamp(0, bins_w - 1)
|
458 |
+
quantized_y = (y / size_per_bin_h).floor().clamp(0, bins_h - 1)
|
459 |
+
|
460 |
+
elif self.mode == 'round':
|
461 |
+
raise NotImplementedError()
|
462 |
+
|
463 |
+
else:
|
464 |
+
raise ValueError('Incorrect quantization type.')
|
465 |
+
|
466 |
+
quantized_coordinates = torch.cat(
|
467 |
+
(quantized_x, quantized_y), dim=-1
|
468 |
+
).int()
|
469 |
+
|
470 |
+
return quantized_coordinates
|
471 |
+
|
472 |
+
def dequantize(self, coordinates: torch.Tensor, size):
|
473 |
+
bins_w, bins_h = self.bins # Quantization bins.
|
474 |
+
size_w, size_h = size # Original image size.
|
475 |
+
size_per_bin_w = size_w / bins_w
|
476 |
+
size_per_bin_h = size_h / bins_h
|
477 |
+
assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
|
478 |
+
x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
|
479 |
+
|
480 |
+
if self.mode == 'floor':
|
481 |
+
# Add 0.5 to use the center position of the bin as the coordinate.
|
482 |
+
dequantized_x = (x + 0.5) * size_per_bin_w
|
483 |
+
dequantized_y = (y + 0.5) * size_per_bin_h
|
484 |
+
|
485 |
+
elif self.mode == 'round':
|
486 |
+
raise NotImplementedError()
|
487 |
+
|
488 |
+
else:
|
489 |
+
raise ValueError('Incorrect quantization type.')
|
490 |
+
|
491 |
+
dequantized_coordinates = torch.cat(
|
492 |
+
(dequantized_x, dequantized_y), dim=-1
|
493 |
+
)
|
494 |
+
|
495 |
+
return dequantized_coordinates
|
496 |
+
|
497 |
+
|
498 |
+
class Florence2PostProcesser(object):
|
499 |
+
r"""
|
500 |
+
Florence-2 post process for converting text prediction to various tasks results.
|
501 |
+
|
502 |
+
Args:
|
503 |
+
config: A dict of configs.
|
504 |
+
tokenizer: A tokenizer for decoding text to spans.
|
505 |
+
sample config:
|
506 |
+
UNIFIED_POST_PROCESS:
|
507 |
+
# commom configs
|
508 |
+
NUM_BBOX_HEIGHT_BINS: 1000
|
509 |
+
NUM_BBOX_WIDTH_BINS: 1000
|
510 |
+
COORDINATES_HEIGHT_BINS: 1000
|
511 |
+
COORDINATES_WIDTH_BINS: 1000
|
512 |
+
# task specific configs, override the common configs
|
513 |
+
PRASE_TASKS:
|
514 |
+
- TASK_NAME: 'video_dense_caption'
|
515 |
+
PATTERN: 'r<time_(\d+)><time_(\d+)>([a-zA-Z0-9 ]+)'
|
516 |
+
SCORE_MODE: 'avg_cat_name_scores'
|
517 |
+
NUM_BINS: 100
|
518 |
+
- TASK_NAME: 'od'
|
519 |
+
PATTERN: 'r<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>([a-zA-Z0-9 ]+)'
|
520 |
+
SCORE_MODE: 'avg_cat_name_scores'
|
521 |
+
|
522 |
+
Returns:
|
523 |
+
parsed_dict (dict): A dict of parsed results.
|
524 |
+
"""
|
525 |
+
def __init__(
|
526 |
+
self,
|
527 |
+
tokenizer=None
|
528 |
+
):
|
529 |
+
parse_tasks = []
|
530 |
+
parse_task_configs = {}
|
531 |
+
config = self._create_default_config()
|
532 |
+
for task in config['PARSE_TASKS']:
|
533 |
+
parse_tasks.append(task['TASK_NAME'])
|
534 |
+
parse_task_configs[task['TASK_NAME']] = task
|
535 |
+
|
536 |
+
self.config = config
|
537 |
+
self.parse_tasks = parse_tasks
|
538 |
+
self.parse_tasks_configs = parse_task_configs
|
539 |
+
|
540 |
+
self.tokenizer = tokenizer
|
541 |
+
if self.tokenizer is not None:
|
542 |
+
self.all_special_tokens = set(self.tokenizer.all_special_tokens)
|
543 |
+
|
544 |
+
self.init_quantizers()
|
545 |
+
self.black_list_of_phrase_grounding = self._create_black_list_of_phrase_grounding()
|
546 |
+
|
547 |
+
def _create_black_list_of_phrase_grounding(self):
|
548 |
+
black_list = {}
|
549 |
+
|
550 |
+
if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding']['FILTER_BY_BLACK_LIST']:
|
551 |
+
black_list = set(
|
552 |
+
['it', 'I', 'me', 'mine',
|
553 |
+
'you', 'your', 'yours',
|
554 |
+
'he', 'him', 'his',
|
555 |
+
'she', 'her', 'hers',
|
556 |
+
'they', 'them', 'their', 'theirs',
|
557 |
+
'one', 'oneself',
|
558 |
+
'we', 'us', 'our', 'ours',
|
559 |
+
'you', 'your', 'yours',
|
560 |
+
'they', 'them', 'their', 'theirs',
|
561 |
+
'mine', 'yours', 'his', 'hers', 'its',
|
562 |
+
'ours', 'yours', 'theirs',
|
563 |
+
'myself', 'yourself', 'himself', 'herself', 'itself',
|
564 |
+
'ourselves', 'yourselves', 'themselves',
|
565 |
+
'this', 'that',
|
566 |
+
'these', 'those',
|
567 |
+
'who', 'whom', 'whose', 'which', 'what',
|
568 |
+
'who', 'whom', 'whose', 'which', 'that',
|
569 |
+
'all', 'another', 'any', 'anybody', 'anyone', 'anything',
|
570 |
+
'each', 'everybody', 'everyone', 'everything',
|
571 |
+
'few', 'many', 'nobody', 'none', 'one', 'several',
|
572 |
+
'some', 'somebody', 'someone', 'something',
|
573 |
+
'each other', 'one another',
|
574 |
+
'myself', 'yourself', 'himself', 'herself', 'itself',
|
575 |
+
'ourselves', 'yourselves', 'themselves',
|
576 |
+
'the image', 'image', 'images', 'the', 'a', 'an', 'a group',
|
577 |
+
'other objects', 'lots', 'a set',
|
578 |
+
]
|
579 |
+
)
|
580 |
+
|
581 |
+
return black_list
|
582 |
+
|
583 |
+
def _create_default_config(self):
|
584 |
+
config = {
|
585 |
+
'NUM_BBOX_HEIGHT_BINS': 1000,
|
586 |
+
'NUM_BBOX_WIDTH_BINS': 1000,
|
587 |
+
'BOX_QUANTIZATION_MODE': 'floor',
|
588 |
+
'COORDINATES_HEIGHT_BINS': 1000,
|
589 |
+
'COORDINATES_WIDTH_BINS': 1000,
|
590 |
+
'COORDINATES_QUANTIZATION_MODE': 'floor',
|
591 |
+
'PARSE_TASKS': [
|
592 |
+
{
|
593 |
+
'TASK_NAME': 'od',
|
594 |
+
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
|
595 |
+
},
|
596 |
+
{
|
597 |
+
'TASK_NAME': 'ocr',
|
598 |
+
'PATTERN': r'(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>',
|
599 |
+
'AREA_THRESHOLD': 0.00
|
600 |
+
},
|
601 |
+
{
|
602 |
+
'TASK_NAME': 'phrase_grounding',
|
603 |
+
'FILTER_BY_BLACK_LIST': True
|
604 |
+
},
|
605 |
+
{
|
606 |
+
'TASK_NAME': 'pure_text',
|
607 |
+
},
|
608 |
+
{
|
609 |
+
'TASK_NAME': 'description_with_bboxes',
|
610 |
+
},
|
611 |
+
{
|
612 |
+
'TASK_NAME': 'description_with_polygons',
|
613 |
+
},
|
614 |
+
{
|
615 |
+
'TASK_NAME': 'polygons',
|
616 |
+
},
|
617 |
+
{
|
618 |
+
'TASK_NAME': 'bboxes',
|
619 |
+
},
|
620 |
+
{
|
621 |
+
'TASK_NAME': 'description_with_bboxes_or_polygons',
|
622 |
+
}
|
623 |
+
]
|
624 |
+
}
|
625 |
+
|
626 |
+
return config
|
627 |
+
|
628 |
+
def init_quantizers(self):
|
629 |
+
# we have box_quantizer (od, grounding) and coordinates_quantizer (ocr, referring_segmentation)
|
630 |
+
num_bbox_height_bins = self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
|
631 |
+
num_bbox_width_bins = self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
|
632 |
+
box_quantization_mode = self.config.get('BOX_QUANTIZATION_MODE', 'floor')
|
633 |
+
self.box_quantizer = BoxQuantizer(
|
634 |
+
box_quantization_mode,
|
635 |
+
(num_bbox_width_bins, num_bbox_height_bins),
|
636 |
+
)
|
637 |
+
|
638 |
+
num_bbox_height_bins = self.config['COORDINATES_HEIGHT_BINS'] if 'COORDINATES_HEIGHT_BINS' in self.config else self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
|
639 |
+
num_bbox_width_bins = self.config['COORDINATES_WIDTH_BINS'] if 'COORDINATES_WIDTH_BINS' in self.config else self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
|
640 |
+
box_quantization_mode = self.config.get('COORDINATES_QUANTIZATION_MODE') if 'COORDINATES_QUANTIZATION_MODE' in self.config else self.config.get('BOX_QUANTIZATION_MODE', 'floor')
|
641 |
+
self.coordinates_quantizer = CoordinatesQuantizer(
|
642 |
+
box_quantization_mode,
|
643 |
+
(num_bbox_width_bins, num_bbox_height_bins),
|
644 |
+
)
|
645 |
+
|
646 |
+
def decode_with_spans(self, tokenizer, token_ids):
|
647 |
+
filtered_tokens = tokenizer.convert_ids_to_tokens(
|
648 |
+
token_ids, skip_special_tokens=False)
|
649 |
+
assert len(filtered_tokens) == len(token_ids)
|
650 |
+
|
651 |
+
# To avoid mixing byte-level and unicode for byte-level BPT
|
652 |
+
# we need to build string separately for added tokens and byte-level tokens
|
653 |
+
# cf. https://github.com/huggingface/transformers/issues/1133
|
654 |
+
sub_texts = []
|
655 |
+
for token in filtered_tokens:
|
656 |
+
if token in self.all_special_tokens:
|
657 |
+
sub_texts.append(token)
|
658 |
+
else:
|
659 |
+
if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
|
660 |
+
sub_text = tokenizer.convert_tokens_to_string([token])
|
661 |
+
elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
|
662 |
+
# Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
|
663 |
+
# Note: Do not strip sub_text as it may have functional whitespace
|
664 |
+
sub_text = token.replace('▁', ' ')
|
665 |
+
else:
|
666 |
+
raise ValueError(f'type {type(tokenizer)} not supported')
|
667 |
+
sub_texts.append(sub_text)
|
668 |
+
|
669 |
+
text = ''
|
670 |
+
spans = []
|
671 |
+
for sub_text in sub_texts:
|
672 |
+
span = (len(text), len(text) + len(sub_text)) # [start index, end index).
|
673 |
+
text += sub_text
|
674 |
+
spans.append(span)
|
675 |
+
|
676 |
+
# Text format:
|
677 |
+
# 1. T5Tokenizer/T5TokenizerFast:
|
678 |
+
# "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
|
679 |
+
# Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
|
680 |
+
# 2. BartTokenizer (need to double check):
|
681 |
+
# "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
|
682 |
+
# Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
|
683 |
+
return text, spans
|
684 |
+
|
685 |
+
def parse_od_from_text_and_spans(
|
686 |
+
self,
|
687 |
+
text,
|
688 |
+
pattern,
|
689 |
+
image_size,
|
690 |
+
phrase_centric=False
|
691 |
+
):
|
692 |
+
parsed = list(re.finditer(pattern, text))
|
693 |
+
|
694 |
+
instances = []
|
695 |
+
for i in range(len(parsed)):
|
696 |
+
# Prepare instance.
|
697 |
+
instance = {}
|
698 |
+
|
699 |
+
if phrase_centric:
|
700 |
+
bbox_bins = [int(parsed[i].group(j)) for j in range(2, 6)]
|
701 |
+
else:
|
702 |
+
bbox_bins = [int(parsed[i].group(j)) for j in range(1, 5)]
|
703 |
+
instance['bbox'] = self.box_quantizer.dequantize(
|
704 |
+
boxes=torch.tensor(bbox_bins),
|
705 |
+
size=image_size
|
706 |
+
).tolist()
|
707 |
+
|
708 |
+
if phrase_centric:
|
709 |
+
instance['cat_name'] = parsed[i].group(1).lower().strip()
|
710 |
+
else:
|
711 |
+
instance['cat_name'] = parsed[i].group(5).lower().strip()
|
712 |
+
instances.append(instance)
|
713 |
+
|
714 |
+
return instances
|
715 |
+
|
716 |
+
def parse_ocr_from_text_and_spans(self,
|
717 |
+
text,
|
718 |
+
pattern,
|
719 |
+
image_size,
|
720 |
+
area_threshold=-1.0,
|
721 |
+
):
|
722 |
+
bboxes = []
|
723 |
+
labels = []
|
724 |
+
text = text.replace('<s>', '')
|
725 |
+
# ocr with regions
|
726 |
+
parsed = re.findall(pattern, text)
|
727 |
+
instances = []
|
728 |
+
image_width, image_height = image_size
|
729 |
+
|
730 |
+
for ocr_line in parsed:
|
731 |
+
ocr_content = ocr_line[0]
|
732 |
+
quad_box = ocr_line[1:]
|
733 |
+
quad_box = [int(i) for i in quad_box]
|
734 |
+
quad_box = self.coordinates_quantizer.dequantize(
|
735 |
+
torch.tensor(np.array(quad_box).reshape(-1, 2)),
|
736 |
+
size=image_size
|
737 |
+
).reshape(-1).tolist()
|
738 |
+
|
739 |
+
if area_threshold > 0:
|
740 |
+
x_coords = [i for i in quad_box[0::2]]
|
741 |
+
y_coords = [i for i in quad_box[1::2]]
|
742 |
+
|
743 |
+
# apply the Shoelace formula
|
744 |
+
area = 0.5 * abs(sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1)))
|
745 |
+
|
746 |
+
if area < (image_width * image_height) * area_threshold:
|
747 |
+
continue
|
748 |
+
|
749 |
+
bboxes.append(quad_box)
|
750 |
+
labels.append(ocr_content)
|
751 |
+
instances.append({
|
752 |
+
'quad_box': quad_box,
|
753 |
+
'text': ocr_content,
|
754 |
+
})
|
755 |
+
return instances
|
756 |
+
|
757 |
+
def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size):
|
758 |
+
# ignore <s> </s> and <pad>
|
759 |
+
cur_span = 0
|
760 |
+
if text.startswith('<s>'):
|
761 |
+
cur_span += 3
|
762 |
+
|
763 |
+
text = text.replace('<s>', '')
|
764 |
+
text = text.replace('</s>', '')
|
765 |
+
text = text.replace('<pad>', '')
|
766 |
+
|
767 |
+
pattern = r"([^<]+(?:<loc_\d+>){4,})"
|
768 |
+
phrases = re.findall(pattern, text)
|
769 |
+
|
770 |
+
# pattern should be text pattern and od pattern
|
771 |
+
pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
|
772 |
+
box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
|
773 |
+
|
774 |
+
instances = []
|
775 |
+
for pharse_text in phrases:
|
776 |
+
phrase_text_strip = pharse_text.replace('<ground>', '', 1)
|
777 |
+
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
778 |
+
|
779 |
+
if phrase_text_strip == '':
|
780 |
+
cur_span += len(pharse_text)
|
781 |
+
continue
|
782 |
+
|
783 |
+
# Prepare instance.
|
784 |
+
instance = {}
|
785 |
+
|
786 |
+
# parse phrase, get string
|
787 |
+
phrase = re.search(pattern, phrase_text_strip)
|
788 |
+
if phrase is None:
|
789 |
+
cur_span += len(pharse_text)
|
790 |
+
continue
|
791 |
+
|
792 |
+
# parse bboxes by box_pattern
|
793 |
+
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
794 |
+
if len(bboxes_parsed) == 0:
|
795 |
+
cur_span += len(pharse_text)
|
796 |
+
continue
|
797 |
+
|
798 |
+
phrase = phrase.group()
|
799 |
+
# remove leading and trailing spaces
|
800 |
+
phrase = phrase.strip()
|
801 |
+
|
802 |
+
if phrase in self.black_list_of_phrase_grounding:
|
803 |
+
cur_span += len(pharse_text)
|
804 |
+
continue
|
805 |
+
|
806 |
+
# a list of list
|
807 |
+
bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
|
808 |
+
instance['bbox'] = self.box_quantizer.dequantize(
|
809 |
+
boxes=torch.tensor(bbox_bins),
|
810 |
+
size=image_size
|
811 |
+
).tolist()
|
812 |
+
|
813 |
+
# exclude non-ascii characters
|
814 |
+
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
815 |
+
instance['cat_name'] = phrase
|
816 |
+
|
817 |
+
instances.append(instance)
|
818 |
+
|
819 |
+
return instances
|
820 |
+
|
821 |
+
def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
|
822 |
+
# temporary parse solution, split by '.'
|
823 |
+
# ignore <s> </s> and <pad>
|
824 |
+
|
825 |
+
text = text.replace('<s>', '')
|
826 |
+
text = text.replace('</s>', '')
|
827 |
+
text = text.replace('<pad>', '')
|
828 |
+
|
829 |
+
if allow_empty_phrase:
|
830 |
+
pattern = rf"(?:(?:<loc_\d+>){{4,}})"
|
831 |
+
else:
|
832 |
+
pattern = r"([^<]+(?:<loc_\d+>){4,})"
|
833 |
+
phrases = re.findall(pattern, text)
|
834 |
+
|
835 |
+
# pattern should be text pattern and od pattern
|
836 |
+
pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
|
837 |
+
box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
|
838 |
+
|
839 |
+
instances = []
|
840 |
+
for pharse_text in phrases:
|
841 |
+
phrase_text_strip = pharse_text.replace('<ground>', '', 1)
|
842 |
+
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
843 |
+
|
844 |
+
if phrase_text_strip == '' and not allow_empty_phrase:
|
845 |
+
continue
|
846 |
+
|
847 |
+
# parse phrase, get string
|
848 |
+
phrase = re.search(pattern, phrase_text_strip)
|
849 |
+
if phrase is None:
|
850 |
+
continue
|
851 |
+
|
852 |
+
phrase = phrase.group()
|
853 |
+
# remove leading and trailing spaces
|
854 |
+
phrase = phrase.strip()
|
855 |
+
|
856 |
+
# parse bboxes by box_pattern
|
857 |
+
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
858 |
+
if len(bboxes_parsed) == 0:
|
859 |
+
continue
|
860 |
+
|
861 |
+
# a list of list
|
862 |
+
bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
|
863 |
+
|
864 |
+
bboxes = self.box_quantizer.dequantize(
|
865 |
+
boxes=torch.tensor(bbox_bins),
|
866 |
+
size=image_size
|
867 |
+
).tolist()
|
868 |
+
|
869 |
+
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
870 |
+
for _bboxes in bboxes:
|
871 |
+
# Prepare instance.
|
872 |
+
instance = {}
|
873 |
+
instance['bbox'] = _bboxes
|
874 |
+
# exclude non-ascii characters
|
875 |
+
instance['cat_name'] = phrase
|
876 |
+
instances.append(instance)
|
877 |
+
|
878 |
+
return instances
|
879 |
+
|
880 |
+
def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
|
881 |
+
allow_empty_phrase=False,
|
882 |
+
polygon_sep_token='<sep>',
|
883 |
+
polygon_start_token='<poly>',
|
884 |
+
polygon_end_token='</poly>',
|
885 |
+
with_box_at_start=False,
|
886 |
+
):
|
887 |
+
|
888 |
+
# ref_seg format: '<expression><x1><y1><x2><y2><><><sep><><><><>'
|
889 |
+
# ignore <s> </s> and <pad>
|
890 |
+
|
891 |
+
text = text.replace('<s>', '')
|
892 |
+
text = text.replace('</s>', '')
|
893 |
+
text = text.replace('<pad>', '')
|
894 |
+
|
895 |
+
if allow_empty_phrase:
|
896 |
+
pattern = rf"(?:(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
|
897 |
+
else:
|
898 |
+
# [^<]+: This part matches one or more characters that are not the < symbol.
|
899 |
+
# The ^ inside the square brackets [] is a negation, meaning it matches anything except <.
|
900 |
+
#
|
901 |
+
pattern = rf"([^<]+(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
|
902 |
+
phrases = re.findall(pattern, text)
|
903 |
+
|
904 |
+
phrase_string_pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_|<poly>)'
|
905 |
+
box_pattern = rf'((?:<loc_\d+>)+)(?:{re.escape(polygon_sep_token)}|$)'
|
906 |
+
|
907 |
+
# one polygons instance is separated by polygon_start_token and polygon_end_token
|
908 |
+
polygons_instance_pattern = rf'{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}'
|
909 |
+
|
910 |
+
instances = []
|
911 |
+
for phrase_text in phrases:
|
912 |
+
|
913 |
+
# exclude loc_\d+>
|
914 |
+
# need to get span if want to include category score
|
915 |
+
phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1)
|
916 |
+
|
917 |
+
# phrase = phrase.replace('<poly>', '')
|
918 |
+
# phrase = phrase.replace('poly>', '')
|
919 |
+
|
920 |
+
if phrase_text_strip == '' and not allow_empty_phrase:
|
921 |
+
continue
|
922 |
+
|
923 |
+
|
924 |
+
# parse phrase, get string
|
925 |
+
phrase = re.search(phrase_string_pattern, phrase_text_strip)
|
926 |
+
if phrase is None:
|
927 |
+
continue
|
928 |
+
phrase = phrase.group()
|
929 |
+
# remove leading and trailing spaces
|
930 |
+
phrase = phrase.strip()
|
931 |
+
|
932 |
+
# parse bboxes by box_pattern
|
933 |
+
|
934 |
+
# split by polygon_start_token and polygon_end_token first using polygons_instance_pattern
|
935 |
+
if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
|
936 |
+
polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text))
|
937 |
+
else:
|
938 |
+
polygons_instances_parsed = [phrase_text]
|
939 |
+
|
940 |
+
for _polygons_instances_parsed in polygons_instances_parsed:
|
941 |
+
# Prepare instance.
|
942 |
+
instance = {}
|
943 |
+
|
944 |
+
# polygons_parsed= list(re.finditer(box_pattern, phrase_text))
|
945 |
+
if isinstance(_polygons_instances_parsed, str):
|
946 |
+
polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed))
|
947 |
+
else:
|
948 |
+
polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed.group(1)))
|
949 |
+
if len(polygons_parsed) == 0:
|
950 |
+
continue
|
951 |
+
|
952 |
+
# a list of list (polygon)
|
953 |
+
bbox = []
|
954 |
+
polygons = []
|
955 |
+
for _polygon_parsed in polygons_parsed:
|
956 |
+
# group 1: whole <loc_\d+>...</loc_\d+>
|
957 |
+
_polygon = _polygon_parsed.group(1)
|
958 |
+
# parse into list of int
|
959 |
+
_polygon = [int(_loc_parsed.group(1)) for _loc_parsed in re.finditer(r'<loc_(\d+)>', _polygon)]
|
960 |
+
if with_box_at_start and len(bbox) == 0:
|
961 |
+
if len(_polygon) > 4:
|
962 |
+
# no valid bbox prediction
|
963 |
+
bbox = _polygon[:4]
|
964 |
+
_polygon = _polygon[4:]
|
965 |
+
else:
|
966 |
+
bbox = [0, 0, 0, 0]
|
967 |
+
# abandon last element if is not paired
|
968 |
+
if len(_polygon) % 2 == 1:
|
969 |
+
_polygon = _polygon[:-1]
|
970 |
+
|
971 |
+
# reshape into (n, 2)
|
972 |
+
_polygon = self.coordinates_quantizer.dequantize(
|
973 |
+
torch.tensor(np.array(_polygon).reshape(-1, 2)),
|
974 |
+
size=image_size
|
975 |
+
).reshape(-1).tolist()
|
976 |
+
# reshape back
|
977 |
+
polygons.append(_polygon)
|
978 |
+
|
979 |
+
instance['cat_name'] = phrase
|
980 |
+
instance['polygons'] = polygons
|
981 |
+
if len(bbox) != 0:
|
982 |
+
instance['bbox'] = self.box_quantizer.dequantize(
|
983 |
+
boxes=torch.tensor([bbox]),
|
984 |
+
size=image_size
|
985 |
+
).tolist()[0]
|
986 |
+
|
987 |
+
instances.append(instance)
|
988 |
+
|
989 |
+
return instances
|
990 |
+
|
991 |
+
def __call__(
|
992 |
+
self,
|
993 |
+
text=None,
|
994 |
+
image_size=None,
|
995 |
+
parse_tasks=None,
|
996 |
+
):
|
997 |
+
"""
|
998 |
+
Args:
|
999 |
+
text: model outputs
|
1000 |
+
image_size: (width, height)
|
1001 |
+
parse_tasks: a list of tasks to parse, if None, parse all tasks.
|
1002 |
+
|
1003 |
+
"""
|
1004 |
+
if parse_tasks is not None:
|
1005 |
+
if isinstance(parse_tasks, str):
|
1006 |
+
parse_tasks = [parse_tasks]
|
1007 |
+
for _parse_task in parse_tasks:
|
1008 |
+
assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
|
1009 |
+
|
1010 |
+
# sequence or text should be provided
|
1011 |
+
assert text is not None, 'text should be provided'
|
1012 |
+
|
1013 |
+
parsed_dict = {
|
1014 |
+
'text': text
|
1015 |
+
}
|
1016 |
+
|
1017 |
+
for task in self.parse_tasks:
|
1018 |
+
if parse_tasks is not None and task not in parse_tasks:
|
1019 |
+
continue
|
1020 |
+
|
1021 |
+
pattern = self.parse_tasks_configs[task].get('PATTERN', None)
|
1022 |
+
|
1023 |
+
if task == 'ocr':
|
1024 |
+
instances = self.parse_ocr_from_text_and_spans(
|
1025 |
+
text,
|
1026 |
+
pattern=pattern,
|
1027 |
+
image_size=image_size,
|
1028 |
+
area_threshold=self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.0),
|
1029 |
+
)
|
1030 |
+
parsed_dict['ocr'] = instances
|
1031 |
+
elif task == 'phrase_grounding':
|
1032 |
+
instances = self.parse_phrase_grounding_from_text_and_spans(
|
1033 |
+
text,
|
1034 |
+
pattern=pattern,
|
1035 |
+
image_size=image_size,
|
1036 |
+
)
|
1037 |
+
parsed_dict['phrase_grounding'] = instances
|
1038 |
+
elif task == 'pure_text':
|
1039 |
+
parsed_dict['pure_text'] = text
|
1040 |
+
elif task == 'description_with_bboxes':
|
1041 |
+
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
1042 |
+
text,
|
1043 |
+
pattern=pattern,
|
1044 |
+
image_size=image_size,
|
1045 |
+
)
|
1046 |
+
parsed_dict['description_with_bboxes'] = instances
|
1047 |
+
elif task == 'description_with_polygons':
|
1048 |
+
instances = self.parse_description_with_polygons_from_text_and_spans(
|
1049 |
+
text,
|
1050 |
+
pattern=pattern,
|
1051 |
+
image_size=image_size,
|
1052 |
+
)
|
1053 |
+
parsed_dict['description_with_polygons'] = instances
|
1054 |
+
elif task == 'polygons':
|
1055 |
+
instances = self.parse_description_with_polygons_from_text_and_spans(
|
1056 |
+
text,
|
1057 |
+
pattern=pattern,
|
1058 |
+
image_size=image_size,
|
1059 |
+
allow_empty_phrase=True,
|
1060 |
+
)
|
1061 |
+
parsed_dict['polygons'] = instances
|
1062 |
+
elif task == 'bboxes':
|
1063 |
+
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
1064 |
+
text,
|
1065 |
+
pattern=pattern,
|
1066 |
+
image_size=image_size,
|
1067 |
+
allow_empty_phrase=True,
|
1068 |
+
)
|
1069 |
+
parsed_dict['bboxes'] = instances
|
1070 |
+
elif task == 'description_with_bboxes_or_polygons':
|
1071 |
+
if '<poly>' in text:
|
1072 |
+
# only support either polygons or bboxes, not both at the same time
|
1073 |
+
instances = self.parse_description_with_polygons_from_text_and_spans(
|
1074 |
+
text,
|
1075 |
+
pattern=pattern,
|
1076 |
+
image_size=image_size,
|
1077 |
+
)
|
1078 |
+
else:
|
1079 |
+
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
1080 |
+
text,
|
1081 |
+
pattern=pattern,
|
1082 |
+
image_size=image_size,
|
1083 |
+
)
|
1084 |
+
parsed_dict['description_with_bboxes_or_polygons'] = instances
|
1085 |
+
else:
|
1086 |
+
raise ValueError("task {} is not supported".format(task))
|
1087 |
+
|
1088 |
+
return parsed_dict
|
models/LLM/Florence-2-large-PromptGen-v2.0/special_tokens_map.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/LLM/Florence-2-large-PromptGen-v2.0/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/LLM/Florence-2-large-PromptGen-v2.0/tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/LLM/Florence-2-large-PromptGen-v2.0/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/RMBG/BEN2/BEN2.py
ADDED
@@ -0,0 +1,1401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
import torch.utils.checkpoint as checkpoint
|
8 |
+
import numpy as np
|
9 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
10 |
+
from PIL import Image, ImageOps
|
11 |
+
from torchvision import transforms
|
12 |
+
import numpy as np
|
13 |
+
import random
|
14 |
+
import cv2
|
15 |
+
import os
|
16 |
+
import subprocess
|
17 |
+
import time
|
18 |
+
import tempfile
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
def set_random_seed(seed):
|
24 |
+
random.seed(seed)
|
25 |
+
np.random.seed(seed)
|
26 |
+
torch.manual_seed(seed)
|
27 |
+
torch.cuda.manual_seed(seed)
|
28 |
+
torch.cuda.manual_seed_all(seed)
|
29 |
+
torch.backends.cudnn.deterministic = True
|
30 |
+
torch.backends.cudnn.benchmark = False
|
31 |
+
set_random_seed(9)
|
32 |
+
|
33 |
+
|
34 |
+
torch.set_float32_matmul_precision('highest')
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
class Mlp(nn.Module):
|
39 |
+
""" Multilayer perceptron."""
|
40 |
+
|
41 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
42 |
+
super().__init__()
|
43 |
+
out_features = out_features or in_features
|
44 |
+
hidden_features = hidden_features or in_features
|
45 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
46 |
+
self.act = act_layer()
|
47 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
48 |
+
self.drop = nn.Dropout(drop)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
x = self.fc1(x)
|
52 |
+
x = self.act(x)
|
53 |
+
x = self.drop(x)
|
54 |
+
x = self.fc2(x)
|
55 |
+
x = self.drop(x)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
def window_partition(x, window_size):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
x: (B, H, W, C)
|
63 |
+
window_size (int): window size
|
64 |
+
Returns:
|
65 |
+
windows: (num_windows*B, window_size, window_size, C)
|
66 |
+
"""
|
67 |
+
B, H, W, C = x.shape
|
68 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
69 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
70 |
+
return windows
|
71 |
+
|
72 |
+
|
73 |
+
def window_reverse(windows, window_size, H, W):
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
windows: (num_windows*B, window_size, window_size, C)
|
77 |
+
window_size (int): Window size
|
78 |
+
H (int): Height of image
|
79 |
+
W (int): Width of image
|
80 |
+
Returns:
|
81 |
+
x: (B, H, W, C)
|
82 |
+
"""
|
83 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
84 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
85 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class WindowAttention(nn.Module):
|
90 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
91 |
+
It supports both of shifted and non-shifted window.
|
92 |
+
Args:
|
93 |
+
dim (int): Number of input channels.
|
94 |
+
window_size (tuple[int]): The height and width of the window.
|
95 |
+
num_heads (int): Number of attention heads.
|
96 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
97 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
98 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
99 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
103 |
+
|
104 |
+
super().__init__()
|
105 |
+
self.dim = dim
|
106 |
+
self.window_size = window_size # Wh, Ww
|
107 |
+
self.num_heads = num_heads
|
108 |
+
head_dim = dim // num_heads
|
109 |
+
self.scale = qk_scale or head_dim ** -0.5
|
110 |
+
|
111 |
+
# define a parameter table of relative position bias
|
112 |
+
self.relative_position_bias_table = nn.Parameter(
|
113 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
114 |
+
|
115 |
+
# get pair-wise relative position index for each token inside the window
|
116 |
+
coords_h = torch.arange(self.window_size[0])
|
117 |
+
coords_w = torch.arange(self.window_size[1])
|
118 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
119 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
120 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
121 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
122 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
123 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
124 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
125 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
126 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
127 |
+
|
128 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
129 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
130 |
+
self.proj = nn.Linear(dim, dim)
|
131 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
132 |
+
|
133 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
134 |
+
self.softmax = nn.Softmax(dim=-1)
|
135 |
+
|
136 |
+
def forward(self, x, mask=None):
|
137 |
+
""" Forward function.
|
138 |
+
Args:
|
139 |
+
x: input features with shape of (num_windows*B, N, C)
|
140 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
141 |
+
"""
|
142 |
+
B_, N, C = x.shape
|
143 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
144 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
145 |
+
|
146 |
+
q = q * self.scale
|
147 |
+
attn = (q @ k.transpose(-2, -1))
|
148 |
+
|
149 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
150 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
151 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
152 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
153 |
+
|
154 |
+
if mask is not None:
|
155 |
+
nW = mask.shape[0]
|
156 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
157 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
158 |
+
attn = self.softmax(attn)
|
159 |
+
else:
|
160 |
+
attn = self.softmax(attn)
|
161 |
+
|
162 |
+
attn = self.attn_drop(attn)
|
163 |
+
|
164 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
165 |
+
x = self.proj(x)
|
166 |
+
x = self.proj_drop(x)
|
167 |
+
return x
|
168 |
+
|
169 |
+
|
170 |
+
class SwinTransformerBlock(nn.Module):
|
171 |
+
""" Swin Transformer Block.
|
172 |
+
Args:
|
173 |
+
dim (int): Number of input channels.
|
174 |
+
num_heads (int): Number of attention heads.
|
175 |
+
window_size (int): Window size.
|
176 |
+
shift_size (int): Shift size for SW-MSA.
|
177 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
178 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
179 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
180 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
181 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
182 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
183 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
184 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
188 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
189 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
190 |
+
super().__init__()
|
191 |
+
self.dim = dim
|
192 |
+
self.num_heads = num_heads
|
193 |
+
self.window_size = window_size
|
194 |
+
self.shift_size = shift_size
|
195 |
+
self.mlp_ratio = mlp_ratio
|
196 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
197 |
+
|
198 |
+
self.norm1 = norm_layer(dim)
|
199 |
+
self.attn = WindowAttention(
|
200 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
201 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
202 |
+
|
203 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
204 |
+
self.norm2 = norm_layer(dim)
|
205 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
206 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
207 |
+
|
208 |
+
self.H = None
|
209 |
+
self.W = None
|
210 |
+
|
211 |
+
def forward(self, x, mask_matrix):
|
212 |
+
""" Forward function.
|
213 |
+
Args:
|
214 |
+
x: Input feature, tensor size (B, H*W, C).
|
215 |
+
H, W: Spatial resolution of the input feature.
|
216 |
+
mask_matrix: Attention mask for cyclic shift.
|
217 |
+
"""
|
218 |
+
B, L, C = x.shape
|
219 |
+
H, W = self.H, self.W
|
220 |
+
assert L == H * W, "input feature has wrong size"
|
221 |
+
|
222 |
+
shortcut = x
|
223 |
+
x = self.norm1(x)
|
224 |
+
x = x.view(B, H, W, C)
|
225 |
+
|
226 |
+
# pad feature maps to multiples of window size
|
227 |
+
pad_l = pad_t = 0
|
228 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
229 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
230 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
231 |
+
_, Hp, Wp, _ = x.shape
|
232 |
+
|
233 |
+
# cyclic shift
|
234 |
+
if self.shift_size > 0:
|
235 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
236 |
+
attn_mask = mask_matrix
|
237 |
+
else:
|
238 |
+
shifted_x = x
|
239 |
+
attn_mask = None
|
240 |
+
|
241 |
+
# partition windows
|
242 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
243 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
244 |
+
|
245 |
+
# W-MSA/SW-MSA
|
246 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
247 |
+
|
248 |
+
# merge windows
|
249 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
250 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
251 |
+
|
252 |
+
# reverse cyclic shift
|
253 |
+
if self.shift_size > 0:
|
254 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
255 |
+
else:
|
256 |
+
x = shifted_x
|
257 |
+
|
258 |
+
if pad_r > 0 or pad_b > 0:
|
259 |
+
x = x[:, :H, :W, :].contiguous()
|
260 |
+
|
261 |
+
x = x.view(B, H * W, C)
|
262 |
+
|
263 |
+
# FFN
|
264 |
+
x = shortcut + self.drop_path(x)
|
265 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
266 |
+
|
267 |
+
return x
|
268 |
+
|
269 |
+
|
270 |
+
class PatchMerging(nn.Module):
|
271 |
+
""" Patch Merging Layer
|
272 |
+
Args:
|
273 |
+
dim (int): Number of input channels.
|
274 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
275 |
+
"""
|
276 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
277 |
+
super().__init__()
|
278 |
+
self.dim = dim
|
279 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
280 |
+
self.norm = norm_layer(4 * dim)
|
281 |
+
|
282 |
+
def forward(self, x, H, W):
|
283 |
+
""" Forward function.
|
284 |
+
Args:
|
285 |
+
x: Input feature, tensor size (B, H*W, C).
|
286 |
+
H, W: Spatial resolution of the input feature.
|
287 |
+
"""
|
288 |
+
B, L, C = x.shape
|
289 |
+
assert L == H * W, "input feature has wrong size"
|
290 |
+
|
291 |
+
x = x.view(B, H, W, C)
|
292 |
+
|
293 |
+
# padding
|
294 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
295 |
+
if pad_input:
|
296 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
297 |
+
|
298 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
299 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
300 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
301 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
302 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
303 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
304 |
+
|
305 |
+
x = self.norm(x)
|
306 |
+
x = self.reduction(x)
|
307 |
+
|
308 |
+
return x
|
309 |
+
|
310 |
+
|
311 |
+
class BasicLayer(nn.Module):
|
312 |
+
""" A basic Swin Transformer layer for one stage.
|
313 |
+
Args:
|
314 |
+
dim (int): Number of feature channels
|
315 |
+
depth (int): Depths of this stage.
|
316 |
+
num_heads (int): Number of attention head.
|
317 |
+
window_size (int): Local window size. Default: 7.
|
318 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
319 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
320 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
321 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
322 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
323 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
324 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
325 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
326 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
327 |
+
"""
|
328 |
+
|
329 |
+
def __init__(self,
|
330 |
+
dim,
|
331 |
+
depth,
|
332 |
+
num_heads,
|
333 |
+
window_size=7,
|
334 |
+
mlp_ratio=4.,
|
335 |
+
qkv_bias=True,
|
336 |
+
qk_scale=None,
|
337 |
+
drop=0.,
|
338 |
+
attn_drop=0.,
|
339 |
+
drop_path=0.,
|
340 |
+
norm_layer=nn.LayerNorm,
|
341 |
+
downsample=None,
|
342 |
+
use_checkpoint=False):
|
343 |
+
super().__init__()
|
344 |
+
self.window_size = window_size
|
345 |
+
self.shift_size = window_size // 2
|
346 |
+
self.depth = depth
|
347 |
+
self.use_checkpoint = use_checkpoint
|
348 |
+
|
349 |
+
# build blocks
|
350 |
+
self.blocks = nn.ModuleList([
|
351 |
+
SwinTransformerBlock(
|
352 |
+
dim=dim,
|
353 |
+
num_heads=num_heads,
|
354 |
+
window_size=window_size,
|
355 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
356 |
+
mlp_ratio=mlp_ratio,
|
357 |
+
qkv_bias=qkv_bias,
|
358 |
+
qk_scale=qk_scale,
|
359 |
+
drop=drop,
|
360 |
+
attn_drop=attn_drop,
|
361 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
362 |
+
norm_layer=norm_layer)
|
363 |
+
for i in range(depth)])
|
364 |
+
|
365 |
+
# patch merging layer
|
366 |
+
if downsample is not None:
|
367 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
368 |
+
else:
|
369 |
+
self.downsample = None
|
370 |
+
|
371 |
+
def forward(self, x, H, W):
|
372 |
+
""" Forward function.
|
373 |
+
Args:
|
374 |
+
x: Input feature, tensor size (B, H*W, C).
|
375 |
+
H, W: Spatial resolution of the input feature.
|
376 |
+
"""
|
377 |
+
|
378 |
+
# calculate attention mask for SW-MSA
|
379 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
380 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
381 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
382 |
+
h_slices = (slice(0, -self.window_size),
|
383 |
+
slice(-self.window_size, -self.shift_size),
|
384 |
+
slice(-self.shift_size, None))
|
385 |
+
w_slices = (slice(0, -self.window_size),
|
386 |
+
slice(-self.window_size, -self.shift_size),
|
387 |
+
slice(-self.shift_size, None))
|
388 |
+
cnt = 0
|
389 |
+
for h in h_slices:
|
390 |
+
for w in w_slices:
|
391 |
+
img_mask[:, h, w, :] = cnt
|
392 |
+
cnt += 1
|
393 |
+
|
394 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
395 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
396 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
397 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
398 |
+
|
399 |
+
for blk in self.blocks:
|
400 |
+
blk.H, blk.W = H, W
|
401 |
+
if self.use_checkpoint:
|
402 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
403 |
+
else:
|
404 |
+
x = blk(x, attn_mask)
|
405 |
+
if self.downsample is not None:
|
406 |
+
x_down = self.downsample(x, H, W)
|
407 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
408 |
+
return x, H, W, x_down, Wh, Ww
|
409 |
+
else:
|
410 |
+
return x, H, W, x, H, W
|
411 |
+
|
412 |
+
|
413 |
+
class PatchEmbed(nn.Module):
|
414 |
+
""" Image to Patch Embedding
|
415 |
+
Args:
|
416 |
+
patch_size (int): Patch token size. Default: 4.
|
417 |
+
in_chans (int): Number of input image channels. Default: 3.
|
418 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
419 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
420 |
+
"""
|
421 |
+
|
422 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
423 |
+
super().__init__()
|
424 |
+
patch_size = to_2tuple(patch_size)
|
425 |
+
self.patch_size = patch_size
|
426 |
+
|
427 |
+
self.in_chans = in_chans
|
428 |
+
self.embed_dim = embed_dim
|
429 |
+
|
430 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
431 |
+
if norm_layer is not None:
|
432 |
+
self.norm = norm_layer(embed_dim)
|
433 |
+
else:
|
434 |
+
self.norm = None
|
435 |
+
|
436 |
+
def forward(self, x):
|
437 |
+
"""Forward function."""
|
438 |
+
# padding
|
439 |
+
_, _, H, W = x.size()
|
440 |
+
if W % self.patch_size[1] != 0:
|
441 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
442 |
+
if H % self.patch_size[0] != 0:
|
443 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
444 |
+
|
445 |
+
x = self.proj(x) # B C Wh Ww
|
446 |
+
if self.norm is not None:
|
447 |
+
Wh, Ww = x.size(2), x.size(3)
|
448 |
+
x = x.flatten(2).transpose(1, 2)
|
449 |
+
x = self.norm(x)
|
450 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
451 |
+
|
452 |
+
return x
|
453 |
+
|
454 |
+
|
455 |
+
class SwinTransformer(nn.Module):
|
456 |
+
""" Swin Transformer backbone.
|
457 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
458 |
+
https://arxiv.org/pdf/2103.14030
|
459 |
+
Args:
|
460 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
461 |
+
used in absolute postion embedding. Default 224.
|
462 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
463 |
+
in_chans (int): Number of input image channels. Default: 3.
|
464 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
465 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
466 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
467 |
+
window_size (int): Window size. Default: 7.
|
468 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
469 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
470 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
471 |
+
drop_rate (float): Dropout rate.
|
472 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
473 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
474 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
475 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
476 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
477 |
+
out_indices (Sequence[int]): Output from which stages.
|
478 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
479 |
+
-1 means not freezing any parameters.
|
480 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
481 |
+
"""
|
482 |
+
|
483 |
+
def __init__(self,
|
484 |
+
pretrain_img_size=224,
|
485 |
+
patch_size=4,
|
486 |
+
in_chans=3,
|
487 |
+
embed_dim=96,
|
488 |
+
depths=[2, 2, 6, 2],
|
489 |
+
num_heads=[3, 6, 12, 24],
|
490 |
+
window_size=7,
|
491 |
+
mlp_ratio=4.,
|
492 |
+
qkv_bias=True,
|
493 |
+
qk_scale=None,
|
494 |
+
drop_rate=0.,
|
495 |
+
attn_drop_rate=0.,
|
496 |
+
drop_path_rate=0.2,
|
497 |
+
norm_layer=nn.LayerNorm,
|
498 |
+
ape=False,
|
499 |
+
patch_norm=True,
|
500 |
+
out_indices=(0, 1, 2, 3),
|
501 |
+
frozen_stages=-1,
|
502 |
+
use_checkpoint=False):
|
503 |
+
super().__init__()
|
504 |
+
|
505 |
+
self.pretrain_img_size = pretrain_img_size
|
506 |
+
self.num_layers = len(depths)
|
507 |
+
self.embed_dim = embed_dim
|
508 |
+
self.ape = ape
|
509 |
+
self.patch_norm = patch_norm
|
510 |
+
self.out_indices = out_indices
|
511 |
+
self.frozen_stages = frozen_stages
|
512 |
+
|
513 |
+
# split image into non-overlapping patches
|
514 |
+
self.patch_embed = PatchEmbed(
|
515 |
+
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
516 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
517 |
+
|
518 |
+
# absolute position embedding
|
519 |
+
if self.ape:
|
520 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
521 |
+
patch_size = to_2tuple(patch_size)
|
522 |
+
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
|
523 |
+
|
524 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
|
525 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
526 |
+
|
527 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
528 |
+
|
529 |
+
# stochastic depth
|
530 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
531 |
+
|
532 |
+
# build layers
|
533 |
+
self.layers = nn.ModuleList()
|
534 |
+
for i_layer in range(self.num_layers):
|
535 |
+
layer = BasicLayer(
|
536 |
+
dim=int(embed_dim * 2 ** i_layer),
|
537 |
+
depth=depths[i_layer],
|
538 |
+
num_heads=num_heads[i_layer],
|
539 |
+
window_size=window_size,
|
540 |
+
mlp_ratio=mlp_ratio,
|
541 |
+
qkv_bias=qkv_bias,
|
542 |
+
qk_scale=qk_scale,
|
543 |
+
drop=drop_rate,
|
544 |
+
attn_drop=attn_drop_rate,
|
545 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
546 |
+
norm_layer=norm_layer,
|
547 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
548 |
+
use_checkpoint=use_checkpoint)
|
549 |
+
self.layers.append(layer)
|
550 |
+
|
551 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
552 |
+
self.num_features = num_features
|
553 |
+
|
554 |
+
# add a norm layer for each output
|
555 |
+
for i_layer in out_indices:
|
556 |
+
layer = norm_layer(num_features[i_layer])
|
557 |
+
layer_name = f'norm{i_layer}'
|
558 |
+
self.add_module(layer_name, layer)
|
559 |
+
|
560 |
+
self._freeze_stages()
|
561 |
+
|
562 |
+
def _freeze_stages(self):
|
563 |
+
if self.frozen_stages >= 0:
|
564 |
+
self.patch_embed.eval()
|
565 |
+
for param in self.patch_embed.parameters():
|
566 |
+
param.requires_grad = False
|
567 |
+
|
568 |
+
if self.frozen_stages >= 1 and self.ape:
|
569 |
+
self.absolute_pos_embed.requires_grad = False
|
570 |
+
|
571 |
+
if self.frozen_stages >= 2:
|
572 |
+
self.pos_drop.eval()
|
573 |
+
for i in range(0, self.frozen_stages - 1):
|
574 |
+
m = self.layers[i]
|
575 |
+
m.eval()
|
576 |
+
for param in m.parameters():
|
577 |
+
param.requires_grad = False
|
578 |
+
|
579 |
+
|
580 |
+
def forward(self, x):
|
581 |
+
|
582 |
+
x = self.patch_embed(x)
|
583 |
+
|
584 |
+
Wh, Ww = x.size(2), x.size(3)
|
585 |
+
if self.ape:
|
586 |
+
# interpolate the position embedding to the corresponding size
|
587 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
588 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
589 |
+
|
590 |
+
outs = [x.contiguous()]
|
591 |
+
x = x.flatten(2).transpose(1, 2)
|
592 |
+
x = self.pos_drop(x)
|
593 |
+
|
594 |
+
|
595 |
+
for i in range(self.num_layers):
|
596 |
+
layer = self.layers[i]
|
597 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
598 |
+
|
599 |
+
|
600 |
+
if i in self.out_indices:
|
601 |
+
norm_layer = getattr(self, f'norm{i}')
|
602 |
+
x_out = norm_layer(x_out)
|
603 |
+
|
604 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
605 |
+
outs.append(out)
|
606 |
+
|
607 |
+
|
608 |
+
|
609 |
+
return tuple(outs)
|
610 |
+
|
611 |
+
|
612 |
+
|
613 |
+
|
614 |
+
|
615 |
+
|
616 |
+
|
617 |
+
|
618 |
+
def get_activation_fn(activation):
|
619 |
+
"""Return an activation function given a string"""
|
620 |
+
if activation == "gelu":
|
621 |
+
return F.gelu
|
622 |
+
|
623 |
+
raise RuntimeError(F"activation should be gelu, not {activation}.")
|
624 |
+
|
625 |
+
|
626 |
+
def make_cbr(in_dim, out_dim):
|
627 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
|
628 |
+
|
629 |
+
|
630 |
+
def make_cbg(in_dim, out_dim):
|
631 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
|
632 |
+
|
633 |
+
|
634 |
+
def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
|
635 |
+
return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
|
636 |
+
|
637 |
+
|
638 |
+
def resize_as(x, y, interpolation='bilinear'):
|
639 |
+
return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
|
640 |
+
|
641 |
+
|
642 |
+
def image2patches(x):
|
643 |
+
"""b c (hg h) (wg w) -> (hg wg b) c h w"""
|
644 |
+
x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2 )
|
645 |
+
return x
|
646 |
+
|
647 |
+
|
648 |
+
def patches2image(x):
|
649 |
+
"""(hg wg b) c h w -> b c (hg h) (wg w)"""
|
650 |
+
x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
651 |
+
return x
|
652 |
+
|
653 |
+
|
654 |
+
|
655 |
+
class PositionEmbeddingSine:
|
656 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
657 |
+
super().__init__()
|
658 |
+
self.num_pos_feats = num_pos_feats
|
659 |
+
self.temperature = temperature
|
660 |
+
self.normalize = normalize
|
661 |
+
if scale is not None and normalize is False:
|
662 |
+
raise ValueError("normalize should be True if scale is passed")
|
663 |
+
if scale is None:
|
664 |
+
scale = 2 * math.pi
|
665 |
+
self.scale = scale
|
666 |
+
self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
|
667 |
+
|
668 |
+
def __call__(self, b, h, w):
|
669 |
+
device = self.dim_t.device
|
670 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
|
671 |
+
assert mask is not None
|
672 |
+
not_mask = ~mask
|
673 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
|
674 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
|
675 |
+
if self.normalize:
|
676 |
+
eps = 1e-6
|
677 |
+
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
678 |
+
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
679 |
+
|
680 |
+
dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
|
681 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
682 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
683 |
+
|
684 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
685 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
686 |
+
|
687 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
688 |
+
|
689 |
+
|
690 |
+
|
691 |
+
class PositionEmbeddingSine:
|
692 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
693 |
+
super().__init__()
|
694 |
+
self.num_pos_feats = num_pos_feats
|
695 |
+
self.temperature = temperature
|
696 |
+
self.normalize = normalize
|
697 |
+
if scale is not None and normalize is False:
|
698 |
+
raise ValueError("normalize should be True if scale is passed")
|
699 |
+
if scale is None:
|
700 |
+
scale = 2 * math.pi
|
701 |
+
self.scale = scale
|
702 |
+
self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
|
703 |
+
|
704 |
+
def __call__(self, b, h, w):
|
705 |
+
device = self.dim_t.device
|
706 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
|
707 |
+
assert mask is not None
|
708 |
+
not_mask = ~mask
|
709 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
|
710 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
|
711 |
+
if self.normalize:
|
712 |
+
eps = 1e-6
|
713 |
+
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
714 |
+
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
715 |
+
|
716 |
+
dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
|
717 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
718 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
719 |
+
|
720 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
721 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
722 |
+
|
723 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
724 |
+
|
725 |
+
|
726 |
+
class MCLM(nn.Module):
|
727 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
728 |
+
super(MCLM, self).__init__()
|
729 |
+
self.attention = nn.ModuleList([
|
730 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
731 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
732 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
733 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
734 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
735 |
+
])
|
736 |
+
|
737 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
738 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
739 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
740 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
741 |
+
self.norm1 = nn.LayerNorm(d_model)
|
742 |
+
self.norm2 = nn.LayerNorm(d_model)
|
743 |
+
self.dropout = nn.Dropout(0.1)
|
744 |
+
self.dropout1 = nn.Dropout(0.1)
|
745 |
+
self.dropout2 = nn.Dropout(0.1)
|
746 |
+
self.activation = get_activation_fn('gelu')
|
747 |
+
self.pool_ratios = pool_ratios
|
748 |
+
self.p_poses = []
|
749 |
+
self.g_pos = None
|
750 |
+
self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
|
751 |
+
|
752 |
+
def forward(self, l, g):
|
753 |
+
"""
|
754 |
+
l: 4,c,h,w
|
755 |
+
g: 1,c,h,w
|
756 |
+
"""
|
757 |
+
self.p_poses = []
|
758 |
+
self.g_pos = None
|
759 |
+
b, c, h, w = l.size()
|
760 |
+
# 4,c,h,w -> 1,c,2h,2w
|
761 |
+
concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
762 |
+
|
763 |
+
pools = []
|
764 |
+
for pool_ratio in self.pool_ratios:
|
765 |
+
# b,c,h,w
|
766 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
767 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
768 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
769 |
+
if self.g_pos is None:
|
770 |
+
pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
|
771 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
772 |
+
self.p_poses.append(pos_emb)
|
773 |
+
pools = torch.cat(pools, 0)
|
774 |
+
if self.g_pos is None:
|
775 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
776 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
|
777 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
778 |
+
|
779 |
+
device = pools.device
|
780 |
+
self.p_poses = self.p_poses.to(device)
|
781 |
+
self.g_pos = self.g_pos.to(device)
|
782 |
+
|
783 |
+
|
784 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
785 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
786 |
+
|
787 |
+
|
788 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
789 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
790 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
791 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
792 |
+
|
793 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
794 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
795 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
796 |
+
_g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
|
797 |
+
outputs_re = []
|
798 |
+
for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
799 |
+
outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
|
800 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
801 |
+
|
802 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
803 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
804 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
805 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
806 |
+
|
807 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
808 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
809 |
+
|
810 |
+
|
811 |
+
|
812 |
+
|
813 |
+
|
814 |
+
|
815 |
+
|
816 |
+
|
817 |
+
|
818 |
+
class MCRM(nn.Module):
|
819 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
820 |
+
super(MCRM, self).__init__()
|
821 |
+
self.attention = nn.ModuleList([
|
822 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
823 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
824 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
825 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
826 |
+
])
|
827 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
828 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
829 |
+
self.norm1 = nn.LayerNorm(d_model)
|
830 |
+
self.norm2 = nn.LayerNorm(d_model)
|
831 |
+
self.dropout = nn.Dropout(0.1)
|
832 |
+
self.dropout1 = nn.Dropout(0.1)
|
833 |
+
self.dropout2 = nn.Dropout(0.1)
|
834 |
+
self.sigmoid = nn.Sigmoid()
|
835 |
+
self.activation = get_activation_fn('gelu')
|
836 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
837 |
+
self.pool_ratios = pool_ratios
|
838 |
+
|
839 |
+
def forward(self, x):
|
840 |
+
device = x.device
|
841 |
+
b, c, h, w = x.size()
|
842 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
843 |
+
|
844 |
+
patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
845 |
+
|
846 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
847 |
+
token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
|
848 |
+
loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
849 |
+
|
850 |
+
pools = []
|
851 |
+
for pool_ratio in self.pool_ratios:
|
852 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
853 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
854 |
+
pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
855 |
+
|
856 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
857 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
858 |
+
|
859 |
+
outputs = []
|
860 |
+
for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
|
861 |
+
v = pools[i]
|
862 |
+
k = v
|
863 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
864 |
+
|
865 |
+
outputs = torch.cat(outputs, 1)
|
866 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
867 |
+
src = self.norm1(src)
|
868 |
+
src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
|
869 |
+
src = self.norm2(src)
|
870 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
871 |
+
glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb
|
872 |
+
|
873 |
+
return torch.cat((src, glb), 0), token_attention_map
|
874 |
+
|
875 |
+
|
876 |
+
|
877 |
+
class BEN_Base(nn.Module):
|
878 |
+
def __init__(self):
|
879 |
+
super().__init__()
|
880 |
+
|
881 |
+
self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
|
882 |
+
emb_dim = 128
|
883 |
+
self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
884 |
+
self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
885 |
+
self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
886 |
+
self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
887 |
+
self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
888 |
+
|
889 |
+
self.output5 = make_cbr(1024, emb_dim)
|
890 |
+
self.output4 = make_cbr(512, emb_dim)
|
891 |
+
self.output3 = make_cbr(256, emb_dim)
|
892 |
+
self.output2 = make_cbr(128, emb_dim)
|
893 |
+
self.output1 = make_cbr(128, emb_dim)
|
894 |
+
|
895 |
+
self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
|
896 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
897 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
898 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
899 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
900 |
+
self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
|
901 |
+
self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
|
902 |
+
self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
|
903 |
+
self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
|
904 |
+
|
905 |
+
self.insmask_head = nn.Sequential(
|
906 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
907 |
+
nn.InstanceNorm2d(384),
|
908 |
+
nn.GELU(),
|
909 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1),
|
910 |
+
nn.InstanceNorm2d(384),
|
911 |
+
nn.GELU(),
|
912 |
+
nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
|
913 |
+
)
|
914 |
+
|
915 |
+
self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
916 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
917 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
918 |
+
self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
919 |
+
|
920 |
+
for m in self.modules():
|
921 |
+
if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
|
922 |
+
m.inplace = True
|
923 |
+
|
924 |
+
|
925 |
+
|
926 |
+
@torch.inference_mode()
|
927 |
+
@torch.autocast(device_type="cuda",dtype=torch.float16)
|
928 |
+
def forward(self, x):
|
929 |
+
real_batch = x.size(0)
|
930 |
+
|
931 |
+
shallow_batch = self.shallow(x)
|
932 |
+
glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
933 |
+
|
934 |
+
|
935 |
+
|
936 |
+
final_input = None
|
937 |
+
for i in range(real_batch):
|
938 |
+
start = i * 4
|
939 |
+
end = (i + 1) * 4
|
940 |
+
loc_batch = image2patches(x[i,:,:,:].unsqueeze(dim=0))
|
941 |
+
input_ = torch.cat((loc_batch, glb_batch[i,:,:,:].unsqueeze(dim=0)), dim=0)
|
942 |
+
|
943 |
+
|
944 |
+
if final_input == None:
|
945 |
+
final_input= input_
|
946 |
+
else: final_input = torch.cat((final_input, input_), dim=0)
|
947 |
+
|
948 |
+
features = self.backbone(final_input)
|
949 |
+
outputs = []
|
950 |
+
|
951 |
+
for i in range(real_batch):
|
952 |
+
|
953 |
+
start = i * 5
|
954 |
+
end = (i + 1) * 5
|
955 |
+
|
956 |
+
f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W]
|
957 |
+
f3 = features[3][start:end, :, :, :]
|
958 |
+
f2 = features[2][start:end, :, :, :]
|
959 |
+
f1 = features[1][start:end, :, :, :]
|
960 |
+
f0 = features[0][start:end, :, :, :]
|
961 |
+
e5 = self.output5(f4)
|
962 |
+
e4 = self.output4(f3)
|
963 |
+
e3 = self.output3(f2)
|
964 |
+
e2 = self.output2(f1)
|
965 |
+
e1 = self.output1(f0)
|
966 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
967 |
+
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
968 |
+
|
969 |
+
|
970 |
+
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
971 |
+
e4 = self.conv4(e4)
|
972 |
+
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
973 |
+
e3 = self.conv3(e3)
|
974 |
+
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
975 |
+
e2 = self.conv2(e2)
|
976 |
+
e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
|
977 |
+
e1 = self.conv1(e1)
|
978 |
+
|
979 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
980 |
+
|
981 |
+
output1_cat = patches2image(loc_e1) # (1,128,256,256)
|
982 |
+
|
983 |
+
# add glb feat in
|
984 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
985 |
+
# merge
|
986 |
+
final_output = self.insmask_head(output1_cat) # (1,128,256,256)
|
987 |
+
# shallow feature merge
|
988 |
+
shallow = shallow_batch[i,:,:,:].unsqueeze(dim=0)
|
989 |
+
final_output = final_output + resize_as(shallow, final_output)
|
990 |
+
final_output = self.upsample1(rescale_to(final_output))
|
991 |
+
final_output = rescale_to(final_output + resize_as(shallow, final_output))
|
992 |
+
final_output = self.upsample2(final_output)
|
993 |
+
final_output = self.output(final_output)
|
994 |
+
mask = final_output.sigmoid()
|
995 |
+
outputs.append(mask)
|
996 |
+
|
997 |
+
return torch.cat(outputs, dim=0)
|
998 |
+
|
999 |
+
|
1000 |
+
|
1001 |
+
|
1002 |
+
def loadcheckpoints(self,model_path):
|
1003 |
+
model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
1004 |
+
self.load_state_dict(model_dict['model_state_dict'], strict=True)
|
1005 |
+
del model_path
|
1006 |
+
|
1007 |
+
def inference(self,image,refine_foreground=False):
|
1008 |
+
|
1009 |
+
set_random_seed(9)
|
1010 |
+
# image = ImageOps.exif_transpose(image)
|
1011 |
+
if isinstance(image, Image.Image):
|
1012 |
+
image, h, w,original_image = rgb_loader_refiner(image)
|
1013 |
+
if torch.cuda.is_available():
|
1014 |
+
|
1015 |
+
img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
|
1016 |
+
else:
|
1017 |
+
img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
|
1018 |
+
|
1019 |
+
|
1020 |
+
with torch.no_grad():
|
1021 |
+
res = self.forward(img_tensor)
|
1022 |
+
|
1023 |
+
# Show Results
|
1024 |
+
if refine_foreground == True:
|
1025 |
+
|
1026 |
+
pred_pil = transforms.ToPILImage()(res.squeeze())
|
1027 |
+
image_masked = refine_foreground_process(original_image, pred_pil)
|
1028 |
+
|
1029 |
+
image_masked.putalpha(pred_pil.resize(original_image.size))
|
1030 |
+
return image_masked
|
1031 |
+
|
1032 |
+
else:
|
1033 |
+
alpha = postprocess_image(res, im_size=[w,h])
|
1034 |
+
pred_pil = transforms.ToPILImage()(alpha)
|
1035 |
+
mask = pred_pil.resize(original_image.size)
|
1036 |
+
original_image.putalpha(mask)
|
1037 |
+
# mask = Image.fromarray(alpha)
|
1038 |
+
|
1039 |
+
return original_image
|
1040 |
+
|
1041 |
+
|
1042 |
+
else:
|
1043 |
+
foregrounds = []
|
1044 |
+
for batch in image:
|
1045 |
+
image, h, w,original_image = rgb_loader_refiner(batch)
|
1046 |
+
if torch.cuda.is_available():
|
1047 |
+
|
1048 |
+
img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
|
1049 |
+
else:
|
1050 |
+
img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
|
1051 |
+
|
1052 |
+
with torch.no_grad():
|
1053 |
+
res = self.forward(img_tensor)
|
1054 |
+
|
1055 |
+
if refine_foreground == True:
|
1056 |
+
|
1057 |
+
pred_pil = transforms.ToPILImage()(res.squeeze())
|
1058 |
+
image_masked = refine_foreground_process(original_image, pred_pil)
|
1059 |
+
|
1060 |
+
image_masked.putalpha(pred_pil.resize(original_image.size))
|
1061 |
+
|
1062 |
+
foregrounds.append(image_masked)
|
1063 |
+
else:
|
1064 |
+
alpha = postprocess_image(res, im_size=[w,h])
|
1065 |
+
pred_pil = transforms.ToPILImage()(alpha)
|
1066 |
+
mask = pred_pil.resize(original_image.size)
|
1067 |
+
original_image.putalpha(mask)
|
1068 |
+
# mask = Image.fromarray(alpha)
|
1069 |
+
foregrounds.append(original_image)
|
1070 |
+
|
1071 |
+
return foregrounds
|
1072 |
+
|
1073 |
+
|
1074 |
+
|
1075 |
+
|
1076 |
+
def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
|
1077 |
+
|
1078 |
+
"""
|
1079 |
+
Segments the given video to extract the foreground (with alpha) from each frame
|
1080 |
+
and saves the result as either a WebM video (with alpha channel) or MP4 (with a
|
1081 |
+
color background).
|
1082 |
+
|
1083 |
+
Args:
|
1084 |
+
video_path (str):
|
1085 |
+
Path to the input video file.
|
1086 |
+
|
1087 |
+
output_path (str, optional):
|
1088 |
+
Directory (or full path) where the output video and/or files will be saved.
|
1089 |
+
Defaults to "./".
|
1090 |
+
|
1091 |
+
fps (int, optional):
|
1092 |
+
The frames per second (FPS) to use for the output video. If 0 (default), the
|
1093 |
+
original FPS of the input video is used. Otherwise, overrides it.
|
1094 |
+
|
1095 |
+
refine_foreground (bool, optional):
|
1096 |
+
Whether to run an additional “refine foreground” process on each frame.
|
1097 |
+
Defaults to False.
|
1098 |
+
|
1099 |
+
batch (int, optional):
|
1100 |
+
Number of frames to process at once (inference batch size). Large batch sizes
|
1101 |
+
may require more GPU memory. Defaults to 1.
|
1102 |
+
|
1103 |
+
print_frames_processed (bool, optional):
|
1104 |
+
If True (default), prints progress (how many frames have been processed) to
|
1105 |
+
the console.
|
1106 |
+
|
1107 |
+
webm (bool, optional):
|
1108 |
+
If True (default), exports a WebM video with alpha channel (VP9 / yuva420p).
|
1109 |
+
If False, exports an MP4 video composited over a solid color background.
|
1110 |
+
|
1111 |
+
rgb_value (tuple, optional):
|
1112 |
+
The RGB background color (e.g., green screen) used to composite frames when
|
1113 |
+
saving to MP4. Defaults to (0, 255, 0).
|
1114 |
+
|
1115 |
+
Returns:
|
1116 |
+
None. Writes the output video(s) to disk in the specified format.
|
1117 |
+
"""
|
1118 |
+
|
1119 |
+
|
1120 |
+
cap = cv2.VideoCapture(video_path)
|
1121 |
+
if not cap.isOpened():
|
1122 |
+
raise IOError(f"Cannot open video: {video_path}")
|
1123 |
+
|
1124 |
+
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
1125 |
+
original_fps = 30 if original_fps == 0 else original_fps
|
1126 |
+
fps = original_fps if fps == 0 else fps
|
1127 |
+
|
1128 |
+
ret, first_frame = cap.read()
|
1129 |
+
if not ret:
|
1130 |
+
raise ValueError("No frames found in the video.")
|
1131 |
+
height, width = first_frame.shape[:2]
|
1132 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
1133 |
+
|
1134 |
+
foregrounds = []
|
1135 |
+
frame_idx = 0
|
1136 |
+
processed_count = 0
|
1137 |
+
batch_frames = []
|
1138 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
1139 |
+
|
1140 |
+
while True:
|
1141 |
+
ret, frame = cap.read()
|
1142 |
+
if not ret:
|
1143 |
+
if batch_frames:
|
1144 |
+
batch_results = self.inference(batch_frames, refine_foreground)
|
1145 |
+
if isinstance(batch_results, Image.Image):
|
1146 |
+
foregrounds.append(batch_results)
|
1147 |
+
else:
|
1148 |
+
foregrounds.extend(batch_results)
|
1149 |
+
if print_frames_processed:
|
1150 |
+
print(f"Processed frames {frame_idx-len(batch_frames)+1} to {frame_idx} of {total_frames}")
|
1151 |
+
break
|
1152 |
+
|
1153 |
+
# Process every frame instead of using intervals
|
1154 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
1155 |
+
pil_frame = Image.fromarray(frame_rgb)
|
1156 |
+
batch_frames.append(pil_frame)
|
1157 |
+
|
1158 |
+
if len(batch_frames) == batch:
|
1159 |
+
batch_results = self.inference(batch_frames, refine_foreground)
|
1160 |
+
if isinstance(batch_results, Image.Image):
|
1161 |
+
foregrounds.append(batch_results)
|
1162 |
+
else:
|
1163 |
+
foregrounds.extend(batch_results)
|
1164 |
+
if print_frames_processed:
|
1165 |
+
print(f"Processed frames {frame_idx-batch+1} to {frame_idx} of {total_frames}")
|
1166 |
+
batch_frames = []
|
1167 |
+
processed_count += batch
|
1168 |
+
|
1169 |
+
frame_idx += 1
|
1170 |
+
|
1171 |
+
|
1172 |
+
if webm:
|
1173 |
+
alpha_webm_path = os.path.join(output_path, "foreground.webm")
|
1174 |
+
pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=original_fps)
|
1175 |
+
|
1176 |
+
else:
|
1177 |
+
cap.release()
|
1178 |
+
fg_output = os.path.join(output_path, 'foreground.mp4')
|
1179 |
+
|
1180 |
+
pil_images_to_mp4(foregrounds, fg_output, fps=original_fps,rgb_value=rgb_value)
|
1181 |
+
cv2.destroyAllWindows()
|
1182 |
+
|
1183 |
+
try:
|
1184 |
+
fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4')
|
1185 |
+
add_audio_to_video(fg_output, video_path, fg_audio_output)
|
1186 |
+
except Exception as e:
|
1187 |
+
print("No audio found in the original video")
|
1188 |
+
print(e)
|
1189 |
+
|
1190 |
+
|
1191 |
+
|
1192 |
+
|
1193 |
+
|
1194 |
+
def rgb_loader_refiner( original_image):
|
1195 |
+
h, w = original_image.size
|
1196 |
+
|
1197 |
+
image = original_image
|
1198 |
+
# Convert to RGB if necessary
|
1199 |
+
if image.mode != 'RGB':
|
1200 |
+
image = image.convert('RGB')
|
1201 |
+
|
1202 |
+
# Resize the image
|
1203 |
+
image = image.resize((1024, 1024), resample=Image.LANCZOS)
|
1204 |
+
|
1205 |
+
return image.convert('RGB'), h, w,original_image
|
1206 |
+
|
1207 |
+
# Define the image transformation
|
1208 |
+
img_transform = transforms.Compose([
|
1209 |
+
transforms.ToTensor(),
|
1210 |
+
transforms.ConvertImageDtype(torch.float16),
|
1211 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
1212 |
+
])
|
1213 |
+
|
1214 |
+
img_transform32 = transforms.Compose([
|
1215 |
+
transforms.ToTensor(),
|
1216 |
+
transforms.ConvertImageDtype(torch.float32),
|
1217 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
1218 |
+
])
|
1219 |
+
|
1220 |
+
|
1221 |
+
|
1222 |
+
|
1223 |
+
|
1224 |
+
def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)):
|
1225 |
+
"""
|
1226 |
+
Converts an array of PIL images to an MP4 video.
|
1227 |
+
|
1228 |
+
Args:
|
1229 |
+
images: List of PIL images
|
1230 |
+
output_path: Path to save the MP4 file
|
1231 |
+
fps: Frames per second (default: 24)
|
1232 |
+
rgb_value: Background RGB color tuple (default: green (0, 255, 0))
|
1233 |
+
"""
|
1234 |
+
if not images:
|
1235 |
+
raise ValueError("No images provided to convert to MP4.")
|
1236 |
+
|
1237 |
+
width, height = images[0].size
|
1238 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
1239 |
+
video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
1240 |
+
|
1241 |
+
for image in images:
|
1242 |
+
# If image has alpha channel, composite onto the specified background color
|
1243 |
+
if image.mode == 'RGBA':
|
1244 |
+
# Create background image with specified RGB color
|
1245 |
+
background = Image.new('RGB', image.size, rgb_value)
|
1246 |
+
background = background.convert('RGBA')
|
1247 |
+
# Composite the image onto the background
|
1248 |
+
image = Image.alpha_composite(background, image)
|
1249 |
+
image = image.convert('RGB')
|
1250 |
+
else:
|
1251 |
+
# Ensure RGB format for non-alpha images
|
1252 |
+
image = image.convert('RGB')
|
1253 |
+
|
1254 |
+
# Convert to OpenCV format and write
|
1255 |
+
open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
1256 |
+
video_writer.write(open_cv_image)
|
1257 |
+
|
1258 |
+
video_writer.release()
|
1259 |
+
|
1260 |
+
def pil_images_to_webm_alpha(images, output_path, fps=30):
|
1261 |
+
"""
|
1262 |
+
Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel.
|
1263 |
+
|
1264 |
+
NOTE: Not all players will display alpha in WebM.
|
1265 |
+
Browsers like Chrome/Firefox typically do support VP9 alpha.
|
1266 |
+
"""
|
1267 |
+
if not images:
|
1268 |
+
raise ValueError("No images provided for WebM with alpha.")
|
1269 |
+
|
1270 |
+
# Ensure output directory exists
|
1271 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
1272 |
+
|
1273 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
1274 |
+
# Save frames as PNG (with alpha)
|
1275 |
+
for idx, img in enumerate(images):
|
1276 |
+
if img.mode != "RGBA":
|
1277 |
+
img = img.convert("RGBA")
|
1278 |
+
out_path = os.path.join(tmpdir, f"{idx:06d}.png")
|
1279 |
+
img.save(out_path, "PNG")
|
1280 |
+
|
1281 |
+
# Construct ffmpeg command
|
1282 |
+
# -c:v libvpx-vp9 => VP9 encoder
|
1283 |
+
# -pix_fmt yuva420p => alpha-enabled pixel format
|
1284 |
+
# -auto-alt-ref 0 => helps preserve alpha frames (libvpx quirk)
|
1285 |
+
ffmpeg_cmd = [
|
1286 |
+
"ffmpeg", "-y",
|
1287 |
+
"-framerate", str(fps),
|
1288 |
+
"-i", os.path.join(tmpdir, "%06d.png"),
|
1289 |
+
"-c:v", "libvpx-vp9",
|
1290 |
+
"-pix_fmt", "yuva420p",
|
1291 |
+
"-auto-alt-ref", "0",
|
1292 |
+
output_path
|
1293 |
+
]
|
1294 |
+
|
1295 |
+
subprocess.run(ffmpeg_cmd, check=True)
|
1296 |
+
|
1297 |
+
print(f"WebM with alpha saved to {output_path}")
|
1298 |
+
|
1299 |
+
def add_audio_to_video(video_without_audio_path, original_video_path, output_path):
|
1300 |
+
"""
|
1301 |
+
Check if the original video has an audio stream. If yes, add it. If not, skip.
|
1302 |
+
"""
|
1303 |
+
# 1) Probe original video for audio streams
|
1304 |
+
probe_command = [
|
1305 |
+
'ffprobe', '-v', 'error',
|
1306 |
+
'-select_streams', 'a:0',
|
1307 |
+
'-show_entries', 'stream=index',
|
1308 |
+
'-of', 'csv=p=0',
|
1309 |
+
original_video_path
|
1310 |
+
]
|
1311 |
+
result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
1312 |
+
|
1313 |
+
# result.stdout is empty if no audio stream found
|
1314 |
+
if not result.stdout.strip():
|
1315 |
+
print("No audio track found in original video, skipping audio addition.")
|
1316 |
+
return
|
1317 |
+
|
1318 |
+
print("Audio track detected; proceeding to mux audio.")
|
1319 |
+
# 2) If audio found, run ffmpeg to add it
|
1320 |
+
command = [
|
1321 |
+
'ffmpeg', '-y',
|
1322 |
+
'-i', video_without_audio_path,
|
1323 |
+
'-i', original_video_path,
|
1324 |
+
'-c', 'copy',
|
1325 |
+
'-map', '0:v:0',
|
1326 |
+
'-map', '1:a:0', # we know there's an audio track now
|
1327 |
+
output_path
|
1328 |
+
]
|
1329 |
+
subprocess.run(command, check=True)
|
1330 |
+
print(f"Audio added successfully => {output_path}")
|
1331 |
+
|
1332 |
+
|
1333 |
+
|
1334 |
+
|
1335 |
+
|
1336 |
+
### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py
|
1337 |
+
def refine_foreground_process(image, mask, r=90):
|
1338 |
+
if mask.size != image.size:
|
1339 |
+
mask = mask.resize(image.size)
|
1340 |
+
image = np.array(image) / 255.0
|
1341 |
+
mask = np.array(mask) / 255.0
|
1342 |
+
estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
|
1343 |
+
image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
|
1344 |
+
return image_masked
|
1345 |
+
|
1346 |
+
|
1347 |
+
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
|
1348 |
+
# Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
|
1349 |
+
alpha = alpha[:, :, None]
|
1350 |
+
F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
|
1351 |
+
return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
|
1352 |
+
|
1353 |
+
|
1354 |
+
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
|
1355 |
+
if isinstance(image, Image.Image):
|
1356 |
+
image = np.array(image) / 255.0
|
1357 |
+
blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
|
1358 |
+
|
1359 |
+
blurred_FA = cv2.blur(F * alpha, (r, r))
|
1360 |
+
blurred_F = blurred_FA / (blurred_alpha + 1e-5)
|
1361 |
+
|
1362 |
+
blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
|
1363 |
+
blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
|
1364 |
+
F = blurred_F + alpha * \
|
1365 |
+
(image - alpha * blurred_F - (1 - alpha) * blurred_B)
|
1366 |
+
F = np.clip(F, 0, 1)
|
1367 |
+
return F, blurred_B
|
1368 |
+
|
1369 |
+
|
1370 |
+
|
1371 |
+
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
|
1372 |
+
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
|
1373 |
+
ma = torch.max(result)
|
1374 |
+
mi = torch.min(result)
|
1375 |
+
result = (result - mi) / (ma - mi)
|
1376 |
+
im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
|
1377 |
+
im_array = np.squeeze(im_array)
|
1378 |
+
return im_array
|
1379 |
+
|
1380 |
+
|
1381 |
+
|
1382 |
+
|
1383 |
+
def rgb_loader_refiner( original_image):
|
1384 |
+
h, w = original_image.size
|
1385 |
+
# # Apply EXIF orientation
|
1386 |
+
|
1387 |
+
image = ImageOps.exif_transpose(original_image)
|
1388 |
+
|
1389 |
+
if original_image.mode != 'RGB':
|
1390 |
+
original_image = original_image.convert('RGB')
|
1391 |
+
|
1392 |
+
image = original_image
|
1393 |
+
# Convert to RGB if necessary
|
1394 |
+
|
1395 |
+
# Resize the image
|
1396 |
+
image = image.resize((1024, 1024), resample=Image.LANCZOS)
|
1397 |
+
|
1398 |
+
return image, h, w,original_image
|
1399 |
+
|
1400 |
+
|
1401 |
+
|
models/RMBG/BEN2/BEN2_Base.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:926144a876bda06f125555b4f5a239ece89dc6eb838a863700ca9bf192161a1c
|
3 |
+
size 1134584206
|
models/RMBG/BEN2/__pycache__/BEN2.cpython-310.pyc
ADDED
Binary file (38.4 kB). View file
|
|
models/RMBG/RMBG-2.0/BiRefNet_config.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
class BiRefNetConfig(PretrainedConfig):
|
4 |
+
model_type = "SegformerForSemanticSegmentation"
|
5 |
+
def __init__(
|
6 |
+
self,
|
7 |
+
bb_pretrained=False,
|
8 |
+
**kwargs
|
9 |
+
):
|
10 |
+
self.bb_pretrained = bb_pretrained
|
11 |
+
super().__init__(**kwargs)
|
models/RMBG/RMBG-2.0/birefnet.py
ADDED
@@ -0,0 +1,2244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### config.py
|
2 |
+
|
3 |
+
import os
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
class Config():
|
8 |
+
def __init__(self) -> None:
|
9 |
+
# PATH settings
|
10 |
+
self.sys_home_dir = os.path.expanduser('~') # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
|
11 |
+
|
12 |
+
# TASK settings
|
13 |
+
self.task = ['DIS5K', 'COD', 'HRSOD', 'DIS5K+HRSOD+HRS10K', 'P3M-10k'][0]
|
14 |
+
self.training_set = {
|
15 |
+
'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
|
16 |
+
'COD': 'TR-COD10K+TR-CAMO',
|
17 |
+
'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
|
18 |
+
'DIS5K+HRSOD+HRS10K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TE-HRS10K+TE-HRSOD+TE-UHRSD+TR-HRS10K+TR-HRSOD+TR-UHRSD', # leave DIS-VD for evaluation.
|
19 |
+
'P3M-10k': 'TR-P3M-10k',
|
20 |
+
}[self.task]
|
21 |
+
self.prompt4loc = ['dense', 'sparse'][0]
|
22 |
+
|
23 |
+
# Faster-Training settings
|
24 |
+
self.load_all = True
|
25 |
+
self.compile = True # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch.
|
26 |
+
# Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting.
|
27 |
+
# 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607.
|
28 |
+
# 3. But compile in Pytorch > 2.0.1 seems to bring no acceleration for training.
|
29 |
+
self.precisionHigh = True
|
30 |
+
|
31 |
+
# MODEL settings
|
32 |
+
self.ms_supervision = True
|
33 |
+
self.out_ref = self.ms_supervision and True
|
34 |
+
self.dec_ipt = True
|
35 |
+
self.dec_ipt_split = True
|
36 |
+
self.cxt_num = [0, 3][1] # multi-scale skip connections from encoder
|
37 |
+
self.mul_scl_ipt = ['', 'add', 'cat'][2]
|
38 |
+
self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2]
|
39 |
+
self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1]
|
40 |
+
self.dec_blk = ['BasicDecBlk', 'ResBlk', 'HierarAttDecBlk'][0]
|
41 |
+
|
42 |
+
# TRAINING settings
|
43 |
+
self.batch_size = 4
|
44 |
+
self.IoU_finetune_last_epochs = [
|
45 |
+
0,
|
46 |
+
{
|
47 |
+
'DIS5K': -50,
|
48 |
+
'COD': -20,
|
49 |
+
'HRSOD': -20,
|
50 |
+
'DIS5K+HRSOD+HRS10K': -20,
|
51 |
+
'P3M-10k': -20,
|
52 |
+
}[self.task]
|
53 |
+
][1] # choose 0 to skip
|
54 |
+
self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. Adapt the lr linearly
|
55 |
+
self.size = 1024
|
56 |
+
self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader
|
57 |
+
|
58 |
+
# Backbone settings
|
59 |
+
self.bb = [
|
60 |
+
'vgg16', 'vgg16bn', 'resnet50', # 0, 1, 2
|
61 |
+
'swin_v1_t', 'swin_v1_s', # 3, 4
|
62 |
+
'swin_v1_b', 'swin_v1_l', # 5-bs9, 6-bs4
|
63 |
+
'pvt_v2_b0', 'pvt_v2_b1', # 7, 8
|
64 |
+
'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5
|
65 |
+
][6]
|
66 |
+
self.lateral_channels_in_collection = {
|
67 |
+
'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
|
68 |
+
'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
|
69 |
+
'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
|
70 |
+
'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96],
|
71 |
+
'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64],
|
72 |
+
}[self.bb]
|
73 |
+
if self.mul_scl_ipt == 'cat':
|
74 |
+
self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection]
|
75 |
+
self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else []
|
76 |
+
|
77 |
+
# MODEL settings - inactive
|
78 |
+
self.lat_blk = ['BasicLatBlk'][0]
|
79 |
+
self.dec_channels_inter = ['fixed', 'adap'][0]
|
80 |
+
self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0]
|
81 |
+
self.progressive_ref = self.refine and True
|
82 |
+
self.ender = self.progressive_ref and False
|
83 |
+
self.scale = self.progressive_ref and 2
|
84 |
+
self.auxiliary_classification = False # Only for DIS5K, where class labels are saved in `dataset.py`.
|
85 |
+
self.refine_iteration = 1
|
86 |
+
self.freeze_bb = False
|
87 |
+
self.model = [
|
88 |
+
'BiRefNet',
|
89 |
+
][0]
|
90 |
+
if self.dec_blk == 'HierarAttDecBlk':
|
91 |
+
self.batch_size = 2 ** [0, 1, 2, 3, 4][2]
|
92 |
+
|
93 |
+
# TRAINING settings - inactive
|
94 |
+
self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4]
|
95 |
+
self.optimizer = ['Adam', 'AdamW'][1]
|
96 |
+
self.lr_decay_epochs = [1e5] # Set to negative N to decay the lr in the last N-th epoch.
|
97 |
+
self.lr_decay_rate = 0.5
|
98 |
+
# Loss
|
99 |
+
self.lambdas_pix_last = {
|
100 |
+
# not 0 means opening this loss
|
101 |
+
# original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
|
102 |
+
'bce': 30 * 1, # high performance
|
103 |
+
'iou': 0.5 * 1, # 0 / 255
|
104 |
+
'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64)
|
105 |
+
'mse': 150 * 0, # can smooth the saliency map
|
106 |
+
'triplet': 3 * 0,
|
107 |
+
'reg': 100 * 0,
|
108 |
+
'ssim': 10 * 1, # help contours,
|
109 |
+
'cnt': 5 * 0, # help contours
|
110 |
+
'structure': 5 * 0, # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4.
|
111 |
+
}
|
112 |
+
self.lambdas_cls = {
|
113 |
+
'ce': 5.0
|
114 |
+
}
|
115 |
+
# Adv
|
116 |
+
self.lambda_adv_g = 10. * 0 # turn to 0 to avoid adv training
|
117 |
+
self.lambda_adv_d = 3. * (self.lambda_adv_g > 0)
|
118 |
+
|
119 |
+
# PATH settings - inactive
|
120 |
+
self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')
|
121 |
+
self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights')
|
122 |
+
self.weights = {
|
123 |
+
'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'),
|
124 |
+
'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
|
125 |
+
'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]),
|
126 |
+
'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]),
|
127 |
+
'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]),
|
128 |
+
'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]),
|
129 |
+
'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]),
|
130 |
+
'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]),
|
131 |
+
}
|
132 |
+
|
133 |
+
# Callbacks - inactive
|
134 |
+
self.verbose_eval = True
|
135 |
+
self.only_S_MAE = False
|
136 |
+
self.use_fp16 = False # Bugs. It may cause nan in training.
|
137 |
+
self.SDPA_enabled = False # Bugs. Slower and errors occur in multi-GPUs
|
138 |
+
|
139 |
+
# others
|
140 |
+
self.device = [0, 'cpu'][0] # .to(0) == .to('cuda:0')
|
141 |
+
|
142 |
+
self.batch_size_valid = 1
|
143 |
+
self.rand_seed = 7
|
144 |
+
# run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f]
|
145 |
+
# with open(run_sh_file[0], 'r') as f:
|
146 |
+
# lines = f.readlines()
|
147 |
+
# self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
|
148 |
+
# self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])
|
149 |
+
# self.val_step = [0, self.save_step][0]
|
150 |
+
|
151 |
+
def print_task(self) -> None:
|
152 |
+
# Return task for choosing settings in shell scripts.
|
153 |
+
print(self.task)
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
### models/backbones/pvt_v2.py
|
158 |
+
|
159 |
+
import torch
|
160 |
+
import torch.nn as nn
|
161 |
+
from functools import partial
|
162 |
+
|
163 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
164 |
+
from timm.models.registry import register_model
|
165 |
+
|
166 |
+
import math
|
167 |
+
|
168 |
+
# from config import Config
|
169 |
+
|
170 |
+
# config = Config()
|
171 |
+
|
172 |
+
class Mlp(nn.Module):
|
173 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
174 |
+
super().__init__()
|
175 |
+
out_features = out_features or in_features
|
176 |
+
hidden_features = hidden_features or in_features
|
177 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
178 |
+
self.dwconv = DWConv(hidden_features)
|
179 |
+
self.act = act_layer()
|
180 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
181 |
+
self.drop = nn.Dropout(drop)
|
182 |
+
|
183 |
+
self.apply(self._init_weights)
|
184 |
+
|
185 |
+
def _init_weights(self, m):
|
186 |
+
if isinstance(m, nn.Linear):
|
187 |
+
trunc_normal_(m.weight, std=.02)
|
188 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
189 |
+
nn.init.constant_(m.bias, 0)
|
190 |
+
elif isinstance(m, nn.LayerNorm):
|
191 |
+
nn.init.constant_(m.bias, 0)
|
192 |
+
nn.init.constant_(m.weight, 1.0)
|
193 |
+
elif isinstance(m, nn.Conv2d):
|
194 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
195 |
+
fan_out //= m.groups
|
196 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
197 |
+
if m.bias is not None:
|
198 |
+
m.bias.data.zero_()
|
199 |
+
|
200 |
+
def forward(self, x, H, W):
|
201 |
+
x = self.fc1(x)
|
202 |
+
x = self.dwconv(x, H, W)
|
203 |
+
x = self.act(x)
|
204 |
+
x = self.drop(x)
|
205 |
+
x = self.fc2(x)
|
206 |
+
x = self.drop(x)
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
210 |
+
class Attention(nn.Module):
|
211 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
|
212 |
+
super().__init__()
|
213 |
+
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
|
214 |
+
|
215 |
+
self.dim = dim
|
216 |
+
self.num_heads = num_heads
|
217 |
+
head_dim = dim // num_heads
|
218 |
+
self.scale = qk_scale or head_dim ** -0.5
|
219 |
+
|
220 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
221 |
+
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
222 |
+
self.attn_drop_prob = attn_drop
|
223 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
224 |
+
self.proj = nn.Linear(dim, dim)
|
225 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
226 |
+
|
227 |
+
self.sr_ratio = sr_ratio
|
228 |
+
if sr_ratio > 1:
|
229 |
+
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
|
230 |
+
self.norm = nn.LayerNorm(dim)
|
231 |
+
|
232 |
+
self.apply(self._init_weights)
|
233 |
+
|
234 |
+
def _init_weights(self, m):
|
235 |
+
if isinstance(m, nn.Linear):
|
236 |
+
trunc_normal_(m.weight, std=.02)
|
237 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
238 |
+
nn.init.constant_(m.bias, 0)
|
239 |
+
elif isinstance(m, nn.LayerNorm):
|
240 |
+
nn.init.constant_(m.bias, 0)
|
241 |
+
nn.init.constant_(m.weight, 1.0)
|
242 |
+
elif isinstance(m, nn.Conv2d):
|
243 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
244 |
+
fan_out //= m.groups
|
245 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
246 |
+
if m.bias is not None:
|
247 |
+
m.bias.data.zero_()
|
248 |
+
|
249 |
+
def forward(self, x, H, W):
|
250 |
+
B, N, C = x.shape
|
251 |
+
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
252 |
+
|
253 |
+
if self.sr_ratio > 1:
|
254 |
+
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
|
255 |
+
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
|
256 |
+
x_ = self.norm(x_)
|
257 |
+
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
258 |
+
else:
|
259 |
+
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
260 |
+
k, v = kv[0], kv[1]
|
261 |
+
|
262 |
+
if config.SDPA_enabled:
|
263 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
264 |
+
q, k, v,
|
265 |
+
attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False
|
266 |
+
).transpose(1, 2).reshape(B, N, C)
|
267 |
+
else:
|
268 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
269 |
+
attn = attn.softmax(dim=-1)
|
270 |
+
attn = self.attn_drop(attn)
|
271 |
+
|
272 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
273 |
+
x = self.proj(x)
|
274 |
+
x = self.proj_drop(x)
|
275 |
+
|
276 |
+
return x
|
277 |
+
|
278 |
+
|
279 |
+
class Block(nn.Module):
|
280 |
+
|
281 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
282 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
|
283 |
+
super().__init__()
|
284 |
+
self.norm1 = norm_layer(dim)
|
285 |
+
self.attn = Attention(
|
286 |
+
dim,
|
287 |
+
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
288 |
+
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
|
289 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
290 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
291 |
+
self.norm2 = norm_layer(dim)
|
292 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
293 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
294 |
+
|
295 |
+
self.apply(self._init_weights)
|
296 |
+
|
297 |
+
def _init_weights(self, m):
|
298 |
+
if isinstance(m, nn.Linear):
|
299 |
+
trunc_normal_(m.weight, std=.02)
|
300 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
301 |
+
nn.init.constant_(m.bias, 0)
|
302 |
+
elif isinstance(m, nn.LayerNorm):
|
303 |
+
nn.init.constant_(m.bias, 0)
|
304 |
+
nn.init.constant_(m.weight, 1.0)
|
305 |
+
elif isinstance(m, nn.Conv2d):
|
306 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
307 |
+
fan_out //= m.groups
|
308 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
309 |
+
if m.bias is not None:
|
310 |
+
m.bias.data.zero_()
|
311 |
+
|
312 |
+
def forward(self, x, H, W):
|
313 |
+
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
|
314 |
+
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
|
315 |
+
|
316 |
+
return x
|
317 |
+
|
318 |
+
|
319 |
+
class OverlapPatchEmbed(nn.Module):
|
320 |
+
""" Image to Patch Embedding
|
321 |
+
"""
|
322 |
+
|
323 |
+
def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768):
|
324 |
+
super().__init__()
|
325 |
+
img_size = to_2tuple(img_size)
|
326 |
+
patch_size = to_2tuple(patch_size)
|
327 |
+
|
328 |
+
self.img_size = img_size
|
329 |
+
self.patch_size = patch_size
|
330 |
+
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
331 |
+
self.num_patches = self.H * self.W
|
332 |
+
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride,
|
333 |
+
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
334 |
+
self.norm = nn.LayerNorm(embed_dim)
|
335 |
+
|
336 |
+
self.apply(self._init_weights)
|
337 |
+
|
338 |
+
def _init_weights(self, m):
|
339 |
+
if isinstance(m, nn.Linear):
|
340 |
+
trunc_normal_(m.weight, std=.02)
|
341 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
342 |
+
nn.init.constant_(m.bias, 0)
|
343 |
+
elif isinstance(m, nn.LayerNorm):
|
344 |
+
nn.init.constant_(m.bias, 0)
|
345 |
+
nn.init.constant_(m.weight, 1.0)
|
346 |
+
elif isinstance(m, nn.Conv2d):
|
347 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
348 |
+
fan_out //= m.groups
|
349 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
350 |
+
if m.bias is not None:
|
351 |
+
m.bias.data.zero_()
|
352 |
+
|
353 |
+
def forward(self, x):
|
354 |
+
x = self.proj(x)
|
355 |
+
_, _, H, W = x.shape
|
356 |
+
x = x.flatten(2).transpose(1, 2)
|
357 |
+
x = self.norm(x)
|
358 |
+
|
359 |
+
return x, H, W
|
360 |
+
|
361 |
+
|
362 |
+
class PyramidVisionTransformerImpr(nn.Module):
|
363 |
+
def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
|
364 |
+
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
|
365 |
+
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
|
366 |
+
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
|
367 |
+
super().__init__()
|
368 |
+
self.num_classes = num_classes
|
369 |
+
self.depths = depths
|
370 |
+
|
371 |
+
# patch_embed
|
372 |
+
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_channels=in_channels,
|
373 |
+
embed_dim=embed_dims[0])
|
374 |
+
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_channels=embed_dims[0],
|
375 |
+
embed_dim=embed_dims[1])
|
376 |
+
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_channels=embed_dims[1],
|
377 |
+
embed_dim=embed_dims[2])
|
378 |
+
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_channels=embed_dims[2],
|
379 |
+
embed_dim=embed_dims[3])
|
380 |
+
|
381 |
+
# transformer encoder
|
382 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
383 |
+
cur = 0
|
384 |
+
self.block1 = nn.ModuleList([Block(
|
385 |
+
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
386 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
387 |
+
sr_ratio=sr_ratios[0])
|
388 |
+
for i in range(depths[0])])
|
389 |
+
self.norm1 = norm_layer(embed_dims[0])
|
390 |
+
|
391 |
+
cur += depths[0]
|
392 |
+
self.block2 = nn.ModuleList([Block(
|
393 |
+
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
394 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
395 |
+
sr_ratio=sr_ratios[1])
|
396 |
+
for i in range(depths[1])])
|
397 |
+
self.norm2 = norm_layer(embed_dims[1])
|
398 |
+
|
399 |
+
cur += depths[1]
|
400 |
+
self.block3 = nn.ModuleList([Block(
|
401 |
+
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
402 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
403 |
+
sr_ratio=sr_ratios[2])
|
404 |
+
for i in range(depths[2])])
|
405 |
+
self.norm3 = norm_layer(embed_dims[2])
|
406 |
+
|
407 |
+
cur += depths[2]
|
408 |
+
self.block4 = nn.ModuleList([Block(
|
409 |
+
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
410 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
411 |
+
sr_ratio=sr_ratios[3])
|
412 |
+
for i in range(depths[3])])
|
413 |
+
self.norm4 = norm_layer(embed_dims[3])
|
414 |
+
|
415 |
+
# classification head
|
416 |
+
# self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
|
417 |
+
|
418 |
+
self.apply(self._init_weights)
|
419 |
+
|
420 |
+
def _init_weights(self, m):
|
421 |
+
if isinstance(m, nn.Linear):
|
422 |
+
trunc_normal_(m.weight, std=.02)
|
423 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
424 |
+
nn.init.constant_(m.bias, 0)
|
425 |
+
elif isinstance(m, nn.LayerNorm):
|
426 |
+
nn.init.constant_(m.bias, 0)
|
427 |
+
nn.init.constant_(m.weight, 1.0)
|
428 |
+
elif isinstance(m, nn.Conv2d):
|
429 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
430 |
+
fan_out //= m.groups
|
431 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
432 |
+
if m.bias is not None:
|
433 |
+
m.bias.data.zero_()
|
434 |
+
|
435 |
+
def init_weights(self, pretrained=None):
|
436 |
+
if isinstance(pretrained, str):
|
437 |
+
logger = 1
|
438 |
+
#load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
|
439 |
+
|
440 |
+
def reset_drop_path(self, drop_path_rate):
|
441 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
442 |
+
cur = 0
|
443 |
+
for i in range(self.depths[0]):
|
444 |
+
self.block1[i].drop_path.drop_prob = dpr[cur + i]
|
445 |
+
|
446 |
+
cur += self.depths[0]
|
447 |
+
for i in range(self.depths[1]):
|
448 |
+
self.block2[i].drop_path.drop_prob = dpr[cur + i]
|
449 |
+
|
450 |
+
cur += self.depths[1]
|
451 |
+
for i in range(self.depths[2]):
|
452 |
+
self.block3[i].drop_path.drop_prob = dpr[cur + i]
|
453 |
+
|
454 |
+
cur += self.depths[2]
|
455 |
+
for i in range(self.depths[3]):
|
456 |
+
self.block4[i].drop_path.drop_prob = dpr[cur + i]
|
457 |
+
|
458 |
+
def freeze_patch_emb(self):
|
459 |
+
self.patch_embed1.requires_grad = False
|
460 |
+
|
461 |
+
@torch.jit.ignore
|
462 |
+
def no_weight_decay(self):
|
463 |
+
return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
|
464 |
+
|
465 |
+
def get_classifier(self):
|
466 |
+
return self.head
|
467 |
+
|
468 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
469 |
+
self.num_classes = num_classes
|
470 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
471 |
+
|
472 |
+
def forward_features(self, x):
|
473 |
+
B = x.shape[0]
|
474 |
+
outs = []
|
475 |
+
|
476 |
+
# stage 1
|
477 |
+
x, H, W = self.patch_embed1(x)
|
478 |
+
for i, blk in enumerate(self.block1):
|
479 |
+
x = blk(x, H, W)
|
480 |
+
x = self.norm1(x)
|
481 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
482 |
+
outs.append(x)
|
483 |
+
|
484 |
+
# stage 2
|
485 |
+
x, H, W = self.patch_embed2(x)
|
486 |
+
for i, blk in enumerate(self.block2):
|
487 |
+
x = blk(x, H, W)
|
488 |
+
x = self.norm2(x)
|
489 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
490 |
+
outs.append(x)
|
491 |
+
|
492 |
+
# stage 3
|
493 |
+
x, H, W = self.patch_embed3(x)
|
494 |
+
for i, blk in enumerate(self.block3):
|
495 |
+
x = blk(x, H, W)
|
496 |
+
x = self.norm3(x)
|
497 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
498 |
+
outs.append(x)
|
499 |
+
|
500 |
+
# stage 4
|
501 |
+
x, H, W = self.patch_embed4(x)
|
502 |
+
for i, blk in enumerate(self.block4):
|
503 |
+
x = blk(x, H, W)
|
504 |
+
x = self.norm4(x)
|
505 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
506 |
+
outs.append(x)
|
507 |
+
|
508 |
+
return outs
|
509 |
+
|
510 |
+
# return x.mean(dim=1)
|
511 |
+
|
512 |
+
def forward(self, x):
|
513 |
+
x = self.forward_features(x)
|
514 |
+
# x = self.head(x)
|
515 |
+
|
516 |
+
return x
|
517 |
+
|
518 |
+
|
519 |
+
class DWConv(nn.Module):
|
520 |
+
def __init__(self, dim=768):
|
521 |
+
super(DWConv, self).__init__()
|
522 |
+
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
523 |
+
|
524 |
+
def forward(self, x, H, W):
|
525 |
+
B, N, C = x.shape
|
526 |
+
x = x.transpose(1, 2).view(B, C, H, W).contiguous()
|
527 |
+
x = self.dwconv(x)
|
528 |
+
x = x.flatten(2).transpose(1, 2)
|
529 |
+
|
530 |
+
return x
|
531 |
+
|
532 |
+
|
533 |
+
def _conv_filter(state_dict, patch_size=16):
|
534 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
535 |
+
out_dict = {}
|
536 |
+
for k, v in state_dict.items():
|
537 |
+
if 'patch_embed.proj.weight' in k:
|
538 |
+
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
539 |
+
out_dict[k] = v
|
540 |
+
|
541 |
+
return out_dict
|
542 |
+
|
543 |
+
|
544 |
+
## @register_model
|
545 |
+
class pvt_v2_b0(PyramidVisionTransformerImpr):
|
546 |
+
def __init__(self, **kwargs):
|
547 |
+
super(pvt_v2_b0, self).__init__(
|
548 |
+
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
|
549 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
550 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
551 |
+
|
552 |
+
|
553 |
+
|
554 |
+
## @register_model
|
555 |
+
class pvt_v2_b1(PyramidVisionTransformerImpr):
|
556 |
+
def __init__(self, **kwargs):
|
557 |
+
super(pvt_v2_b1, self).__init__(
|
558 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
|
559 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
560 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
561 |
+
|
562 |
+
## @register_model
|
563 |
+
class pvt_v2_b2(PyramidVisionTransformerImpr):
|
564 |
+
def __init__(self, in_channels=3, **kwargs):
|
565 |
+
super(pvt_v2_b2, self).__init__(
|
566 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
|
567 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
|
568 |
+
drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
|
569 |
+
|
570 |
+
## @register_model
|
571 |
+
class pvt_v2_b3(PyramidVisionTransformerImpr):
|
572 |
+
def __init__(self, **kwargs):
|
573 |
+
super(pvt_v2_b3, self).__init__(
|
574 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
|
575 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
576 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
577 |
+
|
578 |
+
## @register_model
|
579 |
+
class pvt_v2_b4(PyramidVisionTransformerImpr):
|
580 |
+
def __init__(self, **kwargs):
|
581 |
+
super(pvt_v2_b4, self).__init__(
|
582 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
|
583 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
|
584 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
585 |
+
|
586 |
+
|
587 |
+
## @register_model
|
588 |
+
class pvt_v2_b5(PyramidVisionTransformerImpr):
|
589 |
+
def __init__(self, **kwargs):
|
590 |
+
super(pvt_v2_b5, self).__init__(
|
591 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
592 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
|
593 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
594 |
+
|
595 |
+
|
596 |
+
|
597 |
+
### models/backbones/swin_v1.py
|
598 |
+
|
599 |
+
# --------------------------------------------------------
|
600 |
+
# Swin Transformer
|
601 |
+
# Copyright (c) 2021 Microsoft
|
602 |
+
# Licensed under The MIT License [see LICENSE for details]
|
603 |
+
# Written by Ze Liu, Yutong Lin, Yixuan Wei
|
604 |
+
# --------------------------------------------------------
|
605 |
+
|
606 |
+
import torch
|
607 |
+
import torch.nn as nn
|
608 |
+
import torch.nn.functional as F
|
609 |
+
import torch.utils.checkpoint as checkpoint
|
610 |
+
import numpy as np
|
611 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
612 |
+
|
613 |
+
# from config import Config
|
614 |
+
|
615 |
+
|
616 |
+
# config = Config()
|
617 |
+
|
618 |
+
class Mlp(nn.Module):
|
619 |
+
""" Multilayer perceptron."""
|
620 |
+
|
621 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
622 |
+
super().__init__()
|
623 |
+
out_features = out_features or in_features
|
624 |
+
hidden_features = hidden_features or in_features
|
625 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
626 |
+
self.act = act_layer()
|
627 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
628 |
+
self.drop = nn.Dropout(drop)
|
629 |
+
|
630 |
+
def forward(self, x):
|
631 |
+
x = self.fc1(x)
|
632 |
+
x = self.act(x)
|
633 |
+
x = self.drop(x)
|
634 |
+
x = self.fc2(x)
|
635 |
+
x = self.drop(x)
|
636 |
+
return x
|
637 |
+
|
638 |
+
|
639 |
+
def window_partition(x, window_size):
|
640 |
+
"""
|
641 |
+
Args:
|
642 |
+
x: (B, H, W, C)
|
643 |
+
window_size (int): window size
|
644 |
+
|
645 |
+
Returns:
|
646 |
+
windows: (num_windows*B, window_size, window_size, C)
|
647 |
+
"""
|
648 |
+
B, H, W, C = x.shape
|
649 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
650 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
651 |
+
return windows
|
652 |
+
|
653 |
+
|
654 |
+
def window_reverse(windows, window_size, H, W):
|
655 |
+
"""
|
656 |
+
Args:
|
657 |
+
windows: (num_windows*B, window_size, window_size, C)
|
658 |
+
window_size (int): Window size
|
659 |
+
H (int): Height of image
|
660 |
+
W (int): Width of image
|
661 |
+
|
662 |
+
Returns:
|
663 |
+
x: (B, H, W, C)
|
664 |
+
"""
|
665 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
666 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
667 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
668 |
+
return x
|
669 |
+
|
670 |
+
|
671 |
+
class WindowAttention(nn.Module):
|
672 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
673 |
+
It supports both of shifted and non-shifted window.
|
674 |
+
|
675 |
+
Args:
|
676 |
+
dim (int): Number of input channels.
|
677 |
+
window_size (tuple[int]): The height and width of the window.
|
678 |
+
num_heads (int): Number of attention heads.
|
679 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
680 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
681 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
682 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
683 |
+
"""
|
684 |
+
|
685 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
686 |
+
|
687 |
+
super().__init__()
|
688 |
+
self.dim = dim
|
689 |
+
self.window_size = window_size # Wh, Ww
|
690 |
+
self.num_heads = num_heads
|
691 |
+
head_dim = dim // num_heads
|
692 |
+
self.scale = qk_scale or head_dim ** -0.5
|
693 |
+
|
694 |
+
# define a parameter table of relative position bias
|
695 |
+
self.relative_position_bias_table = nn.Parameter(
|
696 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
697 |
+
|
698 |
+
# get pair-wise relative position index for each token inside the window
|
699 |
+
coords_h = torch.arange(self.window_size[0])
|
700 |
+
coords_w = torch.arange(self.window_size[1])
|
701 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
|
702 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
703 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
704 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
705 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
706 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
707 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
708 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
709 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
710 |
+
|
711 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
712 |
+
self.attn_drop_prob = attn_drop
|
713 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
714 |
+
self.proj = nn.Linear(dim, dim)
|
715 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
716 |
+
|
717 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
718 |
+
self.softmax = nn.Softmax(dim=-1)
|
719 |
+
|
720 |
+
def forward(self, x, mask=None):
|
721 |
+
""" Forward function.
|
722 |
+
|
723 |
+
Args:
|
724 |
+
x: input features with shape of (num_windows*B, N, C)
|
725 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
726 |
+
"""
|
727 |
+
B_, N, C = x.shape
|
728 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
729 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
730 |
+
|
731 |
+
q = q * self.scale
|
732 |
+
|
733 |
+
if config.SDPA_enabled:
|
734 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
735 |
+
q, k, v,
|
736 |
+
attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False
|
737 |
+
).transpose(1, 2).reshape(B_, N, C)
|
738 |
+
else:
|
739 |
+
attn = (q @ k.transpose(-2, -1))
|
740 |
+
|
741 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
742 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
743 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
744 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
745 |
+
|
746 |
+
if mask is not None:
|
747 |
+
nW = mask.shape[0]
|
748 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
749 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
750 |
+
attn = self.softmax(attn)
|
751 |
+
else:
|
752 |
+
attn = self.softmax(attn)
|
753 |
+
|
754 |
+
attn = self.attn_drop(attn)
|
755 |
+
|
756 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
757 |
+
x = self.proj(x)
|
758 |
+
x = self.proj_drop(x)
|
759 |
+
return x
|
760 |
+
|
761 |
+
|
762 |
+
class SwinTransformerBlock(nn.Module):
|
763 |
+
""" Swin Transformer Block.
|
764 |
+
|
765 |
+
Args:
|
766 |
+
dim (int): Number of input channels.
|
767 |
+
num_heads (int): Number of attention heads.
|
768 |
+
window_size (int): Window size.
|
769 |
+
shift_size (int): Shift size for SW-MSA.
|
770 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
771 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
772 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
773 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
774 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
775 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
776 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
777 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
778 |
+
"""
|
779 |
+
|
780 |
+
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
781 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
782 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
783 |
+
super().__init__()
|
784 |
+
self.dim = dim
|
785 |
+
self.num_heads = num_heads
|
786 |
+
self.window_size = window_size
|
787 |
+
self.shift_size = shift_size
|
788 |
+
self.mlp_ratio = mlp_ratio
|
789 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
790 |
+
|
791 |
+
self.norm1 = norm_layer(dim)
|
792 |
+
self.attn = WindowAttention(
|
793 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
794 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
795 |
+
|
796 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
797 |
+
self.norm2 = norm_layer(dim)
|
798 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
799 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
800 |
+
|
801 |
+
self.H = None
|
802 |
+
self.W = None
|
803 |
+
|
804 |
+
def forward(self, x, mask_matrix):
|
805 |
+
""" Forward function.
|
806 |
+
|
807 |
+
Args:
|
808 |
+
x: Input feature, tensor size (B, H*W, C).
|
809 |
+
H, W: Spatial resolution of the input feature.
|
810 |
+
mask_matrix: Attention mask for cyclic shift.
|
811 |
+
"""
|
812 |
+
B, L, C = x.shape
|
813 |
+
H, W = self.H, self.W
|
814 |
+
assert L == H * W, "input feature has wrong size"
|
815 |
+
|
816 |
+
shortcut = x
|
817 |
+
x = self.norm1(x)
|
818 |
+
x = x.view(B, H, W, C)
|
819 |
+
|
820 |
+
# pad feature maps to multiples of window size
|
821 |
+
pad_l = pad_t = 0
|
822 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
823 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
824 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
825 |
+
_, Hp, Wp, _ = x.shape
|
826 |
+
|
827 |
+
# cyclic shift
|
828 |
+
if self.shift_size > 0:
|
829 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
830 |
+
attn_mask = mask_matrix
|
831 |
+
else:
|
832 |
+
shifted_x = x
|
833 |
+
attn_mask = None
|
834 |
+
|
835 |
+
# partition windows
|
836 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
837 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
838 |
+
|
839 |
+
# W-MSA/SW-MSA
|
840 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
841 |
+
|
842 |
+
# merge windows
|
843 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
844 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
845 |
+
|
846 |
+
# reverse cyclic shift
|
847 |
+
if self.shift_size > 0:
|
848 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
849 |
+
else:
|
850 |
+
x = shifted_x
|
851 |
+
|
852 |
+
if pad_r > 0 or pad_b > 0:
|
853 |
+
x = x[:, :H, :W, :].contiguous()
|
854 |
+
|
855 |
+
x = x.view(B, H * W, C)
|
856 |
+
|
857 |
+
# FFN
|
858 |
+
x = shortcut + self.drop_path(x)
|
859 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
860 |
+
|
861 |
+
return x
|
862 |
+
|
863 |
+
|
864 |
+
class PatchMerging(nn.Module):
|
865 |
+
""" Patch Merging Layer
|
866 |
+
|
867 |
+
Args:
|
868 |
+
dim (int): Number of input channels.
|
869 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
870 |
+
"""
|
871 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
872 |
+
super().__init__()
|
873 |
+
self.dim = dim
|
874 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
875 |
+
self.norm = norm_layer(4 * dim)
|
876 |
+
|
877 |
+
def forward(self, x, H, W):
|
878 |
+
""" Forward function.
|
879 |
+
|
880 |
+
Args:
|
881 |
+
x: Input feature, tensor size (B, H*W, C).
|
882 |
+
H, W: Spatial resolution of the input feature.
|
883 |
+
"""
|
884 |
+
B, L, C = x.shape
|
885 |
+
assert L == H * W, "input feature has wrong size"
|
886 |
+
|
887 |
+
x = x.view(B, H, W, C)
|
888 |
+
|
889 |
+
# padding
|
890 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
891 |
+
if pad_input:
|
892 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
893 |
+
|
894 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
895 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
896 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
897 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
898 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
899 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
900 |
+
|
901 |
+
x = self.norm(x)
|
902 |
+
x = self.reduction(x)
|
903 |
+
|
904 |
+
return x
|
905 |
+
|
906 |
+
|
907 |
+
class BasicLayer(nn.Module):
|
908 |
+
""" A basic Swin Transformer layer for one stage.
|
909 |
+
|
910 |
+
Args:
|
911 |
+
dim (int): Number of feature channels
|
912 |
+
depth (int): Depths of this stage.
|
913 |
+
num_heads (int): Number of attention head.
|
914 |
+
window_size (int): Local window size. Default: 7.
|
915 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
916 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
917 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
918 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
919 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
920 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
921 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
922 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
923 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
924 |
+
"""
|
925 |
+
|
926 |
+
def __init__(self,
|
927 |
+
dim,
|
928 |
+
depth,
|
929 |
+
num_heads,
|
930 |
+
window_size=7,
|
931 |
+
mlp_ratio=4.,
|
932 |
+
qkv_bias=True,
|
933 |
+
qk_scale=None,
|
934 |
+
drop=0.,
|
935 |
+
attn_drop=0.,
|
936 |
+
drop_path=0.,
|
937 |
+
norm_layer=nn.LayerNorm,
|
938 |
+
downsample=None,
|
939 |
+
use_checkpoint=False):
|
940 |
+
super().__init__()
|
941 |
+
self.window_size = window_size
|
942 |
+
self.shift_size = window_size // 2
|
943 |
+
self.depth = depth
|
944 |
+
self.use_checkpoint = use_checkpoint
|
945 |
+
|
946 |
+
# build blocks
|
947 |
+
self.blocks = nn.ModuleList([
|
948 |
+
SwinTransformerBlock(
|
949 |
+
dim=dim,
|
950 |
+
num_heads=num_heads,
|
951 |
+
window_size=window_size,
|
952 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
953 |
+
mlp_ratio=mlp_ratio,
|
954 |
+
qkv_bias=qkv_bias,
|
955 |
+
qk_scale=qk_scale,
|
956 |
+
drop=drop,
|
957 |
+
attn_drop=attn_drop,
|
958 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
959 |
+
norm_layer=norm_layer)
|
960 |
+
for i in range(depth)])
|
961 |
+
|
962 |
+
# patch merging layer
|
963 |
+
if downsample is not None:
|
964 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
965 |
+
else:
|
966 |
+
self.downsample = None
|
967 |
+
|
968 |
+
def forward(self, x, H, W):
|
969 |
+
""" Forward function.
|
970 |
+
|
971 |
+
Args:
|
972 |
+
x: Input feature, tensor size (B, H*W, C).
|
973 |
+
H, W: Spatial resolution of the input feature.
|
974 |
+
"""
|
975 |
+
|
976 |
+
# calculate attention mask for SW-MSA
|
977 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
978 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
979 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
980 |
+
h_slices = (slice(0, -self.window_size),
|
981 |
+
slice(-self.window_size, -self.shift_size),
|
982 |
+
slice(-self.shift_size, None))
|
983 |
+
w_slices = (slice(0, -self.window_size),
|
984 |
+
slice(-self.window_size, -self.shift_size),
|
985 |
+
slice(-self.shift_size, None))
|
986 |
+
cnt = 0
|
987 |
+
for h in h_slices:
|
988 |
+
for w in w_slices:
|
989 |
+
img_mask[:, h, w, :] = cnt
|
990 |
+
cnt += 1
|
991 |
+
|
992 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
993 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
994 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
995 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
996 |
+
|
997 |
+
for blk in self.blocks:
|
998 |
+
blk.H, blk.W = H, W
|
999 |
+
if self.use_checkpoint:
|
1000 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
1001 |
+
else:
|
1002 |
+
x = blk(x, attn_mask)
|
1003 |
+
if self.downsample is not None:
|
1004 |
+
x_down = self.downsample(x, H, W)
|
1005 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
1006 |
+
return x, H, W, x_down, Wh, Ww
|
1007 |
+
else:
|
1008 |
+
return x, H, W, x, H, W
|
1009 |
+
|
1010 |
+
|
1011 |
+
class PatchEmbed(nn.Module):
|
1012 |
+
""" Image to Patch Embedding
|
1013 |
+
|
1014 |
+
Args:
|
1015 |
+
patch_size (int): Patch token size. Default: 4.
|
1016 |
+
in_channels (int): Number of input image channels. Default: 3.
|
1017 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
1018 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
1019 |
+
"""
|
1020 |
+
|
1021 |
+
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
|
1022 |
+
super().__init__()
|
1023 |
+
patch_size = to_2tuple(patch_size)
|
1024 |
+
self.patch_size = patch_size
|
1025 |
+
|
1026 |
+
self.in_channels = in_channels
|
1027 |
+
self.embed_dim = embed_dim
|
1028 |
+
|
1029 |
+
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
1030 |
+
if norm_layer is not None:
|
1031 |
+
self.norm = norm_layer(embed_dim)
|
1032 |
+
else:
|
1033 |
+
self.norm = None
|
1034 |
+
|
1035 |
+
def forward(self, x):
|
1036 |
+
"""Forward function."""
|
1037 |
+
# padding
|
1038 |
+
_, _, H, W = x.size()
|
1039 |
+
if W % self.patch_size[1] != 0:
|
1040 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
1041 |
+
if H % self.patch_size[0] != 0:
|
1042 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
1043 |
+
|
1044 |
+
x = self.proj(x) # B C Wh Ww
|
1045 |
+
if self.norm is not None:
|
1046 |
+
Wh, Ww = x.size(2), x.size(3)
|
1047 |
+
x = x.flatten(2).transpose(1, 2)
|
1048 |
+
x = self.norm(x)
|
1049 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
1050 |
+
|
1051 |
+
return x
|
1052 |
+
|
1053 |
+
|
1054 |
+
class SwinTransformer(nn.Module):
|
1055 |
+
""" Swin Transformer backbone.
|
1056 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
1057 |
+
https://arxiv.org/pdf/2103.14030
|
1058 |
+
|
1059 |
+
Args:
|
1060 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
1061 |
+
used in absolute postion embedding. Default 224.
|
1062 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
1063 |
+
in_channels (int): Number of input image channels. Default: 3.
|
1064 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
1065 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
1066 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
1067 |
+
window_size (int): Window size. Default: 7.
|
1068 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
1069 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
1070 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
1071 |
+
drop_rate (float): Dropout rate.
|
1072 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
1073 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
1074 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
1075 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
1076 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
1077 |
+
out_indices (Sequence[int]): Output from which stages.
|
1078 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
1079 |
+
-1 means not freezing any parameters.
|
1080 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
1081 |
+
"""
|
1082 |
+
|
1083 |
+
def __init__(self,
|
1084 |
+
pretrain_img_size=224,
|
1085 |
+
patch_size=4,
|
1086 |
+
in_channels=3,
|
1087 |
+
embed_dim=96,
|
1088 |
+
depths=[2, 2, 6, 2],
|
1089 |
+
num_heads=[3, 6, 12, 24],
|
1090 |
+
window_size=7,
|
1091 |
+
mlp_ratio=4.,
|
1092 |
+
qkv_bias=True,
|
1093 |
+
qk_scale=None,
|
1094 |
+
drop_rate=0.,
|
1095 |
+
attn_drop_rate=0.,
|
1096 |
+
drop_path_rate=0.2,
|
1097 |
+
norm_layer=nn.LayerNorm,
|
1098 |
+
ape=False,
|
1099 |
+
patch_norm=True,
|
1100 |
+
out_indices=(0, 1, 2, 3),
|
1101 |
+
frozen_stages=-1,
|
1102 |
+
use_checkpoint=False):
|
1103 |
+
super().__init__()
|
1104 |
+
|
1105 |
+
self.pretrain_img_size = pretrain_img_size
|
1106 |
+
self.num_layers = len(depths)
|
1107 |
+
self.embed_dim = embed_dim
|
1108 |
+
self.ape = ape
|
1109 |
+
self.patch_norm = patch_norm
|
1110 |
+
self.out_indices = out_indices
|
1111 |
+
self.frozen_stages = frozen_stages
|
1112 |
+
|
1113 |
+
# split image into non-overlapping patches
|
1114 |
+
self.patch_embed = PatchEmbed(
|
1115 |
+
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
|
1116 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
1117 |
+
|
1118 |
+
# absolute position embedding
|
1119 |
+
if self.ape:
|
1120 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
1121 |
+
patch_size = to_2tuple(patch_size)
|
1122 |
+
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
|
1123 |
+
|
1124 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
|
1125 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
1126 |
+
|
1127 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
1128 |
+
|
1129 |
+
# stochastic depth
|
1130 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
1131 |
+
|
1132 |
+
# build layers
|
1133 |
+
self.layers = nn.ModuleList()
|
1134 |
+
for i_layer in range(self.num_layers):
|
1135 |
+
layer = BasicLayer(
|
1136 |
+
dim=int(embed_dim * 2 ** i_layer),
|
1137 |
+
depth=depths[i_layer],
|
1138 |
+
num_heads=num_heads[i_layer],
|
1139 |
+
window_size=window_size,
|
1140 |
+
mlp_ratio=mlp_ratio,
|
1141 |
+
qkv_bias=qkv_bias,
|
1142 |
+
qk_scale=qk_scale,
|
1143 |
+
drop=drop_rate,
|
1144 |
+
attn_drop=attn_drop_rate,
|
1145 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
1146 |
+
norm_layer=norm_layer,
|
1147 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
1148 |
+
use_checkpoint=use_checkpoint)
|
1149 |
+
self.layers.append(layer)
|
1150 |
+
|
1151 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
1152 |
+
self.num_features = num_features
|
1153 |
+
|
1154 |
+
# add a norm layer for each output
|
1155 |
+
for i_layer in out_indices:
|
1156 |
+
layer = norm_layer(num_features[i_layer])
|
1157 |
+
layer_name = f'norm{i_layer}'
|
1158 |
+
self.add_module(layer_name, layer)
|
1159 |
+
|
1160 |
+
self._freeze_stages()
|
1161 |
+
|
1162 |
+
def _freeze_stages(self):
|
1163 |
+
if self.frozen_stages >= 0:
|
1164 |
+
self.patch_embed.eval()
|
1165 |
+
for param in self.patch_embed.parameters():
|
1166 |
+
param.requires_grad = False
|
1167 |
+
|
1168 |
+
if self.frozen_stages >= 1 and self.ape:
|
1169 |
+
self.absolute_pos_embed.requires_grad = False
|
1170 |
+
|
1171 |
+
if self.frozen_stages >= 2:
|
1172 |
+
self.pos_drop.eval()
|
1173 |
+
for i in range(0, self.frozen_stages - 1):
|
1174 |
+
m = self.layers[i]
|
1175 |
+
m.eval()
|
1176 |
+
for param in m.parameters():
|
1177 |
+
param.requires_grad = False
|
1178 |
+
|
1179 |
+
|
1180 |
+
def forward(self, x):
|
1181 |
+
"""Forward function."""
|
1182 |
+
x = self.patch_embed(x)
|
1183 |
+
|
1184 |
+
Wh, Ww = x.size(2), x.size(3)
|
1185 |
+
if self.ape:
|
1186 |
+
# interpolate the position embedding to the corresponding size
|
1187 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
1188 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
1189 |
+
|
1190 |
+
outs = []#x.contiguous()]
|
1191 |
+
x = x.flatten(2).transpose(1, 2)
|
1192 |
+
x = self.pos_drop(x)
|
1193 |
+
for i in range(self.num_layers):
|
1194 |
+
layer = self.layers[i]
|
1195 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
1196 |
+
|
1197 |
+
if i in self.out_indices:
|
1198 |
+
norm_layer = getattr(self, f'norm{i}')
|
1199 |
+
x_out = norm_layer(x_out)
|
1200 |
+
|
1201 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
1202 |
+
outs.append(out)
|
1203 |
+
|
1204 |
+
return tuple(outs)
|
1205 |
+
|
1206 |
+
def train(self, mode=True):
|
1207 |
+
"""Convert the model into training mode while keep layers freezed."""
|
1208 |
+
super(SwinTransformer, self).train(mode)
|
1209 |
+
self._freeze_stages()
|
1210 |
+
|
1211 |
+
def swin_v1_t():
|
1212 |
+
model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
|
1213 |
+
return model
|
1214 |
+
|
1215 |
+
def swin_v1_s():
|
1216 |
+
model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7)
|
1217 |
+
return model
|
1218 |
+
|
1219 |
+
def swin_v1_b():
|
1220 |
+
model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
|
1221 |
+
return model
|
1222 |
+
|
1223 |
+
def swin_v1_l():
|
1224 |
+
model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12)
|
1225 |
+
return model
|
1226 |
+
|
1227 |
+
|
1228 |
+
|
1229 |
+
### models/modules/deform_conv.py
|
1230 |
+
|
1231 |
+
import torch
|
1232 |
+
import torch.nn as nn
|
1233 |
+
from torchvision.ops import deform_conv2d
|
1234 |
+
|
1235 |
+
|
1236 |
+
class DeformableConv2d(nn.Module):
|
1237 |
+
def __init__(self,
|
1238 |
+
in_channels,
|
1239 |
+
out_channels,
|
1240 |
+
kernel_size=3,
|
1241 |
+
stride=1,
|
1242 |
+
padding=1,
|
1243 |
+
bias=False):
|
1244 |
+
|
1245 |
+
super(DeformableConv2d, self).__init__()
|
1246 |
+
|
1247 |
+
assert type(kernel_size) == tuple or type(kernel_size) == int
|
1248 |
+
|
1249 |
+
kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
|
1250 |
+
self.stride = stride if type(stride) == tuple else (stride, stride)
|
1251 |
+
self.padding = padding
|
1252 |
+
|
1253 |
+
self.offset_conv = nn.Conv2d(in_channels,
|
1254 |
+
2 * kernel_size[0] * kernel_size[1],
|
1255 |
+
kernel_size=kernel_size,
|
1256 |
+
stride=stride,
|
1257 |
+
padding=self.padding,
|
1258 |
+
bias=True)
|
1259 |
+
|
1260 |
+
nn.init.constant_(self.offset_conv.weight, 0.)
|
1261 |
+
nn.init.constant_(self.offset_conv.bias, 0.)
|
1262 |
+
|
1263 |
+
self.modulator_conv = nn.Conv2d(in_channels,
|
1264 |
+
1 * kernel_size[0] * kernel_size[1],
|
1265 |
+
kernel_size=kernel_size,
|
1266 |
+
stride=stride,
|
1267 |
+
padding=self.padding,
|
1268 |
+
bias=True)
|
1269 |
+
|
1270 |
+
nn.init.constant_(self.modulator_conv.weight, 0.)
|
1271 |
+
nn.init.constant_(self.modulator_conv.bias, 0.)
|
1272 |
+
|
1273 |
+
self.regular_conv = nn.Conv2d(in_channels,
|
1274 |
+
out_channels=out_channels,
|
1275 |
+
kernel_size=kernel_size,
|
1276 |
+
stride=stride,
|
1277 |
+
padding=self.padding,
|
1278 |
+
bias=bias)
|
1279 |
+
|
1280 |
+
def forward(self, x):
|
1281 |
+
#h, w = x.shape[2:]
|
1282 |
+
#max_offset = max(h, w)/4.
|
1283 |
+
|
1284 |
+
offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
|
1285 |
+
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
1286 |
+
|
1287 |
+
x = deform_conv2d(
|
1288 |
+
input=x,
|
1289 |
+
offset=offset,
|
1290 |
+
weight=self.regular_conv.weight,
|
1291 |
+
bias=self.regular_conv.bias,
|
1292 |
+
padding=self.padding,
|
1293 |
+
mask=modulator,
|
1294 |
+
stride=self.stride,
|
1295 |
+
)
|
1296 |
+
return x
|
1297 |
+
|
1298 |
+
|
1299 |
+
|
1300 |
+
|
1301 |
+
### utils.py
|
1302 |
+
|
1303 |
+
import torch.nn as nn
|
1304 |
+
|
1305 |
+
|
1306 |
+
def build_act_layer(act_layer):
|
1307 |
+
if act_layer == 'ReLU':
|
1308 |
+
return nn.ReLU(inplace=True)
|
1309 |
+
elif act_layer == 'SiLU':
|
1310 |
+
return nn.SiLU(inplace=True)
|
1311 |
+
elif act_layer == 'GELU':
|
1312 |
+
return nn.GELU()
|
1313 |
+
|
1314 |
+
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
|
1315 |
+
|
1316 |
+
|
1317 |
+
def build_norm_layer(dim,
|
1318 |
+
norm_layer,
|
1319 |
+
in_format='channels_last',
|
1320 |
+
out_format='channels_last',
|
1321 |
+
eps=1e-6):
|
1322 |
+
layers = []
|
1323 |
+
if norm_layer == 'BN':
|
1324 |
+
if in_format == 'channels_last':
|
1325 |
+
layers.append(to_channels_first())
|
1326 |
+
layers.append(nn.BatchNorm2d(dim))
|
1327 |
+
if out_format == 'channels_last':
|
1328 |
+
layers.append(to_channels_last())
|
1329 |
+
elif norm_layer == 'LN':
|
1330 |
+
if in_format == 'channels_first':
|
1331 |
+
layers.append(to_channels_last())
|
1332 |
+
layers.append(nn.LayerNorm(dim, eps=eps))
|
1333 |
+
if out_format == 'channels_first':
|
1334 |
+
layers.append(to_channels_first())
|
1335 |
+
else:
|
1336 |
+
raise NotImplementedError(
|
1337 |
+
f'build_norm_layer does not support {norm_layer}')
|
1338 |
+
return nn.Sequential(*layers)
|
1339 |
+
|
1340 |
+
|
1341 |
+
class to_channels_first(nn.Module):
|
1342 |
+
|
1343 |
+
def __init__(self):
|
1344 |
+
super().__init__()
|
1345 |
+
|
1346 |
+
def forward(self, x):
|
1347 |
+
return x.permute(0, 3, 1, 2)
|
1348 |
+
|
1349 |
+
|
1350 |
+
class to_channels_last(nn.Module):
|
1351 |
+
|
1352 |
+
def __init__(self):
|
1353 |
+
super().__init__()
|
1354 |
+
|
1355 |
+
def forward(self, x):
|
1356 |
+
return x.permute(0, 2, 3, 1)
|
1357 |
+
|
1358 |
+
|
1359 |
+
|
1360 |
+
### dataset.py
|
1361 |
+
|
1362 |
+
_class_labels_TR_sorted = (
|
1363 |
+
'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, '
|
1364 |
+
'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, '
|
1365 |
+
'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, '
|
1366 |
+
'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, '
|
1367 |
+
'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, '
|
1368 |
+
'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, '
|
1369 |
+
'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, '
|
1370 |
+
'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, '
|
1371 |
+
'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, '
|
1372 |
+
'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, '
|
1373 |
+
'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, '
|
1374 |
+
'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, '
|
1375 |
+
'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, '
|
1376 |
+
'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
|
1377 |
+
)
|
1378 |
+
class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')
|
1379 |
+
|
1380 |
+
|
1381 |
+
### models/backbones/build_backbones.py
|
1382 |
+
|
1383 |
+
import torch
|
1384 |
+
import torch.nn as nn
|
1385 |
+
from collections import OrderedDict
|
1386 |
+
from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights
|
1387 |
+
# from models.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5
|
1388 |
+
# from models.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l
|
1389 |
+
# from config import Config
|
1390 |
+
|
1391 |
+
|
1392 |
+
config = Config()
|
1393 |
+
|
1394 |
+
def build_backbone(bb_name, pretrained=True, params_settings=''):
|
1395 |
+
if bb_name == 'vgg16':
|
1396 |
+
bb_net = list(vgg16(pretrained=VGG16_Weights.DEFAULT if pretrained else None).children())[0]
|
1397 |
+
bb = nn.Sequential(OrderedDict({'conv1': bb_net[:4], 'conv2': bb_net[4:9], 'conv3': bb_net[9:16], 'conv4': bb_net[16:23]}))
|
1398 |
+
elif bb_name == 'vgg16bn':
|
1399 |
+
bb_net = list(vgg16_bn(pretrained=VGG16_BN_Weights.DEFAULT if pretrained else None).children())[0]
|
1400 |
+
bb = nn.Sequential(OrderedDict({'conv1': bb_net[:6], 'conv2': bb_net[6:13], 'conv3': bb_net[13:23], 'conv4': bb_net[23:33]}))
|
1401 |
+
elif bb_name == 'resnet50':
|
1402 |
+
bb_net = list(resnet50(pretrained=ResNet50_Weights.DEFAULT if pretrained else None).children())
|
1403 |
+
bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:3]), 'conv2': bb_net[4], 'conv3': bb_net[5], 'conv4': bb_net[6]}))
|
1404 |
+
else:
|
1405 |
+
bb = eval('{}({})'.format(bb_name, params_settings))
|
1406 |
+
if pretrained:
|
1407 |
+
bb = load_weights(bb, bb_name)
|
1408 |
+
return bb
|
1409 |
+
|
1410 |
+
def load_weights(model, model_name):
|
1411 |
+
save_model = torch.load(config.weights[model_name], map_location='cpu')
|
1412 |
+
model_dict = model.state_dict()
|
1413 |
+
state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()}
|
1414 |
+
# to ignore the weights with mismatched size when I modify the backbone itself.
|
1415 |
+
if not state_dict:
|
1416 |
+
save_model_keys = list(save_model.keys())
|
1417 |
+
sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None
|
1418 |
+
state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()}
|
1419 |
+
if not state_dict or not sub_item:
|
1420 |
+
print('Weights are not successully loaded. Check the state dict of weights file.')
|
1421 |
+
return None
|
1422 |
+
else:
|
1423 |
+
print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item))
|
1424 |
+
model_dict.update(state_dict)
|
1425 |
+
model.load_state_dict(model_dict)
|
1426 |
+
return model
|
1427 |
+
|
1428 |
+
|
1429 |
+
|
1430 |
+
### models/modules/decoder_blocks.py
|
1431 |
+
|
1432 |
+
import torch
|
1433 |
+
import torch.nn as nn
|
1434 |
+
# from models.aspp import ASPP, ASPPDeformable
|
1435 |
+
# from config import Config
|
1436 |
+
|
1437 |
+
|
1438 |
+
# config = Config()
|
1439 |
+
|
1440 |
+
|
1441 |
+
class BasicDecBlk(nn.Module):
|
1442 |
+
def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
|
1443 |
+
super(BasicDecBlk, self).__init__()
|
1444 |
+
inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
|
1445 |
+
self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
|
1446 |
+
self.relu_in = nn.ReLU(inplace=True)
|
1447 |
+
if config.dec_att == 'ASPP':
|
1448 |
+
self.dec_att = ASPP(in_channels=inter_channels)
|
1449 |
+
elif config.dec_att == 'ASPPDeformable':
|
1450 |
+
self.dec_att = ASPPDeformable(in_channels=inter_channels)
|
1451 |
+
self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
|
1452 |
+
self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
|
1453 |
+
self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
|
1454 |
+
|
1455 |
+
def forward(self, x):
|
1456 |
+
x = self.conv_in(x)
|
1457 |
+
x = self.bn_in(x)
|
1458 |
+
x = self.relu_in(x)
|
1459 |
+
if hasattr(self, 'dec_att'):
|
1460 |
+
x = self.dec_att(x)
|
1461 |
+
x = self.conv_out(x)
|
1462 |
+
x = self.bn_out(x)
|
1463 |
+
return x
|
1464 |
+
|
1465 |
+
|
1466 |
+
class ResBlk(nn.Module):
|
1467 |
+
def __init__(self, in_channels=64, out_channels=None, inter_channels=64):
|
1468 |
+
super(ResBlk, self).__init__()
|
1469 |
+
if out_channels is None:
|
1470 |
+
out_channels = in_channels
|
1471 |
+
inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
|
1472 |
+
|
1473 |
+
self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
|
1474 |
+
self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
|
1475 |
+
self.relu_in = nn.ReLU(inplace=True)
|
1476 |
+
|
1477 |
+
if config.dec_att == 'ASPP':
|
1478 |
+
self.dec_att = ASPP(in_channels=inter_channels)
|
1479 |
+
elif config.dec_att == 'ASPPDeformable':
|
1480 |
+
self.dec_att = ASPPDeformable(in_channels=inter_channels)
|
1481 |
+
|
1482 |
+
self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
|
1483 |
+
self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
|
1484 |
+
|
1485 |
+
self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
|
1486 |
+
|
1487 |
+
def forward(self, x):
|
1488 |
+
_x = self.conv_resi(x)
|
1489 |
+
x = self.conv_in(x)
|
1490 |
+
x = self.bn_in(x)
|
1491 |
+
x = self.relu_in(x)
|
1492 |
+
if hasattr(self, 'dec_att'):
|
1493 |
+
x = self.dec_att(x)
|
1494 |
+
x = self.conv_out(x)
|
1495 |
+
x = self.bn_out(x)
|
1496 |
+
return x + _x
|
1497 |
+
|
1498 |
+
|
1499 |
+
|
1500 |
+
### models/modules/lateral_blocks.py
|
1501 |
+
|
1502 |
+
import numpy as np
|
1503 |
+
import torch
|
1504 |
+
import torch.nn as nn
|
1505 |
+
import torch.nn.functional as F
|
1506 |
+
from functools import partial
|
1507 |
+
|
1508 |
+
# from config import Config
|
1509 |
+
|
1510 |
+
|
1511 |
+
# config = Config()
|
1512 |
+
|
1513 |
+
|
1514 |
+
class BasicLatBlk(nn.Module):
|
1515 |
+
def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
|
1516 |
+
super(BasicLatBlk, self).__init__()
|
1517 |
+
inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
|
1518 |
+
self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
|
1519 |
+
|
1520 |
+
def forward(self, x):
|
1521 |
+
x = self.conv(x)
|
1522 |
+
return x
|
1523 |
+
|
1524 |
+
|
1525 |
+
|
1526 |
+
### models/modules/aspp.py
|
1527 |
+
|
1528 |
+
import torch
|
1529 |
+
import torch.nn as nn
|
1530 |
+
import torch.nn.functional as F
|
1531 |
+
# from models.deform_conv import DeformableConv2d
|
1532 |
+
# from config import Config
|
1533 |
+
|
1534 |
+
|
1535 |
+
# config = Config()
|
1536 |
+
|
1537 |
+
|
1538 |
+
class _ASPPModule(nn.Module):
|
1539 |
+
def __init__(self, in_channels, planes, kernel_size, padding, dilation):
|
1540 |
+
super(_ASPPModule, self).__init__()
|
1541 |
+
self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size,
|
1542 |
+
stride=1, padding=padding, dilation=dilation, bias=False)
|
1543 |
+
self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
|
1544 |
+
self.relu = nn.ReLU(inplace=True)
|
1545 |
+
|
1546 |
+
def forward(self, x):
|
1547 |
+
x = self.atrous_conv(x)
|
1548 |
+
x = self.bn(x)
|
1549 |
+
|
1550 |
+
return self.relu(x)
|
1551 |
+
|
1552 |
+
|
1553 |
+
class ASPP(nn.Module):
|
1554 |
+
def __init__(self, in_channels=64, out_channels=None, output_stride=16):
|
1555 |
+
super(ASPP, self).__init__()
|
1556 |
+
self.down_scale = 1
|
1557 |
+
if out_channels is None:
|
1558 |
+
out_channels = in_channels
|
1559 |
+
self.in_channelster = 256 // self.down_scale
|
1560 |
+
if output_stride == 16:
|
1561 |
+
dilations = [1, 6, 12, 18]
|
1562 |
+
elif output_stride == 8:
|
1563 |
+
dilations = [1, 12, 24, 36]
|
1564 |
+
else:
|
1565 |
+
raise NotImplementedError
|
1566 |
+
|
1567 |
+
self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0])
|
1568 |
+
self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1])
|
1569 |
+
self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2])
|
1570 |
+
self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3])
|
1571 |
+
|
1572 |
+
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
|
1573 |
+
nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
|
1574 |
+
nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
|
1575 |
+
nn.ReLU(inplace=True))
|
1576 |
+
self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False)
|
1577 |
+
self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
|
1578 |
+
self.relu = nn.ReLU(inplace=True)
|
1579 |
+
self.dropout = nn.Dropout(0.5)
|
1580 |
+
|
1581 |
+
def forward(self, x):
|
1582 |
+
x1 = self.aspp1(x)
|
1583 |
+
x2 = self.aspp2(x)
|
1584 |
+
x3 = self.aspp3(x)
|
1585 |
+
x4 = self.aspp4(x)
|
1586 |
+
x5 = self.global_avg_pool(x)
|
1587 |
+
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
|
1588 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
1589 |
+
|
1590 |
+
x = self.conv1(x)
|
1591 |
+
x = self.bn1(x)
|
1592 |
+
x = self.relu(x)
|
1593 |
+
|
1594 |
+
return self.dropout(x)
|
1595 |
+
|
1596 |
+
|
1597 |
+
##################### Deformable
|
1598 |
+
class _ASPPModuleDeformable(nn.Module):
|
1599 |
+
def __init__(self, in_channels, planes, kernel_size, padding):
|
1600 |
+
super(_ASPPModuleDeformable, self).__init__()
|
1601 |
+
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
|
1602 |
+
stride=1, padding=padding, bias=False)
|
1603 |
+
self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
|
1604 |
+
self.relu = nn.ReLU(inplace=True)
|
1605 |
+
|
1606 |
+
def forward(self, x):
|
1607 |
+
x = self.atrous_conv(x)
|
1608 |
+
x = self.bn(x)
|
1609 |
+
|
1610 |
+
return self.relu(x)
|
1611 |
+
|
1612 |
+
|
1613 |
+
class ASPPDeformable(nn.Module):
|
1614 |
+
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]):
|
1615 |
+
super(ASPPDeformable, self).__init__()
|
1616 |
+
self.down_scale = 1
|
1617 |
+
if out_channels is None:
|
1618 |
+
out_channels = in_channels
|
1619 |
+
self.in_channelster = 256 // self.down_scale
|
1620 |
+
|
1621 |
+
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0)
|
1622 |
+
self.aspp_deforms = nn.ModuleList([
|
1623 |
+
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2)) for conv_size in parallel_block_sizes
|
1624 |
+
])
|
1625 |
+
|
1626 |
+
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
|
1627 |
+
nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
|
1628 |
+
nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
|
1629 |
+
nn.ReLU(inplace=True))
|
1630 |
+
self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False)
|
1631 |
+
self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
|
1632 |
+
self.relu = nn.ReLU(inplace=True)
|
1633 |
+
self.dropout = nn.Dropout(0.5)
|
1634 |
+
|
1635 |
+
def forward(self, x):
|
1636 |
+
x1 = self.aspp1(x)
|
1637 |
+
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
|
1638 |
+
x5 = self.global_avg_pool(x)
|
1639 |
+
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
|
1640 |
+
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
|
1641 |
+
|
1642 |
+
x = self.conv1(x)
|
1643 |
+
x = self.bn1(x)
|
1644 |
+
x = self.relu(x)
|
1645 |
+
|
1646 |
+
return self.dropout(x)
|
1647 |
+
|
1648 |
+
|
1649 |
+
|
1650 |
+
### models/refinement/refiner.py
|
1651 |
+
|
1652 |
+
import torch
|
1653 |
+
import torch.nn as nn
|
1654 |
+
from collections import OrderedDict
|
1655 |
+
import torch
|
1656 |
+
import torch.nn as nn
|
1657 |
+
import torch.nn.functional as F
|
1658 |
+
from torchvision.models import vgg16, vgg16_bn
|
1659 |
+
from torchvision.models import resnet50
|
1660 |
+
|
1661 |
+
# from config import Config
|
1662 |
+
# from dataset import class_labels_TR_sorted
|
1663 |
+
# from models.build_backbone import build_backbone
|
1664 |
+
# from models.decoder_blocks import BasicDecBlk
|
1665 |
+
# from models.lateral_blocks import BasicLatBlk
|
1666 |
+
# from models.ing import *
|
1667 |
+
# from models.stem_layer import StemLayer
|
1668 |
+
|
1669 |
+
|
1670 |
+
class RefinerPVTInChannels4(nn.Module):
|
1671 |
+
def __init__(self, in_channels=3+1):
|
1672 |
+
super(RefinerPVTInChannels4, self).__init__()
|
1673 |
+
self.config = Config()
|
1674 |
+
self.epoch = 1
|
1675 |
+
self.bb = build_backbone(self.config.bb, params_settings='in_channels=4')
|
1676 |
+
|
1677 |
+
lateral_channels_in_collection = {
|
1678 |
+
'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
|
1679 |
+
'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
|
1680 |
+
'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
|
1681 |
+
}
|
1682 |
+
channels = lateral_channels_in_collection[self.config.bb]
|
1683 |
+
self.squeeze_module = BasicDecBlk(channels[0], channels[0])
|
1684 |
+
|
1685 |
+
self.decoder = Decoder(channels)
|
1686 |
+
|
1687 |
+
if 0:
|
1688 |
+
for key, value in self.named_parameters():
|
1689 |
+
if 'bb.' in key:
|
1690 |
+
value.requires_grad = False
|
1691 |
+
|
1692 |
+
def forward(self, x):
|
1693 |
+
if isinstance(x, list):
|
1694 |
+
x = torch.cat(x, dim=1)
|
1695 |
+
########## Encoder ##########
|
1696 |
+
if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
|
1697 |
+
x1 = self.bb.conv1(x)
|
1698 |
+
x2 = self.bb.conv2(x1)
|
1699 |
+
x3 = self.bb.conv3(x2)
|
1700 |
+
x4 = self.bb.conv4(x3)
|
1701 |
+
else:
|
1702 |
+
x1, x2, x3, x4 = self.bb(x)
|
1703 |
+
|
1704 |
+
x4 = self.squeeze_module(x4)
|
1705 |
+
|
1706 |
+
########## Decoder ##########
|
1707 |
+
|
1708 |
+
features = [x, x1, x2, x3, x4]
|
1709 |
+
scaled_preds = self.decoder(features)
|
1710 |
+
|
1711 |
+
return scaled_preds
|
1712 |
+
|
1713 |
+
|
1714 |
+
class Refiner(nn.Module):
|
1715 |
+
def __init__(self, in_channels=3+1):
|
1716 |
+
super(Refiner, self).__init__()
|
1717 |
+
self.config = Config()
|
1718 |
+
self.epoch = 1
|
1719 |
+
self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
|
1720 |
+
self.bb = build_backbone(self.config.bb)
|
1721 |
+
|
1722 |
+
lateral_channels_in_collection = {
|
1723 |
+
'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
|
1724 |
+
'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
|
1725 |
+
'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
|
1726 |
+
}
|
1727 |
+
channels = lateral_channels_in_collection[self.config.bb]
|
1728 |
+
self.squeeze_module = BasicDecBlk(channels[0], channels[0])
|
1729 |
+
|
1730 |
+
self.decoder = Decoder(channels)
|
1731 |
+
|
1732 |
+
if 0:
|
1733 |
+
for key, value in self.named_parameters():
|
1734 |
+
if 'bb.' in key:
|
1735 |
+
value.requires_grad = False
|
1736 |
+
|
1737 |
+
def forward(self, x):
|
1738 |
+
if isinstance(x, list):
|
1739 |
+
x = torch.cat(x, dim=1)
|
1740 |
+
x = self.stem_layer(x)
|
1741 |
+
########## Encoder ##########
|
1742 |
+
if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
|
1743 |
+
x1 = self.bb.conv1(x)
|
1744 |
+
x2 = self.bb.conv2(x1)
|
1745 |
+
x3 = self.bb.conv3(x2)
|
1746 |
+
x4 = self.bb.conv4(x3)
|
1747 |
+
else:
|
1748 |
+
x1, x2, x3, x4 = self.bb(x)
|
1749 |
+
|
1750 |
+
x4 = self.squeeze_module(x4)
|
1751 |
+
|
1752 |
+
########## Decoder ##########
|
1753 |
+
|
1754 |
+
features = [x, x1, x2, x3, x4]
|
1755 |
+
scaled_preds = self.decoder(features)
|
1756 |
+
|
1757 |
+
return scaled_preds
|
1758 |
+
|
1759 |
+
|
1760 |
+
class Decoder(nn.Module):
|
1761 |
+
def __init__(self, channels):
|
1762 |
+
super(Decoder, self).__init__()
|
1763 |
+
self.config = Config()
|
1764 |
+
DecoderBlock = eval('BasicDecBlk')
|
1765 |
+
LateralBlock = eval('BasicLatBlk')
|
1766 |
+
|
1767 |
+
self.decoder_block4 = DecoderBlock(channels[0], channels[1])
|
1768 |
+
self.decoder_block3 = DecoderBlock(channels[1], channels[2])
|
1769 |
+
self.decoder_block2 = DecoderBlock(channels[2], channels[3])
|
1770 |
+
self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2)
|
1771 |
+
|
1772 |
+
self.lateral_block4 = LateralBlock(channels[1], channels[1])
|
1773 |
+
self.lateral_block3 = LateralBlock(channels[2], channels[2])
|
1774 |
+
self.lateral_block2 = LateralBlock(channels[3], channels[3])
|
1775 |
+
|
1776 |
+
if self.config.ms_supervision:
|
1777 |
+
self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
|
1778 |
+
self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
|
1779 |
+
self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
|
1780 |
+
self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0))
|
1781 |
+
|
1782 |
+
def forward(self, features):
|
1783 |
+
x, x1, x2, x3, x4 = features
|
1784 |
+
outs = []
|
1785 |
+
p4 = self.decoder_block4(x4)
|
1786 |
+
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
|
1787 |
+
_p3 = _p4 + self.lateral_block4(x3)
|
1788 |
+
|
1789 |
+
p3 = self.decoder_block3(_p3)
|
1790 |
+
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
|
1791 |
+
_p2 = _p3 + self.lateral_block3(x2)
|
1792 |
+
|
1793 |
+
p2 = self.decoder_block2(_p2)
|
1794 |
+
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
|
1795 |
+
_p1 = _p2 + self.lateral_block2(x1)
|
1796 |
+
|
1797 |
+
_p1 = self.decoder_block1(_p1)
|
1798 |
+
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
|
1799 |
+
p1_out = self.conv_out1(_p1)
|
1800 |
+
|
1801 |
+
if self.config.ms_supervision:
|
1802 |
+
outs.append(self.conv_ms_spvn_4(p4))
|
1803 |
+
outs.append(self.conv_ms_spvn_3(p3))
|
1804 |
+
outs.append(self.conv_ms_spvn_2(p2))
|
1805 |
+
outs.append(p1_out)
|
1806 |
+
return outs
|
1807 |
+
|
1808 |
+
|
1809 |
+
class RefUNet(nn.Module):
|
1810 |
+
# Refinement
|
1811 |
+
def __init__(self, in_channels=3+1):
|
1812 |
+
super(RefUNet, self).__init__()
|
1813 |
+
self.encoder_1 = nn.Sequential(
|
1814 |
+
nn.Conv2d(in_channels, 64, 3, 1, 1),
|
1815 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
1816 |
+
nn.BatchNorm2d(64),
|
1817 |
+
nn.ReLU(inplace=True)
|
1818 |
+
)
|
1819 |
+
|
1820 |
+
self.encoder_2 = nn.Sequential(
|
1821 |
+
nn.MaxPool2d(2, 2, ceil_mode=True),
|
1822 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
1823 |
+
nn.BatchNorm2d(64),
|
1824 |
+
nn.ReLU(inplace=True)
|
1825 |
+
)
|
1826 |
+
|
1827 |
+
self.encoder_3 = nn.Sequential(
|
1828 |
+
nn.MaxPool2d(2, 2, ceil_mode=True),
|
1829 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
1830 |
+
nn.BatchNorm2d(64),
|
1831 |
+
nn.ReLU(inplace=True)
|
1832 |
+
)
|
1833 |
+
|
1834 |
+
self.encoder_4 = nn.Sequential(
|
1835 |
+
nn.MaxPool2d(2, 2, ceil_mode=True),
|
1836 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
1837 |
+
nn.BatchNorm2d(64),
|
1838 |
+
nn.ReLU(inplace=True)
|
1839 |
+
)
|
1840 |
+
|
1841 |
+
self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
|
1842 |
+
#####
|
1843 |
+
self.decoder_5 = nn.Sequential(
|
1844 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
1845 |
+
nn.BatchNorm2d(64),
|
1846 |
+
nn.ReLU(inplace=True)
|
1847 |
+
)
|
1848 |
+
#####
|
1849 |
+
self.decoder_4 = nn.Sequential(
|
1850 |
+
nn.Conv2d(128, 64, 3, 1, 1),
|
1851 |
+
nn.BatchNorm2d(64),
|
1852 |
+
nn.ReLU(inplace=True)
|
1853 |
+
)
|
1854 |
+
|
1855 |
+
self.decoder_3 = nn.Sequential(
|
1856 |
+
nn.Conv2d(128, 64, 3, 1, 1),
|
1857 |
+
nn.BatchNorm2d(64),
|
1858 |
+
nn.ReLU(inplace=True)
|
1859 |
+
)
|
1860 |
+
|
1861 |
+
self.decoder_2 = nn.Sequential(
|
1862 |
+
nn.Conv2d(128, 64, 3, 1, 1),
|
1863 |
+
nn.BatchNorm2d(64),
|
1864 |
+
nn.ReLU(inplace=True)
|
1865 |
+
)
|
1866 |
+
|
1867 |
+
self.decoder_1 = nn.Sequential(
|
1868 |
+
nn.Conv2d(128, 64, 3, 1, 1),
|
1869 |
+
nn.BatchNorm2d(64),
|
1870 |
+
nn.ReLU(inplace=True)
|
1871 |
+
)
|
1872 |
+
|
1873 |
+
self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1)
|
1874 |
+
|
1875 |
+
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
1876 |
+
|
1877 |
+
def forward(self, x):
|
1878 |
+
outs = []
|
1879 |
+
if isinstance(x, list):
|
1880 |
+
x = torch.cat(x, dim=1)
|
1881 |
+
hx = x
|
1882 |
+
|
1883 |
+
hx1 = self.encoder_1(hx)
|
1884 |
+
hx2 = self.encoder_2(hx1)
|
1885 |
+
hx3 = self.encoder_3(hx2)
|
1886 |
+
hx4 = self.encoder_4(hx3)
|
1887 |
+
|
1888 |
+
hx = self.decoder_5(self.pool4(hx4))
|
1889 |
+
hx = torch.cat((self.upscore2(hx), hx4), 1)
|
1890 |
+
|
1891 |
+
d4 = self.decoder_4(hx)
|
1892 |
+
hx = torch.cat((self.upscore2(d4), hx3), 1)
|
1893 |
+
|
1894 |
+
d3 = self.decoder_3(hx)
|
1895 |
+
hx = torch.cat((self.upscore2(d3), hx2), 1)
|
1896 |
+
|
1897 |
+
d2 = self.decoder_2(hx)
|
1898 |
+
hx = torch.cat((self.upscore2(d2), hx1), 1)
|
1899 |
+
|
1900 |
+
d1 = self.decoder_1(hx)
|
1901 |
+
|
1902 |
+
x = self.conv_d0(d1)
|
1903 |
+
outs.append(x)
|
1904 |
+
return outs
|
1905 |
+
|
1906 |
+
|
1907 |
+
|
1908 |
+
### models/stem_layer.py
|
1909 |
+
|
1910 |
+
import torch.nn as nn
|
1911 |
+
# from utils import build_act_layer, build_norm_layer
|
1912 |
+
|
1913 |
+
|
1914 |
+
class StemLayer(nn.Module):
|
1915 |
+
r""" Stem layer of InternImage
|
1916 |
+
Args:
|
1917 |
+
in_channels (int): number of input channels
|
1918 |
+
out_channels (int): number of output channels
|
1919 |
+
act_layer (str): activation layer
|
1920 |
+
norm_layer (str): normalization layer
|
1921 |
+
"""
|
1922 |
+
|
1923 |
+
def __init__(self,
|
1924 |
+
in_channels=3+1,
|
1925 |
+
inter_channels=48,
|
1926 |
+
out_channels=96,
|
1927 |
+
act_layer='GELU',
|
1928 |
+
norm_layer='BN'):
|
1929 |
+
super().__init__()
|
1930 |
+
self.conv1 = nn.Conv2d(in_channels,
|
1931 |
+
inter_channels,
|
1932 |
+
kernel_size=3,
|
1933 |
+
stride=1,
|
1934 |
+
padding=1)
|
1935 |
+
self.norm1 = build_norm_layer(
|
1936 |
+
inter_channels, norm_layer, 'channels_first', 'channels_first'
|
1937 |
+
)
|
1938 |
+
self.act = build_act_layer(act_layer)
|
1939 |
+
self.conv2 = nn.Conv2d(inter_channels,
|
1940 |
+
out_channels,
|
1941 |
+
kernel_size=3,
|
1942 |
+
stride=1,
|
1943 |
+
padding=1)
|
1944 |
+
self.norm2 = build_norm_layer(
|
1945 |
+
out_channels, norm_layer, 'channels_first', 'channels_first'
|
1946 |
+
)
|
1947 |
+
|
1948 |
+
def forward(self, x):
|
1949 |
+
x = self.conv1(x)
|
1950 |
+
x = self.norm1(x)
|
1951 |
+
x = self.act(x)
|
1952 |
+
x = self.conv2(x)
|
1953 |
+
x = self.norm2(x)
|
1954 |
+
return x
|
1955 |
+
|
1956 |
+
|
1957 |
+
### models/birefnet.py
|
1958 |
+
|
1959 |
+
import torch
|
1960 |
+
import torch.nn as nn
|
1961 |
+
import torch.nn.functional as F
|
1962 |
+
from kornia.filters import laplacian
|
1963 |
+
from transformers import PreTrainedModel
|
1964 |
+
|
1965 |
+
# from config import Config
|
1966 |
+
# from dataset import class_labels_TR_sorted
|
1967 |
+
# from models.build_backbone import build_backbone
|
1968 |
+
# from models.decoder_blocks import BasicDecBlk, ResBlk, HierarAttDecBlk
|
1969 |
+
# from models.lateral_blocks import BasicLatBlk
|
1970 |
+
# from models.aspp import ASPP, ASPPDeformable
|
1971 |
+
# from models.ing import *
|
1972 |
+
# from models.refiner import Refiner, RefinerPVTInChannels4, RefUNet
|
1973 |
+
# from models.stem_layer import StemLayer
|
1974 |
+
from .BiRefNet_config import BiRefNetConfig
|
1975 |
+
|
1976 |
+
|
1977 |
+
class BiRefNet(
|
1978 |
+
PreTrainedModel
|
1979 |
+
):
|
1980 |
+
config_class = BiRefNetConfig
|
1981 |
+
def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
|
1982 |
+
super(BiRefNet, self).__init__(config)
|
1983 |
+
bb_pretrained = config.bb_pretrained
|
1984 |
+
self.config = Config()
|
1985 |
+
self.epoch = 1
|
1986 |
+
self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
|
1987 |
+
|
1988 |
+
channels = self.config.lateral_channels_in_collection
|
1989 |
+
|
1990 |
+
if self.config.auxiliary_classification:
|
1991 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
1992 |
+
self.cls_head = nn.Sequential(
|
1993 |
+
nn.Linear(channels[0], len(class_labels_TR_sorted))
|
1994 |
+
)
|
1995 |
+
|
1996 |
+
if self.config.squeeze_block:
|
1997 |
+
self.squeeze_module = nn.Sequential(*[
|
1998 |
+
eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
|
1999 |
+
for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
|
2000 |
+
])
|
2001 |
+
|
2002 |
+
self.decoder = Decoder(channels)
|
2003 |
+
|
2004 |
+
if self.config.ender:
|
2005 |
+
self.dec_end = nn.Sequential(
|
2006 |
+
nn.Conv2d(1, 16, 3, 1, 1),
|
2007 |
+
nn.Conv2d(16, 1, 3, 1, 1),
|
2008 |
+
nn.ReLU(inplace=True),
|
2009 |
+
)
|
2010 |
+
|
2011 |
+
# refine patch-level segmentation
|
2012 |
+
if self.config.refine:
|
2013 |
+
if self.config.refine == 'itself':
|
2014 |
+
self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
|
2015 |
+
else:
|
2016 |
+
self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))
|
2017 |
+
|
2018 |
+
if self.config.freeze_bb:
|
2019 |
+
# Freeze the backbone...
|
2020 |
+
print(self.named_parameters())
|
2021 |
+
for key, value in self.named_parameters():
|
2022 |
+
if 'bb.' in key and 'refiner.' not in key:
|
2023 |
+
value.requires_grad = False
|
2024 |
+
|
2025 |
+
def forward_enc(self, x):
|
2026 |
+
if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
|
2027 |
+
x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
|
2028 |
+
else:
|
2029 |
+
x1, x2, x3, x4 = self.bb(x)
|
2030 |
+
if self.config.mul_scl_ipt == 'cat':
|
2031 |
+
B, C, H, W = x.shape
|
2032 |
+
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
|
2033 |
+
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
2034 |
+
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
2035 |
+
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
2036 |
+
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
2037 |
+
elif self.config.mul_scl_ipt == 'add':
|
2038 |
+
B, C, H, W = x.shape
|
2039 |
+
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
|
2040 |
+
x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
|
2041 |
+
x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
|
2042 |
+
x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
|
2043 |
+
x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
|
2044 |
+
class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
|
2045 |
+
if self.config.cxt:
|
2046 |
+
x4 = torch.cat(
|
2047 |
+
(
|
2048 |
+
*[
|
2049 |
+
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
2050 |
+
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
2051 |
+
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
2052 |
+
][-len(self.config.cxt):],
|
2053 |
+
x4
|
2054 |
+
),
|
2055 |
+
dim=1
|
2056 |
+
)
|
2057 |
+
return (x1, x2, x3, x4), class_preds
|
2058 |
+
|
2059 |
+
def forward_ori(self, x):
|
2060 |
+
########## Encoder ##########
|
2061 |
+
(x1, x2, x3, x4), class_preds = self.forward_enc(x)
|
2062 |
+
if self.config.squeeze_block:
|
2063 |
+
x4 = self.squeeze_module(x4)
|
2064 |
+
########## Decoder ##########
|
2065 |
+
features = [x, x1, x2, x3, x4]
|
2066 |
+
if self.training and self.config.out_ref:
|
2067 |
+
features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
|
2068 |
+
scaled_preds = self.decoder(features)
|
2069 |
+
return scaled_preds, class_preds
|
2070 |
+
|
2071 |
+
def forward(self, x):
|
2072 |
+
scaled_preds, class_preds = self.forward_ori(x)
|
2073 |
+
class_preds_lst = [class_preds]
|
2074 |
+
return [scaled_preds, class_preds_lst] if self.training else scaled_preds
|
2075 |
+
|
2076 |
+
|
2077 |
+
class Decoder(nn.Module):
|
2078 |
+
def __init__(self, channels):
|
2079 |
+
super(Decoder, self).__init__()
|
2080 |
+
self.config = Config()
|
2081 |
+
DecoderBlock = eval(self.config.dec_blk)
|
2082 |
+
LateralBlock = eval(self.config.lat_blk)
|
2083 |
+
|
2084 |
+
if self.config.dec_ipt:
|
2085 |
+
self.split = self.config.dec_ipt_split
|
2086 |
+
N_dec_ipt = 64
|
2087 |
+
DBlock = SimpleConvs
|
2088 |
+
ic = 64
|
2089 |
+
ipt_cha_opt = 1
|
2090 |
+
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
2091 |
+
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
2092 |
+
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
|
2093 |
+
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
|
2094 |
+
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
|
2095 |
+
else:
|
2096 |
+
self.split = None
|
2097 |
+
|
2098 |
+
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[1])
|
2099 |
+
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2])
|
2100 |
+
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3])
|
2101 |
+
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2)
|
2102 |
+
self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt] if self.config.dec_ipt else 0), 1, 1, 1, 0))
|
2103 |
+
|
2104 |
+
self.lateral_block4 = LateralBlock(channels[1], channels[1])
|
2105 |
+
self.lateral_block3 = LateralBlock(channels[2], channels[2])
|
2106 |
+
self.lateral_block2 = LateralBlock(channels[3], channels[3])
|
2107 |
+
|
2108 |
+
if self.config.ms_supervision:
|
2109 |
+
self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
|
2110 |
+
self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
|
2111 |
+
self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
|
2112 |
+
|
2113 |
+
if self.config.out_ref:
|
2114 |
+
_N = 16
|
2115 |
+
self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
|
2116 |
+
self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
|
2117 |
+
self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
|
2118 |
+
|
2119 |
+
self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
2120 |
+
self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
2121 |
+
self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
2122 |
+
|
2123 |
+
self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
2124 |
+
self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
2125 |
+
self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
2126 |
+
|
2127 |
+
def get_patches_batch(self, x, p):
|
2128 |
+
_size_h, _size_w = p.shape[2:]
|
2129 |
+
patches_batch = []
|
2130 |
+
for idx in range(x.shape[0]):
|
2131 |
+
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
|
2132 |
+
patches_x = []
|
2133 |
+
for column_x in columns_x:
|
2134 |
+
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
|
2135 |
+
patch_sample = torch.cat(patches_x, dim=1)
|
2136 |
+
patches_batch.append(patch_sample)
|
2137 |
+
return torch.cat(patches_batch, dim=0)
|
2138 |
+
|
2139 |
+
def forward(self, features):
|
2140 |
+
if self.training and self.config.out_ref:
|
2141 |
+
outs_gdt_pred = []
|
2142 |
+
outs_gdt_label = []
|
2143 |
+
x, x1, x2, x3, x4, gdt_gt = features
|
2144 |
+
else:
|
2145 |
+
x, x1, x2, x3, x4 = features
|
2146 |
+
outs = []
|
2147 |
+
|
2148 |
+
if self.config.dec_ipt:
|
2149 |
+
patches_batch = self.get_patches_batch(x, x4) if self.split else x
|
2150 |
+
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
|
2151 |
+
p4 = self.decoder_block4(x4)
|
2152 |
+
m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
|
2153 |
+
if self.config.out_ref:
|
2154 |
+
p4_gdt = self.gdt_convs_4(p4)
|
2155 |
+
if self.training:
|
2156 |
+
# >> GT:
|
2157 |
+
m4_dia = m4
|
2158 |
+
gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
2159 |
+
outs_gdt_label.append(gdt_label_main_4)
|
2160 |
+
# >> Pred:
|
2161 |
+
gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt)
|
2162 |
+
outs_gdt_pred.append(gdt_pred_4)
|
2163 |
+
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
|
2164 |
+
# >> Finally:
|
2165 |
+
p4 = p4 * gdt_attn_4
|
2166 |
+
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
|
2167 |
+
_p3 = _p4 + self.lateral_block4(x3)
|
2168 |
+
|
2169 |
+
if self.config.dec_ipt:
|
2170 |
+
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
|
2171 |
+
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
|
2172 |
+
p3 = self.decoder_block3(_p3)
|
2173 |
+
m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
|
2174 |
+
if self.config.out_ref:
|
2175 |
+
p3_gdt = self.gdt_convs_3(p3)
|
2176 |
+
if self.training:
|
2177 |
+
# >> GT:
|
2178 |
+
# m3 --dilation--> m3_dia
|
2179 |
+
# G_3^gt * m3_dia --> G_3^m, which is the label of gradient
|
2180 |
+
m3_dia = m3
|
2181 |
+
gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
2182 |
+
outs_gdt_label.append(gdt_label_main_3)
|
2183 |
+
# >> Pred:
|
2184 |
+
# p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
|
2185 |
+
# F_3^G --sigmoid--> A_3^G
|
2186 |
+
gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
|
2187 |
+
outs_gdt_pred.append(gdt_pred_3)
|
2188 |
+
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
2189 |
+
# >> Finally:
|
2190 |
+
# p3 = p3 * A_3^G
|
2191 |
+
p3 = p3 * gdt_attn_3
|
2192 |
+
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
|
2193 |
+
_p2 = _p3 + self.lateral_block3(x2)
|
2194 |
+
|
2195 |
+
if self.config.dec_ipt:
|
2196 |
+
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
|
2197 |
+
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
|
2198 |
+
p2 = self.decoder_block2(_p2)
|
2199 |
+
m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
|
2200 |
+
if self.config.out_ref:
|
2201 |
+
p2_gdt = self.gdt_convs_2(p2)
|
2202 |
+
if self.training:
|
2203 |
+
# >> GT:
|
2204 |
+
m2_dia = m2
|
2205 |
+
gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
2206 |
+
outs_gdt_label.append(gdt_label_main_2)
|
2207 |
+
# >> Pred:
|
2208 |
+
gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
|
2209 |
+
outs_gdt_pred.append(gdt_pred_2)
|
2210 |
+
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
2211 |
+
# >> Finally:
|
2212 |
+
p2 = p2 * gdt_attn_2
|
2213 |
+
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
|
2214 |
+
_p1 = _p2 + self.lateral_block2(x1)
|
2215 |
+
|
2216 |
+
if self.config.dec_ipt:
|
2217 |
+
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
2218 |
+
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
|
2219 |
+
_p1 = self.decoder_block1(_p1)
|
2220 |
+
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
|
2221 |
+
|
2222 |
+
if self.config.dec_ipt:
|
2223 |
+
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
2224 |
+
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
|
2225 |
+
p1_out = self.conv_out1(_p1)
|
2226 |
+
|
2227 |
+
if self.config.ms_supervision:
|
2228 |
+
outs.append(m4)
|
2229 |
+
outs.append(m3)
|
2230 |
+
outs.append(m2)
|
2231 |
+
outs.append(p1_out)
|
2232 |
+
return outs if not (self.config.out_ref and self.training) else ([outs_gdt_pred, outs_gdt_label], outs)
|
2233 |
+
|
2234 |
+
|
2235 |
+
class SimpleConvs(nn.Module):
|
2236 |
+
def __init__(
|
2237 |
+
self, in_channels: int, out_channels: int, inter_channels=64
|
2238 |
+
) -> None:
|
2239 |
+
super().__init__()
|
2240 |
+
self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1)
|
2241 |
+
self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1)
|
2242 |
+
|
2243 |
+
def forward(self, x):
|
2244 |
+
return self.conv_out(self.conv1(x))
|
models/RMBG/RMBG-2.0/config.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "ZhengPeng7/BiRefNet",
|
3 |
+
"architectures": [
|
4 |
+
"BiRefNet"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "BiRefNet_config.BiRefNetConfig",
|
8 |
+
"AutoModelForImageSegmentation": "birefnet.BiRefNet"
|
9 |
+
},
|
10 |
+
"custom_pipelines": {
|
11 |
+
"image-segmentation": {
|
12 |
+
"pt": [
|
13 |
+
"AutoModelForImageSegmentation"
|
14 |
+
],
|
15 |
+
"tf": [],
|
16 |
+
"type": "image"
|
17 |
+
}
|
18 |
+
},
|
19 |
+
"bb_pretrained": false
|
20 |
+
}
|
models/RMBG/RMBG-2.0/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:566ed80c3d95f87ada6864d4cbe2290a1c5eb1c7bb0b123e984f60f76b02c3a7
|
3 |
+
size 884878856
|
models/TTS/DiffRhythm/.gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*jpg* filter=lfs diff=lfs merge=lfs -text
|
models/TTS/DiffRhythm/.locks/models--OpenMuQ--MuQ-MuLan-large/1d2f0a1aedbc66ea23e7fef7985c875c3e98c08d.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/.locks/models--OpenMuQ--MuQ-MuLan-large/d42ae3f7cb9b66759ee0089ddc70e2f28b130c2d8ba621457358272d32dd0444.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/.locks/models--OpenMuQ--MuQ-large-msd-iter/334df3de2832ec1acfd8b6ce54e7de4073401fe821f7ec0ad0d954832be2d26a.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/.locks/models--OpenMuQ--MuQ-large-msd-iter/fec6c73f7b811281b440462fcf4d98c7953c3d94.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/1960141250d189366dfb76630ba794a9c104ec07.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/34ddbd64a4cd3f2d9d8a9120d3662d0bf91baead.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/463f3414782c1c9405828c9b31bfa36dda1f45c5.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/6fd4797bc397c3b8b55d6bb5740366b57e6a3ce91c04c77f22aafc0c128e6feb.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/.locks/models--xlm-roberta-base/db9af13bf09fd3028ca32be90d3fb66d5e470399.lock
ADDED
File without changes
|
models/TTS/DiffRhythm/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
models/TTS/DiffRhythm/LICENSE.md
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
STABILITY AI COMMUNITY LICENSE AGREEMENT
|
2 |
+
|
3 |
+
Last Updated: July 5, 2024
|
4 |
+
|
5 |
+
1. INTRODUCTION
|
6 |
+
|
7 |
+
This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
|
8 |
+
|
9 |
+
This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
|
10 |
+
|
11 |
+
By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity’s behalf.
|
12 |
+
|
13 |
+
2. RESEARCH & NON-COMMERCIAL USE LICENSE
|
14 |
+
|
15 |
+
Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
|
16 |
+
|
17 |
+
3. COMMERCIAL USE LICENSE
|
18 |
+
|
19 |
+
Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business’s or organization’s internal operations.
|
20 |
+
If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
|
21 |
+
|
22 |
+
4. GENERAL TERMS
|
23 |
+
|
24 |
+
Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
|
25 |
+
a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified.
|
26 |
+
b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
|
27 |
+
c. Intellectual Property.
|
28 |
+
(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
|
29 |
+
(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
|
30 |
+
(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
|
31 |
+
(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
|
32 |
+
(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI’s existing or prospective technology, products or services (collectively, “Feedback”). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback.
|
33 |
+
d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
|
34 |
+
e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
35 |
+
f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
|
36 |
+
g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
|
37 |
+
|
38 |
+
5. DEFINITIONS
|
39 |
+
|
40 |
+
“Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
|
41 |
+
|
42 |
+
"Agreement" means this Stability AI Community License Agreement.
|
43 |
+
|
44 |
+
“AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time.
|
45 |
+
|
46 |
+
"Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Model’s output, but do not include the output of any Model.
|
47 |
+
|
48 |
+
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
|
49 |
+
|
50 |
+
“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability’s Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time.
|
51 |
+
|
52 |
+
"Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
|
53 |
+
|
54 |
+
"Software" means Stability AI’s proprietary software made available under this Agreement now or in the future.
|
55 |
+
|
56 |
+
“Stability AI Materials” means, collectively, Stability’s proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
|
57 |
+
|
58 |
+
“Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
|
models/TTS/DiffRhythm/README.md
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- zh
|
4 |
+
- en
|
5 |
+
tags:
|
6 |
+
- music
|
7 |
+
- art
|
8 |
+
- diffusion
|
9 |
+
license: other
|
10 |
+
license_name: stable-audio-community
|
11 |
+
license_link: LICENSE
|
12 |
+
library_name: DiffRhythm
|
13 |
+
---
|
14 |
+
|
15 |
+
<p align="center">
|
16 |
+
<h1>DiffRhythm: Blazingly Fast and Embarrassingly Simple End-to-End Full-Length Song Generation with Latent Diffusion</h1>
|
17 |
+
</p>
|
18 |
+
|
19 |
+
Ziqian Ning, Huakang Chen, Yuepeng Jiang, Chunbo Hao, Guobin Ma, Shuai Wang, Jixun Yao, Lei Xie†
|
20 |
+
|
21 |
+
<p align="center">
|
22 |
+
<a href="https://huggingface.co/spaces/ASLP-lab/DiffRhythm"> Huggingface Space</a> </a> 
|
23 |
+
<br>
|
24 |
+
📑 <a href="https://arxiv.org/abs/2503.01183">Paper</a>    |    📑 <a href="https://aslp-lab.github.io/DiffRhythm.github.io/">Demo</a>   
|
25 |
+
</p>
|
26 |
+
|
27 |
+
DiffRhythm (Chinese: 谛韵, Dì Yùn) is the ***first*** diffusion-based song generation model that is capable of creating full-length songs. The name combines "Diff" (referencing its diffusion architecture) with "Rhythm" (highlighting its focus on music and song creation). The Chinese name 谛韵 (Dì Yùn) phonetically mirrors "DiffRhythm", where "谛" (attentive listening) symbolizes auditory perception, and "韵" (melodic charm) represents musicality.
|
28 |
+
|
29 |
+
|
30 |
+
<p align="center">
|
31 |
+
<img src="src/diffrhythm.jpg" width="90%"/>
|
32 |
+
<p>
|
33 |
+
|
34 |
+
## News and Updates
|
35 |
+
|
36 |
+
### 2025.3.4 🔥 We released the [DiffRhythm paper](https://arxiv.org/abs/2503.01183) and [Huggingface Space demo](https://huggingface.co/spaces/ASLP-lab/DiffRhythm).
|
37 |
+
|
38 |
+
## TODOs
|
39 |
+
- [ ] Support local deployment:
|
40 |
+
- [ ] Support Colab:
|
41 |
+
- [ ] Support Docker:
|
42 |
+
- [x] Release paper to Arxiv.
|
43 |
+
- [x] Online serving on huggingface space.
|
44 |
+
|
45 |
+
## Model Versions
|
46 |
+
|
47 |
+
| Model | HuggingFace |
|
48 |
+
| ---- | ---- |
|
49 |
+
| DiffRhythm-base (1m35s) | https://huggingface.co/ASLP-lab/DiffRhythm-base |
|
50 |
+
| DiffRhythm-full (4m45s) | Coming soon... |
|
51 |
+
| DiffRhythm-vae | https://huggingface.co/ASLP-lab/DiffRhythm-vae |
|
52 |
+
|
53 |
+
|
54 |
+
## License & Disclaimer
|
55 |
+
|
56 |
+
As the VAE is fine-tuned from [Stable Audio Open](https://huggingface.co/stabilityai/stable-audio-open-1.0), DiffRhythm is subject to the [Stability AI Community License Agreement](LICENSE.md)
|
57 |
+
|
58 |
+
DiffRhythm enables the creation of original music across diverse genres, supporting applications in artistic creation, education, and entertainment. While designed for positive use cases, potential risks include unintentional copyright infringement through stylistic similarities, inappropriate blending of cultural musical elements, and misuse for generating harmful content. To ensure responsible deployment, users must implement verification mechanisms to confirm musical originality, disclose AI involvement in generated works, and obtain permissions when adapting protected styles.
|
59 |
+
|
60 |
+
## Citation
|
61 |
+
```
|
62 |
+
@article{ning2025diffrhythm,
|
63 |
+
title={{DiffRhythm}: Blazingly Fast and Embarrassingly Simple</br>End-to-End Full-Length Song Generation with Latent Diffusion<},
|
64 |
+
author={Ziqian, Ning and Huakang, Chen and Yuepeng, Jiang and Chunbo, Hao and Guobin, Ma and Shuai, Wang and Jixun, Yao and Lei, Xie},
|
65 |
+
journal={arXiv preprint arXiv:2503.01183},
|
66 |
+
year={2025}
|
67 |
+
}
|
68 |
+
```
|
69 |
+
## Contact Us
|
70 |
+
|
71 |
+
If you are interested in leaving a message to our research team, feel free to email `[email protected]`.
|
72 |
+
<p align="center">
|
73 |
+
<a href="http://www.nwpu-aslp.org/">
|
74 |
+
<img src="src/ASLP.jpg" width="400"/>
|
75 |
+
</a>
|
76 |
+
</p>
|
models/TTS/DiffRhythm/cfm_full_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d8f33d24107bfdcf98cf9b7c427800cc35a2b61023750518d3fff58e617f0a47
|
3 |
+
size 2218708010
|
models/TTS/DiffRhythm/cfm_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:128a2513fca01531339cc5698c30e4f5a375256d1ed2257b464db9c886d6801a
|
3 |
+
size 2218706831
|
models/TTS/DiffRhythm/config.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "diffrhythm",
|
3 |
+
"model": {
|
4 |
+
"dim": 2048,
|
5 |
+
"depth": 16,
|
6 |
+
"heads": 32,
|
7 |
+
"ff_mult": 4,
|
8 |
+
"text_dim": 512,
|
9 |
+
"conv_layers": 4,
|
10 |
+
"mel_dim": 64,
|
11 |
+
"text_num_embeds": 363
|
12 |
+
}
|
13 |
+
}
|
models/TTS/DiffRhythm/models--OpenMuQ--MuQ-MuLan-large/.no_exist/8a081dbcf84edd47ea7db3c4ecb8fd1ec1ddacfe/model.safetensors
ADDED
File without changes
|
models/TTS/DiffRhythm/models--OpenMuQ--MuQ-MuLan-large/blobs/1d2f0a1aedbc66ea23e7fef7985c875c3e98c08d
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"mulan": {
|
3 |
+
"sr": 24000,
|
4 |
+
"clip_secs": 10,
|
5 |
+
"dim_latent": 512,
|
6 |
+
"decoupled_contrastive_learning": true,
|
7 |
+
"hierarchical_contrastive_loss": false,
|
8 |
+
"hierarchical_contrastive_loss_layers": null,
|
9 |
+
"sigmoid_contrastive_loss": false,
|
10 |
+
"rank_contrast": true
|
11 |
+
},
|
12 |
+
"audio_model": {
|
13 |
+
"name": "OpenMuQ/MuQ-large-msd-iter",
|
14 |
+
"model_dim": 1024,
|
15 |
+
"use_layer_idx": -1
|
16 |
+
},
|
17 |
+
"text_model": {
|
18 |
+
"name": "xlm-roberta-base",
|
19 |
+
"model_dim": null,
|
20 |
+
"use_layer_idx": -1
|
21 |
+
},
|
22 |
+
"audio_transformer": {
|
23 |
+
"dim": 768,
|
24 |
+
"tf_depth": 0,
|
25 |
+
"heads": 8,
|
26 |
+
"dim_head": 64,
|
27 |
+
"attn_dropout": 0,
|
28 |
+
"ff_dropout": 0,
|
29 |
+
"ff_mult": 4
|
30 |
+
},
|
31 |
+
"text_transformer": {
|
32 |
+
"dim": 768,
|
33 |
+
"tf_depth": 8,
|
34 |
+
"max_seq_len": 1024,
|
35 |
+
"dim_head": 64,
|
36 |
+
"heads": 8,
|
37 |
+
"attn_dropout": 0,
|
38 |
+
"ff_dropout": 0,
|
39 |
+
"ff_mult": 4
|
40 |
+
}
|
41 |
+
}
|
models/TTS/DiffRhythm/models--OpenMuQ--MuQ-MuLan-large/blobs/d42ae3f7cb9b66759ee0089ddc70e2f28b130c2d8ba621457358272d32dd0444
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d42ae3f7cb9b66759ee0089ddc70e2f28b130c2d8ba621457358272d32dd0444
|
3 |
+
size 2653954401
|