Fly-ShuAI commited on
Commit
df946da
·
verified ·
1 Parent(s): 7b39fe7

Delete diffsynth

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. diffsynth/__init__.py +0 -6
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +0 -806
  4. diffsynth/controlnets/__init__.py +0 -2
  5. diffsynth/controlnets/controlnet_unit.py +0 -91
  6. diffsynth/controlnets/processors.py +0 -62
  7. diffsynth/data/__init__.py +0 -1
  8. diffsynth/data/simple_text_image.py +0 -41
  9. diffsynth/data/video.py +0 -148
  10. diffsynth/distributed/__init__.py +0 -0
  11. diffsynth/distributed/xdit_context_parallel.py +0 -129
  12. diffsynth/extensions/ESRGAN/__init__.py +0 -137
  13. diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-310.pyc +0 -0
  14. diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-311.pyc +0 -0
  15. diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-312.pyc +0 -0
  16. diffsynth/extensions/FastBlend/__init__.py +0 -63
  17. diffsynth/extensions/FastBlend/api.py +0 -397
  18. diffsynth/extensions/FastBlend/cupy_kernels.py +0 -119
  19. diffsynth/extensions/FastBlend/data.py +0 -146
  20. diffsynth/extensions/FastBlend/patch_match.py +0 -298
  21. diffsynth/extensions/FastBlend/runners/__init__.py +0 -4
  22. diffsynth/extensions/FastBlend/runners/accurate.py +0 -35
  23. diffsynth/extensions/FastBlend/runners/balanced.py +0 -46
  24. diffsynth/extensions/FastBlend/runners/fast.py +0 -141
  25. diffsynth/extensions/FastBlend/runners/interpolation.py +0 -121
  26. diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py +0 -1
  27. diffsynth/extensions/ImageQualityMetric/BLIP/blip.py +0 -77
  28. diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py +0 -44
  29. diffsynth/extensions/ImageQualityMetric/BLIP/med.py +0 -947
  30. diffsynth/extensions/ImageQualityMetric/BLIP/vit.py +0 -301
  31. diffsynth/extensions/ImageQualityMetric/__init__.py +0 -148
  32. diffsynth/extensions/ImageQualityMetric/aesthetic.py +0 -148
  33. diffsynth/extensions/ImageQualityMetric/clip.py +0 -97
  34. diffsynth/extensions/ImageQualityMetric/config.py +0 -23
  35. diffsynth/extensions/ImageQualityMetric/hps.py +0 -118
  36. diffsynth/extensions/ImageQualityMetric/imagereward.py +0 -212
  37. diffsynth/extensions/ImageQualityMetric/mps.py +0 -129
  38. diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py +0 -14
  39. diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py +0 -458
  40. diffsynth/extensions/ImageQualityMetric/open_clip/constants.py +0 -2
  41. diffsynth/extensions/ImageQualityMetric/open_clip/factory.py +0 -433
  42. diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py +0 -0
  43. diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py +0 -45
  44. diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py +0 -176
  45. diffsynth/extensions/ImageQualityMetric/open_clip/loss.py +0 -270
  46. diffsynth/extensions/ImageQualityMetric/open_clip/model.py +0 -461
  47. diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json +0 -17
  48. diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py +0 -181
  49. diffsynth/extensions/ImageQualityMetric/open_clip/openai.py +0 -144
  50. diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py +0 -376
diffsynth/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from .data import *
2
- from .models import *
3
- from .prompters import *
4
- from .schedulers import *
5
- from .pipelines import *
6
- from .controlnets import *
 
 
 
 
 
 
 
diffsynth/configs/__init__.py DELETED
File without changes
diffsynth/configs/model_config.py DELETED
@@ -1,806 +0,0 @@
1
- from typing_extensions import Literal, TypeAlias
2
-
3
- from ..models.sd_text_encoder import SDTextEncoder
4
- from ..models.sd_unet import SDUNet
5
- from ..models.sd_vae_encoder import SDVAEEncoder
6
- from ..models.sd_vae_decoder import SDVAEDecoder
7
-
8
- from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
9
- from ..models.sdxl_unet import SDXLUNet
10
- from ..models.sdxl_vae_decoder import SDXLVAEDecoder
11
- from ..models.sdxl_vae_encoder import SDXLVAEEncoder
12
-
13
- from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
14
- from ..models.sd3_dit import SD3DiT
15
- from ..models.sd3_vae_decoder import SD3VAEDecoder
16
- from ..models.sd3_vae_encoder import SD3VAEEncoder
17
-
18
- from ..models.sd_controlnet import SDControlNet
19
- from ..models.sdxl_controlnet import SDXLControlNetUnion
20
-
21
- from ..models.sd_motion import SDMotionModel
22
- from ..models.sdxl_motion import SDXLMotionModel
23
-
24
- from ..models.svd_image_encoder import SVDImageEncoder
25
- from ..models.svd_unet import SVDUNet
26
- from ..models.svd_vae_decoder import SVDVAEDecoder
27
- from ..models.svd_vae_encoder import SVDVAEEncoder
28
-
29
- from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
30
- from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
31
-
32
- from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
33
- from ..models.hunyuan_dit import HunyuanDiT
34
-
35
- from ..models.flux_dit import FluxDiT
36
- from ..models.flux_text_encoder import FluxTextEncoder2
37
- from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
38
- from ..models.flux_controlnet import FluxControlNet
39
- from ..models.flux_ipadapter import FluxIpAdapter
40
- from ..models.flux_infiniteyou import InfiniteYouImageProjector
41
-
42
- from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
43
- from ..models.cog_dit import CogDiT
44
-
45
- from ..models.omnigen import OmniGenTransformer
46
-
47
- from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
48
- from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
49
-
50
- from ..extensions.RIFE import IFNet
51
- from ..extensions.ESRGAN import RRDBNet
52
-
53
- from ..models.hunyuan_video_dit import HunyuanVideoDiT
54
-
55
- from ..models.stepvideo_vae import StepVideoVAE
56
- from ..models.stepvideo_dit import StepVideoModel
57
-
58
- from ..models.wan_video_dit import WanModel
59
- from ..models.wan_video_text_encoder import WanTextEncoder
60
- from ..models.wan_video_image_encoder import WanImageEncoder
61
- from ..models.wan_video_vae import WanVideoVAE
62
- from ..models.wan_video_motion_controller import WanMotionControllerModel
63
-
64
-
65
- model_loader_configs = [
66
- # These configs are provided for detecting model type automatically.
67
- # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
68
- (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
69
- (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
70
- (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
71
- (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
72
- (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
73
- (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
74
- (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
75
- (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
76
- (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
77
- (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
78
- (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
79
- (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
80
- (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
81
- (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
82
- (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
83
- (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
84
- (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
85
- (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
86
- (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
87
- (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
88
- (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
89
- (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
90
- (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
91
- (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
92
- (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
93
- (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
94
- (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
95
- (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
96
- (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
97
- (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
98
- (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
99
- (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
100
- (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
101
- (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
102
- (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
103
- (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
104
- (None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
105
- (None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
106
- (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
107
- (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
108
- (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
109
- (None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
110
- (None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
111
- (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
112
- (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
113
- (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
114
- (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
115
- (None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
116
- (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
117
- (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
118
- (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
119
- (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
120
- (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
121
- (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
122
- (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
123
- (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
124
- (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
125
- (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
126
- (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
127
- (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
128
- (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
129
- (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
130
- (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
131
- (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
132
- (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
133
- (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
134
- ]
135
- huggingface_model_loader_configs = [
136
- # These configs are provided for detecting model type automatically.
137
- # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
138
- ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
139
- ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
140
- ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
141
- ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
142
- # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
143
- ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
144
- ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
145
- ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
146
- ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
147
- ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
148
- ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
149
- ]
150
- patch_model_loader_configs = [
151
- # These configs are provided for detecting model type automatically.
152
- # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
153
- ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
154
- ]
155
-
156
- preset_models_on_huggingface = {
157
- "HunyuanDiT": [
158
- ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
159
- ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
160
- ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
161
- ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
162
- ],
163
- "stable-video-diffusion-img2vid-xt": [
164
- ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
165
- ],
166
- "ExVideo-SVD-128f-v1": [
167
- ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
168
- ],
169
- # Stable Diffusion
170
- "StableDiffusion_v15": [
171
- ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
172
- ],
173
- "DreamShaper_8": [
174
- ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
175
- ],
176
- # Textual Inversion
177
- "TextualInversion_VeryBadImageNegative_v1.3": [
178
- ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
179
- ],
180
- # Stable Diffusion XL
181
- "StableDiffusionXL_v1": [
182
- ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
183
- ],
184
- "BluePencilXL_v200": [
185
- ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
186
- ],
187
- "StableDiffusionXL_Turbo": [
188
- ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
189
- ],
190
- # Stable Diffusion 3
191
- "StableDiffusion3": [
192
- ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
193
- ],
194
- "StableDiffusion3_without_T5": [
195
- ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
196
- ],
197
- # ControlNet
198
- "ControlNet_v11f1p_sd15_depth": [
199
- ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
200
- ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
201
- ],
202
- "ControlNet_v11p_sd15_softedge": [
203
- ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
204
- ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
205
- ],
206
- "ControlNet_v11f1e_sd15_tile": [
207
- ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
208
- ],
209
- "ControlNet_v11p_sd15_lineart": [
210
- ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
211
- ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
212
- ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
213
- ],
214
- "ControlNet_union_sdxl_promax": [
215
- ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
216
- ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
217
- ],
218
- # AnimateDiff
219
- "AnimateDiff_v2": [
220
- ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
221
- ],
222
- "AnimateDiff_xl_beta": [
223
- ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
224
- ],
225
-
226
- # Qwen Prompt
227
- "QwenPrompt": [
228
- ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
229
- ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
230
- ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
231
- ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
232
- ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
233
- ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
234
- ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
235
- ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
236
- ],
237
- # Beautiful Prompt
238
- "BeautifulPrompt": [
239
- ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
240
- ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
241
- ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
242
- ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
243
- ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
244
- ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
245
- ],
246
- # Omost prompt
247
- "OmostPrompt":[
248
- ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
249
- ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
250
- ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
251
- ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
252
- ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
253
- ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
254
- ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
255
- ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
256
- ],
257
- # Translator
258
- "opus-mt-zh-en": [
259
- ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
260
- ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
261
- ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
262
- ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
263
- ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
264
- ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
265
- ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
266
- ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
267
- ],
268
- # IP-Adapter
269
- "IP-Adapter-SD": [
270
- ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
271
- ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
272
- ],
273
- "IP-Adapter-SDXL": [
274
- ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
275
- ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
276
- ],
277
- "SDXL-vae-fp16-fix": [
278
- ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
279
- ],
280
- # Kolors
281
- "Kolors": [
282
- ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
283
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
284
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
285
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
286
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
287
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
288
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
289
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
290
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
291
- ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
292
- ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
293
- ],
294
- # FLUX
295
- "FLUX.1-dev": [
296
- ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
297
- ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
298
- ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
299
- ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
300
- ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
301
- ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
302
- ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
303
- ],
304
- "InstantX/FLUX.1-dev-IP-Adapter": {
305
- "file_list": [
306
- ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
307
- ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
308
- ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
309
- ],
310
- "load_path": [
311
- "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
312
- "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
313
- ],
314
- },
315
- # RIFE
316
- "RIFE": [
317
- ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
318
- ],
319
- # CogVideo
320
- "CogVideoX-5B": [
321
- ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
322
- ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
323
- ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
324
- ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
325
- ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
326
- ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
327
- ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
328
- ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
329
- ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
330
- ],
331
- # Stable Diffusion 3.5
332
- "StableDiffusion3.5-large": [
333
- ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
334
- ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
335
- ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
336
- ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
337
- ],
338
- }
339
- preset_models_on_modelscope = {
340
- # Hunyuan DiT
341
- "HunyuanDiT": [
342
- ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
343
- ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
344
- ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
345
- ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
346
- ],
347
- # Stable Video Diffusion
348
- "stable-video-diffusion-img2vid-xt": [
349
- ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
350
- ],
351
- # ExVideo
352
- "ExVideo-SVD-128f-v1": [
353
- ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
354
- ],
355
- "ExVideo-CogVideoX-LoRA-129f-v1": [
356
- ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
357
- ],
358
- # Stable Diffusion
359
- "StableDiffusion_v15": [
360
- ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
361
- ],
362
- "DreamShaper_8": [
363
- ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
364
- ],
365
- "AingDiffusion_v12": [
366
- ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
367
- ],
368
- "Flat2DAnimerge_v45Sharp": [
369
- ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
370
- ],
371
- # Textual Inversion
372
- "TextualInversion_VeryBadImageNegative_v1.3": [
373
- ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
374
- ],
375
- # Stable Diffusion XL
376
- "StableDiffusionXL_v1": [
377
- ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
378
- ],
379
- "BluePencilXL_v200": [
380
- ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
381
- ],
382
- "StableDiffusionXL_Turbo": [
383
- ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
384
- ],
385
- "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
386
- ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
387
- ],
388
- # Stable Diffusion 3
389
- "StableDiffusion3": [
390
- ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
391
- ],
392
- "StableDiffusion3_without_T5": [
393
- ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
394
- ],
395
- # ControlNet
396
- "ControlNet_v11f1p_sd15_depth": [
397
- ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
398
- ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
399
- ],
400
- "ControlNet_v11p_sd15_softedge": [
401
- ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
402
- ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
403
- ],
404
- "ControlNet_v11f1e_sd15_tile": [
405
- ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
406
- ],
407
- "ControlNet_v11p_sd15_lineart": [
408
- ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
409
- ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
410
- ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
411
- ],
412
- "ControlNet_union_sdxl_promax": [
413
- ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
414
- ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
415
- ],
416
- "Annotators:Depth": [
417
- ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
418
- ],
419
- "Annotators:Softedge": [
420
- ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
421
- ],
422
- "Annotators:Lineart": [
423
- ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
424
- ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
425
- ],
426
- "Annotators:Normal": [
427
- ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
428
- ],
429
- "Annotators:Openpose": [
430
- ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
431
- ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
432
- ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
433
- ],
434
- # AnimateDiff
435
- "AnimateDiff_v2": [
436
- ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
437
- ],
438
- "AnimateDiff_xl_beta": [
439
- ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
440
- ],
441
- # RIFE
442
- "RIFE": [
443
- ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
444
- ],
445
- # Qwen Prompt
446
- "QwenPrompt": {
447
- "file_list": [
448
- ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
449
- ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
450
- ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
451
- ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
452
- ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
453
- ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
454
- ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
455
- ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
456
- ],
457
- "load_path": [
458
- "models/QwenPrompt/qwen2-1.5b-instruct",
459
- ],
460
- },
461
- # Beautiful Prompt
462
- "BeautifulPrompt": {
463
- "file_list": [
464
- ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
465
- ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
466
- ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
467
- ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
468
- ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
469
- ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
470
- ],
471
- "load_path": [
472
- "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
473
- ],
474
- },
475
- # Omost prompt
476
- "OmostPrompt": {
477
- "file_list": [
478
- ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
479
- ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
480
- ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
481
- ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
482
- ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
483
- ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
484
- ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
485
- ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
486
- ],
487
- "load_path": [
488
- "models/OmostPrompt/omost-llama-3-8b-4bits",
489
- ],
490
- },
491
- # Translator
492
- "opus-mt-zh-en": {
493
- "file_list": [
494
- ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
495
- ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
496
- ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
497
- ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
498
- ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
499
- ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
500
- ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
501
- ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
502
- ],
503
- "load_path": [
504
- "models/translator/opus-mt-zh-en",
505
- ],
506
- },
507
- # IP-Adapter
508
- "IP-Adapter-SD": [
509
- ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
510
- ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
511
- ],
512
- "IP-Adapter-SDXL": [
513
- ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
514
- ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
515
- ],
516
- # Kolors
517
- "Kolors": {
518
- "file_list": [
519
- ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
520
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
521
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
522
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
523
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
524
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
525
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
526
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
527
- ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
528
- ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
529
- ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
530
- ],
531
- "load_path": [
532
- "models/kolors/Kolors/text_encoder",
533
- "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
534
- "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
535
- ],
536
- },
537
- "SDXL-vae-fp16-fix": [
538
- ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
539
- ],
540
- # FLUX
541
- "FLUX.1-dev": {
542
- "file_list": [
543
- ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
544
- ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
545
- ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
546
- ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
547
- ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
548
- ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
549
- ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
550
- ],
551
- "load_path": [
552
- "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
553
- "models/FLUX/FLUX.1-dev/text_encoder_2",
554
- "models/FLUX/FLUX.1-dev/ae.safetensors",
555
- "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
556
- ],
557
- },
558
- "FLUX.1-schnell": {
559
- "file_list": [
560
- ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
561
- ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
562
- ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
563
- ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
564
- ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
565
- ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
566
- ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
567
- ],
568
- "load_path": [
569
- "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
570
- "models/FLUX/FLUX.1-dev/text_encoder_2",
571
- "models/FLUX/FLUX.1-dev/ae.safetensors",
572
- "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
573
- ],
574
- },
575
- "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
576
- ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
577
- ],
578
- "jasperai/Flux.1-dev-Controlnet-Depth": [
579
- ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
580
- ],
581
- "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
582
- ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
583
- ],
584
- "jasperai/Flux.1-dev-Controlnet-Upscaler": [
585
- ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
586
- ],
587
- "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
588
- ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
589
- ],
590
- "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
591
- ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
592
- ],
593
- "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
594
- ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
595
- ],
596
- "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
597
- ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
598
- ],
599
- "InstantX/FLUX.1-dev-IP-Adapter": {
600
- "file_list": [
601
- ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
602
- ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
603
- ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
604
- ],
605
- "load_path": [
606
- "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
607
- "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
608
- ],
609
- },
610
- "InfiniteYou":{
611
- "file_list":[
612
- ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
613
- ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
614
- ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
615
- ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
616
- ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
617
- ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
618
- ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
619
- ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
620
- ],
621
- "load_path":[
622
- [
623
- "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
624
- "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
625
- ],
626
- "models/InfiniteYou/image_proj_model.bin",
627
- ],
628
- },
629
- # ESRGAN
630
- "ESRGAN_x4": [
631
- ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
632
- ],
633
- # RIFE
634
- "RIFE": [
635
- ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
636
- ],
637
- # Omnigen
638
- "OmniGen-v1": {
639
- "file_list": [
640
- ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
641
- ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
642
- ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
643
- ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
644
- ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
645
- ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
646
- ],
647
- "load_path": [
648
- "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
649
- "models/OmniGen/OmniGen-v1/model.safetensors",
650
- ]
651
- },
652
- # CogVideo
653
- "CogVideoX-5B": {
654
- "file_list": [
655
- ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
656
- ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
657
- ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
658
- ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
659
- ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
660
- ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
661
- ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
662
- ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
663
- ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
664
- ],
665
- "load_path": [
666
- "models/CogVideo/CogVideoX-5b/text_encoder",
667
- "models/CogVideo/CogVideoX-5b/transformer",
668
- "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
669
- ],
670
- },
671
- # Stable Diffusion 3.5
672
- "StableDiffusion3.5-large": [
673
- ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
674
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
675
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
676
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
677
- ],
678
- "StableDiffusion3.5-medium": [
679
- ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
680
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
681
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
682
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
683
- ],
684
- "StableDiffusion3.5-large-turbo": [
685
- ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
686
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
687
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
688
- ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
689
- ],
690
- "HunyuanVideo":{
691
- "file_list": [
692
- ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
693
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
694
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
695
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
696
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
697
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
698
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
699
- ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
700
- ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
701
- ],
702
- "load_path": [
703
- "models/HunyuanVideo/text_encoder/model.safetensors",
704
- "models/HunyuanVideo/text_encoder_2",
705
- "models/HunyuanVideo/vae/pytorch_model.pt",
706
- "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
707
- ],
708
- },
709
- "HunyuanVideoI2V":{
710
- "file_list": [
711
- ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
712
- ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
713
- ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
714
- ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
715
- ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
716
- ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
717
- ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
718
- ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
719
- ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
720
- ],
721
- "load_path": [
722
- "models/HunyuanVideoI2V/text_encoder/model.safetensors",
723
- "models/HunyuanVideoI2V/text_encoder_2",
724
- "models/HunyuanVideoI2V/vae/pytorch_model.pt",
725
- "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
726
- ],
727
- },
728
- "HunyuanVideo-fp8":{
729
- "file_list": [
730
- ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
731
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
732
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
733
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
734
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
735
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
736
- ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
737
- ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
738
- ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
739
- ],
740
- "load_path": [
741
- "models/HunyuanVideo/text_encoder/model.safetensors",
742
- "models/HunyuanVideo/text_encoder_2",
743
- "models/HunyuanVideo/vae/pytorch_model.pt",
744
- "models/HunyuanVideo/transformers/model.fp8.safetensors"
745
- ],
746
- },
747
- }
748
- Preset_model_id: TypeAlias = Literal[
749
- "HunyuanDiT",
750
- "stable-video-diffusion-img2vid-xt",
751
- "ExVideo-SVD-128f-v1",
752
- "ExVideo-CogVideoX-LoRA-129f-v1",
753
- "StableDiffusion_v15",
754
- "DreamShaper_8",
755
- "AingDiffusion_v12",
756
- "Flat2DAnimerge_v45Sharp",
757
- "TextualInversion_VeryBadImageNegative_v1.3",
758
- "StableDiffusionXL_v1",
759
- "BluePencilXL_v200",
760
- "StableDiffusionXL_Turbo",
761
- "ControlNet_v11f1p_sd15_depth",
762
- "ControlNet_v11p_sd15_softedge",
763
- "ControlNet_v11f1e_sd15_tile",
764
- "ControlNet_v11p_sd15_lineart",
765
- "AnimateDiff_v2",
766
- "AnimateDiff_xl_beta",
767
- "RIFE",
768
- "BeautifulPrompt",
769
- "opus-mt-zh-en",
770
- "IP-Adapter-SD",
771
- "IP-Adapter-SDXL",
772
- "StableDiffusion3",
773
- "StableDiffusion3_without_T5",
774
- "Kolors",
775
- "SDXL-vae-fp16-fix",
776
- "ControlNet_union_sdxl_promax",
777
- "FLUX.1-dev",
778
- "FLUX.1-schnell",
779
- "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
780
- "jasperai/Flux.1-dev-Controlnet-Depth",
781
- "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
782
- "jasperai/Flux.1-dev-Controlnet-Upscaler",
783
- "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
784
- "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
785
- "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
786
- "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
787
- "InstantX/FLUX.1-dev-IP-Adapter",
788
- "InfiniteYou",
789
- "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
790
- "QwenPrompt",
791
- "OmostPrompt",
792
- "ESRGAN_x4",
793
- "RIFE",
794
- "OmniGen-v1",
795
- "CogVideoX-5B",
796
- "Annotators:Depth",
797
- "Annotators:Softedge",
798
- "Annotators:Lineart",
799
- "Annotators:Normal",
800
- "Annotators:Openpose",
801
- "StableDiffusion3.5-large",
802
- "StableDiffusion3.5-medium",
803
- "HunyuanVideo",
804
- "HunyuanVideo-fp8",
805
- "HunyuanVideoI2V",
806
- ]
 
 
 
diffsynth/controlnets/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
2
- from .processors import Annotator
 
 
 
diffsynth/controlnets/controlnet_unit.py DELETED
@@ -1,91 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from .processors import Processor_id
4
-
5
-
6
- class ControlNetConfigUnit:
7
- def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
8
- self.processor_id = processor_id
9
- self.model_path = model_path
10
- self.scale = scale
11
- self.skip_processor = skip_processor
12
-
13
-
14
- class ControlNetUnit:
15
- def __init__(self, processor, model, scale=1.0):
16
- self.processor = processor
17
- self.model = model
18
- self.scale = scale
19
-
20
-
21
- class MultiControlNetManager:
22
- def __init__(self, controlnet_units=[]):
23
- self.processors = [unit.processor for unit in controlnet_units]
24
- self.models = [unit.model for unit in controlnet_units]
25
- self.scales = [unit.scale for unit in controlnet_units]
26
-
27
- def cpu(self):
28
- for model in self.models:
29
- model.cpu()
30
-
31
- def to(self, device):
32
- for model in self.models:
33
- model.to(device)
34
- for processor in self.processors:
35
- processor.to(device)
36
-
37
- def process_image(self, image, processor_id=None):
38
- if processor_id is None:
39
- processed_image = [processor(image) for processor in self.processors]
40
- else:
41
- processed_image = [self.processors[processor_id](image)]
42
- processed_image = torch.concat([
43
- torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
44
- for image_ in processed_image
45
- ], dim=0)
46
- return processed_image
47
-
48
- def __call__(
49
- self,
50
- sample, timestep, encoder_hidden_states, conditionings,
51
- tiled=False, tile_size=64, tile_stride=32, **kwargs
52
- ):
53
- res_stack = None
54
- for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
55
- res_stack_ = model(
56
- sample, timestep, encoder_hidden_states, conditioning, **kwargs,
57
- tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
58
- processor_id=processor.processor_id
59
- )
60
- res_stack_ = [res * scale for res in res_stack_]
61
- if res_stack is None:
62
- res_stack = res_stack_
63
- else:
64
- res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
65
- return res_stack
66
-
67
-
68
- class FluxMultiControlNetManager(MultiControlNetManager):
69
- def __init__(self, controlnet_units=[]):
70
- super().__init__(controlnet_units=controlnet_units)
71
-
72
- def process_image(self, image, processor_id=None):
73
- if processor_id is None:
74
- processed_image = [processor(image) for processor in self.processors]
75
- else:
76
- processed_image = [self.processors[processor_id](image)]
77
- return processed_image
78
-
79
- def __call__(self, conditionings, **kwargs):
80
- res_stack, single_res_stack = None, None
81
- for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
82
- res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
83
- res_stack_ = [res * scale for res in res_stack_]
84
- single_res_stack_ = [res * scale for res in single_res_stack_]
85
- if res_stack is None:
86
- res_stack = res_stack_
87
- single_res_stack = single_res_stack_
88
- else:
89
- res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
90
- single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
91
- return res_stack, single_res_stack
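
The heart of both managers above is the scale-then-accumulate loop: each ControlNet unit returns a stack of residual tensors, each stack is multiplied by that unit's scale, and the stacks are summed elementwise. A self-contained sketch of that accumulation with dummy tensors (no real ControlNet weights involved):

import torch

def combine_controlnet_residuals(res_stacks, scales):
    # res_stacks: one entry per ControlNet unit; each entry is a list of residual tensors.
    # Mirrors the scale-then-accumulate loop in MultiControlNetManager.__call__ above.
    combined = None
    for res_stack_, scale in zip(res_stacks, scales):
        res_stack_ = [res * scale for res in res_stack_]
        combined = res_stack_ if combined is None else [i + j for i, j in zip(combined, res_stack_)]
    return combined

# Two dummy "units", each emitting three residual tensors.
unit_a = [torch.ones(1, 4, 8, 8) for _ in range(3)]
unit_b = [torch.full((1, 4, 8, 8), 2.0) for _ in range(3)]
out = combine_controlnet_residuals([unit_a, unit_b], scales=[1.0, 0.5])
print(out[0].mean())  # tensor(2.) = 1*1.0 + 2*0.5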
 
 
 
diffsynth/controlnets/processors.py DELETED
@@ -1,62 +0,0 @@
1
- from typing_extensions import Literal, TypeAlias
2
-
3
-
4
- Processor_id: TypeAlias = Literal[
5
- "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
6
- ]
7
-
8
- class Annotator:
9
- def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
10
- if not skip_processor:
11
- if processor_id == "canny":
12
- from controlnet_aux.processor import CannyDetector
13
- self.processor = CannyDetector()
14
- elif processor_id == "depth":
15
- from controlnet_aux.processor import MidasDetector
16
- self.processor = MidasDetector.from_pretrained(model_path).to(device)
17
- elif processor_id == "softedge":
18
- from controlnet_aux.processor import HEDdetector
19
- self.processor = HEDdetector.from_pretrained(model_path).to(device)
20
- elif processor_id == "lineart":
21
- from controlnet_aux.processor import LineartDetector
22
- self.processor = LineartDetector.from_pretrained(model_path).to(device)
23
- elif processor_id == "lineart_anime":
24
- from controlnet_aux.processor import LineartAnimeDetector
25
- self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
26
- elif processor_id == "openpose":
27
- from controlnet_aux.processor import OpenposeDetector
28
- self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
29
- elif processor_id == "normal":
30
- from controlnet_aux.processor import NormalBaeDetector
31
- self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
32
- elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
33
- self.processor = None
34
- else:
35
- raise ValueError(f"Unsupported processor_id: {processor_id}")
36
- else:
37
- self.processor = None
38
-
39
- self.processor_id = processor_id
40
- self.detect_resolution = detect_resolution
41
-
42
- def to(self,device):
43
- if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"):
44
-
45
- self.processor.model.to(device)
46
-
47
- def __call__(self, image, mask=None):
48
- width, height = image.size
49
- if self.processor_id == "openpose":
50
- kwargs = {
51
- "include_body": True,
52
- "include_hand": True,
53
- "include_face": True
54
- }
55
- else:
56
- kwargs = {}
57
- if self.processor is not None:
58
- detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
59
- image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
60
- image = image.resize((width, height))
61
- return image
62
-
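
Typical use of the Annotator defined above, assuming controlnet_aux is installed and the module is still importable; the "canny" processor needs no pretrained weights, while the others load from models/Annotators. The blank input image is only a placeholder.

from PIL import Image
from diffsynth.controlnets.processors import Annotator  # the module shown above

annotator = Annotator("canny")                 # no weight download needed for canny
image = Image.new("RGB", (768, 512), "white")  # placeholder input
control_image = annotator(image)               # returned at the input's (width, height)
print(control_image.size)                      # (768, 512)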
 
 
 
diffsynth/data/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .video import VideoData, save_video, save_frames
 
 
diffsynth/data/simple_text_image.py DELETED
@@ -1,41 +0,0 @@
1
- import torch, os, torchvision
2
- from torchvision import transforms
3
- import pandas as pd
4
- from PIL import Image
5
-
6
-
7
-
8
- class TextImageDataset(torch.utils.data.Dataset):
9
- def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
10
- self.steps_per_epoch = steps_per_epoch
11
- metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
12
- self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
13
- self.text = metadata["text"].to_list()
14
- self.height = height
15
- self.width = width
16
- self.image_processor = transforms.Compose(
17
- [
18
- transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
19
- transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
20
- transforms.ToTensor(),
21
- transforms.Normalize([0.5], [0.5]),
22
- ]
23
- )
24
-
25
-
26
- def __getitem__(self, index):
27
- data_id = torch.randint(0, len(self.path), (1,))[0]
28
- data_id = (data_id + index) % len(self.path) # For fixed seed.
29
- text = self.text[data_id]
30
- image = Image.open(self.path[data_id]).convert("RGB")
31
- target_height, target_width = self.height, self.width
32
- width, height = image.size
33
- scale = max(target_width / width, target_height / height)
34
- shape = [round(height*scale),round(width*scale)]
35
- image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
36
- image = self.image_processor(image)
37
- return {"text": text, "image": image}
38
-
39
-
40
- def __len__(self):
41
- return self.steps_per_epoch
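
The resize rule in __getitem__ above scales the image until both target sides are covered and then center-crops. A worked example with concrete numbers (640x480 source, 1024x1024 target, chosen only for illustration):

from PIL import Image
import torchvision
from torchvision import transforms

image = Image.new("RGB", (640, 480), "gray")
target_height, target_width = 1024, 1024
width, height = image.size
scale = max(target_width / width, target_height / height)   # max(1.6, 2.133...) = 2.133...
shape = [round(height * scale), round(width * scale)]        # [1024, 1365]
image = torchvision.transforms.functional.resize(
    image, shape, interpolation=transforms.InterpolationMode.BILINEAR)
image = transforms.CenterCrop((target_height, target_width))(image)
print(image.size)  # (1024, 1024)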
 
 
 
diffsynth/data/video.py DELETED
@@ -1,148 +0,0 @@
1
- import imageio, os
2
- import numpy as np
3
- from PIL import Image
4
- from tqdm import tqdm
5
-
6
-
7
- class LowMemoryVideo:
8
- def __init__(self, file_name):
9
- self.reader = imageio.get_reader(file_name)
10
-
11
- def __len__(self):
12
- return self.reader.count_frames()
13
-
14
- def __getitem__(self, item):
15
- return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
16
-
17
- def __del__(self):
18
- self.reader.close()
19
-
20
-
21
- def split_file_name(file_name):
22
- result = []
23
- number = -1
24
- for i in file_name:
25
- if ord(i)>=ord("0") and ord(i)<=ord("9"):
26
- if number == -1:
27
- number = 0
28
- number = number*10 + ord(i) - ord("0")
29
- else:
30
- if number != -1:
31
- result.append(number)
32
- number = -1
33
- result.append(i)
34
- if number != -1:
35
- result.append(number)
36
- result = tuple(result)
37
- return result
38
-
39
-
40
- def search_for_images(folder):
41
- file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
42
- file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
43
- file_list = [i[1] for i in sorted(file_list)]
44
- file_list = [os.path.join(folder, i) for i in file_list]
45
- return file_list
46
-
47
-
48
- class LowMemoryImageFolder:
49
- def __init__(self, folder, file_list=None):
50
- if file_list is None:
51
- self.file_list = search_for_images(folder)
52
- else:
53
- self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
54
-
55
- def __len__(self):
56
- return len(self.file_list)
57
-
58
- def __getitem__(self, item):
59
- return Image.open(self.file_list[item]).convert("RGB")
60
-
61
- def __del__(self):
62
- pass
63
-
64
-
65
- def crop_and_resize(image, height, width):
66
- image = np.array(image)
67
- image_height, image_width, _ = image.shape
68
- if image_height / image_width < height / width:
69
- croped_width = int(image_height / height * width)
70
- left = (image_width - croped_width) // 2
71
- image = image[:, left: left+croped_width]
72
- image = Image.fromarray(image).resize((width, height))
73
- else:
74
- croped_height = int(image_width / width * height)
75
- left = (image_height - croped_height) // 2
76
- image = image[left: left+croped_height, :]
77
- image = Image.fromarray(image).resize((width, height))
78
- return image
79
-
80
-
81
- class VideoData:
82
- def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
83
- if video_file is not None:
84
- self.data_type = "video"
85
- self.data = LowMemoryVideo(video_file, **kwargs)
86
- elif image_folder is not None:
87
- self.data_type = "images"
88
- self.data = LowMemoryImageFolder(image_folder, **kwargs)
89
- else:
90
- raise ValueError("Cannot open video or image folder")
91
- self.length = None
92
- self.set_shape(height, width)
93
-
94
- def raw_data(self):
95
- frames = []
96
- for i in range(self.__len__()):
97
- frames.append(self.__getitem__(i))
98
- return frames
99
-
100
- def set_length(self, length):
101
- self.length = length
102
-
103
- def set_shape(self, height, width):
104
- self.height = height
105
- self.width = width
106
-
107
- def __len__(self):
108
- if self.length is None:
109
- return len(self.data)
110
- else:
111
- return self.length
112
-
113
- def shape(self):
114
- if self.height is not None and self.width is not None:
115
- return self.height, self.width
116
- else:
117
- height, width, _ = self.__getitem__(0).shape
118
- return height, width
119
-
120
- def __getitem__(self, item):
121
- frame = self.data.__getitem__(item)
122
- width, height = frame.size
123
- if self.height is not None and self.width is not None:
124
- if self.height != height or self.width != width:
125
- frame = crop_and_resize(frame, self.height, self.width)
126
- return frame
127
-
128
- def __del__(self):
129
- pass
130
-
131
- def save_images(self, folder):
132
- os.makedirs(folder, exist_ok=True)
133
- for i in tqdm(range(self.__len__()), desc="Saving images"):
134
- frame = self.__getitem__(i)
135
- frame.save(os.path.join(folder, f"{i}.png"))
136
-
137
-
138
- def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
139
- writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
140
- for frame in tqdm(frames, desc="Saving video"):
141
- frame = np.array(frame)
142
- writer.append_data(frame)
143
- writer.close()
144
-
145
- def save_frames(frames, save_path):
146
- os.makedirs(save_path, exist_ok=True)
147
- for i, frame in enumerate(tqdm(frames, desc="Saving images")):
148
- frame.save(os.path.join(save_path, f"{i}.png"))
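
End-to-end use of the helpers above with synthetic frames; assumes imageio plus imageio-ffmpeg for the mp4 writer, and that the module is importable.

import numpy as np
from PIL import Image
from diffsynth.data.video import save_frames, save_video  # the module shown above

# 30 synthetic 64x64 frames, brightening over time.
frames = [Image.fromarray(np.full((64, 64, 3), i * 8, dtype=np.uint8)) for i in range(30)]
save_frames(frames, "output/frames")             # writes 0.png ... 29.png
save_video(frames, "output/video.mp4", fps=15)   # uses imageio's ffmpeg writer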
 
 
 
diffsynth/distributed/__init__.py DELETED
File without changes
diffsynth/distributed/xdit_context_parallel.py DELETED
@@ -1,129 +0,0 @@
1
- import torch
2
- from typing import Optional
3
- from einops import rearrange
4
- from xfuser.core.distributed import (get_sequence_parallel_rank,
5
- get_sequence_parallel_world_size,
6
- get_sp_group)
7
- from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8
-
9
- def sinusoidal_embedding_1d(dim, position):
10
- sinusoid = torch.outer(position.type(torch.float64), torch.pow(
11
- 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
12
- x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
13
- return x.to(position.dtype)
14
-
15
- def pad_freqs(original_tensor, target_len):
16
- seq_len, s1, s2 = original_tensor.shape
17
- pad_size = target_len - seq_len
18
- padding_tensor = torch.ones(
19
- pad_size,
20
- s1,
21
- s2,
22
- dtype=original_tensor.dtype,
23
- device=original_tensor.device)
24
- padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
25
- return padded_tensor
26
-
27
- def rope_apply(x, freqs, num_heads):
28
- x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
29
- s_per_rank = x.shape[1]
30
-
31
- x_out = torch.view_as_complex(x.to(torch.float64).reshape(
32
- x.shape[0], x.shape[1], x.shape[2], -1, 2))
33
-
34
- sp_size = get_sequence_parallel_world_size()
35
- sp_rank = get_sequence_parallel_rank()
36
- freqs = pad_freqs(freqs, s_per_rank * sp_size)
37
- freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
38
-
39
- x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
40
- return x_out.to(x.dtype)
41
-
42
- def usp_dit_forward(self,
43
- x: torch.Tensor,
44
- timestep: torch.Tensor,
45
- context: torch.Tensor,
46
- clip_feature: Optional[torch.Tensor] = None,
47
- y: Optional[torch.Tensor] = None,
48
- use_gradient_checkpointing: bool = False,
49
- use_gradient_checkpointing_offload: bool = False,
50
- **kwargs,
51
- ):
52
- t = self.time_embedding(
53
- sinusoidal_embedding_1d(self.freq_dim, timestep))
54
- t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
55
- context = self.text_embedding(context)
56
-
57
- if self.has_image_input:
58
- x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
59
- clip_embdding = self.img_emb(clip_feature)
60
- context = torch.cat([clip_embdding, context], dim=1)
61
-
62
- x, (f, h, w) = self.patchify(x)
63
-
64
- freqs = torch.cat([
65
- self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
66
- self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
67
- self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
68
- ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
69
-
70
- def create_custom_forward(module):
71
- def custom_forward(*inputs):
72
- return module(*inputs)
73
- return custom_forward
74
-
75
- # Context Parallel
76
- x = torch.chunk(
77
- x, get_sequence_parallel_world_size(),
78
- dim=1)[get_sequence_parallel_rank()]
79
-
80
- for block in self.blocks:
81
- if self.training and use_gradient_checkpointing:
82
- if use_gradient_checkpointing_offload:
83
- with torch.autograd.graph.save_on_cpu():
84
- x = torch.utils.checkpoint.checkpoint(
85
- create_custom_forward(block),
86
- x, context, t_mod, freqs,
87
- use_reentrant=False,
88
- )
89
- else:
90
- x = torch.utils.checkpoint.checkpoint(
91
- create_custom_forward(block),
92
- x, context, t_mod, freqs,
93
- use_reentrant=False,
94
- )
95
- else:
96
- x = block(x, context, t_mod, freqs)
97
-
98
- x = self.head(x, t)
99
-
100
- # Context Parallel
101
- x = get_sp_group().all_gather(x, dim=1)
102
-
103
- # unpatchify
104
- x = self.unpatchify(x, (f, h, w))
105
- return x
106
-
107
-
108
- def usp_attn_forward(self, x, freqs):
109
- q = self.norm_q(self.q(x))
110
- k = self.norm_k(self.k(x))
111
- v = self.v(x)
112
-
113
- q = rope_apply(q, freqs, self.num_heads)
114
- k = rope_apply(k, freqs, self.num_heads)
115
- q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
116
- k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
117
- v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
118
-
119
- x = xFuserLongContextAttention()(
120
- None,
121
- query=q,
122
- key=k,
123
- value=v,
124
- )
125
- x = x.flatten(2)
126
-
127
- del q, k, v
128
- torch.cuda.empty_cache()
129
- return self.o(x)
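
The context-parallel pattern above, simulated on a single process so the data flow is visible without xfuser: the token sequence is chunked along dim=1, each rank transforms only its shard, and the shards are gathered back in rank order (get_sp_group().all_gather(x, dim=1) in the real code). The *2.0 is only a stand-in for the DiT blocks.

import torch

sp_size = 4                       # pretend sequence-parallel world size
x = torch.randn(1, 1024, 64)      # (batch, sequence, hidden)

shards = torch.chunk(x, sp_size, dim=1)           # what each rank would receive
processed = [shard * 2.0 for shard in shards]     # stand-in for the transformer blocks
x_out = torch.cat(processed, dim=1)               # stand-in for all_gather along dim=1

assert x_out.shape == x.shape
print(x_out.shape)  # torch.Size([1, 1024, 64])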
 
 
 
diffsynth/extensions/ESRGAN/__init__.py DELETED
@@ -1,137 +0,0 @@
1
- import torch
2
- from einops import repeat
3
- from PIL import Image
4
- import numpy as np
5
-
6
-
7
- class ResidualDenseBlock(torch.nn.Module):
8
-
9
- def __init__(self, num_feat=64, num_grow_ch=32):
10
- super(ResidualDenseBlock, self).__init__()
11
- self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
12
- self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
13
- self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
14
- self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
15
- self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
16
- self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
17
-
18
- def forward(self, x):
19
- x1 = self.lrelu(self.conv1(x))
20
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
21
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
22
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
23
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
24
- return x5 * 0.2 + x
25
-
26
-
27
- class RRDB(torch.nn.Module):
28
-
29
- def __init__(self, num_feat, num_grow_ch=32):
30
- super(RRDB, self).__init__()
31
- self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
32
- self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
33
- self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
34
-
35
- def forward(self, x):
36
- out = self.rdb1(x)
37
- out = self.rdb2(out)
38
- out = self.rdb3(out)
39
- return out * 0.2 + x
40
-
41
-
42
- class RRDBNet(torch.nn.Module):
43
-
44
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
45
- super(RRDBNet, self).__init__()
46
- self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
47
- self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
48
- self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
49
- # upsample
50
- self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
51
- self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
52
- self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
53
- self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
54
- self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
55
-
56
- def forward(self, x):
57
- feat = x
58
- feat = self.conv_first(feat)
59
- body_feat = self.conv_body(self.body(feat))
60
- feat = feat + body_feat
61
- # upsample
62
- feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
63
- feat = self.lrelu(self.conv_up1(feat))
64
- feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
65
- feat = self.lrelu(self.conv_up2(feat))
66
- out = self.conv_last(self.lrelu(self.conv_hr(feat)))
67
- return out
68
-
69
- @staticmethod
70
- def state_dict_converter():
71
- return RRDBNetStateDictConverter()
72
-
73
-
74
- class RRDBNetStateDictConverter:
75
- def __init__(self):
76
- pass
77
-
78
- def from_diffusers(self, state_dict):
79
- return state_dict, {"upcast_to_float32": True}
80
-
81
- def from_civitai(self, state_dict):
82
- return state_dict, {"upcast_to_float32": True}
83
-
84
-
85
- class ESRGAN(torch.nn.Module):
86
- def __init__(self, model):
87
- super().__init__()
88
- self.model = model
89
-
90
- @staticmethod
91
- def from_model_manager(model_manager):
92
- return ESRGAN(model_manager.fetch_model("esrgan"))
93
-
94
- def process_image(self, image):
95
- image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
96
- return image
97
-
98
- def process_images(self, images):
99
- images = [self.process_image(image) for image in images]
100
- images = torch.stack(images)
101
- return images
102
-
103
- def decode_images(self, images):
104
- images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
105
- images = [Image.fromarray(image) for image in images]
106
- return images
107
-
108
- @torch.no_grad()
109
- def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
110
- if not isinstance(images, list):
111
- images = [images]
112
- is_single_image = True
113
- else:
114
- is_single_image = False
115
-
116
- # Preprocess
117
- input_tensor = self.process_images(images)
118
-
119
- # Interpolate
120
- output_tensor = []
121
- for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
122
- batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
123
- batch_input_tensor = input_tensor[batch_id: batch_id_]
124
- batch_input_tensor = batch_input_tensor.to(
125
- device=self.model.conv_first.weight.device,
126
- dtype=self.model.conv_first.weight.dtype)
127
- batch_output_tensor = self.model(batch_input_tensor)
128
- output_tensor.append(batch_output_tensor.cpu())
129
-
130
- # Output
131
- output_tensor = torch.concat(output_tensor, dim=0)
132
-
133
- # To images
134
- output_images = self.decode_images(output_tensor)
135
- if is_single_image:
136
- output_images = output_images[0]
137
- return output_images
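
RRDBNet.forward above upsamples with einops.repeat rather than an explicit interpolation layer: the pattern duplicates every spatial position into a 2x2 block, i.e. a nearest-neighbour 2x upsample, which the following 3x3 convolutions then refine. A tiny demonstration:

import torch
from einops import repeat

feat = torch.arange(4.0).reshape(1, 1, 2, 2)      # one 2x2 feature map
up = repeat(feat, "B C H W -> B C (H 2) (W 2)")   # nearest-neighbour 2x upsample
print(up.shape)    # torch.Size([1, 1, 4, 4])
print(up[0, 0])    # every original value now fills a 2x2 block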
 
 
 
diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (6.09 kB)
 
diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (11.8 kB)
 
diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (10.1 kB)
 
diffsynth/extensions/FastBlend/__init__.py DELETED
@@ -1,63 +0,0 @@
1
- from .runners.fast import TableManager, PyramidPatchMatcher
2
- from PIL import Image
3
- import numpy as np
4
- import cupy as cp
5
-
6
-
7
- class FastBlendSmoother:
8
- def __init__(self):
9
- self.batch_size = 8
10
- self.window_size = 64
11
- self.ebsynth_config = {
12
- "minimum_patch_size": 5,
13
- "threads_per_block": 8,
14
- "num_iter": 5,
15
- "gpu_id": 0,
16
- "guide_weight": 10.0,
17
- "initialize": "identity",
18
- "tracking_window_size": 0,
19
- }
20
-
21
- @staticmethod
22
- def from_model_manager(model_manager):
23
- # TODO: fetch GPU ID from model_manager
24
- return FastBlendSmoother()
25
-
26
- def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
27
- frames_guide = [np.array(frame) for frame in frames_guide]
28
- frames_style = [np.array(frame) for frame in frames_style]
29
- table_manager = TableManager()
30
- patch_match_engine = PyramidPatchMatcher(
31
- image_height=frames_style[0].shape[0],
32
- image_width=frames_style[0].shape[1],
33
- channel=3,
34
- **ebsynth_config
35
- )
36
- # left part
37
- table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
38
- table_l = table_manager.remapping_table_to_blending_table(table_l)
39
- table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
40
- # right part
41
- table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
42
- table_r = table_manager.remapping_table_to_blending_table(table_r)
43
- table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
44
- # merge
45
- frames = []
46
- for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
47
- weight_m = -1
48
- weight = weight_l + weight_m + weight_r
49
- frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
50
- frames.append(frame)
51
- frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
52
- return frames
53
-
54
- def __call__(self, rendered_frames, original_frames=None, **kwargs):
55
- frames = self.run(
56
- original_frames, rendered_frames,
57
- self.batch_size, self.window_size, self.ebsynth_config
58
- )
59
- mempool = cp.get_default_memory_pool()
60
- pinned_mempool = cp.get_default_pinned_memory_pool()
61
- mempool.free_all_blocks()
62
- pinned_mempool.free_all_blocks()
63
- return frames
 
 
 
diffsynth/extensions/FastBlend/api.py DELETED
@@ -1,397 +0,0 @@
1
- from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
2
- from .data import VideoData, get_video_fps, save_video, search_for_images
3
- import os
4
- import gradio as gr
5
-
6
-
7
- def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
8
- frames_guide = VideoData(video_guide, video_guide_folder)
9
- frames_style = VideoData(video_style, video_style_folder)
10
- message = ""
11
- if len(frames_guide) < len(frames_style):
12
- message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
13
- frames_style.set_length(len(frames_guide))
14
- elif len(frames_guide) > len(frames_style):
15
- message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
16
- frames_guide.set_length(len(frames_style))
17
- height_guide, width_guide = frames_guide.shape()
18
- height_style, width_style = frames_style.shape()
19
- if height_guide != height_style or width_guide != width_style:
20
- message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
21
- frames_style.set_shape(height_guide, width_guide)
22
- return frames_guide, frames_style, message
23
-
24
-
25
- def smooth_video(
26
- video_guide,
27
- video_guide_folder,
28
- video_style,
29
- video_style_folder,
30
- mode,
31
- window_size,
32
- batch_size,
33
- tracking_window_size,
34
- output_path,
35
- fps,
36
- minimum_patch_size,
37
- num_iter,
38
- guide_weight,
39
- initialize,
40
- progress = None,
41
- ):
42
- # input
43
- frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
44
- if len(message) > 0:
45
- print(message)
46
- # output
47
- if output_path == "":
48
- if video_style is None:
49
- output_path = os.path.join(video_style_folder, "output")
50
- else:
51
- output_path = os.path.join(os.path.split(video_style)[0], "output")
52
- os.makedirs(output_path, exist_ok=True)
53
- print("No valid output_path. Your video will be saved here:", output_path)
54
- elif not os.path.exists(output_path):
55
- os.makedirs(output_path, exist_ok=True)
56
- print("Your video will be saved here:", output_path)
57
- frames_path = os.path.join(output_path, "frames")
58
- video_path = os.path.join(output_path, "video.mp4")
59
- os.makedirs(frames_path, exist_ok=True)
60
- # process
61
- if mode == "Fast" or mode == "Balanced":
62
- tracking_window_size = 0
63
- ebsynth_config = {
64
- "minimum_patch_size": minimum_patch_size,
65
- "threads_per_block": 8,
66
- "num_iter": num_iter,
67
- "gpu_id": 0,
68
- "guide_weight": guide_weight,
69
- "initialize": initialize,
70
- "tracking_window_size": tracking_window_size,
71
- }
72
- if mode == "Fast":
73
- FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
74
- elif mode == "Balanced":
75
- BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
76
- elif mode == "Accurate":
77
- AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
78
- # output
79
- try:
80
- fps = int(fps)
81
- except:
82
- fps = get_video_fps(video_style) if video_style is not None else 30
83
- print("Fps:", fps)
84
- print("Saving video...")
85
- video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
86
- print("Success!")
87
- print("Your frames are here:", frames_path)
88
- print("Your video is here:", video_path)
89
- return output_path, fps, video_path
90
-
91
-
92
- class KeyFrameMatcher:
93
- def __init__(self):
94
- pass
95
-
96
- def extract_number_from_filename(self, file_name):
97
- result = []
98
- number = -1
99
- for i in file_name:
100
- if ord(i)>=ord("0") and ord(i)<=ord("9"):
101
- if number == -1:
102
- number = 0
103
- number = number*10 + ord(i) - ord("0")
104
- else:
105
- if number != -1:
106
- result.append(number)
107
- number = -1
108
- if number != -1:
109
- result.append(number)
110
- result = tuple(result)
111
- return result
112
-
113
- def extract_number_from_filenames(self, file_names):
114
- numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
115
- min_length = min(len(i) for i in numbers)
116
- for i in range(min_length-1, -1, -1):
117
- if len(set(number[i] for number in numbers))==len(file_names):
118
- return [number[i] for number in numbers]
119
- return list(range(len(file_names)))
120
-
121
- def match_using_filename(self, file_names_a, file_names_b):
122
- file_names_b_set = set(file_names_b)
123
- matched_file_name = []
124
- for file_name in file_names_a:
125
- if file_name not in file_names_b_set:
126
- matched_file_name.append(None)
127
- else:
128
- matched_file_name.append(file_name)
129
- return matched_file_name
130
-
131
- def match_using_numbers(self, file_names_a, file_names_b):
132
- numbers_a = self.extract_number_from_filenames(file_names_a)
133
- numbers_b = self.extract_number_from_filenames(file_names_b)
134
- numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
135
- matched_file_name = []
136
- for number in numbers_a:
137
- if number in numbers_b_dict:
138
- matched_file_name.append(numbers_b_dict[number])
139
- else:
140
- matched_file_name.append(None)
141
- return matched_file_name
142
-
143
- def match_filenames(self, file_names_a, file_names_b):
144
- matched_file_name = self.match_using_filename(file_names_a, file_names_b)
145
- if sum([i is not None for i in matched_file_name]) > 0:
146
- return matched_file_name
147
- matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
148
- return matched_file_name
149
-
150
-
151
- def detect_frames(frames_path, keyframes_path):
152
- if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
153
- return "Please input the directory of guide video and rendered frames"
154
- elif not os.path.exists(frames_path):
155
- return "Please input the directory of guide video"
156
- elif not os.path.exists(keyframes_path):
157
- return "Please input the directory of rendered frames"
158
- frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
159
- keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
160
- if len(frames)==0:
161
- return f"No images detected in {frames_path}"
162
- if len(keyframes)==0:
163
- return f"No images detected in {keyframes_path}"
164
- matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
165
- max_filename_length = max([len(i) for i in frames])
166
- if sum([i is not None for i in matched_keyframes])==0:
167
- message = ""
168
- for frame, matched_keyframe in zip(frames, matched_keyframes):
169
- message += frame + " " * (max_filename_length - len(frame) + 1)
170
- message += "--> No matched keyframes\n"
171
- else:
172
- message = ""
173
- for frame, matched_keyframe in zip(frames, matched_keyframes):
174
- message += frame + " " * (max_filename_length - len(frame) + 1)
175
- if matched_keyframe is None:
176
- message += "--> [to be rendered]\n"
177
- else:
178
- message += f"--> {matched_keyframe}\n"
179
- return message
180
-
181
-
182
- def check_input_for_interpolating(frames_path, keyframes_path):
183
- # search for images
184
- frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
185
- keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
186
- # match frames
187
- matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
188
- file_list = [file_name for file_name in matched_keyframes if file_name is not None]
189
- index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
190
- frames_guide = VideoData(None, frames_path)
191
- frames_style = VideoData(None, keyframes_path, file_list=file_list)
192
- # match shape
193
- message = ""
194
- height_guide, width_guide = frames_guide.shape()
195
- height_style, width_style = frames_style.shape()
196
- if height_guide != height_style or width_guide != width_style:
197
- message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
198
- frames_style.set_shape(height_guide, width_guide)
199
- return frames_guide, frames_style, index_style, message
200
-
201
-
202
- def interpolate_video(
203
- frames_path,
204
- keyframes_path,
205
- output_path,
206
- fps,
207
- batch_size,
208
- tracking_window_size,
209
- minimum_patch_size,
210
- num_iter,
211
- guide_weight,
212
- initialize,
213
- progress = None,
214
- ):
215
- # input
216
- frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
217
- if len(message) > 0:
218
- print(message)
219
- # output
220
- if output_path == "":
221
- output_path = os.path.join(keyframes_path, "output")
222
- os.makedirs(output_path, exist_ok=True)
223
- print("No valid output_path. Your video will be saved here:", output_path)
224
- elif not os.path.exists(output_path):
225
- os.makedirs(output_path, exist_ok=True)
226
- print("Your video will be saved here:", output_path)
227
- output_frames_path = os.path.join(output_path, "frames")
228
- output_video_path = os.path.join(output_path, "video.mp4")
229
- os.makedirs(output_frames_path, exist_ok=True)
230
- # process
231
- ebsynth_config = {
232
- "minimum_patch_size": minimum_patch_size,
233
- "threads_per_block": 8,
234
- "num_iter": num_iter,
235
- "gpu_id": 0,
236
- "guide_weight": guide_weight,
237
- "initialize": initialize,
238
- "tracking_window_size": tracking_window_size
239
- }
240
- if len(index_style)==1:
241
- InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
242
- else:
243
- InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
244
- try:
245
- fps = int(fps)
246
- except:
247
- fps = 30
248
- print("Fps:", fps)
249
- print("Saving video...")
250
- video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
251
- print("Success!")
252
- print("Your frames are here:", output_frames_path)
253
- print("Your video is here:", video_path)
254
- return output_path, fps, video_path
255
-
256
-
257
- def on_ui_tabs():
258
- with gr.Blocks(analytics_enabled=False) as ui_component:
259
- with gr.Tab("Blend"):
260
- gr.Markdown("""
261
- # Blend
262
-
263
- Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
264
- """)
265
- with gr.Row():
266
- with gr.Column():
267
- with gr.Tab("Guide video"):
268
- video_guide = gr.Video(label="Guide video")
269
- with gr.Tab("Guide video (images format)"):
270
- video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
271
- with gr.Column():
272
- with gr.Tab("Style video"):
273
- video_style = gr.Video(label="Style video")
274
- with gr.Tab("Style video (images format)"):
275
- video_style_folder = gr.Textbox(label="Style video (images format)", value="")
276
- with gr.Column():
277
- output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
278
- fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
279
- video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
280
- btn = gr.Button(value="Blend")
281
- with gr.Row():
282
- with gr.Column():
283
- gr.Markdown("# Settings")
284
- mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
285
- window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
286
- batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
287
- tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
288
- gr.Markdown("## Advanced Settings")
289
- minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
290
- num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
291
- guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
292
- initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
293
- with gr.Column():
294
- gr.Markdown("""
295
- # Reference
296
-
297
- * Output directory: the directory to save the video.
298
- * Inference mode
299
-
300
- |Mode|Time|Memory|Quality|Frame by frame output|Description|
301
- |-|-|-|-|-|-|
302
- |Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
303
- |Balanced|■■|■|■■|Yes|Blend the frames naively.|
304
- |Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
305
-
306
- * Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
307
- * Batch size: a larger batch size makes the program faster but requires more VRAM.
308
- * Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
309
- * Advanced settings
310
- * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
311
- * Number of iterations: the number of iterations of patch matching. (Default: 5)
312
- * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
313
- * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
314
- """)
315
- btn.click(
316
- smooth_video,
317
- inputs=[
318
- video_guide,
319
- video_guide_folder,
320
- video_style,
321
- video_style_folder,
322
- mode,
323
- window_size,
324
- batch_size,
325
- tracking_window_size,
326
- output_path,
327
- fps,
328
- minimum_patch_size,
329
- num_iter,
330
- guide_weight,
331
- initialize
332
- ],
333
- outputs=[output_path, fps, video_output]
334
- )
335
- with gr.Tab("Interpolate"):
336
- gr.Markdown("""
337
- # Interpolate
338
-
339
- Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
340
- """)
341
- with gr.Row():
342
- with gr.Column():
343
- with gr.Row():
344
- with gr.Column():
345
- video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
346
- with gr.Column():
347
- rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
348
- with gr.Row():
349
- detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
350
- video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
351
- rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
352
- with gr.Column():
353
- output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
354
- fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
355
- video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
356
- btn_ = gr.Button(value="Interpolate")
357
- with gr.Row():
358
- with gr.Column():
359
- gr.Markdown("# Settings")
360
- batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
361
- tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
362
- gr.Markdown("## Advanced Settings")
363
- minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
364
- num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
365
- guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
366
- initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
367
- with gr.Column():
368
- gr.Markdown("""
369
- # Reference
370
-
371
- * Output directory: the directory to save the video.
372
- * Batch size: a larger batch size makes the program faster but requires more VRAM.
373
- * Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
374
- * Advanced settings
375
- * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
376
- * Number of iterations: the number of iterations of patch matching. (Default: 5)
377
- * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
378
- * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
379
- """)
380
- btn_.click(
381
- interpolate_video,
382
- inputs=[
383
- video_guide_folder_,
384
- rendered_keyframes_,
385
- output_path_,
386
- fps_,
387
- batch_size_,
388
- tracking_window_size_,
389
- minimum_patch_size_,
390
- num_iter_,
391
- guide_weight_,
392
- initialize_,
393
- ],
394
- outputs=[output_path_, fps_, video_output_]
395
- )
396
-
397
- return [(ui_component, "FastBlend", "FastBlend_ui")]
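
The Gradio wiring pattern used throughout on_ui_tabs above, reduced to a single input/output pair as a sketch: component values flow into a plain Python function through Button.click. The stub function is illustrative only, not part of FastBlend.

import gradio as gr

def smooth_stub(output_path):
    # Stand-in for smooth_video: just echoes the resolved output directory.
    return output_path or "output"

with gr.Blocks(analytics_enabled=False) as demo:
    output_path = gr.Textbox(label="Output directory", value="")
    result = gr.Textbox(label="Resolved directory", interactive=False)
    btn = gr.Button(value="Blend")
    btn.click(smooth_stub, inputs=[output_path], outputs=[result])

# demo.launch()  # uncomment to serve the stub UI locally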
 
 
 
diffsynth/extensions/FastBlend/cupy_kernels.py DELETED
@@ -1,119 +0,0 @@
1
- import cupy as cp
2
-
3
- remapping_kernel = cp.RawKernel(r'''
4
- extern "C" __global__
5
- void remap(
6
- const int height,
7
- const int width,
8
- const int channel,
9
- const int patch_size,
10
- const int pad_size,
11
- const float* source_style,
12
- const int* nnf,
13
- float* target_style
14
- ) {
15
- const int r = (patch_size - 1) / 2;
16
- const int x = blockDim.x * blockIdx.x + threadIdx.x;
17
- const int y = blockDim.y * blockIdx.y + threadIdx.y;
18
- if (x >= height or y >= width) return;
19
- const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
20
- const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
21
- const int min_px = x < r ? -x : -r;
22
- const int max_px = x + r > height - 1 ? height - 1 - x : r;
23
- const int min_py = y < r ? -y : -r;
24
- const int max_py = y + r > width - 1 ? width - 1 - y : r;
25
- int num = 0;
26
- for (int px = min_px; px <= max_px; px++){
27
- for (int py = min_py; py <= max_py; py++){
28
- const int nid = (x + px) * width + y + py;
29
- const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
30
- const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
31
- if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
32
- const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
33
- num++;
34
- for (int c = 0; c < channel; c++){
35
- target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
36
- }
37
- }
38
- }
39
- for (int c = 0; c < channel; c++){
40
- target_style[z + pid * channel + c] /= num;
41
- }
42
- }
43
- ''', 'remap')
44
-
45
-
46
- patch_error_kernel = cp.RawKernel(r'''
47
- extern "C" __global__
48
- void patch_error(
49
- const int height,
50
- const int width,
51
- const int channel,
52
- const int patch_size,
53
- const int pad_size,
54
- const float* source,
55
- const int* nnf,
56
- const float* target,
57
- float* error
58
- ) {
59
- const int r = (patch_size - 1) / 2;
60
- const int x = blockDim.x * blockIdx.x + threadIdx.x;
61
- const int y = blockDim.y * blockIdx.y + threadIdx.y;
62
- const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
63
- if (x >= height or y >= width) return;
64
- const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
65
- const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
66
- float e = 0;
67
- for (int px = -r; px <= r; px++){
68
- for (int py = -r; py <= r; py++){
69
- const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
70
- const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
71
- for (int c = 0; c < channel; c++){
72
- const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
73
- e += diff * diff;
74
- }
75
- }
76
- }
77
- error[blockIdx.z * height * width + x * width + y] = e;
78
- }
79
- ''', 'patch_error')
80
-
81
-
82
- pairwise_patch_error_kernel = cp.RawKernel(r'''
83
- extern "C" __global__
84
- void pairwise_patch_error(
85
- const int height,
86
- const int width,
87
- const int channel,
88
- const int patch_size,
89
- const int pad_size,
90
- const float* source_a,
91
- const int* nnf_a,
92
- const float* source_b,
93
- const int* nnf_b,
94
- float* error
95
- ) {
96
- const int r = (patch_size - 1) / 2;
97
- const int x = blockDim.x * blockIdx.x + threadIdx.x;
98
- const int y = blockDim.y * blockIdx.y + threadIdx.y;
99
- const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
100
- if (x >= height or y >= width) return;
101
- const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
102
- const int x_a = nnf_a[z_nnf + 0];
103
- const int y_a = nnf_a[z_nnf + 1];
104
- const int x_b = nnf_b[z_nnf + 0];
105
- const int y_b = nnf_b[z_nnf + 1];
106
- float e = 0;
107
- for (int px = -r; px <= r; px++){
108
- for (int py = -r; py <= r; py++){
109
- const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
110
- const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
111
- for (int c = 0; c < channel; c++){
112
- const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
113
- e += diff * diff;
114
- }
115
- }
116
- }
117
- error[blockIdx.z * height * width + x * width + y] = e;
118
- }
119
- ''', 'pairwise_patch_error')
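
These kernels are launched by PatchMatcher (see patch_match.py below) with the CuPy convention kernel(grid, block, args). A minimal RawKernel with the same calling convention, runnable on any CUDA device with cupy installed; the elementwise square is only a toy kernel, not part of FastBlend.

import cupy as cp

square_kernel = cp.RawKernel(r'''
extern "C" __global__
void square(const int n, const float* x, float* y) {
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= n) return;
    y[i] = x[i] * x[i];
}
''', 'square')

n = 1024
x = cp.arange(n, dtype=cp.float32)
y = cp.zeros(n, dtype=cp.float32)
threads_per_block = 128
grid = ((n + threads_per_block - 1) // threads_per_block,)
square_kernel(grid, (threads_per_block,), (cp.int32(n), x, y))  # kernel(grid, block, args)
print(float(y[3]))  # 9.0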
 
 
 
diffsynth/extensions/FastBlend/data.py DELETED
@@ -1,146 +0,0 @@
1
- import imageio, os
2
- import numpy as np
3
- from PIL import Image
4
-
5
-
6
- def read_video(file_name):
7
- reader = imageio.get_reader(file_name)
8
- video = []
9
- for frame in reader:
10
- frame = np.array(frame)
11
- video.append(frame)
12
- reader.close()
13
- return video
14
-
15
-
16
- def get_video_fps(file_name):
17
- reader = imageio.get_reader(file_name)
18
- fps = reader.get_meta_data()["fps"]
19
- reader.close()
20
- return fps
21
-
22
-
23
- def save_video(frames_path, video_path, num_frames, fps):
24
- writer = imageio.get_writer(video_path, fps=fps, quality=9)
25
- for i in range(num_frames):
26
- frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
27
- writer.append_data(frame)
28
- writer.close()
29
- return video_path
30
-
31
-
32
- class LowMemoryVideo:
33
- def __init__(self, file_name):
34
- self.reader = imageio.get_reader(file_name)
35
-
36
- def __len__(self):
37
- return self.reader.count_frames()
38
-
39
- def __getitem__(self, item):
40
- return np.array(self.reader.get_data(item))
41
-
42
- def __del__(self):
43
- self.reader.close()
44
-
45
-
46
- def split_file_name(file_name):
47
- result = []
48
- number = -1
49
- for i in file_name:
50
- if ord(i)>=ord("0") and ord(i)<=ord("9"):
51
- if number == -1:
52
- number = 0
53
- number = number*10 + ord(i) - ord("0")
54
- else:
55
- if number != -1:
56
- result.append(number)
57
- number = -1
58
- result.append(i)
59
- if number != -1:
60
- result.append(number)
61
- result = tuple(result)
62
- return result
63
-
64
-
65
- def search_for_images(folder):
66
- file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
67
- file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
68
- file_list = [i[1] for i in sorted(file_list)]
69
- file_list = [os.path.join(folder, i) for i in file_list]
70
- return file_list
71
-
72
-
73
- def read_images(folder):
74
- file_list = search_for_images(folder)
75
- frames = [np.array(Image.open(i)) for i in file_list]
76
- return frames
77
-
78
-
79
- class LowMemoryImageFolder:
80
- def __init__(self, folder, file_list=None):
81
- if file_list is None:
82
- self.file_list = search_for_images(folder)
83
- else:
84
- self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
85
-
86
- def __len__(self):
87
- return len(self.file_list)
88
-
89
- def __getitem__(self, item):
90
- return np.array(Image.open(self.file_list[item]))
91
-
92
- def __del__(self):
93
- pass
94
-
95
-
96
- class VideoData:
97
- def __init__(self, video_file, image_folder, **kwargs):
98
- if video_file is not None:
99
- self.data_type = "video"
100
- self.data = LowMemoryVideo(video_file, **kwargs)
101
- elif image_folder is not None:
102
- self.data_type = "images"
103
- self.data = LowMemoryImageFolder(image_folder, **kwargs)
104
- else:
105
- raise ValueError("Cannot open video or image folder")
106
- self.length = None
107
- self.height = None
108
- self.width = None
109
-
110
- def raw_data(self):
111
- frames = []
112
- for i in range(self.__len__()):
113
- frames.append(self.__getitem__(i))
114
- return frames
115
-
116
- def set_length(self, length):
117
- self.length = length
118
-
119
- def set_shape(self, height, width):
120
- self.height = height
121
- self.width = width
122
-
123
- def __len__(self):
124
- if self.length is None:
125
- return len(self.data)
126
- else:
127
- return self.length
128
-
129
- def shape(self):
130
- if self.height is not None and self.width is not None:
131
- return self.height, self.width
132
- else:
133
- height, width, _ = self.__getitem__(0).shape
134
- return height, width
135
-
136
- def __getitem__(self, item):
137
- frame = self.data.__getitem__(item)
138
- height, width, _ = frame.shape
139
- if self.height is not None and self.width is not None:
140
- if self.height != height or self.width != width:
141
- frame = Image.fromarray(frame).resize((self.width, self.height))
142
- frame = np.array(frame)
143
- return frame
144
-
145
- def __del__(self):
146
- pass
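
split_file_name above builds a natural-sort key: digit runs are folded into integers so frame filenames order numerically rather than lexicographically. A quick check, assuming the module is importable:

from diffsynth.extensions.FastBlend.data import split_file_name  # the module shown above

names = ["frame10.png", "frame2.png", "frame1.png"]
print(sorted(names))                       # ['frame1.png', 'frame10.png', 'frame2.png']
print(sorted(names, key=split_file_name))  # ['frame1.png', 'frame2.png', 'frame10.png']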
 
 
 
diffsynth/extensions/FastBlend/patch_match.py DELETED
@@ -1,298 +0,0 @@
- from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
- import numpy as np
- import cupy as cp
- import cv2
-
-
- class PatchMatcher:
-     def __init__(
-         self, height, width, channel, minimum_patch_size,
-         threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
-         random_search_steps=3, random_search_range=4,
-         use_mean_target_style=False, use_pairwise_patch_error=False,
-         tracking_window_size=0
-     ):
-         self.height = height
-         self.width = width
-         self.channel = channel
-         self.minimum_patch_size = minimum_patch_size
-         self.threads_per_block = threads_per_block
-         self.num_iter = num_iter
-         self.gpu_id = gpu_id
-         self.guide_weight = guide_weight
-         self.random_search_steps = random_search_steps
-         self.random_search_range = random_search_range
-         self.use_mean_target_style = use_mean_target_style
-         self.use_pairwise_patch_error = use_pairwise_patch_error
-         self.tracking_window_size = tracking_window_size
-
-         self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
-         self.pad_size = self.patch_size_list[0] // 2
-         self.grid = (
-             (height + threads_per_block - 1) // threads_per_block,
-             (width + threads_per_block - 1) // threads_per_block
-         )
-         self.block = (threads_per_block, threads_per_block)
-
-     def pad_image(self, image):
-         return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
-
-     def unpad_image(self, image):
-         return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
-
-     def apply_nnf_to_image(self, nnf, source):
-         batch_size = source.shape[0]
-         target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
-         remapping_kernel(
-             self.grid + (batch_size,),
-             self.block,
-             (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
-         )
-         return target
-
-     def get_patch_error(self, source, nnf, target):
-         batch_size = source.shape[0]
-         error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
-         patch_error_kernel(
-             self.grid + (batch_size,),
-             self.block,
-             (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
-         )
-         return error
-
-     def get_pairwise_patch_error(self, source, nnf):
-         batch_size = source.shape[0]//2
-         error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
-         source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
-         source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
-         pairwise_patch_error_kernel(
-             self.grid + (batch_size,),
-             self.block,
-             (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
-         )
-         error = error.repeat(2, axis=0)
-         return error
-
-     def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
-         error_guide = self.get_patch_error(source_guide, nnf, target_guide)
-         if self.use_mean_target_style:
-             target_style = self.apply_nnf_to_image(nnf, source_style)
-             target_style = target_style.mean(axis=0, keepdims=True)
-             target_style = target_style.repeat(source_guide.shape[0], axis=0)
-         if self.use_pairwise_patch_error:
-             error_style = self.get_pairwise_patch_error(source_style, nnf)
-         else:
-             error_style = self.get_patch_error(source_style, nnf, target_style)
-         error = error_guide * self.guide_weight + error_style
-         return error
-
-     def clamp_bound(self, nnf):
-         nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
-         nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
-         return nnf
-
-     def random_step(self, nnf, r):
-         batch_size = nnf.shape[0]
-         step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
-         upd_nnf = self.clamp_bound(nnf + step)
-         return upd_nnf
-
-     def neighboor_step(self, nnf, d):
-         if d==0:
-             upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
-             upd_nnf[:, :, :, 0] += 1
-         elif d==1:
-             upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
-             upd_nnf[:, :, :, 1] += 1
-         elif d==2:
-             upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
-             upd_nnf[:, :, :, 0] -= 1
-         elif d==3:
-             upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
-             upd_nnf[:, :, :, 1] -= 1
-         upd_nnf = self.clamp_bound(upd_nnf)
-         return upd_nnf
-
-     def shift_nnf(self, nnf, d):
-         if d>0:
-             d = min(nnf.shape[0], d)
-             upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
-         else:
-             d = max(-nnf.shape[0], d)
-             upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
-         return upd_nnf
-
-     def track_step(self, nnf, d):
-         if self.use_pairwise_patch_error:
-             upd_nnf = cp.zeros_like(nnf)
-             upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
-             upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
-         else:
-             upd_nnf = self.shift_nnf(nnf, d)
-         return upd_nnf
-
-     def C(self, n, m):
-         # not used
-         c = 1
-         for i in range(1, n+1):
-             c *= i
-         for i in range(1, m+1):
-             c //= i
-         for i in range(1, n-m+1):
-             c //= i
-         return c
-
-     def bezier_step(self, nnf, r):
-         # not used
-         n = r * 2 - 1
-         upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
-         for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
-             if d>0:
-                 ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
-             elif d<0:
-                 ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
-             upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
-         upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
-         return upd_nnf
-
-     def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
-         upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
-         upd_idx = (upd_err < err)
-         nnf[upd_idx] = upd_nnf[upd_idx]
-         err[upd_idx] = upd_err[upd_idx]
-         return nnf, err
-
-     def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
-         for d in cp.random.permutation(4):
-             upd_nnf = self.neighboor_step(nnf, d)
-             nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
-         return nnf, err
-
-     def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
-         for i in range(self.random_search_steps):
-             upd_nnf = self.random_step(nnf, self.random_search_range)
-             nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
-         return nnf, err
-
-     def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
-         for d in range(1, self.tracking_window_size + 1):
-             upd_nnf = self.track_step(nnf, d)
-             nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
-             upd_nnf = self.track_step(nnf, -d)
-             nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
-         return nnf, err
-
-     def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
-         nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
-         nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
-         nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
-         return nnf, err
-
-     def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
-         with cp.cuda.Device(self.gpu_id):
-             source_guide = self.pad_image(source_guide)
-             target_guide = self.pad_image(target_guide)
-             source_style = self.pad_image(source_style)
-             for it in range(self.num_iter):
-                 self.patch_size = self.patch_size_list[it]
-                 target_style = self.apply_nnf_to_image(nnf, source_style)
-                 err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
-                 nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
-             target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
-         return nnf, target_style
-
-
- class PyramidPatchMatcher:
-     def __init__(
-         self, image_height, image_width, channel, minimum_patch_size,
-         threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
-         use_mean_target_style=False, use_pairwise_patch_error=False,
-         tracking_window_size=0,
-         initialize="identity"
-     ):
-         maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
-         self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
-         self.pyramid_heights = []
-         self.pyramid_widths = []
-         self.patch_matchers = []
-         self.minimum_patch_size = minimum_patch_size
-         self.num_iter = num_iter
-         self.gpu_id = gpu_id
-         self.initialize = initialize
-         for level in range(self.pyramid_level):
-             height = image_height//(2**(self.pyramid_level - 1 - level))
-             width = image_width//(2**(self.pyramid_level - 1 - level))
-             self.pyramid_heights.append(height)
-             self.pyramid_widths.append(width)
-             self.patch_matchers.append(PatchMatcher(
-                 height, width, channel, minimum_patch_size=minimum_patch_size,
-                 threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
-                 use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
-                 tracking_window_size=tracking_window_size
-             ))
-
-     def resample_image(self, images, level):
-         height, width = self.pyramid_heights[level], self.pyramid_widths[level]
-         images = images.get()
-         images_resample = []
-         for image in images:
-             image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
-             images_resample.append(image_resample)
-         images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
-         return images_resample
-
-     def initialize_nnf(self, batch_size):
-         if self.initialize == "random":
-             height, width = self.pyramid_heights[0], self.pyramid_widths[0]
-             nnf = cp.stack([
-                 cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
-                 cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
-             ], axis=3)
-         elif self.initialize == "identity":
-             height, width = self.pyramid_heights[0], self.pyramid_widths[0]
-             nnf = cp.stack([
-                 cp.repeat(cp.arange(height), width).reshape(height, width),
-                 cp.tile(cp.arange(width), height).reshape(height, width)
-             ], axis=2)
-             nnf = cp.stack([nnf] * batch_size)
-         else:
-             raise NotImplementedError()
-         return nnf
-
-     def update_nnf(self, nnf, level):
-         # upscale
-         nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
-         nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
-         nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
-         # check if scale is 2
-         height, width = self.pyramid_heights[level], self.pyramid_widths[level]
-         if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
-             nnf = nnf.get().astype(np.float32)
-             nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
-             nnf = cp.array(np.stack(nnf), dtype=cp.int32)
-         nnf = self.patch_matchers[level].clamp_bound(nnf)
-         return nnf
-
-     def apply_nnf_to_image(self, nnf, image):
-         with cp.cuda.Device(self.gpu_id):
-             image = self.patch_matchers[-1].pad_image(image)
-             image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
-         return image
-
-     def estimate_nnf(self, source_guide, target_guide, source_style):
-         with cp.cuda.Device(self.gpu_id):
-             if not isinstance(source_guide, cp.ndarray):
-                 source_guide = cp.array(source_guide, dtype=cp.float32)
-             if not isinstance(target_guide, cp.ndarray):
-                 target_guide = cp.array(target_guide, dtype=cp.float32)
-             if not isinstance(source_style, cp.ndarray):
-                 source_style = cp.array(source_style, dtype=cp.float32)
-             for level in range(self.pyramid_level):
-                 nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
-                 source_guide_ = self.resample_image(source_guide, level)
-                 target_guide_ = self.resample_image(target_guide, level)
-                 source_style_ = self.resample_image(source_style, level)
-                 nnf, target_style = self.patch_matchers[level].estimate_nnf(
-                     source_guide_, target_guide_, source_style_, nnf
-                 )
-         return nnf.get(), target_style.get()
diffsynth/extensions/FastBlend/runners/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .accurate import AccurateModeRunner
- from .fast import FastModeRunner
- from .balanced import BalancedModeRunner
- from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py DELETED
@@ -1,35 +0,0 @@
- from ..patch_match import PyramidPatchMatcher
- import os
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
-
- class AccurateModeRunner:
-     def __init__(self):
-         pass
-
-     def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
-         patch_match_engine = PyramidPatchMatcher(
-             image_height=frames_style[0].shape[0],
-             image_width=frames_style[0].shape[1],
-             channel=3,
-             use_mean_target_style=True,
-             **ebsynth_config
-         )
-         # run
-         n = len(frames_style)
-         for target in tqdm(range(n), desc=desc):
-             l, r = max(target - window_size, 0), min(target + window_size + 1, n)
-             remapped_frames = []
-             for i in range(l, r, batch_size):
-                 j = min(i + batch_size, r)
-                 source_guide = np.stack([frames_guide[source] for source in range(i, j)])
-                 target_guide = np.stack([frames_guide[target]] * (j - i))
-                 source_style = np.stack([frames_style[source] for source in range(i, j)])
-                 _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
-                 remapped_frames.append(target_style)
-             frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
-             frame = frame.clip(0, 255).astype("uint8")
-             if save_path is not None:
-                 Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/balanced.py DELETED
@@ -1,46 +0,0 @@
- from ..patch_match import PyramidPatchMatcher
- import os
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
-
- class BalancedModeRunner:
-     def __init__(self):
-         pass
-
-     def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
-         patch_match_engine = PyramidPatchMatcher(
-             image_height=frames_style[0].shape[0],
-             image_width=frames_style[0].shape[1],
-             channel=3,
-             **ebsynth_config
-         )
-         # tasks
-         n = len(frames_style)
-         tasks = []
-         for target in range(n):
-             for source in range(target - window_size, target + window_size + 1):
-                 if source >= 0 and source < n and source != target:
-                     tasks.append((source, target))
-         # run
-         frames = [(None, 1) for i in range(n)]
-         for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
-             tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
-             source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
-             target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
-             source_style = np.stack([frames_style[source] for source, target in tasks_batch])
-             _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
-             for (source, target), result in zip(tasks_batch, target_style):
-                 frame, weight = frames[target]
-                 if frame is None:
-                     frame = frames_style[target]
-                 frames[target] = (
-                     frame * (weight / (weight + 1)) + result / (weight + 1),
-                     weight + 1
-                 )
-                 if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
-                     frame = frame.clip(0, 255).astype("uint8")
-                     if save_path is not None:
-                         Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
-                     frames[target] = (None, 1)
diffsynth/extensions/FastBlend/runners/fast.py DELETED
@@ -1,141 +0,0 @@
- from ..patch_match import PyramidPatchMatcher
- import functools, os
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
-
- class TableManager:
-     def __init__(self):
-         pass
-
-     def task_list(self, n):
-         tasks = []
-         max_level = 1
-         while (1<<max_level)<=n:
-             max_level += 1
-         for i in range(n):
-             j = i
-             for level in range(max_level):
-                 if i&(1<<level):
-                     continue
-                 j |= 1<<level
-                 if j>=n:
-                     break
-                 meta_data = {
-                     "source": i,
-                     "target": j,
-                     "level": level + 1
-                 }
-                 tasks.append(meta_data)
-         tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
-         return tasks
-
-     def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
-         n = len(frames_guide)
-         tasks = self.task_list(n)
-         remapping_table = [[(frames_style[i], 1)] for i in range(n)]
-         for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
-             tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
-             source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
-             target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
-             source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
-             _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
-             for task, result in zip(tasks_batch, target_style):
-                 target, level = task["target"], task["level"]
-                 if len(remapping_table[target])==level:
-                     remapping_table[target].append((result, 1))
-                 else:
-                     frame, weight = remapping_table[target][level]
-                     remapping_table[target][level] = (
-                         frame * (weight / (weight + 1)) + result / (weight + 1),
-                         weight + 1
-                     )
-         return remapping_table
-
-     def remapping_table_to_blending_table(self, table):
-         for i in range(len(table)):
-             for j in range(1, len(table[i])):
-                 frame_1, weight_1 = table[i][j-1]
-                 frame_2, weight_2 = table[i][j]
-                 frame = (frame_1 + frame_2) / 2
-                 weight = weight_1 + weight_2
-                 table[i][j] = (frame, weight)
-         return table
-
-     def tree_query(self, leftbound, rightbound):
-         node_list = []
-         node_index = rightbound
-         while node_index>=leftbound:
-             node_level = 0
-             while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
-                 node_level += 1
-             node_list.append((node_index, node_level))
-             node_index -= 1<<node_level
-         return node_list
-
-     def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
-         n = len(blending_table)
-         tasks = []
-         frames_result = []
-         for target in range(n):
-             node_list = self.tree_query(max(target-window_size, 0), target)
-             for source, level in node_list:
-                 if source!=target:
-                     meta_data = {
-                         "source": source,
-                         "target": target,
-                         "level": level
-                     }
-                     tasks.append(meta_data)
-                 else:
-                     frames_result.append(blending_table[target][level])
-         for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
-             tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
-             source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
-             target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
-             source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
-             _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
-             for task, frame_2 in zip(tasks_batch, target_style):
-                 source, target, level = task["source"], task["target"], task["level"]
-                 frame_1, weight_1 = frames_result[target]
-                 weight_2 = blending_table[source][level][1]
-                 weight = weight_1 + weight_2
-                 frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
-                 frames_result[target] = (frame, weight)
-         return frames_result
-
-
- class FastModeRunner:
-     def __init__(self):
-         pass
-
-     def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
-         frames_guide = frames_guide.raw_data()
-         frames_style = frames_style.raw_data()
-         table_manager = TableManager()
-         patch_match_engine = PyramidPatchMatcher(
-             image_height=frames_style[0].shape[0],
-             image_width=frames_style[0].shape[1],
-             channel=3,
-             **ebsynth_config
-         )
-         # left part
-         table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
-         table_l = table_manager.remapping_table_to_blending_table(table_l)
-         table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
-         # right part
-         table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
-         table_r = table_manager.remapping_table_to_blending_table(table_r)
-         table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
-         # merge
-         frames = []
-         for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
-             weight_m = -1
-             weight = weight_l + weight_m + weight_r
-             frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
-             frames.append(frame)
-         frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
-         if save_path is not None:
-             for target, frame in enumerate(frames):
-                 Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/interpolation.py DELETED
@@ -1,121 +0,0 @@
- from ..patch_match import PyramidPatchMatcher
- import os
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
-
- class InterpolationModeRunner:
-     def __init__(self):
-         pass
-
-     def get_index_dict(self, index_style):
-         index_dict = {}
-         for i, index in enumerate(index_style):
-             index_dict[index] = i
-         return index_dict
-
-     def get_weight(self, l, m, r):
-         weight_l, weight_r = abs(m - r), abs(m - l)
-         if weight_l + weight_r == 0:
-             weight_l, weight_r = 0.5, 0.5
-         else:
-             weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
-         return weight_l, weight_r
-
-     def get_task_group(self, index_style, n):
-         task_group = []
-         index_style = sorted(index_style)
-         # first frame
-         if index_style[0]>0:
-             tasks = []
-             for m in range(index_style[0]):
-                 tasks.append((index_style[0], m, index_style[0]))
-             task_group.append(tasks)
-         # middle frames
-         for l, r in zip(index_style[:-1], index_style[1:]):
-             tasks = []
-             for m in range(l, r):
-                 tasks.append((l, m, r))
-             task_group.append(tasks)
-         # last frame
-         tasks = []
-         for m in range(index_style[-1], n):
-             tasks.append((index_style[-1], m, index_style[-1]))
-         task_group.append(tasks)
-         return task_group
-
-     def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
-         patch_match_engine = PyramidPatchMatcher(
-             image_height=frames_style[0].shape[0],
-             image_width=frames_style[0].shape[1],
-             channel=3,
-             use_mean_target_style=False,
-             use_pairwise_patch_error=True,
-             **ebsynth_config
-         )
-         # task
-         index_dict = self.get_index_dict(index_style)
-         task_group = self.get_task_group(index_style, len(frames_guide))
-         # run
-         for tasks in task_group:
-             index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
-             for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
-                 tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
-                 source_guide, target_guide, source_style = [], [], []
-                 for l, m, r in tasks_batch:
-                     # l -> m
-                     source_guide.append(frames_guide[l])
-                     target_guide.append(frames_guide[m])
-                     source_style.append(frames_style[index_dict[l]])
-                     # r -> m
-                     source_guide.append(frames_guide[r])
-                     target_guide.append(frames_guide[m])
-                     source_style.append(frames_style[index_dict[r]])
-                 source_guide = np.stack(source_guide)
-                 target_guide = np.stack(target_guide)
-                 source_style = np.stack(source_style)
-                 _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
-                 if save_path is not None:
-                     for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
-                         weight_l, weight_r = self.get_weight(l, m, r)
-                         frame = frame_l * weight_l + frame_r * weight_r
-                         frame = frame.clip(0, 255).astype("uint8")
-                         Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
-
-
- class InterpolationModeSingleFrameRunner:
-     def __init__(self):
-         pass
-
-     def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
-         # check input
-         tracking_window_size = ebsynth_config["tracking_window_size"]
-         if tracking_window_size * 2 >= batch_size:
-             raise ValueError("batch_size should be larger than track_window_size * 2")
-         frame_style = frames_style[0]
-         frame_guide = frames_guide[index_style[0]]
-         patch_match_engine = PyramidPatchMatcher(
-             image_height=frame_style.shape[0],
-             image_width=frame_style.shape[1],
-             channel=3,
-             **ebsynth_config
-         )
-         # run
-         frame_id, n = 0, len(frames_guide)
-         for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
-             if i + batch_size > n:
-                 l, r = max(n - batch_size, 0), n
-             else:
-                 l, r = i, i + batch_size
-             source_guide = np.stack([frame_guide] * (r-l))
-             target_guide = np.stack([frames_guide[i] for i in range(l, r)])
-             source_style = np.stack([frame_style] * (r-l))
-             _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
-             for i, frame in zip(range(l, r), target_style):
-                 if i==frame_id:
-                     frame = frame.clip(0, 255).astype("uint8")
-                     Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
-                     frame_id += 1
-                 if r < n and r-frame_id <= tracking_window_size:
-                     break
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py DELETED
@@ -1 +0,0 @@
- from .blip_pretrain import *
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py DELETED
@@ -1,77 +0,0 @@
- '''
- * Adapted from BLIP (https://github.com/salesforce/BLIP)
- '''
-
- import warnings
- warnings.filterwarnings("ignore")
-
- import torch
- import os
- from urllib.parse import urlparse
- from timm.models.hub import download_cached_file
- from transformers import BertTokenizer
- from .vit import VisionTransformer, interpolate_pos_embed
-
-
- def default_bert():
-     current_dir = os.path.dirname(os.path.abspath(__file__))
-     project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
-     model_path = os.path.join(project_root, 'models', 'QualityMetric')
-     return os.path.join(model_path, "bert-base-uncased")
-
-
- def init_tokenizer(bert_model_path):
-     tokenizer = BertTokenizer.from_pretrained(bert_model_path)
-     tokenizer.add_special_tokens({'bos_token':'[DEC]'})
-     tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
-     tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
-     return tokenizer
-
-
- def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
-
-     assert vit in ['base', 'large'], "vit parameter must be base or large"
-     if vit=='base':
-         vision_width = 768
-         visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
-             num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
-             drop_path_rate=0 or drop_path_rate
-         )
-     elif vit=='large':
-         vision_width = 1024
-         visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
-             num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
-             drop_path_rate=0.1 or drop_path_rate
-         )
-     return visual_encoder, vision_width
-
-
- def is_url(url_or_filename):
-     parsed = urlparse(url_or_filename)
-     return parsed.scheme in ("http", "https")
-
- def load_checkpoint(model,url_or_filename):
-     if is_url(url_or_filename):
-         cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
-         checkpoint = torch.load(cached_file, map_location='cpu')
-     elif os.path.isfile(url_or_filename):
-         checkpoint = torch.load(url_or_filename, map_location='cpu')
-     else:
-         raise RuntimeError('checkpoint url or path is invalid')
-
-     state_dict = checkpoint['model']
-
-     state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
-     if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
-         state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
-             model.visual_encoder_m)
-     for key in model.state_dict().keys():
-         if key in state_dict.keys():
-             if state_dict[key].shape!=model.state_dict()[key].shape:
-                 print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
-                 del state_dict[key]
-
-     msg = model.load_state_dict(state_dict,strict=False)
-     print('load checkpoint from %s'%url_or_filename)
-     return model,msg
-
diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py DELETED
@@ -1,44 +0,0 @@
- '''
- * Adapted from BLIP (https://github.com/salesforce/BLIP)
- '''
-
- import transformers
- transformers.logging.set_verbosity_error()
-
- from torch import nn
- import os
- from .med import BertConfig, BertModel
- from .blip import create_vit, init_tokenizer
-
- class BLIP_Pretrain(nn.Module):
-     def __init__(self,
-                  med_config = "med_config.json",
-                  image_size = 224,
-                  vit = 'base',
-                  vit_grad_ckpt = False,
-                  vit_ckpt_layer = 0,
-                  embed_dim = 256,
-                  queue_size = 57600,
-                  momentum = 0.995,
-                  bert_model_path = ""
-                  ):
-         """
-         Args:
-             med_config (str): path for the mixture of encoder-decoder model's configuration file
-             image_size (int): input image size
-             vit (str): model size of vision transformer
-         """
-         super().__init__()
-
-         self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
-
-         self.tokenizer = init_tokenizer(bert_model_path)
-         encoder_config = BertConfig.from_json_file(med_config)
-         encoder_config.encoder_width = vision_width
-         self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
-
-         text_width = self.text_encoder.config.hidden_size
-
-         self.vision_proj = nn.Linear(vision_width, embed_dim)
-         self.text_proj = nn.Linear(text_width, embed_dim)
-
diffsynth/extensions/ImageQualityMetric/BLIP/med.py DELETED
@@ -1,947 +0,0 @@
1
- '''
2
- * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
- * Based on huggingface code base
4
- * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
5
- '''
6
-
7
- import math
8
- from typing import Tuple
9
-
10
- import torch
11
- from torch import Tensor, device, nn
12
- import torch.utils.checkpoint
13
- from torch import nn
14
- from torch.nn import CrossEntropyLoss
15
-
16
- from transformers.activations import ACT2FN
17
- from transformers.file_utils import (
18
- ModelOutput,
19
- )
20
- from transformers.modeling_outputs import (
21
- BaseModelOutputWithPastAndCrossAttentions,
22
- BaseModelOutputWithPoolingAndCrossAttentions,
23
- CausalLMOutputWithCrossAttentions,
24
- MaskedLMOutput,
25
- MultipleChoiceModelOutput,
26
- NextSentencePredictorOutput,
27
- QuestionAnsweringModelOutput,
28
- SequenceClassifierOutput,
29
- TokenClassifierOutput,
30
- )
31
- from transformers.modeling_utils import (
32
- PreTrainedModel,
33
- apply_chunking_to_forward,
34
- find_pruneable_heads_and_indices,
35
- prune_linear_layer,
36
- )
37
- from transformers.utils import logging
38
- from transformers.models.bert.configuration_bert import BertConfig
39
-
40
-
41
- logger = logging.get_logger(__name__)
42
-
43
-
44
- class BertEmbeddings(nn.Module):
45
- """Construct the embeddings from word and position embeddings."""
46
-
47
- def __init__(self, config):
48
- super().__init__()
49
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
50
- self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
51
-
52
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
53
- # any TensorFlow checkpoint file
54
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
55
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
56
-
57
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
58
- self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
59
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
60
-
61
- self.config = config
62
-
63
- def forward(
64
- self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
65
- ):
66
- if input_ids is not None:
67
- input_shape = input_ids.size()
68
- else:
69
- input_shape = inputs_embeds.size()[:-1]
70
-
71
- seq_length = input_shape[1]
72
-
73
- if position_ids is None:
74
- position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
75
-
76
- if inputs_embeds is None:
77
- inputs_embeds = self.word_embeddings(input_ids)
78
-
79
- embeddings = inputs_embeds
80
-
81
- if self.position_embedding_type == "absolute":
82
- position_embeddings = self.position_embeddings(position_ids)
83
- embeddings += position_embeddings
84
- embeddings = self.LayerNorm(embeddings)
85
- embeddings = self.dropout(embeddings)
86
- return embeddings
87
-
88
-
89
- class BertSelfAttention(nn.Module):
90
- def __init__(self, config, is_cross_attention):
91
- super().__init__()
92
- self.config = config
93
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
94
- raise ValueError(
95
- "The hidden size (%d) is not a multiple of the number of attention "
96
- "heads (%d)" % (config.hidden_size, config.num_attention_heads)
97
- )
98
-
99
- self.num_attention_heads = config.num_attention_heads
100
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
101
- self.all_head_size = self.num_attention_heads * self.attention_head_size
102
-
103
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
104
- if is_cross_attention:
105
- self.key = nn.Linear(config.encoder_width, self.all_head_size)
106
- self.value = nn.Linear(config.encoder_width, self.all_head_size)
107
- else:
108
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
109
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
110
-
111
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
112
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
113
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
114
- self.max_position_embeddings = config.max_position_embeddings
115
- self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
116
- self.save_attention = False
117
-
118
- def save_attn_gradients(self, attn_gradients):
119
- self.attn_gradients = attn_gradients
120
-
121
- def get_attn_gradients(self):
122
- return self.attn_gradients
123
-
124
- def save_attention_map(self, attention_map):
125
- self.attention_map = attention_map
126
-
127
- def get_attention_map(self):
128
- return self.attention_map
129
-
130
- def transpose_for_scores(self, x):
131
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
132
- x = x.view(*new_x_shape)
133
- return x.permute(0, 2, 1, 3)
134
-
135
- def forward(
136
- self,
137
- hidden_states,
138
- attention_mask=None,
139
- head_mask=None,
140
- encoder_hidden_states=None,
141
- encoder_attention_mask=None,
142
- past_key_value=None,
143
- output_attentions=False,
144
- ):
145
- mixed_query_layer = self.query(hidden_states)
146
-
147
- # If this is instantiated as a cross-attention module, the keys
148
- # and values come from an encoder; the attention mask needs to be
149
- # such that the encoder's padding tokens are not attended to.
150
- is_cross_attention = encoder_hidden_states is not None
151
-
152
- if is_cross_attention:
153
- key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
154
- value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
155
- attention_mask = encoder_attention_mask
156
- elif past_key_value is not None:
157
- key_layer = self.transpose_for_scores(self.key(hidden_states))
158
- value_layer = self.transpose_for_scores(self.value(hidden_states))
159
- key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
160
- value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
161
- else:
162
- key_layer = self.transpose_for_scores(self.key(hidden_states))
163
- value_layer = self.transpose_for_scores(self.value(hidden_states))
164
-
165
- query_layer = self.transpose_for_scores(mixed_query_layer)
166
-
167
- past_key_value = (key_layer, value_layer)
168
-
169
- # Take the dot product between "query" and "key" to get the raw attention scores.
170
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
171
-
172
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
173
- seq_length = hidden_states.size()[1]
174
- position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
175
- position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
176
- distance = position_ids_l - position_ids_r
177
- positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
178
- positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
179
-
180
- if self.position_embedding_type == "relative_key":
181
- relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
182
- attention_scores = attention_scores + relative_position_scores
183
- elif self.position_embedding_type == "relative_key_query":
184
- relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
185
- relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
186
- attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
187
-
188
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
189
- if attention_mask is not None:
190
- # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
191
- attention_scores = attention_scores + attention_mask
192
-
193
- # Normalize the attention scores to probabilities.
194
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
195
-
196
- if is_cross_attention and self.save_attention:
197
- self.save_attention_map(attention_probs)
198
- attention_probs.register_hook(self.save_attn_gradients)
199
-
200
- # This is actually dropping out entire tokens to attend to, which might
201
- # seem a bit unusual, but is taken from the original Transformer paper.
202
- attention_probs_dropped = self.dropout(attention_probs)
203
-
204
- # Mask heads if we want to
205
- if head_mask is not None:
206
- attention_probs_dropped = attention_probs_dropped * head_mask
207
-
208
- context_layer = torch.matmul(attention_probs_dropped, value_layer)
209
-
210
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
211
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
212
- context_layer = context_layer.view(*new_context_layer_shape)
213
-
214
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
215
-
216
- outputs = outputs + (past_key_value,)
217
- return outputs
218
-
219
-
220
- class BertSelfOutput(nn.Module):
221
- def __init__(self, config):
222
- super().__init__()
223
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
224
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
225
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
226
-
227
- def forward(self, hidden_states, input_tensor):
228
- hidden_states = self.dense(hidden_states)
229
- hidden_states = self.dropout(hidden_states)
230
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
231
- return hidden_states
232
-
233
-
234
- class BertAttention(nn.Module):
235
- def __init__(self, config, is_cross_attention=False):
236
- super().__init__()
237
- self.self = BertSelfAttention(config, is_cross_attention)
238
- self.output = BertSelfOutput(config)
239
- self.pruned_heads = set()
240
-
241
- def prune_heads(self, heads):
242
- if len(heads) == 0:
243
- return
244
- heads, index = find_pruneable_heads_and_indices(
245
- heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
246
- )
247
-
248
- # Prune linear layers
249
- self.self.query = prune_linear_layer(self.self.query, index)
250
- self.self.key = prune_linear_layer(self.self.key, index)
251
- self.self.value = prune_linear_layer(self.self.value, index)
252
- self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
253
-
254
- # Update hyper params and store pruned heads
255
- self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
256
- self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
257
- self.pruned_heads = self.pruned_heads.union(heads)
258
-
259
- def forward(
260
- self,
261
- hidden_states,
262
- attention_mask=None,
263
- head_mask=None,
264
- encoder_hidden_states=None,
265
- encoder_attention_mask=None,
266
- past_key_value=None,
267
- output_attentions=False,
268
- ):
269
- self_outputs = self.self(
270
- hidden_states,
271
- attention_mask,
272
- head_mask,
273
- encoder_hidden_states,
274
- encoder_attention_mask,
275
- past_key_value,
276
- output_attentions,
277
- )
278
- attention_output = self.output(self_outputs[0], hidden_states)
279
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
280
- return outputs
281
-
282
-
283
- class BertIntermediate(nn.Module):
284
- def __init__(self, config):
285
- super().__init__()
286
- self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
287
- if isinstance(config.hidden_act, str):
288
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
289
- else:
290
- self.intermediate_act_fn = config.hidden_act
291
-
292
- def forward(self, hidden_states):
293
- hidden_states = self.dense(hidden_states)
294
- hidden_states = self.intermediate_act_fn(hidden_states)
295
- return hidden_states
296
-
297
-
298
- class BertOutput(nn.Module):
299
- def __init__(self, config):
300
- super().__init__()
301
- self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
302
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
303
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
304
-
305
- def forward(self, hidden_states, input_tensor):
306
- hidden_states = self.dense(hidden_states)
307
- hidden_states = self.dropout(hidden_states)
308
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
309
- return hidden_states
310
-
311
-
312
- class BertLayer(nn.Module):
313
- def __init__(self, config, layer_num):
314
- super().__init__()
315
- self.config = config
316
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
317
- self.seq_len_dim = 1
318
- self.attention = BertAttention(config)
319
- self.layer_num = layer_num
320
- if self.config.add_cross_attention:
321
- self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
322
- self.intermediate = BertIntermediate(config)
323
- self.output = BertOutput(config)
324
-
325
- def forward(
326
- self,
327
- hidden_states,
328
- attention_mask=None,
329
- head_mask=None,
330
- encoder_hidden_states=None,
331
- encoder_attention_mask=None,
332
- past_key_value=None,
333
- output_attentions=False,
334
- mode=None,
335
- ):
336
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
337
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
338
- self_attention_outputs = self.attention(
339
- hidden_states,
340
- attention_mask,
341
- head_mask,
342
- output_attentions=output_attentions,
343
- past_key_value=self_attn_past_key_value,
344
- )
345
- attention_output = self_attention_outputs[0]
346
-
347
- outputs = self_attention_outputs[1:-1]
348
- present_key_value = self_attention_outputs[-1]
349
-
350
- if mode=='multimodal':
351
- assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
352
-
353
- cross_attention_outputs = self.crossattention(
354
- attention_output,
355
- attention_mask,
356
- head_mask,
357
- encoder_hidden_states,
358
- encoder_attention_mask,
359
- output_attentions=output_attentions,
360
- )
361
- attention_output = cross_attention_outputs[0]
362
- outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
363
- layer_output = apply_chunking_to_forward(
364
- self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
365
- )
366
- outputs = (layer_output,) + outputs
367
-
368
- outputs = outputs + (present_key_value,)
369
-
370
- return outputs
371
-
372
- def feed_forward_chunk(self, attention_output):
373
- intermediate_output = self.intermediate(attention_output)
374
- layer_output = self.output(intermediate_output, attention_output)
375
- return layer_output
376
-
377
-
378
- class BertEncoder(nn.Module):
379
- def __init__(self, config):
380
- super().__init__()
381
- self.config = config
382
- self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
383
- self.gradient_checkpointing = False
384
-
385
- def forward(
386
- self,
387
- hidden_states,
388
- attention_mask=None,
389
- head_mask=None,
390
- encoder_hidden_states=None,
391
- encoder_attention_mask=None,
392
- past_key_values=None,
393
- use_cache=None,
394
- output_attentions=False,
395
- output_hidden_states=False,
396
- return_dict=True,
397
- mode='multimodal',
398
- ):
399
- all_hidden_states = () if output_hidden_states else None
400
- all_self_attentions = () if output_attentions else None
401
- all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
402
-
403
- next_decoder_cache = () if use_cache else None
404
-
405
- for i in range(self.config.num_hidden_layers):
406
- layer_module = self.layer[i]
407
- if output_hidden_states:
408
- all_hidden_states = all_hidden_states + (hidden_states,)
409
-
410
- layer_head_mask = head_mask[i] if head_mask is not None else None
411
- past_key_value = past_key_values[i] if past_key_values is not None else None
412
-
413
- if self.gradient_checkpointing and self.training:
414
-
415
- if use_cache:
416
- logger.warn(
417
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
418
- )
419
- use_cache = False
420
-
421
- def create_custom_forward(module):
422
- def custom_forward(*inputs):
423
- return module(*inputs, past_key_value, output_attentions)
424
-
425
- return custom_forward
426
-
427
- layer_outputs = torch.utils.checkpoint.checkpoint(
428
- create_custom_forward(layer_module),
429
- hidden_states,
430
- attention_mask,
431
- layer_head_mask,
432
- encoder_hidden_states,
433
- encoder_attention_mask,
434
- mode=mode,
435
- )
436
- else:
437
- layer_outputs = layer_module(
438
- hidden_states,
439
- attention_mask,
440
- layer_head_mask,
441
- encoder_hidden_states,
442
- encoder_attention_mask,
443
- past_key_value,
444
- output_attentions,
445
- mode=mode,
446
- )
447
-
448
- hidden_states = layer_outputs[0]
449
- if use_cache:
450
- next_decoder_cache += (layer_outputs[-1],)
451
- if output_attentions:
452
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
453
-
454
- if output_hidden_states:
455
- all_hidden_states = all_hidden_states + (hidden_states,)
456
-
457
- if not return_dict:
458
- return tuple(
459
- v
460
- for v in [
461
- hidden_states,
462
- next_decoder_cache,
463
- all_hidden_states,
464
- all_self_attentions,
465
- all_cross_attentions,
466
- ]
467
- if v is not None
468
- )
469
- return BaseModelOutputWithPastAndCrossAttentions(
470
- last_hidden_state=hidden_states,
471
- past_key_values=next_decoder_cache,
472
- hidden_states=all_hidden_states,
473
- attentions=all_self_attentions,
474
- cross_attentions=all_cross_attentions,
475
- )
476
-
477
-
478
- class BertPooler(nn.Module):
479
- def __init__(self, config):
480
- super().__init__()
481
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
482
- self.activation = nn.Tanh()
483
-
484
- def forward(self, hidden_states):
485
- # We "pool" the model by simply taking the hidden state corresponding
486
- # to the first token.
487
- first_token_tensor = hidden_states[:, 0]
488
- pooled_output = self.dense(first_token_tensor)
489
- pooled_output = self.activation(pooled_output)
490
- return pooled_output
491
-
492
-
493
- class BertPredictionHeadTransform(nn.Module):
494
- def __init__(self, config):
495
- super().__init__()
496
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
497
- if isinstance(config.hidden_act, str):
498
- self.transform_act_fn = ACT2FN[config.hidden_act]
499
- else:
500
- self.transform_act_fn = config.hidden_act
501
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
502
-
503
- def forward(self, hidden_states):
504
- hidden_states = self.dense(hidden_states)
505
- hidden_states = self.transform_act_fn(hidden_states)
506
- hidden_states = self.LayerNorm(hidden_states)
507
- return hidden_states
508
-
509
-
510
- class BertLMPredictionHead(nn.Module):
511
- def __init__(self, config):
512
- super().__init__()
513
- self.transform = BertPredictionHeadTransform(config)
514
-
515
- # The output weights are the same as the input embeddings, but there is
516
- # an output-only bias for each token.
517
- self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
518
-
519
- self.bias = nn.Parameter(torch.zeros(config.vocab_size))
520
-
521
- # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
522
- self.decoder.bias = self.bias
523
-
524
- def forward(self, hidden_states):
525
- hidden_states = self.transform(hidden_states)
526
- hidden_states = self.decoder(hidden_states)
527
- return hidden_states
528
-
529
-
530
- class BertOnlyMLMHead(nn.Module):
531
- def __init__(self, config):
532
- super().__init__()
533
- self.predictions = BertLMPredictionHead(config)
534
-
535
- def forward(self, sequence_output):
536
- prediction_scores = self.predictions(sequence_output)
537
- return prediction_scores
538
-
539
-
540
- class BertPreTrainedModel(PreTrainedModel):
541
- """
542
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
543
- models.
544
- """
545
-
546
- config_class = BertConfig
547
- base_model_prefix = "bert"
548
- _keys_to_ignore_on_load_missing = [r"position_ids"]
549
-
550
- def _init_weights(self, module):
551
- """ Initialize the weights """
552
- if isinstance(module, (nn.Linear, nn.Embedding)):
553
- # Slightly different from the TF version which uses truncated_normal for initialization
554
- # cf https://github.com/pytorch/pytorch/pull/5617
555
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
556
- elif isinstance(module, nn.LayerNorm):
557
- module.bias.data.zero_()
558
- module.weight.data.fill_(1.0)
559
- if isinstance(module, nn.Linear) and module.bias is not None:
560
- module.bias.data.zero_()
561
-
562
-
563
- class BertModel(BertPreTrainedModel):
564
- """
565
- The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
566
- cross-attention is added between the self-attention layers, following the architecture described in `Attention is
567
- all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
568
- Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
569
- argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
570
- input to the forward pass.
571
- """
572
-
573
- def __init__(self, config, add_pooling_layer=True):
574
- super().__init__(config)
575
- self.config = config
576
-
577
- self.embeddings = BertEmbeddings(config)
578
-
579
- self.encoder = BertEncoder(config)
580
-
581
- self.pooler = BertPooler(config) if add_pooling_layer else None
582
-
583
- self.init_weights()
584
-
585
-
586
- def get_input_embeddings(self):
587
- return self.embeddings.word_embeddings
588
-
589
- def set_input_embeddings(self, value):
590
- self.embeddings.word_embeddings = value
591
-
592
- def _prune_heads(self, heads_to_prune):
593
- """
594
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
595
- class PreTrainedModel
596
- """
597
- for layer, heads in heads_to_prune.items():
598
- self.encoder.layer[layer].attention.prune_heads(heads)
599
-
600
-
601
- def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
602
- """
603
- Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
604
-
605
- Arguments:
606
- attention_mask (:obj:`torch.Tensor`):
607
- Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
608
- input_shape (:obj:`Tuple[int]`):
609
- The shape of the input to the model.
610
- device: (:obj:`torch.device`):
611
- The device of the input to the model.
612
-
613
- Returns:
614
- :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
615
- """
616
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
617
- # ourselves in which case we just need to make it broadcastable to all heads.
618
- if attention_mask.dim() == 3:
619
- extended_attention_mask = attention_mask[:, None, :, :]
620
- elif attention_mask.dim() == 2:
621
- # Provided a padding mask of dimensions [batch_size, seq_length]
622
- # - if the model is a decoder, apply a causal mask in addition to the padding mask
623
- # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
624
- if is_decoder:
625
- batch_size, seq_length = input_shape
626
-
627
- seq_ids = torch.arange(seq_length, device=device)
628
- causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
629
- # in case past_key_values are used we need to add a prefix ones mask to the causal mask
630
- # causal and attention masks must have same type with pytorch version < 1.3
631
- causal_mask = causal_mask.to(attention_mask.dtype)
632
-
633
- if causal_mask.shape[1] < attention_mask.shape[1]:
634
- prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
635
- causal_mask = torch.cat(
636
- [
637
- torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
638
- causal_mask,
639
- ],
640
- axis=-1,
641
- )
642
-
643
- extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
644
- else:
645
- extended_attention_mask = attention_mask[:, None, None, :]
646
- else:
647
- raise ValueError(
648
- "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
649
- input_shape, attention_mask.shape
650
- )
651
- )
652
-
653
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
654
- # masked positions, this operation will create a tensor which is 0.0 for
655
- # positions we want to attend and -10000.0 for masked positions.
656
- # Since we are adding it to the raw scores before the softmax, this is
657
- # effectively the same as removing these entirely.
658
- extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
659
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
660
- return extended_attention_mask
661
-
662
- def forward(
663
- self,
664
- input_ids=None,
665
- attention_mask=None,
666
- position_ids=None,
667
- head_mask=None,
668
- inputs_embeds=None,
669
- encoder_embeds=None,
670
- encoder_hidden_states=None,
671
- encoder_attention_mask=None,
672
- past_key_values=None,
673
- use_cache=None,
674
- output_attentions=None,
675
- output_hidden_states=None,
676
- return_dict=None,
677
- is_decoder=False,
678
- mode='multimodal',
679
- ):
680
- r"""
681
- encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
682
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
683
- the model is configured as a decoder.
684
- encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
685
- Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
686
- the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
687
- - 1 for tokens that are **not masked**,
688
- - 0 for tokens that are **masked**.
689
- past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
690
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
691
- If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
692
- (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
693
- instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
694
- use_cache (:obj:`bool`, `optional`):
695
- If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
696
- decoding (see :obj:`past_key_values`).
697
- """
698
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
699
- output_hidden_states = (
700
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
701
- )
702
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
703
-
704
- if is_decoder:
705
- use_cache = use_cache if use_cache is not None else self.config.use_cache
706
- else:
707
- use_cache = False
708
-
709
- if input_ids is not None and inputs_embeds is not None:
710
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
711
- elif input_ids is not None:
712
- input_shape = input_ids.size()
713
- batch_size, seq_length = input_shape
714
- device = input_ids.device
715
- elif inputs_embeds is not None:
716
- input_shape = inputs_embeds.size()[:-1]
717
- batch_size, seq_length = input_shape
718
- device = inputs_embeds.device
719
- elif encoder_embeds is not None:
720
- input_shape = encoder_embeds.size()[:-1]
721
- batch_size, seq_length = input_shape
722
- device = encoder_embeds.device
723
- else:
724
- raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
725
-
726
- # past_key_values_length
727
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
728
-
729
- if attention_mask is None:
730
- attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
731
-
732
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
733
- # ourselves in which case we just need to make it broadcastable to all heads.
734
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
735
- device, is_decoder)
736
-
737
- # If a 2D or 3D attention mask is provided for the cross-attention
738
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
739
- if encoder_hidden_states is not None:
740
- if type(encoder_hidden_states) == list:
741
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
742
- else:
743
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
744
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
745
-
746
- if type(encoder_attention_mask) == list:
747
- encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
748
- elif encoder_attention_mask is None:
749
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
750
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
751
- else:
752
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
753
- else:
754
- encoder_extended_attention_mask = None
755
-
756
- # Prepare head mask if needed
757
- # 1.0 in head_mask indicate we keep the head
758
- # attention_probs has shape bsz x n_heads x N x N
759
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
760
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
761
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
762
-
763
- if encoder_embeds is None:
764
- embedding_output = self.embeddings(
765
- input_ids=input_ids,
766
- position_ids=position_ids,
767
- inputs_embeds=inputs_embeds,
768
- past_key_values_length=past_key_values_length,
769
- )
770
- else:
771
- embedding_output = encoder_embeds
772
-
773
- encoder_outputs = self.encoder(
774
- embedding_output,
775
- attention_mask=extended_attention_mask,
776
- head_mask=head_mask,
777
- encoder_hidden_states=encoder_hidden_states,
778
- encoder_attention_mask=encoder_extended_attention_mask,
779
- past_key_values=past_key_values,
780
- use_cache=use_cache,
781
- output_attentions=output_attentions,
782
- output_hidden_states=output_hidden_states,
783
- return_dict=return_dict,
784
- mode=mode,
785
- )
786
- sequence_output = encoder_outputs[0]
787
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
788
-
789
- if not return_dict:
790
- return (sequence_output, pooled_output) + encoder_outputs[1:]
791
-
792
- return BaseModelOutputWithPoolingAndCrossAttentions(
793
- last_hidden_state=sequence_output,
794
- pooler_output=pooled_output,
795
- past_key_values=encoder_outputs.past_key_values,
796
- hidden_states=encoder_outputs.hidden_states,
797
- attentions=encoder_outputs.attentions,
798
- cross_attentions=encoder_outputs.cross_attentions,
799
- )
800
-
801
-
802
-
803
- class BertLMHeadModel(BertPreTrainedModel):
804
-
805
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
806
- _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
807
-
808
- def __init__(self, config):
809
- super().__init__(config)
810
-
811
- self.bert = BertModel(config, add_pooling_layer=False)
812
- self.cls = BertOnlyMLMHead(config)
813
-
814
- self.init_weights()
815
-
816
- def get_output_embeddings(self):
817
- return self.cls.predictions.decoder
818
-
819
- def set_output_embeddings(self, new_embeddings):
820
- self.cls.predictions.decoder = new_embeddings
821
-
822
- def forward(
823
- self,
824
- input_ids=None,
825
- attention_mask=None,
826
- position_ids=None,
827
- head_mask=None,
828
- inputs_embeds=None,
829
- encoder_hidden_states=None,
830
- encoder_attention_mask=None,
831
- labels=None,
832
- past_key_values=None,
833
- use_cache=None,
834
- output_attentions=None,
835
- output_hidden_states=None,
836
- return_dict=None,
837
- return_logits=False,
838
- is_decoder=True,
839
- reduction='mean',
840
- mode='multimodal',
841
- ):
842
- r"""
843
- encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
844
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
845
- the model is configured as a decoder.
846
- encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
847
- Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
848
- the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
849
- - 1 for tokens that are **not masked**,
850
- - 0 for tokens that are **masked**.
851
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
852
- Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
853
- ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
854
- ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
855
- past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
856
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
857
- If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
858
- (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
859
- instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
860
- use_cache (:obj:`bool`, `optional`):
861
- If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
862
- decoding (see :obj:`past_key_values`).
863
- Returns:
864
- Example::
865
- >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
866
- >>> import torch
867
- >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
868
- >>> config = BertConfig.from_pretrained("bert-base-cased")
869
- >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
870
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
871
- >>> outputs = model(**inputs)
872
- >>> prediction_logits = outputs.logits
873
- """
874
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
- if labels is not None:
876
- use_cache = False
877
-
878
- outputs = self.bert(
879
- input_ids,
880
- attention_mask=attention_mask,
881
- position_ids=position_ids,
882
- head_mask=head_mask,
883
- inputs_embeds=inputs_embeds,
884
- encoder_hidden_states=encoder_hidden_states,
885
- encoder_attention_mask=encoder_attention_mask,
886
- past_key_values=past_key_values,
887
- use_cache=use_cache,
888
- output_attentions=output_attentions,
889
- output_hidden_states=output_hidden_states,
890
- return_dict=return_dict,
891
- is_decoder=is_decoder,
892
- mode=mode,
893
- )
894
-
895
- sequence_output = outputs[0]
896
- prediction_scores = self.cls(sequence_output)
897
-
898
- if return_logits:
899
- return prediction_scores[:, :-1, :].contiguous()
900
-
901
- lm_loss = None
902
- if labels is not None:
903
- # we are doing next-token prediction; shift prediction scores and input ids by one
904
- shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
905
- labels = labels[:, 1:].contiguous()
906
- loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
907
- lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
908
- if reduction=='none':
909
- lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
910
-
911
- if not return_dict:
912
- output = (prediction_scores,) + outputs[2:]
913
- return ((lm_loss,) + output) if lm_loss is not None else output
914
-
915
- return CausalLMOutputWithCrossAttentions(
916
- loss=lm_loss,
917
- logits=prediction_scores,
918
- past_key_values=outputs.past_key_values,
919
- hidden_states=outputs.hidden_states,
920
- attentions=outputs.attentions,
921
- cross_attentions=outputs.cross_attentions,
922
- )
923
-
924
- def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
925
- input_shape = input_ids.shape
926
- # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
927
- if attention_mask is None:
928
- attention_mask = input_ids.new_ones(input_shape)
929
-
930
- # cut decoder_input_ids if past is used
931
- if past is not None:
932
- input_ids = input_ids[:, -1:]
933
-
934
- return {
935
- "input_ids": input_ids,
936
- "attention_mask": attention_mask,
937
- "past_key_values": past,
938
- "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
939
- "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
940
- "is_decoder": True,
941
- }
942
-
943
- def _reorder_cache(self, past, beam_idx):
944
- reordered_past = ()
945
- for layer_past in past:
946
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
947
- return reordered_past
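For reference while reading the mask logic in the deleted `med.py` above, here is a minimal, self-contained sketch of the decoder branch of `get_extended_attention_mask` (2D padding mask, no cached prefix). The helper name `build_decoder_mask` and the example mask are illustrative only and are not part of the deleted file.

```python
import torch

def build_decoder_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: [batch_size, seq_length]; 1.0 = attend, 0.0 = padding
    batch_size, seq_length = attention_mask.shape
    seq_ids = torch.arange(seq_length, device=attention_mask.device)
    # causal_mask[b, i, j] == 1 when key position j is not in the future of query position i
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    causal_mask = causal_mask.to(attention_mask.dtype)
    # combine causal and padding masks, add a head dimension, then turn 0/1 into additive 0/-10000
    extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    return (1.0 - extended) * -10000.0

mask = build_decoder_mask(torch.tensor([[1.0, 1.0, 1.0, 0.0]]))
print(mask.shape)  # torch.Size([1, 1, 4, 4])
```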
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py DELETED
@@ -1,301 +0,0 @@
1
- '''
2
- * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
- * Based on timm code base
4
- * https://github.com/rwightman/pytorch-image-models/tree/master/timm
5
- '''
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- from functools import partial
11
-
12
- from timm.models.vision_transformer import _cfg, PatchEmbed
13
- from timm.models.registry import register_model
14
- from timm.models.layers import trunc_normal_, DropPath
15
- from timm.models.helpers import named_apply, adapt_input_conv
16
-
17
- # from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
18
-
19
- class Mlp(nn.Module):
20
- """ MLP as used in Vision Transformer, MLP-Mixer and related networks
21
- """
22
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
23
- super().__init__()
24
- out_features = out_features or in_features
25
- hidden_features = hidden_features or in_features
26
- self.fc1 = nn.Linear(in_features, hidden_features)
27
- self.act = act_layer()
28
- self.fc2 = nn.Linear(hidden_features, out_features)
29
- self.drop = nn.Dropout(drop)
30
-
31
- def forward(self, x):
32
- x = self.fc1(x)
33
- x = self.act(x)
34
- x = self.drop(x)
35
- x = self.fc2(x)
36
- x = self.drop(x)
37
- return x
38
-
39
-
40
- class Attention(nn.Module):
41
- def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
42
- super().__init__()
43
- self.num_heads = num_heads
44
- head_dim = dim // num_heads
45
- # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
46
- self.scale = qk_scale or head_dim ** -0.5
47
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
48
- self.attn_drop = nn.Dropout(attn_drop)
49
- self.proj = nn.Linear(dim, dim)
50
- self.proj_drop = nn.Dropout(proj_drop)
51
- self.attn_gradients = None
52
- self.attention_map = None
53
-
54
- def save_attn_gradients(self, attn_gradients):
55
- self.attn_gradients = attn_gradients
56
-
57
- def get_attn_gradients(self):
58
- return self.attn_gradients
59
-
60
- def save_attention_map(self, attention_map):
61
- self.attention_map = attention_map
62
-
63
- def get_attention_map(self):
64
- return self.attention_map
65
-
66
- def forward(self, x, register_hook=False):
67
- B, N, C = x.shape
68
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
69
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
70
-
71
- attn = (q @ k.transpose(-2, -1)) * self.scale
72
- attn = attn.softmax(dim=-1)
73
- attn = self.attn_drop(attn)
74
-
75
- if register_hook:
76
- self.save_attention_map(attn)
77
- attn.register_hook(self.save_attn_gradients)
78
-
79
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
80
- x = self.proj(x)
81
- x = self.proj_drop(x)
82
- return x
83
-
84
-
85
- class Block(nn.Module):
86
-
87
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
88
- drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
89
- super().__init__()
90
- self.norm1 = norm_layer(dim)
91
- self.attn = Attention(
92
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
93
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
94
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
95
- self.norm2 = norm_layer(dim)
96
- mlp_hidden_dim = int(dim * mlp_ratio)
97
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
98
-
99
- # if use_grad_checkpointing:
100
- # self.attn = checkpoint_wrapper(self.attn)
101
- # self.mlp = checkpoint_wrapper(self.mlp)
102
-
103
- def forward(self, x, register_hook=False):
104
- x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
105
- x = x + self.drop_path(self.mlp(self.norm2(x)))
106
- return x
107
-
108
-
109
- class VisionTransformer(nn.Module):
110
- """ Vision Transformer
111
- A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
112
- https://arxiv.org/abs/2010.11929
113
- """
114
- def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
115
- num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
116
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
117
- use_grad_checkpointing=False, ckpt_layer=0):
118
- """
119
- Args:
120
- img_size (int, tuple): input image size
121
- patch_size (int, tuple): patch size
122
- in_chans (int): number of input channels
123
- num_classes (int): number of classes for classification head
124
- embed_dim (int): embedding dimension
125
- depth (int): depth of transformer
126
- num_heads (int): number of attention heads
127
- mlp_ratio (int): ratio of mlp hidden dim to embedding dim
128
- qkv_bias (bool): enable bias for qkv if True
129
- qk_scale (float): override default qk scale of head_dim ** -0.5 if set
130
- representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
131
- drop_rate (float): dropout rate
132
- attn_drop_rate (float): attention dropout rate
133
- drop_path_rate (float): stochastic depth rate
134
- norm_layer: (nn.Module): normalization layer
135
- """
136
- super().__init__()
137
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
138
- norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
139
-
140
- self.patch_embed = PatchEmbed(
141
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
142
-
143
- num_patches = self.patch_embed.num_patches
144
-
145
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
146
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
147
- self.pos_drop = nn.Dropout(p=drop_rate)
148
-
149
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
150
- self.blocks = nn.ModuleList([
151
- Block(
152
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
153
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
154
- use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
155
- )
156
- for i in range(depth)])
157
- self.norm = norm_layer(embed_dim)
158
-
159
- trunc_normal_(self.pos_embed, std=.02)
160
- trunc_normal_(self.cls_token, std=.02)
161
- self.apply(self._init_weights)
162
-
163
- def _init_weights(self, m):
164
- if isinstance(m, nn.Linear):
165
- trunc_normal_(m.weight, std=.02)
166
- if isinstance(m, nn.Linear) and m.bias is not None:
167
- nn.init.constant_(m.bias, 0)
168
- elif isinstance(m, nn.LayerNorm):
169
- nn.init.constant_(m.bias, 0)
170
- nn.init.constant_(m.weight, 1.0)
171
-
172
- @torch.jit.ignore
173
- def no_weight_decay(self):
174
- return {'pos_embed', 'cls_token'}
175
-
176
- def forward(self, x, register_blk=-1):
177
- B = x.shape[0]
178
- x = self.patch_embed(x)
179
-
180
- cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
181
- x = torch.cat((cls_tokens, x), dim=1)
182
-
183
- x = x + self.pos_embed[:,:x.size(1),:]
184
- x = self.pos_drop(x)
185
-
186
- for i,blk in enumerate(self.blocks):
187
- x = blk(x, register_blk==i)
188
- x = self.norm(x)
189
-
190
- return x
191
-
192
- @torch.jit.ignore()
193
- def load_pretrained(self, checkpoint_path, prefix=''):
194
- _load_weights(self, checkpoint_path, prefix)
195
-
196
-
197
- @torch.no_grad()
198
- def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
199
- """ Load weights from .npz checkpoints for official Google Brain Flax implementation
200
- """
201
- import numpy as np
202
-
203
- def _n2p(w, t=True):
204
- if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
205
- w = w.flatten()
206
- if t:
207
- if w.ndim == 4:
208
- w = w.transpose([3, 2, 0, 1])
209
- elif w.ndim == 3:
210
- w = w.transpose([2, 0, 1])
211
- elif w.ndim == 2:
212
- w = w.transpose([1, 0])
213
- return torch.from_numpy(w)
214
-
215
- w = np.load(checkpoint_path)
216
- if not prefix and 'opt/target/embedding/kernel' in w:
217
- prefix = 'opt/target/'
218
-
219
- if hasattr(model.patch_embed, 'backbone'):
220
- # hybrid
221
- backbone = model.patch_embed.backbone
222
- stem_only = not hasattr(backbone, 'stem')
223
- stem = backbone if stem_only else backbone.stem
224
- stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
225
- stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
226
- stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
227
- if not stem_only:
228
- for i, stage in enumerate(backbone.stages):
229
- for j, block in enumerate(stage.blocks):
230
- bp = f'{prefix}block{i + 1}/unit{j + 1}/'
231
- for r in range(3):
232
- getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
233
- getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
234
- getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
235
- if block.downsample is not None:
236
- block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
237
- block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
238
- block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
239
- embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
240
- else:
241
- embed_conv_w = adapt_input_conv(
242
- model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
243
- model.patch_embed.proj.weight.copy_(embed_conv_w)
244
- model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
245
- model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
246
- pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
247
- if pos_embed_w.shape != model.pos_embed.shape:
248
- pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
249
- pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
250
- model.pos_embed.copy_(pos_embed_w)
251
- model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
252
- model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
253
- # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
254
- # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
255
- # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
256
- # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
257
- # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
258
- # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
259
- for i, block in enumerate(model.blocks.children()):
260
- block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
261
- mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
262
- block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
263
- block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
264
- block.attn.qkv.weight.copy_(torch.cat([
265
- _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
266
- block.attn.qkv.bias.copy_(torch.cat([
267
- _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
268
- block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
269
- block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
270
- for r in range(2):
271
- getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
272
- getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
273
- block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
274
- block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
275
-
276
-
277
- def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
278
- # interpolate position embedding
279
- embedding_size = pos_embed_checkpoint.shape[-1]
280
- num_patches = visual_encoder.patch_embed.num_patches
281
- num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
282
- # height (== width) for the checkpoint position embedding
283
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
284
- # height (== width) for the new position embedding
285
- new_size = int(num_patches ** 0.5)
286
-
287
- if orig_size!=new_size:
288
- # class_token and dist_token are kept unchanged
289
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
290
- # only the position tokens are interpolated
291
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
292
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
293
- pos_tokens = torch.nn.functional.interpolate(
294
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
295
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
296
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
297
- print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
298
-
299
- return new_pos_embed
300
- else:
301
- return pos_embed_checkpoint
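A minimal sketch of the bicubic resize performed by `interpolate_pos_embed` above, run on dummy tensors (one class token, 14x14 patch grid resized to 16x16). The dimensions are illustrative only.

```python
import torch
import torch.nn.functional as F

embed_dim, num_extra_tokens = 768, 1          # illustrative dimensions
old_grid, new_grid = 14, 16                   # e.g. 224px/16 patches -> 256px/16 patches
pos_embed = torch.randn(1, num_extra_tokens + old_grid ** 2, embed_dim)

extra_tokens = pos_embed[:, :num_extra_tokens]            # class token kept unchanged
pos_tokens = pos_embed[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(1, old_grid, old_grid, embed_dim).permute(0, 3, 1, 2)
pos_tokens = F.interpolate(pos_tokens, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
print(new_pos_embed.shape)  # torch.Size([1, 257, 768])
```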
diffsynth/extensions/ImageQualityMetric/__init__.py DELETED
@@ -1,148 +0,0 @@
1
- from modelscope import snapshot_download
2
- from typing_extensions import Literal, TypeAlias
3
- import os
4
- from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
5
- from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
6
- from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
7
- from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
8
- from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
9
- from diffsynth.extensions.ImageQualityMetric.mps import MPScore
10
-
11
-
12
- preference_model_id: TypeAlias = Literal[
13
- "ImageReward",
14
- "Aesthetic",
15
- "PickScore",
16
- "CLIP",
17
- "HPSv2",
18
- "HPSv2.1",
19
- "MPS",
20
- ]
21
- model_dict = {
22
- "ImageReward": {
23
- "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
24
- "allow_file_pattern": [
25
- "ImageReward/ImageReward.safetensors",
26
- "ImageReward/med_config.json",
27
- "bert-base-uncased/config.json",
28
- "bert-base-uncased/model.safetensors",
29
- "bert-base-uncased/tokenizer.json",
30
- "bert-base-uncased/tokenizer_config.json",
31
- "bert-base-uncased/vocab.txt",
32
- ],
33
- "load_path": {
34
- "imagereward": "ImageReward/ImageReward.safetensors",
35
- "med_config": "ImageReward/med_config.json",
36
- "bert_model_path": "bert-base-uncased",
37
- },
38
- "model_class": ImageRewardScore
39
- },
40
- "Aesthetic": {
41
- "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
42
- "allow_file_pattern": [
43
- "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
44
- "clip-vit-large-patch14/config.json",
45
- "clip-vit-large-patch14/merges.txt",
46
- "clip-vit-large-patch14/model.safetensors",
47
- "clip-vit-large-patch14/preprocessor_config.json",
48
- "clip-vit-large-patch14/special_tokens_map.json",
49
- "clip-vit-large-patch14/tokenizer.json",
50
- "clip-vit-large-patch14/tokenizer_config.json",
51
- "clip-vit-large-patch14/vocab.json",
52
- ],
53
- "load_path": {
54
- "aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
55
- "clip-large": "clip-vit-large-patch14",
56
- },
57
- "model_class": AestheticScore
58
- },
59
- "PickScore": {
60
- "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
61
- "allow_file_pattern": [
62
- "PickScore_v1/*",
63
- "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
64
- "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
65
- "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
66
- "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
67
- "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
68
- "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
69
- "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
70
- ],
71
- "load_path": {
72
- "pickscore": "PickScore_v1",
73
- "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
74
- },
75
- "model_class": PickScore
76
- },
77
- "CLIP": {
78
- "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
79
- "allow_file_pattern": [
80
- "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
81
- "bpe_simple_vocab_16e6.txt.gz",
82
- ],
83
- "load_path": {
84
- "open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
85
- "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
86
- },
87
- "model_class": CLIPScore
88
- },
89
- "HPSv2": {
90
- "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
91
- "allow_file_pattern": [
92
- "HPS_v2/HPS_v2_compressed.safetensors",
93
- "bpe_simple_vocab_16e6.txt.gz",
94
- ],
95
- "load_path": {
96
- "hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
97
- "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
98
- },
99
- "model_class": HPScore_v2,
100
- "extra_kwargs": {"model_version": "v2"}
101
- },
102
- "HPSv2.1": {
103
- "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
104
- "allow_file_pattern": [
105
- "HPS_v2/HPS_v2.1_compressed.safetensors",
106
- "bpe_simple_vocab_16e6.txt.gz",
107
- ],
108
- "load_path": {
109
- "hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
110
- "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
111
- },
112
- "model_class": HPScore_v2,
113
- "extra_kwargs": {"model_version": "v21"}
114
- },
115
- "MPS": {
116
- "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
117
- "allow_file_pattern": [
118
- "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
119
- "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
120
- "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
121
- "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
122
- "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
123
- "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
124
- "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
125
- "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
126
- ],
127
- "load_path": {
128
- "mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
129
- "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
130
- },
131
- "model_class": MPScore
132
- },
133
- }
134
-
135
-
136
- def download_preference_model(model_name: preference_model_id, cache_dir="models"):
137
- metadata = model_dict[model_name]
138
- snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
139
- load_path = metadata["load_path"]
140
- load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
141
- return load_path
142
-
143
-
144
- def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
145
- model_class = model_dict[model_name]["model_class"]
146
- extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
147
- preference_model = model_class(device=device, path=path, **extra_kwargs)
148
- return preference_model
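A hypothetical usage sketch of the two helpers deleted above, as they worked before this commit; the model name, cache directory, image path, and prompt are placeholders, and the ModelScope weights are assumed to be downloadable.

```python
from diffsynth.extensions.ImageQualityMetric import (
    download_preference_model,
    load_preference_model,
)

# Downloads only the files listed under allow_file_pattern and returns the load_path dict.
path = download_preference_model("Aesthetic", cache_dir="models")
scorer = load_preference_model("Aesthetic", device="cuda", path=path)

scores = scorer.score(["example.jpg"], prompt="a watercolor painting of a fox")
print(scores)  # one float per image
```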
diffsynth/extensions/ImageQualityMetric/aesthetic.py DELETED
@@ -1,148 +0,0 @@
1
- from typing import List, Optional
2
- from PIL import Image
3
- import torch
4
- from transformers import AutoProcessor, AutoModel
5
- from safetensors.torch import load_file
6
- import os
7
- from typing import Union, List
8
- from .config import MODEL_PATHS
9
-
10
- class MLP(torch.nn.Module):
11
- def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
12
- super().__init__()
13
- self.input_size = input_size
14
- self.xcol = xcol
15
- self.ycol = ycol
16
- self.layers = torch.nn.Sequential(
17
- torch.nn.Linear(self.input_size, 1024),
18
- #torch.nn.ReLU(),
19
- torch.nn.Dropout(0.2),
20
- torch.nn.Linear(1024, 128),
21
- #torch.nn.ReLU(),
22
- torch.nn.Dropout(0.2),
23
- torch.nn.Linear(128, 64),
24
- #torch.nn.ReLU(),
25
- torch.nn.Dropout(0.1),
26
- torch.nn.Linear(64, 16),
27
- #torch.nn.ReLU(),
28
- torch.nn.Linear(16, 1),
29
- )
30
-
31
- def forward(self, x: torch.Tensor) -> torch.Tensor:
32
- return self.layers(x)
33
-
34
- def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
35
- x = batch[self.xcol]
36
- y = batch[self.ycol].reshape(-1, 1)
37
- x_hat = self.layers(x)
38
- loss = torch.nn.functional.mse_loss(x_hat, y)
39
- return loss
40
-
41
- def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
42
- x = batch[self.xcol]
43
- y = batch[self.ycol].reshape(-1, 1)
44
- x_hat = self.layers(x)
45
- loss = torch.nn.functional.mse_loss(x_hat, y)
46
- return loss
47
-
48
- def configure_optimizers(self) -> torch.optim.Optimizer:
49
- return torch.optim.Adam(self.parameters(), lr=1e-3)
50
-
51
-
52
- class AestheticScore(torch.nn.Module):
53
- def __init__(self, device: torch.device, path: str = MODEL_PATHS):
54
- super().__init__()
55
- self.device = device
56
- self.aes_model_path = path.get("aesthetic_predictor")
57
- # Load the MLP model
58
- self.model = MLP(768)
59
- try:
60
- if self.aes_model_path.endswith(".safetensors"):
61
- state_dict = load_file(self.aes_model_path)
62
- else:
63
- state_dict = torch.load(self.aes_model_path)
64
- self.model.load_state_dict(state_dict)
65
- except Exception as e:
66
- raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
67
-
68
- self.model.to(device)
69
- self.model.eval()
70
-
71
- # Load the CLIP model and processor
72
- clip_model_name = path.get('clip-large')
73
- self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
74
- self.processor = AutoProcessor.from_pretrained(clip_model_name)
75
-
76
- def _calculate_score(self, image: torch.Tensor) -> float:
77
- """Calculate the aesthetic score for a single image.
78
-
79
- Args:
80
- image (torch.Tensor): The processed image tensor.
81
-
82
- Returns:
83
- float: The aesthetic score.
84
- """
85
- with torch.no_grad():
86
- # Get image embeddings
87
- image_embs = self.model2.get_image_features(image)
88
- image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
89
-
90
- # Compute score
91
- score = self.model(image_embs).cpu().flatten().item()
92
-
93
- return score
94
-
95
- @torch.no_grad()
96
- def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
97
- """Score the images based on their aesthetic quality.
98
-
99
- Args:
100
- images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
101
-
102
- Returns:
103
- List[float]: List of scores for the images.
104
- """
105
- try:
106
- if isinstance(images, (str, Image.Image)):
107
- # Single image
108
- if isinstance(images, str):
109
- pil_image = Image.open(images)
110
- else:
111
- pil_image = images
112
-
113
- # Prepare image inputs
114
- image_inputs = self.processor(
115
- images=pil_image,
116
- padding=True,
117
- truncation=True,
118
- max_length=77,
119
- return_tensors="pt",
120
- ).to(self.device)
121
-
122
- return [self._calculate_score(image_inputs["pixel_values"])]
123
- elif isinstance(images, list):
124
- # Multiple images
125
- scores = []
126
- for one_image in images:
127
- if isinstance(one_image, str):
128
- pil_image = Image.open(one_image)
129
- elif isinstance(one_image, Image.Image):
130
- pil_image = one_image
131
- else:
132
- raise TypeError("The type of parameter images is illegal.")
133
-
134
- # Prepare image inputs
135
- image_inputs = self.processor(
136
- images=pil_image,
137
- padding=True,
138
- truncation=True,
139
- max_length=77,
140
- return_tensors="pt",
141
- ).to(self.device)
142
-
143
- scores.append(self._calculate_score(image_inputs["pixel_values"]))
144
- return scores
145
- else:
146
- raise TypeError("The type of parameter images is illegal.")
147
- except Exception as e:
148
- raise RuntimeError(f"Error in scoring images: {e}")
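A minimal sketch of the scoring path in `AestheticScore._calculate_score` above: a CLIP image embedding is L2-normalised, then regressed to a scalar by the small MLP. The random vector stands in for `model2.get_image_features(...)`, and the MLP here is freshly initialised, so the printed value is arbitrary.

```python
import torch

mlp = torch.nn.Sequential(
    torch.nn.Linear(768, 1024), torch.nn.Dropout(0.2),
    torch.nn.Linear(1024, 128), torch.nn.Dropout(0.2),
    torch.nn.Linear(128, 64),   torch.nn.Dropout(0.1),
    torch.nn.Linear(64, 16),    torch.nn.Linear(16, 1),
).eval()

image_embs = torch.randn(1, 768)                                   # stand-in for CLIP image features
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
with torch.no_grad():
    score = mlp(image_embs).cpu().flatten().item()
print(score)
```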
diffsynth/extensions/ImageQualityMetric/clip.py DELETED
@@ -1,97 +0,0 @@
1
- from typing import List, Union
2
- from PIL import Image
3
- import torch
4
- from .open_clip import create_model_and_transforms, get_tokenizer
5
- from .config import MODEL_PATHS
6
-
7
- class CLIPScore(torch.nn.Module):
8
- def __init__(self, device: torch.device, path: str = MODEL_PATHS):
9
- super().__init__()
10
- """Initialize the CLIPScore with a model and tokenizer.
11
-
12
- Args:
13
- device (torch.device): The device to load the model on.
14
- """
15
- self.device = device
16
-
17
- # Create model and transforms
18
- self.model, _, self.preprocess_val = create_model_and_transforms(
19
- "ViT-H-14",
20
- # "laion2B-s32B-b79K",
21
- pretrained=path.get("open_clip"),
22
- precision="amp",
23
- device=device,
24
- jit=False,
25
- force_quick_gelu=False,
26
- force_custom_text=False,
27
- force_patch_dropout=False,
28
- force_image_size=None,
29
- pretrained_image=False,
30
- image_mean=None,
31
- image_std=None,
32
- light_augmentation=True,
33
- aug_cfg={},
34
- output_dict=True,
35
- with_score_predictor=False,
36
- with_region_predictor=False,
37
- )
38
-
39
- # Initialize tokenizer
40
- self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
41
- self.model = self.model.to(device)
42
- self.model.eval()
43
-
44
- def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
45
- """Calculate the CLIP score for a single image and prompt.
46
-
47
- Args:
48
- image (torch.Tensor): The processed image tensor.
49
- prompt (str): The prompt text.
50
-
51
- Returns:
52
- float: The CLIP score.
53
- """
54
- with torch.no_grad():
55
- # Process the prompt
56
- text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
57
-
58
- # Calculate the CLIP score
59
- outputs = self.model(image, text)
60
- image_features, text_features = outputs["image_features"], outputs["text_features"]
61
- logits_per_image = image_features @ text_features.T
62
- clip_score = torch.diagonal(logits_per_image).cpu().numpy()
63
-
64
- return clip_score[0].item()
65
-
66
- @torch.no_grad()
67
- def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
68
- """Score the images based on the prompt.
69
-
70
- Args:
71
- images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
72
- prompt (str): The prompt text.
73
-
74
- Returns:
75
- List[float]: List of CLIP scores for the images.
76
- """
77
- if isinstance(images, (str, Image.Image)):
78
- # Single image
79
- if isinstance(images, str):
80
- image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
81
- else:
82
- image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
83
- return [self._calculate_score(image, prompt)]
84
- elif isinstance(images, list):
85
- # Multiple images
86
- scores = []
87
- for one_images in images:
88
- if isinstance(one_images, str):
89
- image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
90
- elif isinstance(one_images, Image.Image):
91
- image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
92
- else:
93
- raise TypeError("The type of parameter images is illegal.")
94
- scores.append(self._calculate_score(image, prompt))
95
- return scores
96
- else:
97
- raise TypeError("The type of parameter images is illegal.")
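A minimal sketch of the similarity computed in `_calculate_score` above: with unit-normalised features (as the open_clip model returns), the score for each matched (image, prompt) pair is the diagonal of `image_features @ text_features.T`. Random features are used as stand-ins here.

```python
import torch
import torch.nn.functional as F

image_features = F.normalize(torch.randn(2, 1024), dim=-1)   # stand-in image embeddings
text_features = F.normalize(torch.randn(2, 1024), dim=-1)    # stand-in text embeddings

logits_per_image = image_features @ text_features.T          # [2, 2] pairwise cosine similarities
clip_scores = torch.diagonal(logits_per_image)               # keep matched (image, prompt) pairs only
print(clip_scores.tolist())
```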
diffsynth/extensions/ImageQualityMetric/config.py DELETED
@@ -1,23 +0,0 @@
1
- import os
2
-
3
- current_dir = os.path.dirname(os.path.abspath(__file__))
4
- project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
5
- model_path = os.path.join(project_root, 'models', 'QualityMetric')
6
-
7
-
8
- def get_model_path(model_name):
9
- return os.path.join(model_path, model_name)
10
-
11
-
12
- MODEL_PATHS = {
13
- "aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
14
- "open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
15
- "hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
16
- "hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
17
- "imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
18
- "med_config": get_model_path("ImageReward/med_config.json"),
19
- "clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
20
- "clip-large": get_model_path("clip-vit-large-patch14"),
21
- "mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
22
- "pickscore": get_model_path("PickScore_v1")
23
- }
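For orientation, the scorer classes in this extension consume this dictionary by key lookup (e.g. `path.get("open_clip")`), so a hand-built dict with the same keys can stand in for `MODEL_PATHS`; the paths below are placeholders.

```python
custom_paths = {
    "open_clip": "/data/weights/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
    "open_clip_bpe": "/data/weights/bpe_simple_vocab_16e6.txt.gz",
}
print(custom_paths.get("open_clip"))
```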
diffsynth/extensions/ImageQualityMetric/hps.py DELETED
@@ -1,118 +0,0 @@
1
- from typing import List, Union
2
- from PIL import Image
3
- import torch
4
- from .open_clip import create_model_and_transforms, get_tokenizer
5
- from safetensors.torch import load_file
6
- import os
7
- from .config import MODEL_PATHS
8
-
9
- class HPScore_v2(torch.nn.Module):
10
- def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
11
- super().__init__()
12
- """Initialize the Selector with a model and tokenizer.
13
-
14
- Args:
15
- device (torch.device): The device to load the model on.
16
- model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
17
- """
18
- self.device = device
19
-
20
- if model_version == "v2":
21
- safetensors_path = path.get("hpsv2")
22
- elif model_version == "v21":
23
- safetensors_path = path.get("hpsv2.1")
24
- else:
25
- raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
26
-
27
- # Create model and transforms
28
- model, _, self.preprocess_val = create_model_and_transforms(
29
- "ViT-H-14",
30
- # "laion2B-s32B-b79K",
31
- pretrained=path.get("open_clip"),
32
- precision="amp",
33
- device=device,
34
- jit=False,
35
- force_quick_gelu=False,
36
- force_custom_text=False,
37
- force_patch_dropout=False,
38
- force_image_size=None,
39
- pretrained_image=False,
40
- image_mean=None,
41
- image_std=None,
42
- light_augmentation=True,
43
- aug_cfg={},
44
- output_dict=True,
45
- with_score_predictor=False,
46
- with_region_predictor=False,
47
- )
48
-
49
- # Load model weights
50
- try:
51
- state_dict = load_file(safetensors_path)
52
- model.load_state_dict(state_dict)
53
- except Exception as e:
54
- raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
55
-
56
- # Initialize tokenizer and model
57
- self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
58
- model = model.to(device)
59
- model.eval()
60
- self.model = model
61
-
62
- def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
63
- """Calculate the HPS score for a single image and prompt.
64
-
65
- Args:
66
- image (torch.Tensor): The processed image tensor.
67
- prompt (str): The prompt text.
68
-
69
- Returns:
70
- float: The HPS score.
71
- """
72
- with torch.no_grad():
73
- # Process the prompt
74
- text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
75
-
76
- # Calculate the HPS score
77
- outputs = self.model(image, text)
78
- image_features, text_features = outputs["image_features"], outputs["text_features"]
79
- logits_per_image = image_features @ text_features.T
80
- hps_score = torch.diagonal(logits_per_image).cpu().numpy()
81
-
82
- return hps_score[0].item()
83
-
84
- @torch.no_grad()
85
- def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
86
- """Score the images based on the prompt.
87
-
88
- Args:
89
- images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
90
- prompt (str): The prompt text.
91
-
92
- Returns:
93
- List[float]: List of HPS scores for the images.
94
- """
95
- try:
96
- if isinstance(images, (str, Image.Image)):
97
- # Single image
98
- if isinstance(images, str):
99
- image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
100
- else:
101
- image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
102
- return [self._calculate_score(image, prompt)]
103
- elif isinstance(images, list):
104
- # Multiple images
105
- scores = []
106
- for one_images in images:
107
- if isinstance(one_images, str):
108
- image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
109
- elif isinstance(one_images, Image.Image):
110
- image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
111
- else:
112
- raise TypeError("The type of parameter images is illegal.")
113
- scores.append(self._calculate_score(image, prompt))
114
- return scores
115
- else:
116
- raise TypeError("The type of parameter images is illegal.")
117
- except Exception as e:
118
- raise RuntimeError(f"Error in scoring images: {e}")
diffsynth/extensions/ImageQualityMetric/imagereward.py DELETED
@@ -1,212 +0,0 @@
1
- import os
2
- import torch
3
- from PIL import Image
4
- from typing import List, Union
5
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
6
- from .BLIP.blip_pretrain import BLIP_Pretrain
7
- from torchvision.transforms import InterpolationMode
8
- from safetensors.torch import load_file
9
- from .config import MODEL_PATHS
10
- BICUBIC = InterpolationMode.BICUBIC
11
-
12
- def _convert_image_to_rgb(image):
13
- return image.convert("RGB")
14
-
15
- def _transform(n_px):
16
- return Compose([
17
- Resize(n_px, interpolation=BICUBIC),
18
- CenterCrop(n_px),
19
- _convert_image_to_rgb,
20
- ToTensor(),
21
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
22
- ])
23
-
24
- class MLP(torch.nn.Module):
25
- def __init__(self, input_size):
26
- super().__init__()
27
- self.input_size = input_size
28
-
29
- self.layers = torch.nn.Sequential(
30
- torch.nn.Linear(self.input_size, 1024),
31
- #nn.ReLU(),
32
- torch.nn.Dropout(0.2),
33
- torch.nn.Linear(1024, 128),
34
- #nn.ReLU(),
35
- torch.nn.Dropout(0.2),
36
- torch.nn.Linear(128, 64),
37
- #nn.ReLU(),
38
- torch.nn.Dropout(0.1),
39
- torch.nn.Linear(64, 16),
40
- #nn.ReLU(),
41
- torch.nn.Linear(16, 1)
42
- )
43
-
44
- # initial MLP param
45
- for name, param in self.layers.named_parameters():
46
- if 'weight' in name:
47
- torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
48
- if 'bias' in name:
49
- torch.nn.init.constant_(param, val=0)
50
-
51
- def forward(self, input):
52
- return self.layers(input)
53
-
54
- class ImageReward(torch.nn.Module):
55
- def __init__(self, med_config, device='cpu', bert_model_path=""):
56
- super().__init__()
57
- self.device = device
58
-
59
- self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
60
- self.preprocess = _transform(224)
61
- self.mlp = MLP(768)
62
-
63
- self.mean = 0.16717362830052426
64
- self.std = 1.0333394966054072
65
-
66
- def score_grad(self, prompt_ids, prompt_attention_mask, image):
67
- """Calculate the score with gradient for a single image and prompt.
68
-
69
- Args:
70
- prompt_ids (torch.Tensor): Tokenized prompt IDs.
71
- prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
72
- image (torch.Tensor): The processed image tensor.
73
-
74
- Returns:
75
- torch.Tensor: The reward score.
76
- """
77
- image_embeds = self.blip.visual_encoder(image)
78
- image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
79
- text_output = self.blip.text_encoder(
80
- prompt_ids,
81
- attention_mask=prompt_attention_mask,
82
- encoder_hidden_states=image_embeds,
83
- encoder_attention_mask=image_atts,
84
- return_dict=True,
85
- )
86
- txt_features = text_output.last_hidden_state[:, 0, :]
87
- rewards = self.mlp(txt_features)
88
- rewards = (rewards - self.mean) / self.std
89
- return rewards
90
-
91
- def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
92
- """Score the images based on the prompt.
93
-
94
- Args:
95
- prompt (str): The prompt text.
96
- images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
97
-
98
- Returns:
99
- List[float]: List of scores for the images.
100
- """
101
- if isinstance(images, (str, Image.Image)):
102
- # Single image
103
- if isinstance(images, str):
104
- pil_image = Image.open(images)
105
- else:
106
- pil_image = images
107
- image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
108
- return [self._calculate_score(prompt, image).item()]
109
- elif isinstance(images, list):
110
- # Multiple images
111
- scores = []
112
- for one_image in images:
113
- if isinstance(one_image, str):
114
- pil_image = Image.open(one_image)
115
- elif isinstance(one_image, Image.Image):
116
- pil_image = one_image
117
- else:
118
- raise TypeError("The type of parameter images is illegal.")
119
- image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
120
- scores.append(self._calculate_score(prompt, image).item())
121
- return scores
122
- else:
123
- raise TypeError("The type of parameter images is illegal.")
124
-
125
- def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
126
- """Calculate the score for a single image and prompt.
127
-
128
- Args:
129
- prompt (str): The prompt text.
130
- image (torch.Tensor): The processed image tensor.
131
-
132
- Returns:
133
- torch.Tensor: The reward score.
134
- """
135
- text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
136
- image_embeds = self.blip.visual_encoder(image)
137
- image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
138
- text_output = self.blip.text_encoder(
139
- text_input.input_ids,
140
- attention_mask=text_input.attention_mask,
141
- encoder_hidden_states=image_embeds,
142
- encoder_attention_mask=image_atts,
143
- return_dict=True,
144
- )
145
- txt_features = text_output.last_hidden_state[:, 0, :].float()
146
- rewards = self.mlp(txt_features)
147
- rewards = (rewards - self.mean) / self.std
148
- return rewards
149
-
150
- def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
151
- """Rank the images based on the prompt.
152
-
153
- Args:
154
- prompt (str): The prompt text.
155
- generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
156
-
157
- Returns:
158
- tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
159
- """
160
- text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
161
- txt_set = []
162
- for generation in generations_list:
163
- if isinstance(generation, str):
164
- pil_image = Image.open(generation)
165
- elif isinstance(generation, Image.Image):
166
- pil_image = generation
167
- else:
168
- raise TypeError("The type of parameter generations_list is illegal.")
169
- image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
170
- image_embeds = self.blip.visual_encoder(image)
171
- image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
172
- text_output = self.blip.text_encoder(
173
- text_input.input_ids,
174
- attention_mask=text_input.attention_mask,
175
- encoder_hidden_states=image_embeds,
176
- encoder_attention_mask=image_atts,
177
- return_dict=True,
178
- )
179
- txt_set.append(text_output.last_hidden_state[:, 0, :])
180
- txt_features = torch.cat(txt_set, 0).float()
181
- rewards = self.mlp(txt_features)
182
- rewards = (rewards - self.mean) / self.std
183
- rewards = torch.squeeze(rewards)
184
- _, rank = torch.sort(rewards, dim=0, descending=True)
185
- _, indices = torch.sort(rank, dim=0)
186
- indices = indices + 1
187
- return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
188
-
189
-
190
- class ImageRewardScore(torch.nn.Module):
191
- def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
192
- super().__init__()
193
- self.device = device if isinstance(device, torch.device) else torch.device(device)
194
- model_path = path.get("imagereward")
195
- med_config = path.get("med_config")
196
- state_dict = load_file(model_path)
197
- self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
198
- self.model.load_state_dict(state_dict, strict=False)
199
- self.model.eval()
200
-
201
- @torch.no_grad()
202
- def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
203
- """Score the images based on the prompt.
204
-
205
- Args:
206
- images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
207
- prompt (str): The prompt text.
208
-
209
- Returns:
210
- List[float]: List of scores for the images.
211
- """
212
- return self.model.score(images, prompt)
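The inference_rank method in the deleted ImageReward code above turns raw reward scores into 1-based ranks with two successive torch.sort calls (sort descending, then sort the resulting permutation). A minimal, self-contained sketch of that trick, using toy scores in place of the BLIP-derived rewards:

    import torch

    # Toy reward scores for four candidate images (higher is better).
    rewards = torch.tensor([0.3, 1.2, -0.5, 0.8])

    # First sort: indices of the scores in descending order.
    _, rank = torch.sort(rewards, dim=0, descending=True)
    # Second sort: invert that permutation to get each image's rank position.
    _, indices = torch.sort(rank, dim=0)
    indices = indices + 1  # 1-based ranks, as returned by inference_rank

    print(indices.tolist())  # [3, 1, 4, 2]: the second image ranks first, the third last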
diffsynth/extensions/ImageQualityMetric/mps.py DELETED
@@ -1,129 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from PIL import Image
4
- from io import BytesIO
5
- from tqdm.auto import tqdm
6
- from transformers import CLIPFeatureExtractor, CLIPImageProcessor
7
- from transformers import CLIPConfig
8
- from dataclasses import dataclass
9
- from transformers import CLIPModel as HFCLIPModel
10
- from safetensors.torch import load_file
11
- from torch import nn, einsum
12
-
13
- from .trainer.models.base_model import BaseModelConfig
14
-
15
- from transformers import CLIPConfig
16
- from transformers import AutoProcessor, AutoModel, AutoTokenizer
17
- from typing import Any, Optional, Tuple, Union, List
18
- import torch
19
-
20
- from .trainer.models.cross_modeling import Cross_model
21
- from .trainer.models import clip_model
22
- import torch.nn.functional as F
23
- import gc
24
- import json
25
- from .config import MODEL_PATHS
26
-
27
- class MPScore(torch.nn.Module):
28
- def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
29
- super().__init__()
30
- """Initialize the MPScore model with a processor, tokenizer, and model.
31
-
32
- Args:
33
- device (Union[str, torch.device]): The device to load the model on.
34
- """
35
- self.device = device
36
- processor_name_or_path = path.get("clip")
37
- self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
38
- self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
39
- self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
40
- state_dict = load_file(path.get("mps"))
41
- self.model.load_state_dict(state_dict, strict=False)
42
- self.model.to(device)
43
- self.condition = condition
44
-
45
- def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
46
- """Calculate the reward score for a single image and prompt.
47
-
48
- Args:
49
- image (torch.Tensor): The processed image tensor.
50
- prompt (str): The prompt text.
51
-
52
- Returns:
53
- float: The reward score.
54
- """
55
- def _tokenize(caption):
56
- input_ids = self.tokenizer(
57
- caption,
58
- max_length=self.tokenizer.model_max_length,
59
- padding="max_length",
60
- truncation=True,
61
- return_tensors="pt"
62
- ).input_ids
63
- return input_ids
64
-
65
- text_input = _tokenize(prompt).to(self.device)
66
- if self.condition == 'overall':
67
- condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
68
- elif self.condition == 'aesthetics':
69
- condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
70
- elif self.condition == 'quality':
71
- condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
72
- elif self.condition == 'semantic':
73
- condition_prompt = 'quantity, attributes, position, number, location'
74
- else:
75
- raise ValueError(
76
- f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
77
- condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
78
-
79
- with torch.no_grad():
80
- text_f, text_features = self.model.model.get_text_features(text_input)
81
-
82
- image_f = self.model.model.get_image_features(image.half())
83
- condition_f, _ = self.model.model.get_text_features(condition_batch)
84
-
85
- sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
86
- sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
87
- sim_text_condition = sim_text_condition / sim_text_condition.max()
88
- mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
89
- mask = mask.repeat(1, image_f.shape[1], 1)
90
- image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
91
-
92
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
93
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
94
- image_score = self.model.logit_scale.exp() * text_features @ image_features.T
95
-
96
- return image_score[0].cpu().numpy().item()
97
-
98
- @torch.no_grad()
99
- def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
100
- """Score the images based on the prompt.
101
-
102
- Args:
103
- images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
104
- prompt (str): The prompt text.
105
-
106
- Returns:
107
- List[float]: List of reward scores for the images.
108
- """
109
- if isinstance(images, (str, Image.Image)):
110
- # Single image
111
- if isinstance(images, str):
112
- image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
113
- else:
114
- image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
115
- return [self._calculate_score(image, prompt)]
116
- elif isinstance(images, list):
117
- # Multiple images
118
- scores = []
119
- for one_images in images:
120
- if isinstance(one_images, str):
121
- image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
122
- elif isinstance(one_images, Image.Image):
123
- image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
124
- else:
125
- raise TypeError("The type of parameter images is illegal.")
126
- scores.append(self._calculate_score(image, prompt))
127
- return scores
128
- else:
129
- raise TypeError("The type of parameter images is illegal.")
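A minimal usage sketch for the deleted MPScore wrapper above, assuming the default MODEL_PATHS entries ("clip", "mps") point at valid local checkpoints and that the module is still importable from its pre-deletion location; the image path and prompt are placeholders:

    import torch
    from PIL import Image
    from diffsynth.extensions.ImageQualityMetric.mps import MPScore  # pre-deletion import path

    device = "cuda" if torch.cuda.is_available() else "cpu"  # the model runs in fp16, so GPU is preferable
    scorer = MPScore(device=device, condition="aesthetics")  # 'overall', 'aesthetics', 'quality', or 'semantic'

    # score() accepts a path, a PIL image, or a list of either, plus a prompt.
    image = Image.open("sample.png")  # placeholder image
    print(scorer.score(image, prompt="a watercolor painting of a fox"))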
diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- from .coca_model import CoCa
2
- from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3
- from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4
- from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5
- from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6
- from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7
- convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
8
- from .openai import load_openai_model, list_openai_models
9
- from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
10
- get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
11
- from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
12
- from .tokenizer import SimpleTokenizer
13
- from .transform import image_transform, AugmentationCfg
14
- from .utils import freeze_batch_norm_2d
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py DELETED
@@ -1,458 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- import numpy as np
7
- from dataclasses import dataclass
8
-
9
- from .transformer import (
10
- LayerNormFp32,
11
- LayerNorm,
12
- QuickGELU,
13
- MultimodalTransformer,
14
- )
15
- from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
-
17
- try:
18
- from transformers import (
19
- BeamSearchScorer,
20
- LogitsProcessorList,
21
- TopPLogitsWarper,
22
- TopKLogitsWarper,
23
- RepetitionPenaltyLogitsProcessor,
24
- MinLengthLogitsProcessor,
25
- MaxLengthCriteria,
26
- StoppingCriteriaList
27
- )
28
-
29
- GENERATION_TYPES = {
30
- "top_k": TopKLogitsWarper,
31
- "top_p": TopPLogitsWarper,
32
- "beam_search": "beam_search"
33
- }
34
- _has_transformers = True
35
- except ImportError as e:
36
- GENERATION_TYPES = {
37
- "top_k": None,
38
- "top_p": None,
39
- "beam_search": "beam_search"
40
- }
41
- _has_transformers = False
42
-
43
-
44
- @dataclass
45
- class MultimodalCfg(CLIPTextCfg):
46
- mlp_ratio: int = 4
47
- dim_head: int = 64
48
- heads: int = 8
49
- n_queries: int = 256
50
- attn_pooler_heads: int = 8
51
-
52
-
53
- def _build_text_decoder_tower(
54
- embed_dim,
55
- multimodal_cfg,
56
- quick_gelu: bool = False,
57
- cast_dtype: Optional[torch.dtype] = None,
58
- ):
59
- multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
60
- act_layer = QuickGELU if quick_gelu else nn.GELU
61
- norm_layer = (
62
- LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
63
- )
64
-
65
- decoder = MultimodalTransformer(
66
- context_length=multimodal_cfg.context_length,
67
- width=multimodal_cfg.width,
68
- heads=multimodal_cfg.heads,
69
- layers=multimodal_cfg.layers,
70
- ls_init_value=multimodal_cfg.ls_init_value,
71
- output_dim=embed_dim,
72
- act_layer=act_layer,
73
- norm_layer=norm_layer,
74
- )
75
-
76
- return decoder
77
-
78
-
79
- class CoCa(nn.Module):
80
- def __init__(
81
- self,
82
- embed_dim,
83
- multimodal_cfg: MultimodalCfg,
84
- text_cfg: CLIPTextCfg,
85
- vision_cfg: CLIPVisionCfg,
86
- quick_gelu: bool = False,
87
- cast_dtype: Optional[torch.dtype] = None,
88
- pad_id: int = 0,
89
- ):
90
- super().__init__()
91
- multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
92
- text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
93
- vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
94
-
95
- self.text = _build_text_tower(
96
- embed_dim=embed_dim,
97
- text_cfg=text_cfg,
98
- quick_gelu=quick_gelu,
99
- cast_dtype=cast_dtype,
100
- )
101
-
102
- vocab_size = (
103
- text_cfg.vocab_size # for hf models
104
- if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
105
- else text_cfg.vocab_size
106
- )
107
-
108
- self.visual = _build_vision_tower(
109
- embed_dim=embed_dim,
110
- vision_cfg=vision_cfg,
111
- quick_gelu=quick_gelu,
112
- cast_dtype=cast_dtype,
113
- )
114
-
115
- self.text_decoder = _build_text_decoder_tower(
116
- vocab_size,
117
- multimodal_cfg=multimodal_cfg,
118
- quick_gelu=quick_gelu,
119
- cast_dtype=cast_dtype,
120
- )
121
-
122
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
123
- self.pad_id = pad_id
124
-
125
- @torch.jit.ignore
126
- def set_grad_checkpointing(self, enable=True):
127
- self.visual.set_grad_checkpointing(enable)
128
- self.text.set_grad_checkpointing(enable)
129
- self.text_decoder.set_grad_checkpointing(enable)
130
-
131
- def _encode_image(self, images, normalize=True):
132
- image_latent, tokens_embs = self.visual(images)
133
- image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
134
- return image_latent, tokens_embs
135
-
136
- def _encode_text(self, text, normalize=True, embed_cls=True):
137
- text = text[:, :-1] if embed_cls else text # make space for CLS token
138
- text_latent, token_emb = self.text(text)
139
- text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
140
- return text_latent, token_emb
141
-
142
- def encode_image(self, images, normalize=True):
143
- image_latent, _ = self._encode_image(images, normalize=normalize)
144
- return image_latent
145
-
146
- def encode_text(self, text, normalize=True, embed_cls=True):
147
- text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
148
- return text_latent
149
-
150
- def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
151
- text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
152
- if image_latent is None or image_embs is None:
153
- image_latent, image_embs = self._encode_image(image)
154
-
155
- # TODO: add assertion to avoid bugs?
156
- labels = text[:, -token_embs.shape[1]:]
157
-
158
- logits = self.text_decoder(image_embs, token_embs)
159
- return {
160
- "image_features": image_latent,
161
- "text_features": text_latent,
162
- "logits": logits,
163
- "labels": labels,
164
- "logit_scale": self.logit_scale.exp()
165
- }
166
-
167
- def generate(
168
- self,
169
- image,
170
- text=None,
171
- seq_len=30,
172
- max_seq_len=77,
173
- temperature=1.,
174
- generation_type="beam_search",
175
- top_p=0.1, # keep tokens in the 1 - top_p quantile
176
- top_k=1, # keeps the top_k most probable tokens
177
- pad_token_id=None,
178
- eos_token_id=None,
179
- sot_token_id=None,
180
- num_beams=6,
181
- num_beam_groups=3,
182
- min_seq_len=5,
183
- stopping_criteria=None,
184
- repetition_penalty=1.0,
185
- fixed_output_length=False # if True output.shape == (batch_size, seq_len)
186
- ):
187
- # taking many ideas and components from HuggingFace GenerationMixin
188
- # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
189
- assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
190
- assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
191
-
192
- with torch.no_grad():
193
- sot_token_id = 49406 if sot_token_id is None else sot_token_id
194
- eos_token_id = 49407 if eos_token_id is None else eos_token_id
195
- pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
196
- logit_processor = LogitsProcessorList(
197
- [
198
- MinLengthLogitsProcessor(min_seq_len, eos_token_id),
199
- RepetitionPenaltyLogitsProcessor(repetition_penalty),
200
- ]
201
- )
202
-
203
- if stopping_criteria is None:
204
- stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
205
-
206
- stopping_criteria = StoppingCriteriaList(
207
- stopping_criteria
208
- )
209
-
210
- device = image.device
211
-
212
- if generation_type == "beam_search":
213
- output = self._generate_beamsearch(
214
- image_inputs = image,
215
- pad_token_id=pad_token_id,
216
- eos_token_id=eos_token_id,
217
- sot_token_id=sot_token_id,
218
- num_beams=num_beams,
219
- num_beam_groups=num_beam_groups,
220
- min_seq_len=min_seq_len,
221
- stopping_criteria=stopping_criteria,
222
- logit_processor=logit_processor,
223
- )
224
- if fixed_output_length and output.shape[1] < seq_len:
225
- return torch.cat(
226
- (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
227
- dim=1
228
- )
229
- return output
230
-
231
- elif generation_type == "top_p":
232
- logit_warper = GENERATION_TYPES[generation_type](top_p)
233
- elif generation_type == "top_k":
234
- logit_warper = GENERATION_TYPES[generation_type](top_k)
235
- else:
236
- raise ValueError(
237
- f"generation_type has to be one of "
238
- f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
239
- )
240
-
241
- image_latent, image_embs = self._encode_image(image)
242
-
243
- if text is None:
244
- text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
245
-
246
- was_training = self.training
247
- num_dims = len(text.shape)
248
-
249
- if num_dims == 1:
250
- text = text[None, :]
251
-
252
- cur_len = text.shape[1]
253
- self.eval()
254
- out = text
255
-
256
- while True:
257
- x = out[:, -max_seq_len:]
258
- cur_len = x.shape[1]
259
- logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
260
- mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
261
- sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
262
-
263
- if mask.all():
264
- if not fixed_output_length:
265
- break
266
- else:
267
- logits = logits[~mask, :]
268
- filtered_logits = logit_processor(x[~mask, :], logits)
269
- filtered_logits = logit_warper(x[~mask, :], filtered_logits)
270
- probs = F.softmax(filtered_logits / temperature, dim=-1)
271
-
272
- if (cur_len + 1 == seq_len):
273
- sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
274
- else:
275
- sample[~mask, :] = torch.multinomial(probs, 1)
276
-
277
- out = torch.cat((out, sample), dim=-1)
278
-
279
- cur_len += 1
280
-
281
- if stopping_criteria(out, None):
282
- break
283
-
284
- if num_dims == 1:
285
- out = out.squeeze(0)
286
-
287
- self.train(was_training)
288
- return out
289
-
290
- def _generate_beamsearch(
291
- self,
292
- image_inputs,
293
- pad_token_id=None,
294
- eos_token_id=None,
295
- sot_token_id=None,
296
- num_beams=6,
297
- num_beam_groups=3,
298
- min_seq_len=5,
299
- stopping_criteria=None,
300
- logit_processor=None,
301
- logit_warper=None,
302
- ):
303
- device = image_inputs.device
304
- batch_size = image_inputs.shape[0]
305
- image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
306
- image_latent, image_embs = self._encode_image(image_inputs)
307
-
308
- input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
309
- input_ids = input_ids * sot_token_id
310
- beam_scorer = BeamSearchScorer(
311
- batch_size=batch_size,
312
- num_beams=num_beams,
313
- device=device,
314
- num_beam_groups=num_beam_groups,
315
- )
316
- # instantiate logits processors
317
- logits_processor = (
318
- LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
319
- if logit_processor is None
320
- else logit_processor
321
- )
322
-
323
- batch_size = len(beam_scorer._beam_hyps)
324
- num_beams = beam_scorer.num_beams
325
- num_beam_groups = beam_scorer.num_beam_groups
326
- num_sub_beams = num_beams // num_beam_groups
327
- batch_beam_size, cur_len = input_ids.shape
328
- beam_indices = None
329
-
330
- if num_beams * batch_size != batch_beam_size:
331
- raise ValueError(
332
- f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
333
- )
334
-
335
- beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
336
- # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
337
- # the same group don't produce the same tokens every time.
338
- beam_scores[:, ::num_sub_beams] = 0
339
- beam_scores = beam_scores.view((batch_size * num_beams,))
340
-
341
- while True:
342
-
343
- # predicted tokens in cur_len step
344
- current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
345
-
346
- # indices which will form the beams in the next time step
347
- reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
348
-
349
- # do one decoder step on all beams of all sentences in batch
350
- model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
351
- outputs = self(
352
- model_inputs['images'],
353
- model_inputs['text'],
354
- embed_cls=False,
355
- image_latent=image_latent,
356
- image_embs=image_embs
357
- )
358
-
359
- for beam_group_idx in range(num_beam_groups):
360
- group_start_idx = beam_group_idx * num_sub_beams
361
- group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
362
- group_size = group_end_idx - group_start_idx
363
-
364
- # indices of beams of current group among all sentences in batch
365
- batch_group_indices = []
366
-
367
- for batch_idx in range(batch_size):
368
- batch_group_indices.extend(
369
- [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
370
- )
371
- group_input_ids = input_ids[batch_group_indices]
372
-
373
- # select outputs of beams of current group only
374
- next_token_logits = outputs['logits'][batch_group_indices, -1, :]
375
- vocab_size = next_token_logits.shape[-1]
376
-
377
- next_token_scores_processed = logits_processor(
378
- group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
379
- )
380
- next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
381
- next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
382
-
383
- # reshape for beam search
384
- next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
385
-
386
- next_token_scores, next_tokens = torch.topk(
387
- next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
388
- )
389
-
390
- next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
391
- next_tokens = next_tokens % vocab_size
392
-
393
- # stateless
394
- process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
395
- beam_outputs = beam_scorer.process(
396
- group_input_ids,
397
- next_token_scores,
398
- next_tokens,
399
- next_indices,
400
- pad_token_id=pad_token_id,
401
- eos_token_id=eos_token_id,
402
- beam_indices=process_beam_indices,
403
- )
404
- beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
405
- beam_next_tokens = beam_outputs["next_beam_tokens"]
406
- beam_idx = beam_outputs["next_beam_indices"]
407
-
408
- input_ids[batch_group_indices] = group_input_ids[beam_idx]
409
- group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
410
- current_tokens[batch_group_indices] = group_input_ids[:, -1]
411
-
412
- # (beam_idx // group_size) -> batch_idx
413
- # (beam_idx % group_size) -> offset of idx inside the group
414
- reordering_indices[batch_group_indices] = (
415
- num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
416
- )
417
-
418
- input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
419
-
420
- # increase cur_len
421
- cur_len = cur_len + 1
422
- if beam_scorer.is_done or stopping_criteria(input_ids, None):
423
- break
424
-
425
- final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
426
- sequence_outputs = beam_scorer.finalize(
427
- input_ids,
428
- beam_scores,
429
- next_tokens,
430
- next_indices,
431
- pad_token_id=pad_token_id,
432
- eos_token_id=eos_token_id,
433
- max_length=stopping_criteria.max_length,
434
- beam_indices=final_beam_indices,
435
- )
436
- return sequence_outputs['sequences']
437
-
438
-
439
- def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
440
- if past:
441
- input_ids = input_ids[:, -1].unsqueeze(-1)
442
-
443
- attention_mask = kwargs.get("attention_mask", None)
444
- position_ids = kwargs.get("position_ids", None)
445
-
446
- if attention_mask is not None and position_ids is None:
447
- # create position_ids on the fly for batch generation
448
- position_ids = attention_mask.long().cumsum(-1) - 1
449
- position_ids.masked_fill_(attention_mask == 0, 1)
450
- else:
451
- position_ids = None
452
- return {
453
- "text": input_ids,
454
- "images": image_inputs,
455
- "past_key_values": past,
456
- "position_ids": position_ids,
457
- "attention_mask": attention_mask,
458
- }
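The prepare_inputs_for_generation helper at the end of the deleted coca_model.py derives position_ids on the fly from the attention mask, so left-padded sequences still count positions from 0 at the first real token while padding slots get a harmless filler value. The same two lines in isolation:

    import torch

    # Left-padded batch: 0 marks padding, 1 marks real tokens.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    print(position_ids)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])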
diffsynth/extensions/ImageQualityMetric/open_clip/constants.py DELETED
@@ -1,2 +0,0 @@
1
- OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
- OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py DELETED
@@ -1,433 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- import re
6
- from copy import deepcopy
7
- from pathlib import Path
8
- # from turtle import forward
9
- from typing import Any, Dict, Optional, Tuple, Union
10
-
11
- import torch
12
-
13
- from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
14
- from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
15
- resize_pos_embed, get_cast_dtype
16
- from .coca_model import CoCa
17
- from .loss import ClipLoss, DistillClipLoss, CoCaLoss
18
- from .openai import load_openai_model
19
- from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
20
- from .transform import image_transform, AugmentationCfg
21
- from .tokenizer import HFTokenizer, SimpleTokenizer
22
-
23
-
24
- HF_HUB_PREFIX = 'hf-hub:'
25
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
26
- _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
27
-
28
-
29
- def _natural_key(string_):
30
- return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
31
-
32
-
33
- def _rescan_model_configs():
34
- global _MODEL_CONFIGS
35
-
36
- config_ext = ('.json',)
37
- config_files = []
38
- for config_path in _MODEL_CONFIG_PATHS:
39
- if config_path.is_file() and config_path.suffix in config_ext:
40
- config_files.append(config_path)
41
- elif config_path.is_dir():
42
- for ext in config_ext:
43
- config_files.extend(config_path.glob(f'*{ext}'))
44
-
45
- for cf in config_files:
46
- with open(cf, 'r') as f:
47
- model_cfg = json.load(f)
48
- if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
49
- _MODEL_CONFIGS[cf.stem] = model_cfg
50
-
51
- _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
52
-
53
-
54
- _rescan_model_configs() # initial populate of model config registry
55
-
56
-
57
- def list_models():
58
- """ enumerate available model architectures based on config files """
59
- return list(_MODEL_CONFIGS.keys())
60
-
61
-
62
- def add_model_config(path):
63
- """ add model config path or file and update registry """
64
- if not isinstance(path, Path):
65
- path = Path(path)
66
- _MODEL_CONFIG_PATHS.append(path)
67
- _rescan_model_configs()
68
-
69
-
70
- def get_model_config(model_name):
71
- if model_name in _MODEL_CONFIGS:
72
- return deepcopy(_MODEL_CONFIGS[model_name])
73
- else:
74
- return None
75
-
76
-
77
- def get_tokenizer(model_name, open_clip_bpe_path=None):
78
- if model_name.startswith(HF_HUB_PREFIX):
79
- tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
80
- else:
81
- config = get_model_config(model_name)
82
- tokenizer = HFTokenizer(
83
- config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
84
- return tokenizer
85
-
86
-
87
- def load_state_dict(checkpoint_path: str, map_location='cpu'):
88
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
89
- if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
90
- state_dict = checkpoint['state_dict']
91
- else:
92
- state_dict = checkpoint
93
- if next(iter(state_dict.items()))[0].startswith('module'):
94
- state_dict = {k[7:]: v for k, v in state_dict.items()}
95
- return state_dict
96
-
97
-
98
- def load_checkpoint(model, checkpoint_path, strict=True):
99
- state_dict = load_state_dict(checkpoint_path)
100
- # detect old format and make compatible with new format
101
- if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
102
- state_dict = convert_to_custom_text_state_dict(state_dict)
103
- resize_pos_embed(state_dict, model)
104
- incompatible_keys = model.load_state_dict(state_dict, strict=strict)
105
- return incompatible_keys
106
-
107
-
108
- def create_model(
109
- model_name: str,
110
- pretrained: Optional[str] = None,
111
- precision: str = 'fp32',
112
- device: Union[str, torch.device] = 'cpu',
113
- jit: bool = False,
114
- force_quick_gelu: bool = False,
115
- force_custom_text: bool = False,
116
- force_patch_dropout: Optional[float] = None,
117
- force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
118
- pretrained_image: bool = False,
119
- pretrained_hf: bool = True,
120
- cache_dir: Optional[str] = None,
121
- output_dict: Optional[bool] = None,
122
- require_pretrained: bool = False,
123
- ):
124
- has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
125
- if has_hf_hub_prefix:
126
- model_id = model_name[len(HF_HUB_PREFIX):]
127
- checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
128
- config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
129
-
130
- with open(config_path, 'r', encoding='utf-8') as f:
131
- config = json.load(f)
132
- pretrained_cfg = config['preprocess_cfg']
133
- model_cfg = config['model_cfg']
134
- else:
135
- model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
136
- checkpoint_path = None
137
- pretrained_cfg = {}
138
- model_cfg = None
139
-
140
- if isinstance(device, str):
141
- device = torch.device(device)
142
-
143
- if pretrained and pretrained.lower() == 'openai':
144
- logging.info(f'Loading pretrained {model_name} from OpenAI.')
145
- model = load_openai_model(
146
- model_name,
147
- precision=precision,
148
- device=device,
149
- jit=jit,
150
- cache_dir=cache_dir,
151
- )
152
-
153
- # to always output dict even if it is clip
154
- if output_dict and hasattr(model, "output_dict"):
155
- model.output_dict = True
156
- else:
157
- model_cfg = model_cfg or get_model_config(model_name)
158
- if model_cfg is not None:
159
- logging.info(f'Loaded {model_name} model config.')
160
- else:
161
- logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
162
- raise RuntimeError(f'Model config for {model_name} not found.')
163
-
164
- if force_quick_gelu:
165
- # override for use of QuickGELU on non-OpenAI transformer models
166
- model_cfg["quick_gelu"] = True
167
-
168
- if force_patch_dropout is not None:
169
- # override the default patch dropout value
170
- model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
171
-
172
- if force_image_size is not None:
173
- # override model config's image size
174
- model_cfg["vision_cfg"]["image_size"] = force_image_size
175
-
176
- if pretrained_image:
177
- if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
178
- # pretrained weight loading for timm models set via vision_cfg
179
- model_cfg['vision_cfg']['timm_model_pretrained'] = True
180
- else:
181
- assert False, 'pretrained image towers currently only supported for timm models'
182
-
183
- cast_dtype = get_cast_dtype(precision)
184
- is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
185
- custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
186
-
187
- if custom_text:
188
- if is_hf_model:
189
- model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
190
- if "coca" in model_name:
191
- model = CoCa(**model_cfg, cast_dtype=cast_dtype)
192
- else:
193
- model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
194
- else:
195
- model = CLIP(**model_cfg, cast_dtype=cast_dtype)
196
-
197
- pretrained_loaded = False
198
- if pretrained:
199
- checkpoint_path = ''
200
- pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
201
- if pretrained_cfg:
202
- checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
203
- elif os.path.exists(pretrained):
204
- checkpoint_path = pretrained
205
-
206
- if checkpoint_path:
207
- logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
208
- load_checkpoint(model, checkpoint_path)
209
- else:
210
- error_str = (
211
- f'Pretrained weights ({pretrained}) not found for model {model_name}.'
212
- f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
213
- logging.warning(error_str)
214
- raise RuntimeError(error_str)
215
- pretrained_loaded = True
216
- elif has_hf_hub_prefix:
217
- logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
218
- load_checkpoint(model, checkpoint_path)
219
- pretrained_loaded = True
220
-
221
- if require_pretrained and not pretrained_loaded:
222
- # callers of create_model_from_pretrained always expect pretrained weights
223
- raise RuntimeError(
224
- f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
225
-
226
- model.to(device=device)
227
- if precision in ("fp16", "bf16"):
228
- convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
229
-
230
- # set image / mean metadata from pretrained_cfg if available, or use default
231
- model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
232
- model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
233
-
234
- # to always output dict even if it is clip
235
- if output_dict and hasattr(model, "output_dict"):
236
- model.output_dict = True
237
-
238
- if jit:
239
- model = torch.jit.script(model)
240
-
241
- return model
242
-
243
-
244
- def create_loss(args):
245
- if args.distill:
246
- return DistillClipLoss(
247
- local_loss=args.local_loss,
248
- gather_with_grad=args.gather_with_grad,
249
- cache_labels=True,
250
- rank=args.rank,
251
- world_size=args.world_size,
252
- use_horovod=args.horovod,
253
- )
254
- elif "coca" in args.model.lower():
255
- return CoCaLoss(
256
- caption_loss_weight=args.coca_caption_loss_weight,
257
- clip_loss_weight=args.coca_contrastive_loss_weight,
258
- local_loss=args.local_loss,
259
- gather_with_grad=args.gather_with_grad,
260
- cache_labels=True,
261
- rank=args.rank,
262
- world_size=args.world_size,
263
- use_horovod=args.horovod,
264
- )
265
- return ClipLoss(
266
- local_loss=args.local_loss,
267
- gather_with_grad=args.gather_with_grad,
268
- cache_labels=True,
269
- rank=args.rank,
270
- world_size=args.world_size,
271
- use_horovod=args.horovod,
272
- )
273
-
274
- class MLP(torch.nn.Module):
275
- def __init__(self, input_size):
276
- super().__init__()
277
- self.input_size = input_size
278
- self.layers = torch.nn.Sequential(
279
- torch.nn.Linear(self.input_size, 1024),
280
- torch.nn.Dropout(0.2),
281
- torch.nn.Linear(1024, 128),
282
- torch.nn.Dropout(0.2),
283
- torch.nn.Linear(128, 64),
284
- torch.nn.Dropout(0.1),
285
- torch.nn.Linear(64, 16),
286
- torch.nn.Linear(16, 1)
287
- )
288
-
289
- def forward(self, x):
290
- return self.layers(x)
291
-
292
- # class semantic_head(torch.nn.Module):
293
- # def __init__(self, input_size):
294
- # super().__init__()
295
- # self.input_size = input_size # for ViT-L-14 is 1024
296
- # self.seg_head = torch.nn.Sequential(
297
- # torch.nn.Linear(input_size, 128),
298
- # torch.nn.Dropout(0.2),
299
- # torch.nn.Linear(128, 64),
300
- # torch.nn.Dropout(0.1),
301
- # torch.nn.Linear(64, 16),
302
- # torch.nn.Linear(16, 1),
303
- # )
304
- # self.sigmoid = torch.nn.Sigmoid()
305
-
306
- # def forward(self, x):
307
- # return self.sigmoid(self.seg_head(x))
308
-
309
- def create_model_and_transforms(
310
- model_name: str,
311
- pretrained: Optional[str] = None,
312
- precision: str = 'fp32',
313
- device: Union[str, torch.device] = 'cpu',
314
- jit: bool = False,
315
- force_quick_gelu: bool = False,
316
- force_custom_text: bool = False,
317
- force_patch_dropout: Optional[float] = None,
318
- force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
319
- pretrained_image: bool = False,
320
- pretrained_hf: bool = True,
321
- image_mean: Optional[Tuple[float, ...]] = None,
322
- image_std: Optional[Tuple[float, ...]] = None,
323
- aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
324
- cache_dir: Optional[str] = None,
325
- light_augmentation = False,
326
- output_dict: Optional[bool] = None,
327
- with_score_predictor: bool = False,
328
- with_region_predictor: bool = False
329
- ):
330
- model = create_model(
331
- model_name,
332
- pretrained,
333
- precision=precision,
334
- device=device,
335
- jit=jit,
336
- force_quick_gelu=force_quick_gelu,
337
- force_custom_text=force_custom_text,
338
- force_patch_dropout=force_patch_dropout,
339
- force_image_size=force_image_size,
340
- pretrained_image=pretrained_image,
341
- pretrained_hf=pretrained_hf,
342
- cache_dir=cache_dir,
343
- output_dict=output_dict,
344
- )
345
-
346
- image_mean = image_mean or getattr(model.visual, 'image_mean', None)
347
- image_std = image_std or getattr(model.visual, 'image_std', None)
348
-
349
- if with_score_predictor:
350
- model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
351
-
352
- if with_region_predictor:
353
- # model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
354
- model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
355
- # preprocess_train = image_transform_region(
356
- # model.visual.image_size,
357
- # is_train=True,
358
- # mean=image_mean,
359
- # std=image_std
360
- # )
361
- # preprocess_val = image_transform_region(
362
- # model.visual.image_size,
363
- # is_train=False,
364
- # mean=image_mean,
365
- # std=image_std
366
- # )
367
-
368
- if light_augmentation:
369
- preprocess_val = image_transform(
370
- model.visual.image_size,
371
- is_train=False,
372
- mean=image_mean,
373
- std=image_std,
374
- resize_longest_max=True,
375
- )
376
- preprocess_train = preprocess_val
377
- else:
378
- preprocess_train = image_transform(
379
- model.visual.image_size,
380
- is_train=True,
381
- mean=image_mean,
382
- std=image_std
383
- )
384
- preprocess_val = image_transform(
385
- model.visual.image_size,
386
- is_train=False,
387
- mean=image_mean,
388
- std=image_std
389
- )
390
-
391
- return model, preprocess_train, preprocess_val
392
-
393
-
394
- def create_model_from_pretrained(
395
- model_name: str,
396
- pretrained: Optional[str] = None,
397
- precision: str = 'fp32',
398
- device: Union[str, torch.device] = 'cpu',
399
- jit: bool = False,
400
- force_quick_gelu: bool = False,
401
- force_custom_text: bool = False,
402
- force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
403
- return_transform: bool = True,
404
- image_mean: Optional[Tuple[float, ...]] = None,
405
- image_std: Optional[Tuple[float, ...]] = None,
406
- cache_dir: Optional[str] = None,
407
- ):
408
- model = create_model(
409
- model_name,
410
- pretrained,
411
- precision=precision,
412
- device=device,
413
- jit=jit,
414
- force_quick_gelu=force_quick_gelu,
415
- force_custom_text=force_custom_text,
416
- force_image_size=force_image_size,
417
- cache_dir=cache_dir,
418
- require_pretrained=True,
419
- )
420
-
421
- if not return_transform:
422
- return model
423
-
424
- image_mean = image_mean or getattr(model.visual, 'image_mean', None)
425
- image_std = image_std or getattr(model.visual, 'image_std', None)
426
- preprocess = image_transform(
427
- model.visual.image_size,
428
- is_train=False,
429
- mean=image_mean,
430
- std=image_std,
431
- )
432
-
433
- return model, preprocess
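The model-config registry in the deleted factory.py sorts architecture names with _natural_key, which splits digit runs out of the lower-cased name so numeric parts compare as integers rather than as text. A standalone sketch of the difference that makes (the model names are just illustrative strings):

    import re

    def _natural_key(string_):
        # Same helper as in the deleted factory.py.
        return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]

    names = ["RN50", "RN101", "RN50x4", "ViT-B-32", "ViT-B-16"]
    print(sorted(names))                    # ['RN101', 'RN50', 'RN50x4', 'ViT-B-16', 'ViT-B-32']
    print(sorted(names, key=_natural_key))  # ['RN50', 'RN50x4', 'RN101', 'ViT-B-16', 'ViT-B-32']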
diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py DELETED
File without changes
diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py DELETED
@@ -1,45 +0,0 @@
1
- # HF architecture dict:
2
- arch_dict = {
3
- # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
- "roberta": {
5
- "config_names": {
6
- "context_length": "max_position_embeddings",
7
- "vocab_size": "vocab_size",
8
- "width": "hidden_size",
9
- "heads": "num_attention_heads",
10
- "layers": "num_hidden_layers",
11
- "layer_attr": "layer",
12
- "token_embeddings_attr": "embeddings"
13
- },
14
- "pooler": "mean_pooler",
15
- },
16
- # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
- "xlm-roberta": {
18
- "config_names": {
19
- "context_length": "max_position_embeddings",
20
- "vocab_size": "vocab_size",
21
- "width": "hidden_size",
22
- "heads": "num_attention_heads",
23
- "layers": "num_hidden_layers",
24
- "layer_attr": "layer",
25
- "token_embeddings_attr": "embeddings"
26
- },
27
- "pooler": "mean_pooler",
28
- },
29
- # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
- "mt5": {
31
- "config_names": {
32
- # unlimited seqlen
33
- # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
- # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
- "context_length": "",
36
- "vocab_size": "vocab_size",
37
- "width": "d_model",
38
- "heads": "num_heads",
39
- "layers": "num_layers",
40
- "layer_attr": "block",
41
- "token_embeddings_attr": "embed_tokens"
42
- },
43
- "pooler": "mean_pooler",
44
- },
45
- }
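arch_dict above only stores attribute-name mappings; the HFTextEncoder in the next deleted file, hf_model.py, resolves them with getattr against a HuggingFace config. A small sketch of that lookup, assuming the transformers package is installed and network access is available to fetch the roberta-base config:

    from transformers import AutoConfig

    # Minimal subset of the deleted arch_dict, enough for roberta-style configs.
    arch_dict = {"roberta": {"config_names": {"width": "hidden_size", "layers": "num_hidden_layers"}}}

    config = AutoConfig.from_pretrained("roberta-base")  # fetches the config only, no weights
    names = arch_dict[config.model_type]["config_names"]  # config.model_type == "roberta"
    print(getattr(config, names["width"]), getattr(config, names["layers"]))  # 768 12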
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py DELETED
@@ -1,176 +0,0 @@
1
- """ huggingface model adapter
2
-
3
- Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
- """
5
-
6
- import re
7
-
8
- import torch
9
- import torch.nn as nn
10
- from torch import TensorType
11
-
12
- try:
13
- import transformers
14
- from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
15
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
- BaseModelOutputWithPoolingAndCrossAttentions
17
- except ImportError as e:
18
- transformers = None
19
-
20
-
21
- class BaseModelOutput:
22
- pass
23
-
24
-
25
- class PretrainedConfig:
26
- pass
27
-
28
- from .hf_configs import arch_dict
29
-
30
-
31
- # utils
32
- def _camel2snake(s):
33
- return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
34
-
35
-
36
- # TODO: ?last - for gpt-like models
37
- _POOLERS = {}
38
-
39
-
40
- def register_pooler(cls):
41
- """Decorator registering pooler class"""
42
- _POOLERS[_camel2snake(cls.__name__)] = cls
43
- return cls
44
-
45
-
46
- @register_pooler
47
- class MeanPooler(nn.Module):
48
- """Mean pooling"""
49
-
50
- def forward(self, x: BaseModelOutput, attention_mask: TensorType):
51
- masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
52
- return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
53
-
54
-
55
- @register_pooler
56
- class MaxPooler(nn.Module):
57
- """Max pooling"""
58
-
59
- def forward(self, x: BaseModelOutput, attention_mask: TensorType):
60
- masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
61
- return masked_output.max(1).values
62
-
63
-
64
- @register_pooler
65
- class ClsPooler(nn.Module):
66
- """CLS token pooling"""
67
-
68
- def __init__(self, use_pooler_output=True):
69
- super().__init__()
70
- self.cls_token_position = 0
71
- self.use_pooler_output = use_pooler_output
72
-
73
- def forward(self, x: BaseModelOutput, attention_mask: TensorType):
74
- if (self.use_pooler_output and
75
- isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
76
- (x.pooler_output is not None)
77
- ):
78
- return x.pooler_output
79
-
80
- return x.last_hidden_state[:, self.cls_token_position, :]
81
-
82
-
83
- class HFTextEncoder(nn.Module):
84
- """HuggingFace model adapter"""
85
- output_tokens: torch.jit.Final[bool]
86
-
87
- def __init__(
88
- self,
89
- model_name_or_path: str,
90
- output_dim: int,
91
- config: PretrainedConfig = None,
92
- pooler_type: str = None,
93
- proj: str = None,
94
- pretrained: bool = True,
95
- output_tokens: bool = False,
96
- ):
97
- super().__init__()
98
- self.output_tokens = output_tokens
99
- self.output_dim = output_dim
100
-
101
- # TODO: find better way to get this information
102
- uses_transformer_pooler = (pooler_type == "cls_pooler")
103
-
104
- if transformers is None:
105
- raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
106
- if config is None:
107
- self.config = AutoConfig.from_pretrained(model_name_or_path)
108
- create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
109
- AutoModel.from_config, self.config)
110
- # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
111
- if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
112
- self.transformer = create_func(model_args)
113
- self.transformer = self.transformer.encoder
114
- else:
115
- self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
116
- else:
117
- self.config = config
118
- self.transformer = AutoModel.from_config(config)
119
- if pooler_type is None: # get default arch pooler
120
- pooler_type = (arch_dict[self.config.model_type]["pooler"])
121
-
122
- self.pooler = _POOLERS[pooler_type]()
123
-
124
- d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
125
- if (d_model == output_dim) and (proj is None): # do we always need a proj?
126
- self.proj = nn.Identity()
127
- elif proj == 'linear':
128
- self.proj = nn.Linear(d_model, output_dim, bias=False)
129
- elif proj == 'mlp':
130
- hidden_size = (d_model + output_dim) // 2
131
- self.proj = nn.Sequential(
132
- nn.Linear(d_model, hidden_size, bias=False),
133
- nn.GELU(),
134
- nn.Linear(hidden_size, output_dim, bias=False),
135
- )
136
-
137
- def forward(self, x: TensorType):
138
- attn_mask = (x != self.config.pad_token_id).long()
139
- out = self.transformer(input_ids=x, attention_mask=attn_mask)
140
- pooled_out = self.pooler(out, attn_mask)
141
- projected = self.proj(pooled_out)
142
-
143
- seq_len = out.last_hidden_state.shape[1]
144
- tokens = (
145
- out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
146
- if type(self.pooler) == ClsPooler
147
- else out.last_hidden_state
148
- )
149
-
150
- if self.output_tokens:
151
- return projected, tokens
152
- return projected
153
-
154
- def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
155
- if not unlocked_layers: # full freezing
156
- for n, p in self.transformer.named_parameters():
157
- p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
158
- return
159
-
160
- encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
161
- layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
162
- print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
163
- embeddings = getattr(
164
- self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
165
- modules = [embeddings, *layer_list][:-unlocked_layers]
166
- # freeze layers
167
- for module in modules:
168
- for n, p in module.named_parameters():
169
- p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
170
-
171
- @torch.jit.ignore
172
- def set_grad_checkpointing(self, enable=True):
173
- self.transformer.gradient_checkpointing_enable()
174
-
175
- def init_parameters(self):
176
- pass
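The pooler registry in the deleted hf_model.py is keyed by snake_case class names produced by _camel2snake, so pooler_type strings such as "mean_pooler" map back to the registered classes. A stripped-down sketch of that decorator mechanism (plain classes stand in for the real nn.Module poolers):

    import re

    def _camel2snake(s):
        # Same regex as in the deleted hf_model.py: insert "_" before every interior capital.
        return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()

    _POOLERS = {}

    def register_pooler(cls):
        _POOLERS[_camel2snake(cls.__name__)] = cls
        return cls

    @register_pooler
    class MeanPooler:  # stand-in for the real pooling module
        pass

    @register_pooler
    class ClsPooler:
        pass

    print(sorted(_POOLERS))  # ['cls_pooler', 'mean_pooler']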
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py DELETED
@@ -1,270 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.nn import functional as F
4
- from torch.nn.utils.rnn import pad_sequence
5
-
6
- try:
7
- import torch.distributed.nn
8
- from torch import distributed as dist
9
-
10
- has_distributed = True
11
- except ImportError:
12
- has_distributed = False
13
-
14
- try:
15
- import horovod.torch as hvd
16
- except ImportError:
17
- hvd = None
18
-
19
-
20
- def gather_features(
21
- image_features,
22
- text_features,
23
- local_loss=False,
24
- gather_with_grad=False,
25
- rank=0,
26
- world_size=1,
27
- use_horovod=False
28
- ):
29
- assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
30
- if use_horovod:
31
- assert hvd is not None, 'Please install horovod'
32
- if gather_with_grad:
33
- all_image_features = hvd.allgather(image_features)
34
- all_text_features = hvd.allgather(text_features)
35
- else:
36
- with torch.no_grad():
37
- all_image_features = hvd.allgather(image_features)
38
- all_text_features = hvd.allgather(text_features)
39
- if not local_loss:
40
- # ensure grads for local rank when all_* features don't have a gradient
41
- gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
42
- gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
43
- gathered_image_features[rank] = image_features
44
- gathered_text_features[rank] = text_features
45
- all_image_features = torch.cat(gathered_image_features, dim=0)
46
- all_text_features = torch.cat(gathered_text_features, dim=0)
47
- else:
48
- # We gather tensors from all gpus
49
- if gather_with_grad:
50
- all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
51
- all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
52
- else:
53
- gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
54
- gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
55
- dist.all_gather(gathered_image_features, image_features)
56
- dist.all_gather(gathered_text_features, text_features)
57
- if not local_loss:
58
- # ensure grads for local rank when all_* features don't have a gradient
59
- gathered_image_features[rank] = image_features
60
- gathered_text_features[rank] = text_features
61
- all_image_features = torch.cat(gathered_image_features, dim=0)
62
- all_text_features = torch.cat(gathered_text_features, dim=0)
63
-
64
- return all_image_features, all_text_features
65
-
66
-
67
- class ClipLoss(nn.Module):
68
-
69
- def __init__(
70
- self,
71
- local_loss=False,
72
- gather_with_grad=False,
73
- cache_labels=False,
74
- rank=0,
75
- world_size=1,
76
- use_horovod=False,
77
- ):
78
- super().__init__()
79
- self.local_loss = local_loss
80
- self.gather_with_grad = gather_with_grad
81
- self.cache_labels = cache_labels
82
- self.rank = rank
83
- self.world_size = world_size
84
- self.use_horovod = use_horovod
85
-
86
- # cache state
87
- self.prev_num_logits = 0
88
- self.labels = {}
89
-
90
- def get_ground_truth(self, device, num_logits) -> torch.Tensor:
91
- # calculated ground-truth and cache if enabled
92
- if self.prev_num_logits != num_logits or device not in self.labels:
93
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
94
- if self.world_size > 1 and self.local_loss:
95
- labels = labels + num_logits * self.rank
96
- if self.cache_labels:
97
- self.labels[device] = labels
98
- self.prev_num_logits = num_logits
99
- else:
100
- labels = self.labels[device]
101
- return labels
102
-
103
- def get_logits(self, image_features, text_features, logit_scale):
104
- if self.world_size > 1:
105
- all_image_features, all_text_features = gather_features(
106
- image_features, text_features,
107
- self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
108
-
109
- if self.local_loss:
110
- logits_per_image = logit_scale * image_features @ all_text_features.T
111
- logits_per_text = logit_scale * text_features @ all_image_features.T
112
- else:
113
- logits_per_image = logit_scale * all_image_features @ all_text_features.T
114
- logits_per_text = logits_per_image.T
115
- else:
116
- logits_per_image = logit_scale * image_features @ text_features.T
117
- logits_per_text = logit_scale * text_features @ image_features.T
118
-
119
- return logits_per_image, logits_per_text
120
-
121
- def forward(self, image_features, text_features, logit_scale, output_dict=False):
122
- device = image_features.device
123
- logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
124
-
125
- labels = self.get_ground_truth(device, logits_per_image.shape[0])
126
-
127
- total_loss = (
128
- F.cross_entropy(logits_per_image, labels) +
129
- F.cross_entropy(logits_per_text, labels)
130
- ) / 2
131
- return total_loss
132
-
133
- class PreferenceLoss(nn.Module):
134
-
135
- def forward(self, logits_per_image, num_images, labels):
136
-
137
- paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
138
- paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
139
-
140
- ce_loss = F.cross_entropy(paired_logits, labels)
141
- return ce_loss
142
-
143
- class HPSLoss(nn.Module):
144
-
145
- def forward(self, text_logits, labels):
146
-
147
- device = text_logits.device
148
- text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
149
- label_0, label_1 = labels.chunk(2, dim=-1)
150
-
151
- index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
152
- text_0_logits = text_0_logits[index, index]
153
- text_1_logits = text_1_logits[index, index]
154
- text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
155
- text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
156
- text_1_labels = text_0_labels + 1
157
-
158
- text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
159
- text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
160
-
161
- text_loss = label_0 * text_0_loss + label_1 * text_1_loss
162
-
163
- # absolute_example_weight = 1 / num_per_prompt
164
- # denominator = absolute_example_weight.sum()
165
- # weight_per_example = absolute_example_weight / denominator
166
- # text_loss *= weight_per_example
167
-
168
- text_loss = text_loss.sum()
169
- return text_loss
170
-
171
- class RankingLoss(nn.Module):
172
-
173
- def forward(self, logits_per_image, num_images, labels, margin = 1.0):
174
- paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
175
- label_list = [label for label in labels.split(num_images.tolist())]
176
- # ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
177
-
178
- paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
179
- padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
180
-
181
- # regulized_logits = torch.log(torch.sigmoid(paired_logits))
182
-
183
- diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
184
- # diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
185
- # diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
186
- diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
187
- mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
188
-
189
- loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
190
- return loss
191
-
192
- class CoCaLoss(ClipLoss):
193
- def __init__(
194
- self,
195
- caption_loss_weight,
196
- clip_loss_weight,
197
- pad_id=0, # pad_token for open_clip custom tokenizer
198
- local_loss=False,
199
- gather_with_grad=False,
200
- cache_labels=False,
201
- rank=0,
202
- world_size=1,
203
- use_horovod=False,
204
- ):
205
- super().__init__(
206
- local_loss=local_loss,
207
- gather_with_grad=gather_with_grad,
208
- cache_labels=cache_labels,
209
- rank=rank,
210
- world_size=world_size,
211
- use_horovod=use_horovod
212
- )
213
-
214
- self.clip_loss_weight = clip_loss_weight
215
- self.caption_loss_weight = caption_loss_weight
216
- self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
217
-
218
- def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
219
- clip_loss = super().forward(image_features, text_features, logit_scale)
220
- clip_loss = self.clip_loss_weight * clip_loss
221
-
222
- caption_loss = self.caption_loss(
223
- logits.permute(0, 2, 1),
224
- labels,
225
- )
226
- caption_loss = caption_loss * self.caption_loss_weight
227
-
228
- if output_dict:
229
- return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
230
-
231
- return clip_loss, caption_loss
232
-
233
-
234
- class DistillClipLoss(ClipLoss):
235
-
236
- def dist_loss(self, teacher_logits, student_logits):
237
- return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
238
-
239
- def forward(
240
- self,
241
- image_features,
242
- text_features,
243
- logit_scale,
244
- dist_image_features,
245
- dist_text_features,
246
- dist_logit_scale,
247
- output_dict=False,
248
- ):
249
- logits_per_image, logits_per_text = \
250
- self.get_logits(image_features, text_features, logit_scale)
251
-
252
- dist_logits_per_image, dist_logits_per_text = \
253
- self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
254
-
255
- labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
256
-
257
- contrastive_loss = (
258
- F.cross_entropy(logits_per_image, labels) +
259
- F.cross_entropy(logits_per_text, labels)
260
- ) / 2
261
-
262
- distill_loss = (
263
- self.dist_loss(dist_logits_per_image, logits_per_image) +
264
- self.dist_loss(dist_logits_per_text, logits_per_text)
265
- ) / 2
266
-
267
- if output_dict:
268
- return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
269
-
270
- return contrastive_loss, distill_loss
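
Note: the loss module deleted above is only exercised at training time. Below is a minimal single-process sketch of how its ClipLoss was typically driven (world_size=1, so no feature gathering). The import path reflects the file's location before this commit, and the batch size, feature dimension, and logit scale are illustrative assumptions, not values taken from this repository.

import torch
import torch.nn.functional as F
# Pre-deletion module path; the upstream open_clip package ships the same class.
from diffsynth.extensions.ImageQualityMetric.open_clip.loss import ClipLoss

batch, dim = 8, 1024  # dim matches the ViT-H-14 embed_dim in the JSON config further down

# Normalized features, as produced by CLIP.encode_image(..., normalize=True) /
# CLIP.encode_text(..., normalize=True).
image_features = F.normalize(torch.randn(batch, dim), dim=-1)
text_features = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale = torch.tensor(100.0)  # exp() of the learned temperature parameter

loss_fn = ClipLoss(local_loss=False, gather_with_grad=False, rank=0, world_size=1)
loss = loss_fn(image_features, text_features, logit_scale)  # symmetric image<->text cross-entropy
print(loss.item())
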
diffsynth/extensions/ImageQualityMetric/open_clip/model.py DELETED
@@ -1,461 +0,0 @@
1
- """ CLIP Model
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
- from dataclasses import dataclass
6
- import logging
7
- import math
8
- from typing import Optional, Tuple, Union
9
-
10
- import numpy as np
11
- import torch
12
- import torch.nn.functional as F
13
- from torch import nn
14
- from torch.utils.checkpoint import checkpoint
15
-
16
- from .hf_model import HFTextEncoder
17
- from .modified_resnet import ModifiedResNet
18
- from .timm_model import TimmModel
19
- from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
20
- from .utils import to_2tuple
21
-
22
-
23
- @dataclass
24
- class CLIPVisionCfg:
25
- layers: Union[Tuple[int, int, int, int], int] = 12
26
- width: int = 768
27
- head_width: int = 64
28
- mlp_ratio: float = 4.0
29
- patch_size: int = 16
30
- image_size: Union[Tuple[int, int], int] = 224
31
- ls_init_value: Optional[float] = None # layer scale initial value
32
- patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
33
- input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
34
- global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
35
- attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
36
- n_queries: int = 256 # n_queries for attentional pooler
37
- attn_pooler_heads: int = 8 # n heads for attentional_pooling
38
- timm_model_name: str = None # a valid model name overrides layers, width, patch_size
39
- timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
40
- timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
41
- timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
42
- timm_proj_bias: bool = False # enable bias final projection
43
- timm_drop: float = 0. # head dropout
44
- timm_drop_path: Optional[float] = None # backbone stochastic depth
45
- output_tokens: bool = False
46
-
47
-
48
- @dataclass
49
- class CLIPTextCfg:
50
- context_length: int = 77
51
- vocab_size: int = 49408
52
- width: int = 512
53
- heads: int = 8
54
- layers: int = 12
55
- ls_init_value: Optional[float] = None # layer scale initial value
56
- hf_model_name: str = None
57
- hf_tokenizer_name: str = None
58
- hf_model_pretrained: bool = True
59
- proj: str = 'mlp'
60
- pooler_type: str = 'mean_pooler'
61
- embed_cls: bool = False
62
- pad_id: int = 0
63
- output_tokens: bool = False
64
-
65
-
66
- def get_cast_dtype(precision: str):
67
- cast_dtype = None
68
- if precision == 'bf16':
69
- cast_dtype = torch.bfloat16
70
- elif precision == 'fp16':
71
- cast_dtype = torch.float16
72
- return cast_dtype
73
-
74
-
75
- def _build_vision_tower(
76
- embed_dim: int,
77
- vision_cfg: CLIPVisionCfg,
78
- quick_gelu: bool = False,
79
- cast_dtype: Optional[torch.dtype] = None
80
- ):
81
- if isinstance(vision_cfg, dict):
82
- vision_cfg = CLIPVisionCfg(**vision_cfg)
83
-
84
- # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
85
- # memory efficient in recent PyTorch releases (>= 1.10).
86
- # NOTE: timm models always use native GELU regardless of quick_gelu flag.
87
- act_layer = QuickGELU if quick_gelu else nn.GELU
88
-
89
- if vision_cfg.timm_model_name:
90
- visual = TimmModel(
91
- vision_cfg.timm_model_name,
92
- pretrained=vision_cfg.timm_model_pretrained,
93
- pool=vision_cfg.timm_pool,
94
- proj=vision_cfg.timm_proj,
95
- proj_bias=vision_cfg.timm_proj_bias,
96
- drop=vision_cfg.timm_drop,
97
- drop_path=vision_cfg.timm_drop_path,
98
- embed_dim=embed_dim,
99
- image_size=vision_cfg.image_size,
100
- )
101
- act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
102
- elif isinstance(vision_cfg.layers, (tuple, list)):
103
- vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
104
- visual = ModifiedResNet(
105
- layers=vision_cfg.layers,
106
- output_dim=embed_dim,
107
- heads=vision_heads,
108
- image_size=vision_cfg.image_size,
109
- width=vision_cfg.width,
110
- )
111
- else:
112
- vision_heads = vision_cfg.width // vision_cfg.head_width
113
- norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
114
- visual = VisionTransformer(
115
- image_size=vision_cfg.image_size,
116
- patch_size=vision_cfg.patch_size,
117
- width=vision_cfg.width,
118
- layers=vision_cfg.layers,
119
- heads=vision_heads,
120
- mlp_ratio=vision_cfg.mlp_ratio,
121
- ls_init_value=vision_cfg.ls_init_value,
122
- patch_dropout=vision_cfg.patch_dropout,
123
- input_patchnorm=vision_cfg.input_patchnorm,
124
- global_average_pool=vision_cfg.global_average_pool,
125
- attentional_pool=vision_cfg.attentional_pool,
126
- n_queries=vision_cfg.n_queries,
127
- attn_pooler_heads=vision_cfg.attn_pooler_heads,
128
- output_tokens=vision_cfg.output_tokens,
129
- output_dim=embed_dim,
130
- act_layer=act_layer,
131
- norm_layer=norm_layer,
132
- )
133
-
134
- return visual
135
-
136
-
137
- def _build_text_tower(
138
- embed_dim: int,
139
- text_cfg: CLIPTextCfg,
140
- quick_gelu: bool = False,
141
- cast_dtype: Optional[torch.dtype] = None,
142
- ):
143
- if isinstance(text_cfg, dict):
144
- text_cfg = CLIPTextCfg(**text_cfg)
145
-
146
- if text_cfg.hf_model_name:
147
- text = HFTextEncoder(
148
- text_cfg.hf_model_name,
149
- output_dim=embed_dim,
150
- proj=text_cfg.proj,
151
- pooler_type=text_cfg.pooler_type,
152
- pretrained=text_cfg.hf_model_pretrained,
153
- output_tokens=text_cfg.output_tokens,
154
- )
155
- else:
156
- act_layer = QuickGELU if quick_gelu else nn.GELU
157
- norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
158
-
159
- text = TextTransformer(
160
- context_length=text_cfg.context_length,
161
- vocab_size=text_cfg.vocab_size,
162
- width=text_cfg.width,
163
- heads=text_cfg.heads,
164
- layers=text_cfg.layers,
165
- ls_init_value=text_cfg.ls_init_value,
166
- output_dim=embed_dim,
167
- embed_cls=text_cfg.embed_cls,
168
- output_tokens=text_cfg.output_tokens,
169
- pad_id=text_cfg.pad_id,
170
- act_layer=act_layer,
171
- norm_layer=norm_layer,
172
- )
173
- return text
174
-
175
-
176
- class CLIP(nn.Module):
177
- output_dict: torch.jit.Final[bool]
178
-
179
- def __init__(
180
- self,
181
- embed_dim: int,
182
- vision_cfg: CLIPVisionCfg,
183
- text_cfg: CLIPTextCfg,
184
- quick_gelu: bool = False,
185
- cast_dtype: Optional[torch.dtype] = None,
186
- output_dict: bool = False,
187
- ):
188
- super().__init__()
189
- self.output_dict = output_dict
190
- self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
191
-
192
- text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
193
- self.transformer = text.transformer
194
- self.vocab_size = text.vocab_size
195
- self.token_embedding = text.token_embedding
196
- self.positional_embedding = text.positional_embedding
197
- self.ln_final = text.ln_final
198
- self.text_projection = text.text_projection
199
- self.register_buffer('attn_mask', text.attn_mask, persistent=False)
200
-
201
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
202
-
203
- def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
204
- # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
205
- self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
206
-
207
- def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
208
- locked_layers = []
209
- locked_layers.append(self.token_embedding)
210
- self.positional_embedding.requires_grad = False
211
- if unlocked_layers > 0:
212
- locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
213
- else:
214
- locked_layers.append(self.transformer)
215
- locked_layers.append(self.ln_final)
216
- self.text_projection.requires_grad = False
217
-
218
- # freeze layers
219
- for module in locked_layers:
220
- for n, p in module.named_parameters():
221
- p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
222
-
223
- @torch.jit.ignore
224
- def set_grad_checkpointing(self, enable=True):
225
- self.visual.set_grad_checkpointing(enable)
226
- self.transformer.grad_checkpointing = enable
227
-
228
- def encode_image(self, image, normalize: bool = False):
229
- features = self.visual(image)
230
- return F.normalize(features, dim=-1) if normalize else features
231
-
232
- def encode_text(self, text, normalize: bool = False):
233
- cast_dtype = self.transformer.get_cast_dtype()
234
-
235
- x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
236
-
237
- x = x + self.positional_embedding.to(cast_dtype)
238
- x = x.permute(1, 0, 2) # NLD -> LND
239
- x = self.transformer(x, attn_mask=self.attn_mask)
240
- x = x.permute(1, 0, 2) # LND -> NLD
241
- x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
242
- # take features from the eot embedding (eot_token is the highest number in each sequence)
243
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
244
- return F.normalize(x, dim=-1) if normalize else x
245
-
246
- def forward(self, image, text):
247
- image_features = self.encode_image(image, normalize=True)
248
- text_features = self.encode_text(text, normalize=True)
249
- if self.output_dict:
250
- return {
251
- "image_features": image_features,
252
- "text_features": text_features,
253
- "logit_scale": self.logit_scale.exp()
254
- }
255
- return image_features, text_features, self.logit_scale.exp()
256
-
257
-
258
- class CustomTextCLIP(nn.Module):
259
- output_dict: torch.jit.Final[bool]
260
-
261
- def __init__(
262
- self,
263
- embed_dim: int,
264
- vision_cfg: CLIPVisionCfg,
265
- text_cfg: CLIPTextCfg,
266
- quick_gelu: bool = False,
267
- cast_dtype: Optional[torch.dtype] = None,
268
- output_dict: bool = False,
269
- ):
270
- super().__init__()
271
- self.output_dict = output_dict
272
- self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
273
- self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
274
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
275
-
276
- def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
277
- # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
278
- self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
279
-
280
- def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
281
- self.text.lock(unlocked_layers, freeze_layer_norm)
282
-
283
- @torch.jit.ignore
284
- def set_grad_checkpointing(self, enable=True):
285
- self.visual.set_grad_checkpointing(enable)
286
- self.text.set_grad_checkpointing(enable)
287
-
288
- def encode_image(self, image, normalize: bool = False):
289
- features = self.visual(image)
290
- return F.normalize(features, dim=-1) if normalize else features
291
-
292
- def encode_text(self, text, normalize: bool = False):
293
- features = self.text(text)
294
- return F.normalize(features, dim=-1) if normalize else features
295
-
296
- def forward(self, image, text):
297
- image_features = self.encode_image(image, normalize=True)
298
- text_features = self.encode_text(text, normalize=True)
299
- if self.output_dict:
300
- return {
301
- "image_features": image_features,
302
- "text_features": text_features,
303
- "logit_scale": self.logit_scale.exp()
304
- }
305
- return image_features, text_features, self.logit_scale.exp()
306
-
307
-
308
- def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
309
- """Convert applicable model parameters to low-precision (bf16 or fp16)"""
310
-
311
- def _convert_weights(l):
312
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
313
- l.weight.data = l.weight.data.to(dtype)
314
- if l.bias is not None:
315
- l.bias.data = l.bias.data.to(dtype)
316
-
317
- if isinstance(l, (nn.MultiheadAttention, Attention)):
318
- for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
319
- tensor = getattr(l, attr)
320
- if tensor is not None:
321
- tensor.data = tensor.data.to(dtype)
322
-
323
- for name in ["text_projection", "proj"]:
324
- if hasattr(l, name):
325
- attr = getattr(l, name)
326
- if attr is not None:
327
- attr.data = attr.data.to(dtype)
328
-
329
- model.apply(_convert_weights)
330
-
331
-
332
- convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
333
-
334
-
335
- # used to maintain checkpoint compatibility
336
- def convert_to_custom_text_state_dict(state_dict: dict):
337
- if 'text_projection' in state_dict:
338
- # old format state_dict, move text tower -> .text
339
- new_state_dict = {}
340
- for k, v in state_dict.items():
341
- if any(k.startswith(p) for p in (
342
- 'text_projection',
343
- 'positional_embedding',
344
- 'token_embedding',
345
- 'transformer',
346
- 'ln_final',
347
- )):
348
- k = 'text.' + k
349
- new_state_dict[k] = v
350
- return new_state_dict
351
- return state_dict
352
-
353
-
354
- def build_model_from_openai_state_dict(
355
- state_dict: dict,
356
- quick_gelu=True,
357
- cast_dtype=torch.float16,
358
- ):
359
- vit = "visual.proj" in state_dict
360
-
361
- if vit:
362
- vision_width = state_dict["visual.conv1.weight"].shape[0]
363
- vision_layers = len(
364
- [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
365
- vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
366
- grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
367
- image_size = vision_patch_size * grid_size
368
- else:
369
- counts: list = [
370
- len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
371
- vision_layers = tuple(counts)
372
- vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
373
- output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
374
- vision_patch_size = None
375
- assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
376
- image_size = output_width * 32
377
-
378
- embed_dim = state_dict["text_projection"].shape[1]
379
- context_length = state_dict["positional_embedding"].shape[0]
380
- vocab_size = state_dict["token_embedding.weight"].shape[0]
381
- transformer_width = state_dict["ln_final.weight"].shape[0]
382
- transformer_heads = transformer_width // 64
383
- transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
384
-
385
- vision_cfg = CLIPVisionCfg(
386
- layers=vision_layers,
387
- width=vision_width,
388
- patch_size=vision_patch_size,
389
- image_size=image_size,
390
- )
391
- text_cfg = CLIPTextCfg(
392
- context_length=context_length,
393
- vocab_size=vocab_size,
394
- width=transformer_width,
395
- heads=transformer_heads,
396
- layers=transformer_layers,
397
- )
398
- model = CLIP(
399
- embed_dim,
400
- vision_cfg=vision_cfg,
401
- text_cfg=text_cfg,
402
- quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
403
- cast_dtype=cast_dtype,
404
- )
405
-
406
- for key in ["input_resolution", "context_length", "vocab_size"]:
407
- state_dict.pop(key, None)
408
-
409
- convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
410
- model.load_state_dict(state_dict)
411
- return model.eval()
412
-
413
-
414
- def trace_model(model, batch_size=256, device=torch.device('cpu')):
415
- model.eval()
416
- image_size = model.visual.image_size
417
- example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
418
- example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
419
- model = torch.jit.trace_module(
420
- model,
421
- inputs=dict(
422
- forward=(example_images, example_text),
423
- encode_text=(example_text,),
424
- encode_image=(example_images,)
425
- ))
426
- model.visual.image_size = image_size
427
- return model
428
-
429
-
430
- def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
431
- # Rescale the grid of position embeddings when loading from state_dict
432
- old_pos_embed = state_dict.get('visual.positional_embedding', None)
433
- if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
434
- return
435
- grid_size = to_2tuple(model.visual.grid_size)
436
- extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
437
- new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
438
- if new_seq_len == old_pos_embed.shape[0]:
439
- return
440
-
441
- if extra_tokens:
442
- pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
443
- else:
444
- pos_emb_tok, pos_emb_img = None, old_pos_embed
445
- old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
446
-
447
- logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
448
- pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
449
- pos_emb_img = F.interpolate(
450
- pos_emb_img,
451
- size=grid_size,
452
- mode=interpolation,
453
- antialias=antialias,
454
- align_corners=False,
455
- )
456
- pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
457
- if pos_emb_tok is not None:
458
- new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
459
- else:
460
- new_pos_embed = pos_emb_img
461
- state_dict['visual.positional_embedding'] = new_pos_embed
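
The CLIP / CustomTextCLIP classes above only assemble the two towers, so a full forward pass also needs the sibling transformer.py removed elsewhere in this commit; the precision helpers, however, are self-contained. A small sketch of those helpers, assuming the pre-deletion import path:

import torch
from torch import nn
from diffsynth.extensions.ImageQualityMetric.open_clip.model import get_cast_dtype, convert_weights_to_lp

print(get_cast_dtype("bf16"), get_cast_dtype("fp16"), get_cast_dtype("fp32"))  # torch.bfloat16 torch.float16 None

proj = nn.Linear(8, 4)
convert_weights_to_lp(proj, dtype=torch.float16)  # casts Conv/Linear/attention weights in place
print(proj.weight.dtype)                          # torch.float16
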
diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json DELETED
@@ -1,17 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 32,
-         "width": 1280,
-         "head_width": 80,
-         "patch_size": 14
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 1024,
-         "heads": 16,
-         "layers": 24
-     }
- }
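
For reference, this JSON feeds straight into the CLIPVisionCfg / CLIPTextCfg dataclasses from the deleted model.py; the file path in this sketch is illustrative and the import path is the pre-deletion one.

import json
from diffsynth.extensions.ImageQualityMetric.open_clip.model import CLIPVisionCfg, CLIPTextCfg

with open("ViT-H-14.json") as f:  # illustrative path
    cfg = json.load(f)
v, t = CLIPVisionCfg(**cfg["vision_cfg"]), CLIPTextCfg(**cfg["text_cfg"])

print(v.width // v.head_width)                  # 16 attention heads per vision block
print((v.image_size // v.patch_size) ** 2 + 1)  # 257 tokens: 16x16 patches + CLS
print(cfg["embed_dim"], t.context_length)       # 1024-d joint embedding, 77-token text context
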
diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py DELETED
@@ -1,181 +0,0 @@
1
- from collections import OrderedDict
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from .utils import freeze_batch_norm_2d
8
-
9
-
10
- class Bottleneck(nn.Module):
11
- expansion = 4
12
-
13
- def __init__(self, inplanes, planes, stride=1):
14
- super().__init__()
15
-
16
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
- self.bn1 = nn.BatchNorm2d(planes)
19
- self.act1 = nn.ReLU(inplace=True)
20
-
21
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
- self.bn2 = nn.BatchNorm2d(planes)
23
- self.act2 = nn.ReLU(inplace=True)
24
-
25
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
-
27
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
- self.act3 = nn.ReLU(inplace=True)
30
-
31
- self.downsample = None
32
- self.stride = stride
33
-
34
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
- self.downsample = nn.Sequential(OrderedDict([
37
- ("-1", nn.AvgPool2d(stride)),
38
- ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
- ("1", nn.BatchNorm2d(planes * self.expansion))
40
- ]))
41
-
42
- def forward(self, x: torch.Tensor):
43
- identity = x
44
-
45
- out = self.act1(self.bn1(self.conv1(x)))
46
- out = self.act2(self.bn2(self.conv2(out)))
47
- out = self.avgpool(out)
48
- out = self.bn3(self.conv3(out))
49
-
50
- if self.downsample is not None:
51
- identity = self.downsample(x)
52
-
53
- out += identity
54
- out = self.act3(out)
55
- return out
56
-
57
-
58
- class AttentionPool2d(nn.Module):
59
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
- super().__init__()
61
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
- self.k_proj = nn.Linear(embed_dim, embed_dim)
63
- self.q_proj = nn.Linear(embed_dim, embed_dim)
64
- self.v_proj = nn.Linear(embed_dim, embed_dim)
65
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
- self.num_heads = num_heads
67
-
68
- def forward(self, x):
69
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
- x, _ = F.multi_head_attention_forward(
73
- query=x, key=x, value=x,
74
- embed_dim_to_check=x.shape[-1],
75
- num_heads=self.num_heads,
76
- q_proj_weight=self.q_proj.weight,
77
- k_proj_weight=self.k_proj.weight,
78
- v_proj_weight=self.v_proj.weight,
79
- in_proj_weight=None,
80
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
- bias_k=None,
82
- bias_v=None,
83
- add_zero_attn=False,
84
- dropout_p=0.,
85
- out_proj_weight=self.c_proj.weight,
86
- out_proj_bias=self.c_proj.bias,
87
- use_separate_proj_weight=True,
88
- training=self.training,
89
- need_weights=False
90
- )
91
-
92
- return x[0]
93
-
94
-
95
- class ModifiedResNet(nn.Module):
96
- """
97
- A ResNet class that is similar to torchvision's but contains the following changes:
98
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
- - The final pooling layer is a QKV attention instead of an average pool
101
- """
102
-
103
- def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104
- super().__init__()
105
- self.output_dim = output_dim
106
- self.image_size = image_size
107
-
108
- # the 3-layer stem
109
- self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
- self.bn1 = nn.BatchNorm2d(width // 2)
111
- self.act1 = nn.ReLU(inplace=True)
112
- self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113
- self.bn2 = nn.BatchNorm2d(width // 2)
114
- self.act2 = nn.ReLU(inplace=True)
115
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116
- self.bn3 = nn.BatchNorm2d(width)
117
- self.act3 = nn.ReLU(inplace=True)
118
- self.avgpool = nn.AvgPool2d(2)
119
-
120
- # residual layers
121
- self._inplanes = width # this is a *mutable* variable used during construction
122
- self.layer1 = self._make_layer(width, layers[0])
123
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126
-
127
- embed_dim = width * 32 # the ResNet feature dimension
128
- self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129
-
130
- self.init_parameters()
131
-
132
- def _make_layer(self, planes, blocks, stride=1):
133
- layers = [Bottleneck(self._inplanes, planes, stride)]
134
-
135
- self._inplanes = planes * Bottleneck.expansion
136
- for _ in range(1, blocks):
137
- layers.append(Bottleneck(self._inplanes, planes))
138
-
139
- return nn.Sequential(*layers)
140
-
141
- def init_parameters(self):
142
- if self.attnpool is not None:
143
- std = self.attnpool.c_proj.in_features ** -0.5
144
- nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145
- nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146
- nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147
- nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148
-
149
- for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150
- for name, param in resnet_block.named_parameters():
151
- if name.endswith("bn3.weight"):
152
- nn.init.zeros_(param)
153
-
154
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155
- assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156
- for param in self.parameters():
157
- param.requires_grad = False
158
- if freeze_bn_stats:
159
- freeze_batch_norm_2d(self)
160
-
161
- @torch.jit.ignore
162
- def set_grad_checkpointing(self, enable=True):
163
- # FIXME support for non-transformer
164
- pass
165
-
166
- def stem(self, x):
167
- x = self.act1(self.bn1(self.conv1(x)))
168
- x = self.act2(self.bn2(self.conv2(x)))
169
- x = self.act3(self.bn3(self.conv3(x)))
170
- x = self.avgpool(x)
171
- return x
172
-
173
- def forward(self, x):
174
- x = self.stem(x)
175
- x = self.layer1(x)
176
- x = self.layer2(x)
177
- x = self.layer3(x)
178
- x = self.layer4(x)
179
- x = self.attnpool(x)
180
-
181
- return x
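
A quick shape check for the removed ModifiedResNet (pre-deletion import path assumed). The (3, 4, 6, 3) layer counts are the standard ResNet-50 values used by OpenAI's RN50 CLIP and are assumed here rather than taken from this repository.

import torch
from diffsynth.extensions.ImageQualityMetric.open_clip.modified_resnet import ModifiedResNet

model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                       image_size=224, width=64).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1024]): attention-pooled global feature, not a spatial map
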
diffsynth/extensions/ImageQualityMetric/open_clip/openai.py DELETED
@@ -1,144 +0,0 @@
1
- """ OpenAI pretrained model functions
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
-
6
- import os
7
- import warnings
8
- from typing import List, Optional, Union
9
-
10
- import torch
11
-
12
- from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
- from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14
-
15
- __all__ = ["list_openai_models", "load_openai_model"]
16
-
17
-
18
- def list_openai_models() -> List[str]:
19
- """Returns the names of available CLIP models"""
20
- return list_pretrained_models_by_tag('openai')
21
-
22
-
23
- def load_openai_model(
24
- name: str,
25
- precision: Optional[str] = None,
26
- device: Optional[Union[str, torch.device]] = None,
27
- jit: bool = True,
28
- cache_dir: Optional[str] = None,
29
- ):
30
- """Load a CLIP model
31
-
32
- Parameters
33
- ----------
34
- name : str
35
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36
- precision: str
37
- Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38
- device : Union[str, torch.device]
39
- The device to put the loaded model
40
- jit : bool
41
- Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42
- cache_dir : Optional[str]
43
- The directory to cache the downloaded model weights
44
-
45
- Returns
46
- -------
47
- model : torch.nn.Module
48
- The CLIP model
49
- preprocess : Callable[[PIL.Image], torch.Tensor]
50
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51
- """
52
- if device is None:
53
- device = "cuda" if torch.cuda.is_available() else "cpu"
54
- if precision is None:
55
- precision = 'fp32' if device == 'cpu' else 'fp16'
56
-
57
- if get_pretrained_url(name, 'openai'):
58
- model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59
- elif os.path.isfile(name):
60
- model_path = name
61
- else:
62
- raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63
-
64
- try:
65
- # loading JIT archive
66
- model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67
- state_dict = None
68
- except RuntimeError:
69
- # loading saved state dict
70
- if jit:
71
- warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72
- jit = False
73
- state_dict = torch.load(model_path, map_location="cpu")
74
-
75
- if not jit:
76
- # Build a non-jit model from the OpenAI jitted model state dict
77
- cast_dtype = get_cast_dtype(precision)
78
- try:
79
- model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80
- except KeyError:
81
- sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82
- model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83
-
84
- # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85
- model = model.to(device)
86
- if precision.startswith('amp') or precision == 'fp32':
87
- model.float()
88
- elif precision == 'bf16':
89
- convert_weights_to_lp(model, dtype=torch.bfloat16)
90
-
91
- return model
92
-
93
- # patch the device names
94
- device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95
- device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96
-
97
- def patch_device(module):
98
- try:
99
- graphs = [module.graph] if hasattr(module, "graph") else []
100
- except RuntimeError:
101
- graphs = []
102
-
103
- if hasattr(module, "forward1"):
104
- graphs.append(module.forward1.graph)
105
-
106
- for graph in graphs:
107
- for node in graph.findAllNodes("prim::Constant"):
108
- if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109
- node.copyAttributes(device_node)
110
-
111
- model.apply(patch_device)
112
- patch_device(model.encode_image)
113
- patch_device(model.encode_text)
114
-
115
- # patch dtype to float32 (typically for CPU)
116
- if precision == 'fp32':
117
- float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119
- float_node = float_input.node()
120
-
121
- def patch_float(module):
122
- try:
123
- graphs = [module.graph] if hasattr(module, "graph") else []
124
- except RuntimeError:
125
- graphs = []
126
-
127
- if hasattr(module, "forward1"):
128
- graphs.append(module.forward1.graph)
129
-
130
- for graph in graphs:
131
- for node in graph.findAllNodes("aten::to"):
132
- inputs = list(node.inputs())
133
- for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134
- if inputs[i].node()["value"] == 5:
135
- inputs[i].node().copyAttributes(float_node)
136
-
137
- model.apply(patch_float)
138
- patch_float(model.encode_image)
139
- patch_float(model.encode_text)
140
- model.float()
141
-
142
- # ensure image_size attr available at consistent location for both jit and non-jit
143
- model.visual.image_size = model.input_resolution.item()
144
- return model
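
Typical use of the removed OpenAI loader (pre-deletion import path assumed): the first call downloads the checkpoint via the URL table in pretrained.py below, and the model name must be one returned by list_openai_models().

import torch
from diffsynth.extensions.ImageQualityMetric.open_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())  # e.g. ['RN50', 'RN101', ..., 'ViT-L-14-336']
model = load_openai_model("RN50", precision="fp32", device="cpu", jit=False)
with torch.no_grad():
    image = torch.randn(1, 3, model.visual.image_size, model.visual.image_size)
    print(model.encode_image(image).shape)  # torch.Size([1, 1024]) for RN50
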
diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py DELETED
@@ -1,376 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
- from functools import partial
6
- from typing import Dict, Union
7
-
8
- from tqdm import tqdm
9
-
10
- from .version import __version__
11
-
12
- try:
13
- from huggingface_hub import hf_hub_download
14
- hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
15
- _has_hf_hub = True
16
- except ImportError:
17
- hf_hub_download = None
18
- _has_hf_hub = False
19
-
20
-
21
- def _pcfg(url='', hf_hub='', mean=None, std=None):
22
- return dict(
23
- url=url,
24
- hf_hub=hf_hub,
25
- mean=mean,
26
- std=std,
27
- )
28
-
29
-
30
- _RN50 = dict(
31
- openai=_pcfg(
32
- "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
33
- yfcc15m=_pcfg(
34
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
35
- cc12m=_pcfg(
36
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
37
- )
38
-
39
- _RN50_quickgelu = dict(
40
- openai=_pcfg(
41
- "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
42
- yfcc15m=_pcfg(
43
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
44
- cc12m=_pcfg(
45
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
46
- )
47
-
48
- _RN101 = dict(
49
- openai=_pcfg(
50
- "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
51
- yfcc15m=_pcfg(
52
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
53
- )
54
-
55
- _RN101_quickgelu = dict(
56
- openai=_pcfg(
57
- "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
58
- yfcc15m=_pcfg(
59
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
60
- )
61
-
62
- _RN50x4 = dict(
63
- openai=_pcfg(
64
- "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
65
- )
66
-
67
- _RN50x16 = dict(
68
- openai=_pcfg(
69
- "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
70
- )
71
-
72
- _RN50x64 = dict(
73
- openai=_pcfg(
74
- "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
75
- )
76
-
77
- _VITB32 = dict(
78
- openai=_pcfg(
79
- "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
80
- laion400m_e31=_pcfg(
81
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
82
- laion400m_e32=_pcfg(
83
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
84
- laion2b_e16=_pcfg(
85
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
86
- laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
87
- )
88
-
89
- _VITB32_quickgelu = dict(
90
- openai=_pcfg(
91
- "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
92
- laion400m_e31=_pcfg(
93
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
94
- laion400m_e32=_pcfg(
95
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
96
- )
97
-
98
- _VITB16 = dict(
99
- openai=_pcfg(
100
- "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
101
- laion400m_e31=_pcfg(
102
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
103
- laion400m_e32=_pcfg(
104
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
105
- # laion400m_32k=_pcfg(
106
- # url="",
107
- # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
108
- # laion400m_64k=_pcfg(
109
- # url="",
110
- # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
111
- laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
112
- )
113
-
114
- _VITB16_PLUS_240 = dict(
115
- laion400m_e31=_pcfg(
116
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
117
- laion400m_e32=_pcfg(
118
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
119
- )
120
-
121
- _VITL14 = dict(
122
- openai=_pcfg(
123
- "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
124
- laion400m_e31=_pcfg(
125
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
126
- laion400m_e32=_pcfg(
127
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
128
- laion2b_s32b_b82k=_pcfg(
129
- hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
130
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
131
- )
132
-
133
- _VITL14_336 = dict(
134
- openai=_pcfg(
135
- "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
136
- )
137
-
138
- _VITH14 = dict(
139
- laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
140
- )
141
-
142
- _VITg14 = dict(
143
- laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
144
- laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
145
- )
146
-
147
- _VITbigG14 = dict(
148
- laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
149
- )
150
-
151
- _robertaViTB32 = dict(
152
- laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
153
- )
154
-
155
- _xlmRobertaBaseViTB32 = dict(
156
- laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
157
- )
158
-
159
- _xlmRobertaLargeFrozenViTH14 = dict(
160
- frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
161
- )
162
-
163
- _convnext_base = dict(
164
- laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
165
- )
166
-
167
- _convnext_base_w = dict(
168
- laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
169
- laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
170
- laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
171
- )
172
-
173
- _convnext_base_w_320 = dict(
174
- laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
175
- laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
176
- )
177
-
178
- _convnext_large_d = dict(
179
- laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
180
- )
181
-
182
- _convnext_large_d_320 = dict(
183
- laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
184
- laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
185
- )
186
-
187
- _convnext_xxlarge = dict(
188
- laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
189
- laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
190
- laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
191
- )
192
-
193
- _coca_VITB32 = dict(
194
- laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
195
- mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
196
- )
197
-
198
- _coca_VITL14 = dict(
199
- laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
200
- mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
201
- )
202
-
203
-
204
- _PRETRAINED = {
205
- "RN50": _RN50,
206
- "RN50-quickgelu": _RN50_quickgelu,
207
- "RN101": _RN101,
208
- "RN101-quickgelu": _RN101_quickgelu,
209
- "RN50x4": _RN50x4,
210
- "RN50x16": _RN50x16,
211
- "RN50x64": _RN50x64,
212
- "ViT-B-32": _VITB32,
213
- "ViT-B-32-quickgelu": _VITB32_quickgelu,
214
- "ViT-B-16": _VITB16,
215
- "ViT-B-16-plus-240": _VITB16_PLUS_240,
216
- "ViT-L-14": _VITL14,
217
- "ViT-L-14-336": _VITL14_336,
218
- "ViT-H-14": _VITH14,
219
- "ViT-g-14": _VITg14,
220
- "ViT-bigG-14": _VITbigG14,
221
- "roberta-ViT-B-32": _robertaViTB32,
222
- "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
223
- "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
224
- "convnext_base": _convnext_base,
225
- "convnext_base_w": _convnext_base_w,
226
- "convnext_base_w_320": _convnext_base_w_320,
227
- "convnext_large_d": _convnext_large_d,
228
- "convnext_large_d_320": _convnext_large_d_320,
229
- "convnext_xxlarge": _convnext_xxlarge,
230
- "coca_ViT-B-32": _coca_VITB32,
231
- "coca_ViT-L-14": _coca_VITL14,
232
- }
233
-
234
-
235
- def _clean_tag(tag: str):
236
- # normalize pretrained tags
237
- return tag.lower().replace('-', '_')
238
-
239
-
240
- def list_pretrained(as_str: bool = False):
241
- """ returns list of pretrained models
242
- Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
243
- """
244
- return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
245
-
246
-
247
- def list_pretrained_models_by_tag(tag: str):
248
- """ return all models having the specified pretrain tag """
249
- models = []
250
- tag = _clean_tag(tag)
251
- for k in _PRETRAINED.keys():
252
- if tag in _PRETRAINED[k]:
253
- models.append(k)
254
- return models
255
-
256
-
257
- def list_pretrained_tags_by_model(model: str):
258
- """ return all pretrain tags for the specified model architecture """
259
- tags = []
260
- if model in _PRETRAINED:
261
- tags.extend(_PRETRAINED[model].keys())
262
- return tags
263
-
264
-
265
- def is_pretrained_cfg(model: str, tag: str):
266
- if model not in _PRETRAINED:
267
- return False
268
- return _clean_tag(tag) in _PRETRAINED[model]
269
-
270
-
271
- def get_pretrained_cfg(model: str, tag: str):
272
- if model not in _PRETRAINED:
273
- return {}
274
- model_pretrained = _PRETRAINED[model]
275
- return model_pretrained.get(_clean_tag(tag), {})
276
-
277
-
278
- def get_pretrained_url(model: str, tag: str):
279
- cfg = get_pretrained_cfg(model, _clean_tag(tag))
280
- return cfg.get('url', '')
281
-
282
-
283
- def download_pretrained_from_url(
284
- url: str,
285
- cache_dir: Union[str, None] = None,
286
- ):
287
- if not cache_dir:
288
- cache_dir = os.path.expanduser("~/.cache/clip")
289
- os.makedirs(cache_dir, exist_ok=True)
290
- filename = os.path.basename(url)
291
-
292
- if 'openaipublic' in url:
293
- expected_sha256 = url.split("/")[-2]
294
- elif 'mlfoundations' in url:
295
- expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
296
- else:
297
- expected_sha256 = ''
298
-
299
- download_target = os.path.join(cache_dir, filename)
300
-
301
- if os.path.exists(download_target) and not os.path.isfile(download_target):
302
- raise RuntimeError(f"{download_target} exists and is not a regular file")
303
-
304
- if os.path.isfile(download_target):
305
- if expected_sha256:
306
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
307
- return download_target
308
- else:
309
- warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
310
- else:
311
- return download_target
312
-
313
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
314
- with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
315
- while True:
316
- buffer = source.read(8192)
317
- if not buffer:
318
- break
319
-
320
- output.write(buffer)
321
- loop.update(len(buffer))
322
-
323
- if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
324
- raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
325
-
326
- return download_target
327
-
328
-
329
- def has_hf_hub(necessary=False):
330
- if not _has_hf_hub and necessary:
331
- # if no HF Hub module installed, and it is necessary to continue, raise error
332
- raise RuntimeError(
333
- 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
334
- return _has_hf_hub
335
-
336
-
337
- def download_pretrained_from_hf(
338
- model_id: str,
339
- filename: str = 'open_clip_pytorch_model.bin',
340
- revision=None,
341
- cache_dir: Union[str, None] = None,
342
- ):
343
- has_hf_hub(True)
344
- cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
345
- return cached_file
346
-
347
-
348
- def download_pretrained(
349
- cfg: Dict,
350
- force_hf_hub: bool = False,
351
- cache_dir: Union[str, None] = None,
352
- ):
353
- target = ''
354
- if not cfg:
355
- return target
356
-
357
- download_url = cfg.get('url', '')
358
- download_hf_hub = cfg.get('hf_hub', '')
359
- if download_hf_hub and force_hf_hub:
360
- # use HF hub even if url exists
361
- download_url = ''
362
-
363
- if download_url:
364
- target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
365
- elif download_hf_hub:
366
- has_hf_hub(True)
367
- # we assume the hf_hub entries in pretrained config combine model_id + filename in
368
- # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
369
- # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
370
- model_id, filename = os.path.split(download_hf_hub)
371
- if filename:
372
- target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
373
- else:
374
- target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
375
-
376
- return target
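
The registry above can be queried without touching the network (pre-deletion import path assumed); the actual download is left commented out because the ViT-H-14 weights referenced here are a multi-gigabyte file.

from diffsynth.extensions.ImageQualityMetric.open_clip.pretrained import (
    list_pretrained, list_pretrained_tags_by_model, get_pretrained_cfg, download_pretrained,
)

print(list_pretrained(as_str=True)[:3])           # ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']
print(list_pretrained_tags_by_model("ViT-H-14"))  # ['laion2b_s32b_b79k']
cfg = get_pretrained_cfg("ViT-H-14", "laion2b_s32b_b79k")
print(cfg["hf_hub"])                              # laion/CLIP-ViT-H-14-laion2B-s32B-b79K/
# path = download_pretrained(cfg)                 # would fetch the weights via huggingface_hub
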