diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..ae2b0cda96652d263cc450684d5ad701feb66bcf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +__assets__/feature_visualization.png filter=lfs diff=lfs merge=lfs -text +__assets__/pipeline.png filter=lfs diff=lfs merge=lfs -text +__assets__/teaser.gif filter=lfs diff=lfs merge=lfs -text +__assets__/teaser.mp4 filter=lfs diff=lfs merge=lfs -text +condition_images/rgb/dog_on_grass.png filter=lfs diff=lfs merge=lfs -text +generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 filter=lfs diff=lfs merge=lfs -text +generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/camera_1.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/camera_pan_down.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/camera_pan_up.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/camera_translation_1.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/camera_translation_2.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/camera_zoom_in.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/camera_zoom_out.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/sample_astronaut.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/sample_blackswan.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/sample_cat.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/sample_cow.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/sample_fox.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/sample_leaves.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/sample_white_tiger.mp4 filter=lfs diff=lfs merge=lfs -text +reference_videos/sample_wolf.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md new file mode 100644 index 0000000000000000000000000000000000000000..02205742035e11514f0fbc51380099ae6d07e91e --- /dev/null +++ b/.ipynb_checkpoints/README-checkpoint.md @@ -0,0 +1,12 @@ +--- +title: MotionClone-Image-to-Video +emoji: 📷 +colorFrom: red +colorTo: pink +sdk: gradio +sdk_version: 4.44.1 +app_file: i2v_video_app.py +pinned: true +license: bsd-3-clause +short_description: Motion cloning for controllable video generation +--- \ No newline at end of file diff --git a/README.md b/README.md index 0cfa3ad4fbe94b34524fe5ea45f764c688c2fe98..02205742035e11514f0fbc51380099ae6d07e91e 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ --- -title: MotionClone Image To Video -emoji: 🦀 -colorFrom: indigo -colorTo: purple +title: MotionClone-Image-to-Video +emoji: 📷 +colorFrom: red +colorTo: pink sdk: gradio -sdk_version: 5.17.0 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +sdk_version: 4.44.1 +app_file: i2v_video_app.py +pinned: true +license: bsd-3-clause +short_description: Motion cloning for controllable video generation +--- \ No newline at end of file diff --git a/__assets__/feature_visualization.png b/__assets__/feature_visualization.png new file mode 100644 index 0000000000000000000000000000000000000000..29061b8073ee2ca9b5fdd1246b3636ab59d6e6b3 --- /dev/null +++ b/__assets__/feature_visualization.png @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c0891fbfe56b1650d6c65dac700d02faee46cff0cc56515c8a23a8be0c9a46b +size 943577 diff --git a/__assets__/pipeline.png b/__assets__/pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..7b78d569a14b8ea5d6f2e5b7ddcc15f5fa332ebc --- /dev/null +++ b/__assets__/pipeline.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc9926f5f4a746475cb1963a4e908671db82d0cc630c8a5e9cd43f78885fd82d +size 1006207 diff --git a/__assets__/teaser.gif b/__assets__/teaser.gif new file mode 100644 index 0000000000000000000000000000000000000000..bafc7bc8bd5cdd28cfc575aa126b0245da972b7a --- /dev/null +++ b/__assets__/teaser.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ee4ff21495ae52ff2c9f4ff9ad5406c3f4445633a437664f9cc20277460ea6f +size 14601625 diff --git a/__assets__/teaser.mp4 b/__assets__/teaser.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d5f80422cb2ebe01032a5412da0aa129569f14d5 --- /dev/null +++ b/__assets__/teaser.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:201747f42691e708b9efe48ea054961fd82cf54b83ac43e0d97a43f81779c00b +size 4957080 diff --git a/condition_images/rgb/dog_on_grass.png b/condition_images/rgb/dog_on_grass.png new file mode 100644 index 0000000000000000000000000000000000000000..d74de1a8e6d7c1e68b993d57c239d8eb68cc4e76 --- /dev/null +++ b/condition_images/rgb/dog_on_grass.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b3ead35573919274f59d763c5085608ca78a993bf508448ca22af31ebcab113 +size 1508291 diff --git a/condition_images/scribble/lion_forest.png b/condition_images/scribble/lion_forest.png new file mode 100644 index 0000000000000000000000000000000000000000..6f9189f6a1a70180ee55cfb684314500c8996cbc Binary files /dev/null and b/condition_images/scribble/lion_forest.png differ diff --git a/configs/i2v_rgb.jsonl b/configs/i2v_rgb.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..27a64f03011d33f68510ed2b07de71ecc65c57b2 --- /dev/null +++ b/configs/i2v_rgb.jsonl @@ -0,0 +1 @@ +{"video_path":"reference_videos/camera_zoom_out.mp4", "condition_image_paths":["condition_images/rgb/dog_on_grass.png"], "new_prompt": "Dog, lying on the grass"} \ No newline at end of file diff --git a/configs/i2v_rgb.yaml b/configs/i2v_rgb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4cf348b0353deb78c2be7c21a47606f8ed142889 --- /dev/null +++ b/configs/i2v_rgb.yaml @@ -0,0 +1,20 @@ +motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" +dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" +model_config: "configs/model_config/model_config.yaml" +controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" +controlnet_config: "configs/sparsectrl/latent_condition.yaml" +adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" + +cfg_scale: 7.5 # in default realistic classifer-free guidance +negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers" + +inference_steps: 100 # the total denosing step for inference +guidance_scale: 0.3 # which scale of time step to end guidance +guidance_steps: 40 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance +warm_up_steps: 10 +cool_up_steps: 10 + +motion_guidance_weight: 
2000 +motion_guidance_blocks: ['up_blocks.1'] + +add_noise_step: 400 \ No newline at end of file diff --git a/configs/i2v_sketch.jsonl b/configs/i2v_sketch.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..a31210e9b0621044119d9abcdd1f5dd04a55c3aa --- /dev/null +++ b/configs/i2v_sketch.jsonl @@ -0,0 +1 @@ +{"video_path":"reference_videos/sample_white_tiger.mp4", "condition_image_paths":["condition_images/scribble/lion_forest.png"], "new_prompt": "Lion, walks in the forest"} \ No newline at end of file diff --git a/configs/i2v_sketch.yaml b/configs/i2v_sketch.yaml new file mode 100644 index 0000000000000000000000000000000000000000..284b2a774f949da006d3973906b481d26f668a04 --- /dev/null +++ b/configs/i2v_sketch.yaml @@ -0,0 +1,20 @@ +motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" +dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" +model_config: "configs/model_config/model_config.yaml" +controlnet_config: "configs/sparsectrl/image_condition.yaml" +controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt" +adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" + +cfg_scale: 7.5 # in default realistic classifer-free guidance +negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers" + +inference_steps: 200 # the total denosing step for inference +guidance_scale: 0.4 # which scale of time step to end guidance +guidance_steps: 120 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance +warm_up_steps: 10 +cool_up_steps: 10 + +motion_guidance_weight: 2000 +motion_guidance_blocks: ['up_blocks.1'] + +add_noise_step: 400 \ No newline at end of file diff --git a/configs/model_config/inference-v1.yaml b/configs/model_config/inference-v1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab7e4be4f3b31175a9f737050b5949f103cdd339 --- /dev/null +++ b/configs/model_config/inference-v1.yaml @@ -0,0 +1,25 @@ +unet_additional_kwargs: + use_inflated_groupnorm: true # from config v3 + + + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + zero_initialize: true # from config v3 + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: False diff --git a/configs/model_config/inference-v2.yaml b/configs/model_config/inference-v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce19a89686b343f1bfd94bd19c91661a3de5bc2a --- /dev/null +++ b/configs/model_config/inference-v2.yaml @@ -0,0 +1,24 @@ +unet_additional_kwargs: + use_inflated_groupnorm: true + unet_use_cross_frame_attention: false + unet_use_temporal_attention: false + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: true + motion_module_decoder_only: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Temporal_Self" ] + temporal_position_encoding: true + 
temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: False diff --git a/configs/model_config/inference-v3.yaml b/configs/model_config/inference-v3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd9a34d7839374221548c5a39cfb1c744f96b3db --- /dev/null +++ b/configs/model_config/inference-v3.yaml @@ -0,0 +1,22 @@ +unet_additional_kwargs: + use_inflated_groupnorm: true + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_type: Vanilla + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + zero_initialize: true + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: False diff --git a/configs/model_config/model_config copy.yaml b/configs/model_config/model_config copy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ddda34018c8eeeb749d0432d247eefc8b617a76 --- /dev/null +++ b/configs/model_config/model_config copy.yaml @@ -0,0 +1,22 @@ +unet_additional_kwargs: + use_inflated_groupnorm: true # from config v3 + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + zero_initialize: true # from config v3 + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: False \ No newline at end of file diff --git a/configs/model_config/model_config.yaml b/configs/model_config/model_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..50f901249ef3c6488db0b01d45d917818feafa84 --- /dev/null +++ b/configs/model_config/model_config.yaml @@ -0,0 +1,21 @@ +unet_additional_kwargs: + use_inflated_groupnorm: true + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Temporal_Self" ] + temporal_position_encoding: true + temporal_attention_dim_div: 1 + zero_initialize: true + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false \ No newline at end of file diff --git a/configs/model_config/model_config_public.yaml b/configs/model_config/model_config_public.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0aa988d0b2963eca02301d669c19a0f6fd9caae1 --- /dev/null +++ b/configs/model_config/model_config_public.yaml @@ -0,0 +1,25 @@ +unet_additional_kwargs: + use_inflated_groupnorm: true # from config v3 + unet_use_cross_frame_attention: false + unet_use_temporal_attention: false + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + 
num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + zero_initialize: true # from config v3 + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: False diff --git a/configs/sparsectrl/image_condition.yaml b/configs/sparsectrl/image_condition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..886f7d2a6cbd805b2a33c6598a8da7cd5bc0b39b --- /dev/null +++ b/configs/sparsectrl/image_condition.yaml @@ -0,0 +1,17 @@ +controlnet_additional_kwargs: + set_noisy_sample_input_to_zero: true + use_simplified_condition_embedding: false + conditioning_channels: 3 + + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 diff --git a/configs/sparsectrl/latent_condition.yaml b/configs/sparsectrl/latent_condition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e1219de599cb4eb7361d6caeacbcc9eb00c16aa --- /dev/null +++ b/configs/sparsectrl/latent_condition.yaml @@ -0,0 +1,17 @@ +controlnet_additional_kwargs: + set_noisy_sample_input_to_zero: true + use_simplified_condition_embedding: true + conditioning_channels: 4 + + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 diff --git a/configs/t2v_camera.jsonl b/configs/t2v_camera.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..3e9c0dfba95a0c8aed888953c6dd857d19a7541a --- /dev/null +++ b/configs/t2v_camera.jsonl @@ -0,0 +1,12 @@ +{"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Relics on the seabed", "seed": 42} +{"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "A road in the mountain", "seed": 42} +{"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Caves, a path for exploration", "seed": 2026} +{"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Railway for train"} +{"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Tree, in the mountain", "seed": 2026} +{"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Red car on the track", "seed": 2026} +{"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Man, standing in his garden.", "seed": 2026} +{"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A island, on the ocean, sunny day"} +{"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A tower, with fireworks"} +{"video_path":"reference_videos/camera_pan_up.mp4", "new_prompt": "Beautiful house, around with flowers", "seed": 42} +{"video_path":"reference_videos/camera_translation_2.mp4", "new_prompt": "Forest, in winter", "seed": 2028} +{"video_path":"reference_videos/camera_pan_down.mp4", "new_prompt": "Eagle, standing in the tree", "seed": 2026} \ No newline at end of file diff --git a/configs/t2v_camera.yaml b/configs/t2v_camera.yaml new file 
mode 100644 index 0000000000000000000000000000000000000000..1d12b26d17ba7df5d80a1429a14e0bd9a4fdca35 --- /dev/null +++ b/configs/t2v_camera.yaml @@ -0,0 +1,19 @@ + +motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" +dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" +model_config: "configs/model_config/model_config.yaml" + +cfg_scale: 7.5 # in default realistic classifer-free guidance +negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers" +postive_prompt: " 8k, high detailed, best quality, film grain, Fujifilm XT3" + +inference_steps: 100 # the total denosing step for inference +guidance_scale: 0.3 # which scale of time step to end guidance 0.2/40 +guidance_steps: 50 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance +warm_up_steps: 10 +cool_up_steps: 10 + +motion_guidance_weight: 2000 +motion_guidance_blocks: ['up_blocks.1'] + +add_noise_step: 400 \ No newline at end of file diff --git a/configs/t2v_object.jsonl b/configs/t2v_object.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..e7143d92f8ec20585402e34556846a8280ccc9b7 --- /dev/null +++ b/configs/t2v_object.jsonl @@ -0,0 +1,6 @@ +{"video_path":"reference_videos/sample_astronaut.mp4", "new_prompt": "Robot, walks in the street.","seed":59} +{"video_path":"reference_videos/sample_cat.mp4", "new_prompt": "Tiger, raises its head.", "seed": 2025} +{"video_path":"reference_videos/sample_leaves.mp4", "new_prompt": "Petals falling in the wind.","seed":3407} +{"video_path":"reference_videos/sample_fox.mp4", "new_prompt": "Cat, turns its head in the living room."} +{"video_path":"reference_videos/sample_blackswan.mp4", "new_prompt": "Duck, swims in the river.","seed":3407} +{"video_path":"reference_videos/sample_cow.mp4", "new_prompt": "Pig, drinks water on beach.","seed":3407} \ No newline at end of file diff --git a/configs/t2v_object.yaml b/configs/t2v_object.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3726f8c24e4a8812c8ce81852f503b19aa05f98d --- /dev/null +++ b/configs/t2v_object.yaml @@ -0,0 +1,19 @@ + +motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" +dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" +model_config: "configs/model_config/model_config.yaml" + +cfg_scale: 7.5 # in default realistic classifer-free guidance +negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers" +postive_prompt: "8k, high detailed, best quality, film grain, Fujifilm XT3" + +inference_steps: 300 # the total denosing step for inference +guidance_scale: 0.4 # which scale of time step to end guidance +guidance_steps: 180 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance +warm_up_steps: 10 +cool_up_steps: 10 + +motion_guidance_weight: 2000 +motion_guidance_blocks: ['up_blocks.1',] + +add_noise_step: 400 \ No newline at end of file diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e50b14ef29d7243589b6aadb379b13fc20ca0c8 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,25 @@ +name: motionclone +channels: + - pytorch + - nvidia +dependencies: + - python=3.11.3 + - 
pytorch=2.0.1 + - torchvision=0.15.2 + - pytorch-cuda=11.8 + - pip + - pip: + - accelerate + - diffusers==0.16.0 + - transformers==4.28.1 + - xformers==0.0.20 + - imageio[ffmpeg] + - decord==0.6.0 + - gdown + - einops + - omegaconf + - safetensors + - gradio + - wandb + - triton + - opencv-python \ No newline at end of file diff --git a/generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 b/generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d1cdc9f9931f6d2a587209b58db0727787b90d18 --- /dev/null +++ b/generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63ecf6f1250b83d71b50352a020c97eb60223ee33813219b2bd8d7588f1ecfec +size 285735 diff --git a/generated_videos/inference_config.json b/generated_videos/inference_config.json new file mode 100644 index 0000000000000000000000000000000000000000..73c94ed5e2dffa79bca459515d5e1cd452ebc23a --- /dev/null +++ b/generated_videos/inference_config.json @@ -0,0 +1,21 @@ +motion_module: models/Motion_Module/v3_sd15_mm.ckpt +dreambooth_path: models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors +model_config: configs/model_config/model_config.yaml +controlnet_config: configs/sparsectrl/image_condition.yaml +controlnet_path: models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt +adapter_lora_path: models/Motion_Module/v3_sd15_adapter.ckpt +cfg_scale: 7.5 +negative_prompt: ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, + extra limbs, poorly drawn face, poorly drawn hands, missing fingers +inference_steps: 200 +guidance_scale: 0.4 +guidance_steps: 120 +warm_up_steps: 10 +cool_up_steps: 10 +motion_guidance_weight: 2000 +motion_guidance_blocks: +- up_blocks.1 +add_noise_step: 400 +width: 512 +height: 512 +video_length: 16 diff --git a/generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 b/generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a007b9b4f3f8e24e38f935ec87addb8ab0c3c95e --- /dev/null +++ b/generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ae68b549f1c6541417009d1cdd35d01286876bada07fb53a3354ad9225856cf +size 538343 diff --git a/i2v_video_app.py b/i2v_video_app.py new file mode 100644 index 0000000000000000000000000000000000000000..bdcddb71e89f242c74c47bea0a52ba8671acca62 --- /dev/null +++ b/i2v_video_app.py @@ -0,0 +1,284 @@ +import gradio as gr +from omegaconf import OmegaConf +import torch +from diffusers import AutoencoderKL, DDIMScheduler +from transformers import CLIPTextModel, CLIPTokenizer +from motionclone.models.unet import UNet3DConditionModel +from motionclone.models.sparse_controlnet import SparseControlNetModel +from motionclone.pipelines.pipeline_animation import AnimationPipeline +from motionclone.utils.util import load_weights, auto_download +from diffusers.utils.import_utils import is_xformers_available +from motionclone.utils.motionclone_functions import * +import json +from motionclone.utils.xformer_attention import * +import os +import numpy as np +import imageio +import shutil +import subprocess +from types import SimpleNamespace + +# Model download logic +def download_weights(): + try: + # Create the model directories + os.makedirs("models", exist_ok=True) + os.makedirs("models/DreamBooth_LoRA", exist_ok=True) +
os.makedirs("models/Motion_Module", exist_ok=True) + os.makedirs("models/SparseCtrl", exist_ok=True) + + # Download the Stable Diffusion base model + if not os.path.exists("models/StableDiffusion"): + subprocess.run(["git", "clone", "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5", "models/StableDiffusion"]) + + # Download the DreamBooth LoRA checkpoint + if not os.path.exists("models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"): + subprocess.run(["wget", "https://huggingface.co/svjack/Realistic-Vision-V6.0-B1/resolve/main/realisticVisionV60B1_v51VAE.safetensors", "-O", "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"]) + + # Download the Motion Module checkpoints + if not os.path.exists("models/Motion_Module/v3_sd15_mm.ckpt"): + subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_mm.ckpt", "-O", "models/Motion_Module/v3_sd15_mm.ckpt"]) + if not os.path.exists("models/Motion_Module/v3_sd15_adapter.ckpt"): + subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt", "-O", "models/Motion_Module/v3_sd15_adapter.ckpt"]) + + # Download the SparseCtrl checkpoints + if not os.path.exists("models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"): + subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_rgb.ckpt", "-O", "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"]) + if not os.path.exists("models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"): + subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_scribble.ckpt", "-O", "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"]) + + print("Weights downloaded successfully.") + except Exception as e: + print(f"Error downloading weights: {e}") + +# Download the weights +download_weights() + +# Model initialization logic +def initialize_models(pretrained_model_path, config): + # Set device and dtype + adopted_dtype = torch.float16 + device = "cuda" + set_all_seed(42) + + # Load the model components + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype) + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype) + + # Update the configuration + config["width"] = config.get("W", 512) + config["height"] = config.get("H", 512) + config["video_length"] = config.get("L", 16) + + # Load the model config + model_config = OmegaConf.load(config.get("model_config", "")) + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs)).to(device).to(dtype=adopted_dtype) + + # Load the ControlNet model + controlnet = None + if config.get("controlnet_path", "") != "": + assert config.get("controlnet_config", "") != "" + + unet.config.num_attention_heads = 8 + unet.config.projection_class_embeddings_input_dim = None + + controlnet_config = OmegaConf.load(config["controlnet_config"]) + controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})).to(device).to(dtype=adopted_dtype) + + auto_download(config["controlnet_path"], is_dreambooth_lora=False) + print(f"loading controlnet checkpoint from {config['controlnet_path']} ...") + controlnet_state_dict = torch.load(config["controlnet_path"], map_location="cpu") + controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict +
controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name} + controlnet_state_dict.pop("animatediff_config", "") + controlnet.load_state_dict(controlnet_state_dict) + del controlnet_state_dict + + # Enable xformers memory-efficient attention + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + + # Create the pipeline + pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + controlnet=controlnet, + scheduler=DDIMScheduler(**model_config.noise_scheduler_kwargs), + ).to(device) + + # Load the weights + pipeline = load_weights( + pipeline, + motion_module_path=config.get("motion_module", ""), + adapter_lora_path=config.get("adapter_lora_path", ""), + adapter_lora_scale=config.get("adapter_lora_scale", 1.0), + dreambooth_model_path=config.get("dreambooth_path", ""), + ).to(device) + pipeline.text_encoder.to(dtype=adopted_dtype) + + # Attach the customized functions from motionclone_functions + pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler) + pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler) + pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet) + pipeline.sample_video = sample_video.__get__(pipeline) + pipeline.single_step_video = single_step_video.__get__(pipeline) + pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline) + pipeline.add_noise = add_noise.__get__(pipeline) + pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline) + pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline) + + # Freeze the UNet and ControlNet parameters + for param in pipeline.unet.parameters(): + param.requires_grad = False + if pipeline.controlnet is not None: + for param in pipeline.controlnet.parameters(): + param.requires_grad = False + + pipeline.input_config, pipeline.unet.input_config = SimpleNamespace(**config), SimpleNamespace(**config) + pipeline.unet = prep_unet_attention(pipeline.unet, config.get("motion_guidance_blocks", [])) + pipeline.unet = prep_unet_conv(pipeline.unet) + + return pipeline + +# Hard-coded configuration values +config = { + "motion_module": "models/Motion_Module/v3_sd15_mm.ckpt", + "dreambooth_path": "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors", + "model_config": "configs/model_config/model_config.yaml", + "controlnet_path": "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt", + "controlnet_config": "configs/sparsectrl/latent_condition.yaml", + "adapter_lora_path": "models/Motion_Module/v3_sd15_adapter.ckpt", + "W": 512, + "H": 512, + "L": 16, + "motion_guidance_blocks": ['up_blocks.1'], +} + +# Initialize the models +pretrained_model_path = "models/StableDiffusion" +pipeline = initialize_models(pretrained_model_path, config) + +# Video generation function +def generate_video(uploaded_video, condition_images, new_prompt, seed, motion_representation_save_dir, generated_videos_save_dir, visible_gpu, without_xformers, cfg_scale, negative_prompt, positive_prompt, inference_steps, guidance_scale, guidance_steps, warm_up_steps, cool_up_steps, motion_guidance_weight, motion_guidance_blocks, add_noise_step): + # Update the configuration with the values from the UI + config.update({ + "cfg_scale": cfg_scale, + "negative_prompt": negative_prompt, + "positive_prompt": positive_prompt, + "inference_steps": inference_steps, + "guidance_scale": guidance_scale, + "guidance_steps": guidance_steps, + "warm_up_steps": warm_up_steps, + "cool_up_steps": cool_up_steps, + "motion_guidance_weight": motion_guidance_weight, + #"motion_guidance_blocks": motion_guidance_blocks, + "add_noise_step": add_noise_step + }) + +
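# Guidance schedule, per the comments in the configs/*.yaml files in this diff: inference_steps is the
# total number of denoising steps; motion guidance is applied only during the first guidance_steps of
# them, and guidance_steps should not exceed 1000 * guidance_scale (guidance_scale marks where along
# the timestep range guidance ends). The remaining (inference_steps - guidance_steps) steps run without
# guidance. warm_up_steps and cool_up_steps are not documented in the configs; presumably they ease the
# guidance in and out at either end of that window.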
# Set environment variables + os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0)) + + device = pipeline.device + + # Create the save directories + if not os.path.exists(generated_videos_save_dir): + os.makedirs(generated_videos_save_dir) + if not os.path.exists(motion_representation_save_dir): + os.makedirs(motion_representation_save_dir) + + # Process the uploaded video + if uploaded_video is not None: + pipeline.scheduler.customized_set_timesteps(config["inference_steps"], config["guidance_steps"], config["guidance_scale"], device=device, timestep_spacing_type="uneven") + + # Copy the uploaded video to the save directory + video_path = os.path.join(generated_videos_save_dir, os.path.basename(uploaded_video)) + shutil.copy2(uploaded_video, video_path) + + # Update the per-example configuration + config["video_path"] = video_path + config["condition_image_path_list"] = condition_images + config["image_index"] = [0] * len(condition_images) + config["new_prompt"] = new_prompt + config.get("positive_prompt", "") + config["controlnet_scale"] = 1.0 + + pipeline.input_config, pipeline.unet.input_config = SimpleNamespace(**config), SimpleNamespace(**config) + + # Extract the motion representation + seed_motion = seed if seed is not None else 76739 + generator = torch.Generator(device=pipeline.device) + generator.manual_seed(seed_motion) + motion_representation_path = os.path.join(motion_representation_save_dir, os.path.splitext(os.path.basename(config["video_path"]))[0] + '.pt') + pipeline.obtain_motion_representation(generator=generator, motion_representation_path=motion_representation_path, use_controlnet=True) + + # Generate the video + seed = seed_motion + generator = torch.Generator(device=pipeline.device) + generator.manual_seed(seed) + pipeline.input_config.seed = seed + videos = pipeline.sample_video(generator=generator, add_controlnet=True) + + videos = rearrange(videos, "b c f h w -> b f h w c") + save_path = os.path.join(generated_videos_save_dir, os.path.splitext(os.path.basename(config["video_path"]))[0] + "_" + config["new_prompt"].strip().replace(' ', '_') + str(seed_motion) + "_" + str(seed) + '.mp4') + videos_uint8 = (videos[0] * 255).astype(np.uint8) + imageio.mimwrite(save_path, videos_uint8, fps=8) + print(save_path, "is done") + + return save_path + else: + return "No video uploaded."
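# Illustrative sketch: one way generate_video() could be invoked headlessly, without the Gradio UI.
# The argument values mirror the Advanced Settings defaults and the gr.Examples entry defined below;
# the MOTIONCLONE_HEADLESS environment-variable guard is hypothetical, not part of the app.
if os.environ.get("MOTIONCLONE_HEADLESS"):
    result_path = generate_video(
        uploaded_video="reference_videos/camera_zoom_out.mp4",       # reference video providing the motion
        condition_images=["condition_images/rgb/dog_on_grass.png"],  # RGB condition image for SparseCtrl
        new_prompt="Dog, lying on the grass",
        seed=42,
        motion_representation_save_dir="motion_representation/",
        generated_videos_save_dir="generated_videos/",
        visible_gpu="0",
        without_xformers=False,
        cfg_scale=7.5,
        negative_prompt="ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers",
        positive_prompt="8k, high detailed, best quality, film grain, Fujifilm XT3",
        inference_steps=100,
        guidance_scale=0.3,
        guidance_steps=40,
        warm_up_steps=10,
        cool_up_steps=10,
        motion_guidance_weight=2000,
        motion_guidance_blocks="['up_blocks.1']",
        add_noise_step=400,
    )
    print("Generated video saved to:", result_path)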
+ +# Build the interface with Gradio +with gr.Blocks() as demo: + gr.Markdown("# MotionClone Video Generation") + with gr.Row(): + with gr.Column(): + uploaded_video = gr.Video(label="Upload Video") + condition_images = gr.Files(label="Condition Images") + new_prompt = gr.Textbox(label="New Prompt", value="A beautiful scene") + seed = gr.Number(label="Seed", value=76739) + generate_button = gr.Button("Generate Video") + with gr.Column(): + output_video = gr.Video(label="Generated Video") + + with gr.Accordion("Advanced Settings", open=False): + motion_representation_save_dir = gr.Textbox(label="Motion Representation Save Dir", value="motion_representation/") + generated_videos_save_dir = gr.Textbox(label="Generated Videos Save Dir", value="generated_videos/") + visible_gpu = gr.Textbox(label="Visible GPU", value="0") + without_xformers = gr.Checkbox(label="Without Xformers", value=False) + cfg_scale = gr.Number(label="CFG Scale", value=7.5) + negative_prompt = gr.Textbox(label="Negative Prompt", value="ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers") + positive_prompt = gr.Textbox(label="Positive Prompt", value="8k, high detailed, best quality, film grain, Fujifilm XT3") + inference_steps = gr.Number(label="Inference Steps", value=100) + guidance_scale = gr.Number(label="Guidance Scale", value=0.3) + guidance_steps = gr.Number(label="Guidance Steps", value=40) + warm_up_steps = gr.Number(label="Warm Up Steps", value=10) + cool_up_steps = gr.Number(label="Cool Up Steps", value=10) + motion_guidance_weight = gr.Number(label="Motion Guidance Weight", value=2000) + motion_guidance_blocks = gr.Textbox(label="Motion Guidance Blocks", value="['up_blocks.1']") + add_noise_step = gr.Number(label="Add Noise Step", value=400) + + # Wire the generate button to the generation function + generate_button.click( + generate_video, + inputs=[ + uploaded_video, condition_images, new_prompt, seed, motion_representation_save_dir, generated_videos_save_dir, visible_gpu, without_xformers, cfg_scale, negative_prompt, positive_prompt, inference_steps, guidance_scale, guidance_steps, warm_up_steps, cool_up_steps, motion_guidance_weight, motion_guidance_blocks, add_noise_step + ], + outputs=output_video + ) + + # Add examples + examples = [ + {"video_path": "reference_videos/camera_zoom_out.mp4", "condition_image_paths": ["condition_images/rgb/dog_on_grass.png"], "new_prompt": "Dog, lying on the grass", "seed": 42} + ] + examples = list(map(lambda d: [d["video_path"], d["condition_image_paths"], d["new_prompt"], d["seed"]], examples)) + + gr.Examples( + examples=examples, + inputs=[uploaded_video, condition_images, new_prompt, seed], + outputs=output_video, + fn=generate_video, + cache_examples=False + ) + +# Launch the app +demo.launch(share=True) diff --git a/i2v_video_sample.py b/i2v_video_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ee4b8f723d304575258e30a66e91eab027aabd --- /dev/null +++ b/i2v_video_sample.py @@ -0,0 +1,157 @@ +import argparse +from omegaconf import OmegaConf +import torch +from diffusers import AutoencoderKL, DDIMScheduler +from transformers import CLIPTextModel, CLIPTokenizer +from motionclone.models.unet import UNet3DConditionModel +from motionclone.models.sparse_controlnet import SparseControlNetModel +from motionclone.pipelines.pipeline_animation import AnimationPipeline +from motionclone.utils.util import load_weights, auto_download +from diffusers.utils.import_utils import is_xformers_available +from motionclone.utils.motionclone_functions
import * +import json +from motionclone.utils.xformer_attention import * + + +def main(args): + + os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0)) + + config = OmegaConf.load(args.inference_config) + adopted_dtype = torch.float16 + device = "cuda" + set_all_seed(42) + + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype) + + config.width = config.get("W", args.W) + config.height = config.get("H", args.H) + config.video_length = config.get("L", args.L) + + if not os.path.exists(args.generated_videos_save_dir): + os.makedirs(args.generated_videos_save_dir) + OmegaConf.save(config, os.path.join(args.generated_videos_save_dir,"inference_config.json")) + + model_config = OmegaConf.load(config.get("model_config", "")) + unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs),).to(device).to(dtype=adopted_dtype) + + # load controlnet model + controlnet = None + if config.get("controlnet_path", "") != "": + # assert model_config.get("controlnet_images", "") != "" + assert config.get("controlnet_config", "") != "" + + unet.config.num_attention_heads = 8 + unet.config.projection_class_embeddings_input_dim = None + + controlnet_config = OmegaConf.load(config.controlnet_config) + controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})).to(device).to(dtype=adopted_dtype) + + auto_download(config.controlnet_path, is_dreambooth_lora=False) + print(f"loading controlnet checkpoint from {config.controlnet_path} ...") + controlnet_state_dict = torch.load(config.controlnet_path, map_location="cpu") + controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict + controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name} + controlnet_state_dict.pop("animatediff_config", "") + controlnet.load_state_dict(controlnet_state_dict) + del controlnet_state_dict + + # set xformers + if is_xformers_available() and (not args.without_xformers): + unet.enable_xformers_memory_efficient_attention() + + pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + controlnet=controlnet, + scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)), + ).to(device) + + pipeline = load_weights( + pipeline, + # motion module + motion_module_path = config.get("motion_module", ""), + # domain adapter + adapter_lora_path = config.get("adapter_lora_path", ""), + adapter_lora_scale = config.get("adapter_lora_scale", 1.0), + # image layer + dreambooth_model_path = config.get("dreambooth_path", ""), + ).to(device) + pipeline.text_encoder.to(dtype=adopted_dtype) + + # customized functions in motionclone_functions + pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler) + pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler) + pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet) + pipeline.sample_video = 
sample_video.__get__(pipeline) + pipeline.single_step_video = single_step_video.__get__(pipeline) + pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline) + pipeline.add_noise = add_noise.__get__(pipeline) + pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline) + pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline) + + for param in pipeline.unet.parameters(): + param.requires_grad = False + for param in pipeline.controlnet.parameters(): + param.requires_grad = False + + pipeline.input_config, pipeline.unet.input_config = config, config + pipeline.unet = prep_unet_attention(pipeline.unet,pipeline.input_config.motion_guidance_blocks) + pipeline.unet = prep_unet_conv(pipeline.unet) + pipeline.scheduler.customized_set_timesteps(config.inference_steps, config.guidance_steps,config.guidance_scale,device=device,timestep_spacing_type = "uneven") + + with open(args.examples, 'r') as files: + for line in files: + # prepare infor of each case + example_infor = json.loads(line) + config.video_path = example_infor["video_path"] + config.condition_image_path_list = example_infor["condition_image_paths"] + config.image_index = example_infor.get("image_index",[0]) + assert len(config.image_index) == len(config.condition_image_path_list) + config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "") + config.controlnet_scale = example_infor.get("controlnet_scale", 1.0) + pipeline.input_config, pipeline.unet.input_config = config, config # update config + + # perform motion representation extraction + seed_motion = seed_motion = example_infor.get("seed", args.default_seed) + generator = torch.Generator(device=pipeline.device) + generator.manual_seed(seed_motion) + if not os.path.exists(args.motion_representation_save_dir): + os.makedirs(args.motion_representation_save_dir) + motion_representation_path = os.path.join(args.motion_representation_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + '.pt') + pipeline.obtain_motion_representation(generator= generator, motion_representation_path = motion_representation_path, use_controlnet=True,) + + # perform video generation + seed = seed_motion # can assign other seed here + generator = torch.Generator(device=pipeline.device) + generator.manual_seed(seed) + pipeline.input_config.seed = seed + videos = pipeline.sample_video(generator = generator, add_controlnet=True,) + + videos = rearrange(videos, "b c f h w -> b f h w c") + save_path = os.path.join(args.generated_videos_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + + "_" + config.new_prompt.strip().replace(' ', '_') + str(seed_motion) + "_" +str(seed)+'.mp4') + videos_uint8 = (videos[0] * 255).astype(np.uint8) + imageio.mimwrite(save_path, videos_uint8, fps=8) + print(save_path,"is done") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion",) + + parser.add_argument("--inference_config", type=str, default="configs/i2v_sketch.yaml") + parser.add_argument("--examples", type=str, default="configs/i2v_sketch.jsonl") + parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/") + parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos/") + + parser.add_argument("--visible_gpu", type=str, default=None) + parser.add_argument("--default-seed", type=int, default=76739) + parser.add_argument("--L", type=int, default=16) + 
parser.add_argument("--W", type=int, default=512) + parser.add_argument("--H", type=int, default=512) + + parser.add_argument("--without-xformers", action="store_true") + + args = parser.parse_args() + main(args) diff --git a/models/Motion_Module/Put motion module checkpoints here.txt b/models/Motion_Module/Put motion module checkpoints here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/motionclone/models/__pycache__/attention.cpython-310.pyc b/motionclone/models/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea54f8caca44ea53ba0ead2c63c1752ce6f24624 Binary files /dev/null and b/motionclone/models/__pycache__/attention.cpython-310.pyc differ diff --git a/motionclone/models/__pycache__/attention.cpython-38.pyc b/motionclone/models/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6fa4ec26e7856ab053c249aa1443844c29a6bfb Binary files /dev/null and b/motionclone/models/__pycache__/attention.cpython-38.pyc differ diff --git a/motionclone/models/__pycache__/motion_module.cpython-310.pyc b/motionclone/models/__pycache__/motion_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..408746bf9bb19b24da3ae21fc247ba474b0504cf Binary files /dev/null and b/motionclone/models/__pycache__/motion_module.cpython-310.pyc differ diff --git a/motionclone/models/__pycache__/motion_module.cpython-38.pyc b/motionclone/models/__pycache__/motion_module.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b290242c4b1a8a7c0f56a7dac3d87ac4da4b83e Binary files /dev/null and b/motionclone/models/__pycache__/motion_module.cpython-38.pyc differ diff --git a/motionclone/models/__pycache__/resnet.cpython-310.pyc b/motionclone/models/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f6f281b38a70fd2684ff18a53cb147aab8f0f7e Binary files /dev/null and b/motionclone/models/__pycache__/resnet.cpython-310.pyc differ diff --git a/motionclone/models/__pycache__/resnet.cpython-38.pyc b/motionclone/models/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5b9a84748436d7561a1df834bbad22ec3afa747 Binary files /dev/null and b/motionclone/models/__pycache__/resnet.cpython-38.pyc differ diff --git a/motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc b/motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d010ea139388883ee27d631bd95571c4d88a706a Binary files /dev/null and b/motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc differ diff --git a/motionclone/models/__pycache__/unet.cpython-310.pyc b/motionclone/models/__pycache__/unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07d24c4b07fdf2c54e773fc5395e8dba92ae7a44 Binary files /dev/null and b/motionclone/models/__pycache__/unet.cpython-310.pyc differ diff --git a/motionclone/models/__pycache__/unet.cpython-38.pyc b/motionclone/models/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71f6558a171629833ee59f7119d44e82d9ea1df6 Binary files /dev/null and b/motionclone/models/__pycache__/unet.cpython-38.pyc differ diff --git a/motionclone/models/__pycache__/unet_blocks.cpython-310.pyc 
b/motionclone/models/__pycache__/unet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2bdceffe562d44da52d7e7075fe4428c8a4c9ba Binary files /dev/null and b/motionclone/models/__pycache__/unet_blocks.cpython-310.pyc differ diff --git a/motionclone/models/__pycache__/unet_blocks.cpython-38.pyc b/motionclone/models/__pycache__/unet_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a051169a9b4d3e3a2a94cfac9130469d94778689 Binary files /dev/null and b/motionclone/models/__pycache__/unet_blocks.cpython-38.pyc differ diff --git a/motionclone/models/attention.py b/motionclone/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..aff141d29a8ad89f2f896d56bee9e9775afe3b2f --- /dev/null +++ b/motionclone/models/attention.py @@ -0,0 +1,611 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm + +from einops import rearrange, repeat +import pdb + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + + unet_use_cross_frame_attention = None, + unet_use_temporal_attention = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.unet_use_cross_frame_attention = unet_use_cross_frame_attention + self.unet_use_temporal_attention = unet_use_temporal_attention + + # SC-Attn + assert unet_use_cross_frame_attention is not None + if unet_use_cross_frame_attention: + self.attn1 = SparseCausalAttention2D( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + else: + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + 
dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + assert unet_use_temporal_attention is not None + if unet_use_temporal_attention: + self.attn_temp = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + # if self.only_cross_attention: + # hidden_states = ( + # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + # ) + # else: + # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + # pdb.set_trace() + if self.unet_use_cross_frame_attention: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + if self.unet_use_temporal_attention: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + 
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + #### add processer + self.processor = None + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, 
attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + # query = self.reshape_heads_to_batch_dim(query) # move backwards + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + ######record###### record before reshape heads to batch dim + if self.processor is not None: + self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask) + ################## + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + ######record###### + if self.processor is not None: + self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask) + ################## + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + query = self.reshape_heads_to_batch_dim(query) # reshape query + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + ######record###### + if self.processor is not None: + self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask) + ################## + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the 
original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. 
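The head bookkeeping in CrossAttention above (reshape_heads_to_batch_dim / reshape_batch_dim_to_heads plus a scaled baddbmm, softmax, and bmm) is plain multi-head attention with the heads folded into the batch dimension. A minimal self-contained sketch of that equivalence, not part of the repository, assuming PyTorch >= 2.0 for scaled_dot_product_attention (tensor sizes are illustrative):

import torch
import torch.nn.functional as F

def fold_heads(x, heads):
    # (b, n, h*d) -> (b*h, n, d), mirroring reshape_heads_to_batch_dim
    b, n, hd = x.shape
    d = hd // heads
    return x.reshape(b, n, heads, d).permute(0, 2, 1, 3).reshape(b * heads, n, d)

def unfold_heads(x, heads):
    # (b*h, n, d) -> (b, n, h*d), mirroring reshape_batch_dim_to_heads
    bh, n, d = x.shape
    b = bh // heads
    return x.reshape(b, heads, n, d).permute(0, 2, 1, 3).reshape(b, n, heads * d)

heads, dim_head = 8, 40
q = torch.randn(2, 16, heads * dim_head)   # 16 query tokens
k = torch.randn(2, 77, heads * dim_head)   # 77 key/value tokens (e.g. text embeddings)
v = torch.randn(2, 77, heads * dim_head)

qh, kh, vh = (fold_heads(t, heads) for t in (q, k, v))
scale = dim_head ** -0.5

# baddbmm with beta=0 is just a scaled batched matmul: scores = scale * q @ k^T
scores = torch.baddbmm(
    torch.empty(qh.shape[0], qh.shape[1], kh.shape[1]),
    qh, kh.transpose(-1, -2), beta=0, alpha=scale,
)
out = unfold_heads(torch.bmm(scores.softmax(dim=-1), vh), heads)

# the same result via PyTorch's fused attention on the (b, h, n, d) layout
ref = F.scaled_dot_product_attention(
    q.view(2, 16, heads, dim_head).transpose(1, 2),
    k.view(2, 77, heads, dim_head).transpose(1, 2),
    v.view(2, 77, heads, dim_head).transpose(1, 2),
).transpose(1, 2).reshape(2, 16, heads * dim_head)
assert torch.allclose(out, ref, atol=1e-4)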
+ """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs diff --git a/motionclone/models/motion_module.py b/motionclone/models/motion_module.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6ac89706d27f20d6c29467410a2938460025f9 --- /dev/null +++ b/motionclone/models/motion_module.py @@ -0,0 +1,347 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import numpy as np +import torch.nn.functional as F +from torch import nn +import torchvision + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward +from .attention import CrossAttention + +from einops import rearrange, repeat +import math + + +def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + +@dataclass +class TemporalTransformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def get_motion_module( # 只能返回VanillaTemporalModule类 + in_channels, + motion_module_type: str, + motion_module_kwargs: dict +): + if motion_module_type == "Vanilla": + return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) + else: + raise ValueError + + +class VanillaTemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads = 8, + num_transformer_block = 2, + attention_block_types =( "Temporal_Self", "Temporal_Self" ), + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 32, + temporal_attention_dim_div = 1, + zero_initialize = True, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + + def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): + hidden_states = input_tensor + hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) + + output = hidden_states + return output + + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + 
num_attention_heads, + attention_head_dim, + + num_layers, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), # 两个TempAttn + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + VersatileAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + 
norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + hidden_states = attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, + video_length=video_length, + ) + hidden_states + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class PositionalEncoding(nn.Module): + def __init__( + self, + d_model, + dropout = 0., + max_len = 24 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + # self.register_buffer('pe', pe) + self.register_buffer('pe', pe, persistent=False) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) + + +class VersatileAttention(CrossAttention): # 继承CrossAttention类,不需要在额外写set_processor功能 + def __init__( + self, + attention_mode = None, + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + *args, **kwargs + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" + + self.attention_mode = attention_mode + self.is_cross_attention = kwargs["cross_attention_dim"] is not None + + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0., + max_len=temporal_position_encoding_max_len + ) if (temporal_position_encoding and attention_mode == "Temporal") else None + + def extra_repr(self): + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + if self.attention_mode == "Temporal": + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states + else: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + # query = self.reshape_heads_to_batch_dim(query) # move backwards + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + ######record###### record before reshape heads to batch dim + if self.processor is not None: + self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask) + ################## + + 
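VersatileAttention above only changes how tokens are grouped before the attention itself runs: the rearrange moves spatial positions into the batch axis so each pixel attends across frames, and a fixed sinusoidal table marks the frame index. A small sketch of that regrouping, not part of the repository, assuming only torch and einops (sizes are illustrative):

import math
import torch
from einops import rearrange

def sinusoidal_pe(max_len, d_model):
    # fixed sinusoidal table of shape (1, max_len, d_model), as in PositionalEncoding above
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)
    pe[0, :, 1::2] = torch.cos(position * div_term)
    return pe

b, f, d, c = 1, 16, 32 * 32, 320           # batch, frames, spatial tokens, channels
hidden_states = torch.randn(b * f, d, c)    # the (b f) d c layout the UNet blocks work in

# group by spatial location: every row now holds the f frames of one pixel
temporal_tokens = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=f)
temporal_tokens = temporal_tokens + sinusoidal_pe(24, c)[:, :f]   # max_len=24 as in the default config

# ... temporal self-attention would run here on sequences of length f ...

# and back to the spatial layout expected by the rest of the block
hidden_states = rearrange(temporal_tokens, "(b d) f c -> (b f) d c", d=d)
assert hidden_states.shape == (b * f, d, c)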
key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + query = self.reshape_heads_to_batch_dim(query) # reshape query here + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + ######record###### + # if self.processor is not None: + # self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask) + ################## + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + diff --git a/motionclone/models/resnet.py b/motionclone/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9e77c7e89408239b0b7e495dbee3d4709b364873 --- /dev/null +++ b/motionclone/models/resnet.py @@ -0,0 +1,218 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class InflatedGroupNorm(nn.GroupNorm): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
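InflatedConv3d (and InflatedGroupNorm) above reuse pretrained 2D layers on video by folding the frame axis into the batch axis, so every frame is processed independently with the original 2D weights. A minimal sketch of that pattern, not part of the repository, assuming only torch and einops (layer and tensor sizes are illustrative):

import torch
import torch.nn as nn
from einops import rearrange

class InflatedConv2dSketch(nn.Conv2d):
    """A 2D conv applied frame-by-frame to a 5D video tensor (b, c, f, h, w)."""
    def forward(self, x):
        video_length = x.shape[2]
        x = rearrange(x, "b c f h w -> (b f) c h w")   # frames become batch entries
        x = super().forward(x)                          # plain 2D convolution
        return rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

conv = InflatedConv2dSketch(4, 320, kernel_size=3, padding=1)
latents = torch.randn(2, 4, 16, 64, 64)                 # batch=2, 16 frames of 64x64 latents
out = conv(latents)
assert out.shape == (2, 320, 16, 64, 64)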
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # if self.use_conv: + # if self.name == "conv": + # hidden_states = self.conv(hidden_states) + # else: + # hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + use_inflated_groupnorm=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + self.upsample = self.downsample = None + + if groups_out is None: + groups_out = groups + + assert use_inflated_groupnorm != None + if use_inflated_groupnorm: + self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + else: + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + if use_inflated_groupnorm: + self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + else: + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if 
non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/motionclone/models/scheduler.py b/motionclone/models/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..b298e57371845a27e0072bf63b54432d80ac8a3e --- /dev/null +++ b/motionclone/models/scheduler.py @@ -0,0 +1,155 @@ +from typing import Optional, Tuple, Union + +import torch +from diffusers import DDIMScheduler +from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput +from diffusers.utils.torch_utils import randn_tensor + + +class CustomDDIMScheduler(DDIMScheduler): + @torch.no_grad() + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + + # Guidance parameters + score=None, + guidance_scale=0.0, + indices=None, # [0] + + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. 
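In ResnetBlock3D above the time embedding can condition the block in two ways: "default" simply adds the projected embedding to the features, while "scale_shift" splits the projection into a FiLM-style (scale, shift) pair applied after the second norm. A compact sketch of the scale_shift branch, not part of the repository, assuming only torch (shapes follow the 5D video layout used above):

import torch
import torch.nn as nn
import torch.nn.functional as F

out_channels, temb_channels = 320, 1280
time_emb_proj = nn.Linear(temb_channels, out_channels * 2)   # scale_shift needs 2x channels

h = torch.randn(2, out_channels, 16, 32, 32)                  # (b, c, f, h, w) features
temb = torch.randn(2, temb_channels)                          # pooled timestep embedding

# project and broadcast over frames and spatial dims, as in the block's forward
temb = time_emb_proj(F.silu(temb))[:, :, None, None, None]    # (b, 2c, 1, 1, 1)
scale, shift = torch.chunk(temb, 2, dim=1)

# "scale_shift": modulate the normalized features instead of adding the embedding
h = h * (1 + scale) + shift
assert h.shape == (2, out_channels, 16, 32, 32)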
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + + # Support IF models + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output + pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # [2, 4, 64, 64] + + # 6. 
apply guidance following the formula (14) from https://arxiv.org/pdf/2105.05233.pdf + if score is not None and guidance_scale > 0.0: # indices specifies where the guidance is applied; here indices = [0] + if indices is not None: + # import pdb; pdb.set_trace() + assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape" + pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score # only the first [1, 4, 64, 64] chunk is modified + else: + assert pred_epsilon.shape == score.shape + pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score + # + + # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon # [2, 4, 64, 64] + + # 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction # [2, 4, 64, 64] + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise # finally add some random noise as well + + prev_sample = prev_sample + variance # [2, 4, 64, 64] + self.pred_epsilon = pred_epsilon + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) diff --git a/motionclone/models/sparse_controlnet.py b/motionclone/models/sparse_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..18ae386dad6b10334b09b4846e23666b2410d113 --- /dev/null +++ b/motionclone/models/sparse_controlnet.py @@ -0,0 +1,593 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Changes were made to this source code by Yuwei Guo.
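The custom step above differs from the stock DDIM step only in step 6: an external score (the gradient of a guidance energy) is folded into the predicted noise before the usual x_{t-1} update, and indices restricts the correction to a subset of the batch. A stripped-down sketch of just that arithmetic, not the repository API, assuming epsilon prediction and precomputed alphas_cumprod (names and shapes are illustrative):

import torch

def guided_ddim_step(sample, model_output, alpha_prod_t, alpha_prod_t_prev,
                     score=None, guidance_scale=0.0, indices=None):
    """One deterministic (eta=0) DDIM update with an optional score-based correction."""
    beta_prod_t = 1 - alpha_prod_t

    # "predicted x_0" and predicted noise under the epsilon parameterization
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    pred_epsilon = model_output.clone()   # cloned so the correction does not mutate the input

    # guidance: eps <- eps - w * sqrt(1 - alpha_bar_t) * score
    if score is not None and guidance_scale > 0.0:
        if indices is not None:
            pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * beta_prod_t ** 0.5 * score
        else:
            pred_epsilon = pred_epsilon - guidance_scale * beta_prod_t ** 0.5 * score

    # standard DDIM recombination: x_{t-1} = sqrt(a_{t-1}) * x0 + sqrt(1 - a_{t-1}) * eps
    pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * pred_epsilon
    return alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction

# toy usage: guide only the first of two batch entries, as indices=[0] does in the pipeline
x_t = torch.randn(2, 4, 16, 64, 64)
eps = torch.randn_like(x_t)
score = torch.randn(1, 4, 16, 64, 64)
x_prev = guided_ddim_step(x_t, eps, torch.tensor(0.5), torch.tensor(0.6),
                          score=score, guidance_scale=2.0, indices=[0])
assert x_prev.shape == x_t.shape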
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin + + +from .unet_blocks import ( + CrossAttnDownBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + get_down_block, +) +from einops import repeat, rearrange +from .resnet import InflatedConv3d + +from diffusers.models.unet_2d_condition import UNet2DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SparseControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class SparseControlNetConditioningEmbedding(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class SparseControlNetModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + + use_motion_module = True, + motion_module_resolutions = ( 1,2,4,8 ), + motion_module_mid_block = False, + motion_module_type = "Vanilla", + motion_module_kwargs = { + "num_attention_heads": 8, + "num_transformer_block": 1, + "attention_block_types": 
["Temporal_Self"], + "temporal_position_encoding": True, + "temporal_position_encoding_max_len": 32, + "temporal_attention_dim_div": 1, + "causal_temporal_attention": False, + }, + + concate_conditioning_mask: bool = True, + use_simplified_condition_embedding: bool = False, + + set_noisy_sample_input_to_zero: bool = False, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + # input + self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero + + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = InflatedConv3d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + if concate_conditioning_mask: + conditioning_channels = conditioning_channels + 1 + self.concate_conditioning_mask = concate_conditioning_mask + + # control net conditioning embedding + if use_simplified_condition_embedding: + self.controlnet_cond_embedding = zero_module( + InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding) + ).to(torch.float16) + else: + self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ).to(torch.float16) + self.use_simplified_condition_embedding = use_simplified_condition_embedding + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
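With concate_conditioning_mask enabled, the controlnet above does not see RGB frames alone: a binary per-frame mask marking which frames actually carry a condition is concatenated as a fourth channel, and the result is embedded down to the latent resolution before being added to the conv_in features. A small sketch of how such an input could be assembled, not the repository API, using frame-wise 3D convs as a trimmed-down stand-in for SparseControlNetConditioningEmbedding (sizes are illustrative):

import torch
import torch.nn as nn

b, f, h, w = 1, 16, 128, 128

# sparse condition: only frame 0 carries an actual image, every other frame is zeros
controlnet_cond = torch.zeros(b, 3, f, h, w)
controlnet_cond[:, :, 0] = torch.rand(b, 3, h, w)

# per-frame mask telling the network which frames really are conditioned
conditioning_mask = torch.zeros(b, 1, f, h, w)
conditioning_mask[:, :, 0] = 1.0

# concate_conditioning_mask=True -> the embedding receives 3 + 1 = 4 input channels
cond_input = torch.cat([controlnet_cond, conditioning_mask], dim=1)

# frame-wise convs (kernel 1 on the frame axis) that stride h and w down 8x to the latent grid
embed = nn.Sequential(
    nn.Conv3d(4, 16, kernel_size=(1, 3, 3), padding=(0, 1, 1)), nn.SiLU(),
    nn.Conv3d(16, 32, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), nn.SiLU(),
    nn.Conv3d(32, 96, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), nn.SiLU(),
    nn.Conv3d(96, 256, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), nn.SiLU(),
    nn.Conv3d(256, 320, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
)
cond_embedding = embed(cond_input)                  # (b, 320, f, h/8, w/8)

# the embedded condition is added element-wise to the controlnet's conv_in(sample) features
sample = torch.randn(b, 320, f, h // 8, w // 8)
sample = sample + cond_embedding
assert sample.shape == (1, 320, 16, 16, 16)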
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + res = 2 ** i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + use_inflated_groupnorm=True, + + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + use_inflated_groupnorm=True, + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + + controlnet_additional_kwargs: dict = 
{}, + ): + controlnet = cls( + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + + **controlnet_additional_kwargs, + ) + + if load_weights_from_unet: + m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False) + assert len(u) == 0 + m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False) + assert len(u) == 0 + m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False) + assert len(u) == 0 + + if controlnet.class_embedding: + m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False) + assert len(u) == 0 + m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False) + assert len(u) == 0 + m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False) + assert len(u) == 0 + + return controlnet + + @staticmethod + def image_layer_filter(state_dict): + new_state_dict = {} + for name, param in state_dict.items(): + if "motion_modules." in name or "lora" in name: continue + new_state_dict[name] = param + return new_state_dict + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
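from_unet above builds the controlnet from the 2D UNet's config and then copies weights module by module with strict=False, after image_layer_filter has dropped every motion-module and LoRA key, so only the spatial (image) layers are transferred and the new temporal layers keep their fresh initialization. A small generic sketch of that filter-then-load pattern, not the repository modules, assuming plain torch layers (the key names are illustrative):

import torch
import torch.nn as nn

def image_layer_filter(state_dict):
    # keep only spatial weights; temporal/LoRA parameters stay randomly initialized
    return {k: v for k, v in state_dict.items()
            if "motion_modules." not in k and "lora" not in k}

# stand-ins for one pretrained 2D block and its inflated 3D counterpart
unet_block = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1))
controlnet_block = nn.ModuleDict({
    "0": nn.Conv2d(4, 8, 3, padding=1),                 # spatial layer, same key "0.weight"
    "motion_modules": nn.ModuleList([nn.Linear(8, 8)]),  # new temporal layer, no pretrained weights
})

filtered = image_layer_filter(unet_block.state_dict())
missing, unexpected = controlnet_block.load_state_dict(filtered, strict=False)

# the spatial weights were copied; the motion-module keys are merely reported as "missing"
assert len(unexpected) == 0
assert all("motion_modules" in k for k in missing)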
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + + controlnet_cond: torch.FloatTensor, + conditioning_mask: Optional[torch.FloatTensor] = None, + + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[SparseControlNetOutput, Tuple]: + + # set input noise to zero + # if self.set_noisy_sample_input_to_zero: + # sample = torch.zeros_like(sample).to(sample.device) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
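The attention_mask handling at the top of forward converts a 1/0 keep-mask into an additive bias: masked positions receive a large negative value so that, once added to the attention scores, they vanish after the softmax. A tiny sketch of that convention, not part of the repository, assuming only torch:

import torch

# 1 = attend to this position, 0 = ignore it
keep_mask = torch.tensor([[1., 1., 0., 0.]])

# same transform as in forward: 0 -> -10000, 1 -> 0, plus a broadcast dim for the heads
attn_bias = (1 - keep_mask) * -10000.0
attn_bias = attn_bias.unsqueeze(1)                   # (batch, 1, key_len)

scores = torch.randn(1, 4, 4)                         # (batch*heads, query_len, key_len) logits
probs = (scores + attn_bias).softmax(dim=-1)

# masked keys receive (almost) zero probability
assert torch.all(probs[..., 2:] < 1e-4)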
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0]) + encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + # equal to set input noise to zero + if self.set_noisy_sample_input_to_zero: + shape = sample.shape + sample = self.conv_in.bias.reshape(1,-1,1,1,1).expand(shape[0],-1,shape[2],shape[3],shape[4]) + else: + sample = self.conv_in(sample) + + if self.concate_conditioning_mask: + controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1).to(torch.float16) + # import pdb; pdb.set_trace() + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # cross_attention_kwargs=cross_attention_kwargs, + ) + else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. controlnet blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. 
scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return SparseControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/motionclone/models/unet.py b/motionclone/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..bd221a9652287e6868ef444566e06a49008c81eb --- /dev/null +++ b/motionclone/models/unet.py @@ -0,0 +1,515 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import json +import pdb + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .resnet import InflatedConv3d, InflatedGroupNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( # 第一个不带有CrossAttn,后面三个带有CrossAttn + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: 
str = "default", + + use_inflated_groupnorm=False, + + # Additional + use_motion_module = False, + motion_module_resolutions = ( 1,2,4,8 ), + motion_module_mid_block = False, + motion_module_decoder_only = False, + motion_module_type = None, + motion_module_kwargs = {}, + unet_use_cross_frame_attention = False, + unet_use_temporal_attention = False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2 ** i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + 
unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if use_inflated_groupnorm: + self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
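+
+        Example (an illustrative usage sketch, assuming an already constructed
+        `UNet3DConditionModel` instance named `unet`; the calls mirror the accepted
+        argument forms described above):
+
+            unet.set_attention_slice("auto")  # half of each sliceable head dim, attention in two steps
+            unet.set_attention_slice("max")   # slice size of 1 everywhere, lowest memory use
+            unet.set_attention_slice(4)       # the same integer slice size for every sliceable layer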
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + + # support controlnet + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + + down_block_res_samples += res_samples + + # support controlnet + down_block_res_samples = list(down_block_res_samples) + if down_block_additional_residuals is not None: + for i, down_block_additional_residual in enumerate(down_block_additional_residuals): + if down_block_additional_residual.dim() == 4: # boardcast + down_block_additional_residual = down_block_additional_residual.unsqueeze(2) + down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # support controlnet + if mid_block_additional_residual is not None: + if mid_block_additional_residual.dim() == 4: # boardcast + mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) + sample = sample + mid_block_additional_residual + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples 
= down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config, **unet_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### motion keys will be loaded: {len(m)}; \n### unexpected keys: {len(u)};") + + params = [p.numel() if "motion_modules." 
in n else 0 for n, p in model.named_parameters()] + print(f"### Motion Module Parameters: {sum(params) / 1e6} M") + + return model diff --git a/motionclone/models/unet_blocks.py b/motionclone/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..0059b05cfd85002a9fff1e0d7dc92354724236d4 --- /dev/null +++ b/motionclone/models/unet_blocks.py @@ -0,0 +1,760 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py + +import torch +from torch import nn + +from .attention import Transformer3DModel +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +from .motion_module import get_motion_module + +import pdb + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + + 
unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + 
in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + 
unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + + use_inflated_groupnorm=False, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + 
self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + 
motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + + use_inflated_groupnorm=False, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = 
nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/motionclone/pipelines/__pycache__/pipeline_animation.cpython-310.pyc b/motionclone/pipelines/__pycache__/pipeline_animation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..266bfe9e106422bbda83fe8a4236670c962f2856 Binary files /dev/null and b/motionclone/pipelines/__pycache__/pipeline_animation.cpython-310.pyc differ diff --git a/motionclone/pipelines/__pycache__/pipeline_animation.cpython-38.pyc b/motionclone/pipelines/__pycache__/pipeline_animation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2340e2d05fc6f584def4100fbc042ddb934a8320 Binary files /dev/null and b/motionclone/pipelines/__pycache__/pipeline_animation.cpython-38.pyc differ diff --git a/motionclone/pipelines/pipeline_animation.py b/motionclone/pipelines/pipeline_animation.py new file mode 100644 index 0000000000000000000000000000000000000000..1794e0dc0671c85303fbd5b67a6c4806d6e0f3ce --- /dev/null +++ b/motionclone/pipelines/pipeline_animation.py @@ -0,0 +1,497 @@ +# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py + +import inspect +from typing import Callable, List, Optional, Union, Any, Dict +from dataclasses import dataclass +from diffusers import StableDiffusionPipeline, DDIMInverseScheduler + +import os +import pickle +import numpy as np +import torch +from tqdm import tqdm +import omegaconf +from omegaconf import OmegaConf +import yaml +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging, BaseOutput +from einops import rearrange +from ..models.unet import UNet3DConditionModel +from 
..models.sparse_controlnet import SparseControlNetModel +from ..utils.xformer_attention import * +from ..utils.conv_layer import * +from ..utils.util import * + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class AnimationPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class AnimationPipeline(DiffusionPipeline): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + controlnet: Union[SparseControlNetModel, None] = None, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. 
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + @torch.no_grad() + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + # video = self.vae.decode(latents).sample + video = [] + for frame_idx in tqdm(range(latents.shape[0])): + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
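+        # Illustrative sketch (assuming a DDIMScheduler, whose `step()` accepts both `eta` and
+        # `generator`): the dict built below would come out as
+        #     extra_step_kwargs = {"eta": eta, "generator": generator}
+        # whereas a scheduler whose `step()` has no `eta` parameter would only receive the
+        # `generator` entry, if it accepts one at all.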
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = shape + # shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + save_probs = False, + # support controlnet + controlnet_images: torch.FloatTensor = None, + controlnet_image_index: list = [0], + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + + **kwargs, + ): + + # to record temp attention probs + self.unet = prep_unet_attention(self.unet) + self.unet = prep_unet_conv(self.unet) + self.guidance_config = guidance_scale + + self.temp_rec = {} + + # Default height and width to unet + height = 
height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # Define call parameters + # batch_size = 1 if isinstance(prompt, str) else len(prompt) + batch_size = 1 + if latents is not None: + batch_size = latents.shape[0] + if isinstance(prompt, list): + batch_size = len(prompt) + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode input prompt + prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size + if negative_prompt is not None: + negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size + text_embeddings = self._encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # import pdb; pdb.set_trace() + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + video_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + latents_dtype = latents.dtype + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + down_block_additional_residuals = mid_block_additional_residual = None + # import pdb; pdb.set_trace() + if (getattr(self, "controlnet", None) != None) and (controlnet_images != None): + assert controlnet_images.dim() == 5 + + controlnet_noisy_latents = latent_model_input + controlnet_prompt_embeds = text_embeddings + + controlnet_images = controlnet_images.to(latents.device) + + controlnet_cond_shape = list(controlnet_images.shape) + controlnet_cond_shape[2] = video_length + controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device) + + controlnet_conditioning_mask_shape = list(controlnet_cond.shape) + controlnet_conditioning_mask_shape[1] = 1 + controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device) + + assert controlnet_images.shape[2] >= len(controlnet_image_index) + controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)] + controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 + + down_block_additional_residuals, mid_block_additional_residual = self.controlnet( + controlnet_noisy_latents, t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_cond, + conditioning_mask=controlnet_conditioning_mask, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=False, return_dict=False, + ) + + # predict the noise residual + 
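+                # note: when classifier-free guidance is active, latent_model_input stacks the
+                # [unconditional, conditional] copies along the batch dimension, so the single
+                # UNet forward pass below yields both predictions at once; the optional ControlNet
+                # residuals computed above are simply forwarded as extra inputs.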
noise_pred = self.unet( + latent_model_input, t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals = down_block_additional_residuals, + mid_block_additional_residual = mid_block_additional_residual, + ).sample.to(dtype=latents_dtype) + + + # get temp attn probs + if save_probs: + temp_attn_prob = self.get_temp_attn_prob() + for key in temp_attn_prob.keys(): + temp_attn_prob[key] = temp_attn_prob[key].chunk(2, dim = 0)[0].detach().clone().cpu() + self.temp_rec[i] = temp_attn_prob + # import pdb; pdb.set_trace() + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # pickle temp attn prob + if save_probs: + with open('temp_dic.pkl', 'wb') as f: + pickle.dump(self.temp_rec, f) + + # Post-processing + video = self.decode_latents(latents) + + # Convert to tensor + if output_type == "tensor": + video = torch.from_numpy(video) + + if not return_dict: + return video + + return AnimationPipelineOutput(videos=video) diff --git a/motionclone/utils/__pycache__/conv_layer.cpython-38.pyc b/motionclone/utils/__pycache__/conv_layer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c773518282b32c1c5522214df4d4fcb791bbed2 Binary files /dev/null and b/motionclone/utils/__pycache__/conv_layer.cpython-38.pyc differ diff --git a/motionclone/utils/__pycache__/convert_from_ckpt.cpython-38.pyc b/motionclone/utils/__pycache__/convert_from_ckpt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61596c277204bdb88dd0e96797e78b65f1611418 Binary files /dev/null and b/motionclone/utils/__pycache__/convert_from_ckpt.cpython-38.pyc differ diff --git a/motionclone/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-38.pyc b/motionclone/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71b89ac40a947be5724989ae64ed147357881a6d Binary files /dev/null and b/motionclone/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-38.pyc differ diff --git a/motionclone/utils/__pycache__/motionclone_functions.cpython-38.pyc b/motionclone/utils/__pycache__/motionclone_functions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cde10399714a705591c9b50ff9640e2eaedbeec Binary files /dev/null and b/motionclone/utils/__pycache__/motionclone_functions.cpython-38.pyc differ diff --git a/motionclone/utils/__pycache__/util.cpython-38.pyc b/motionclone/utils/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6bd26f6c04f624eccb6a5061832f83e40f277f5 Binary files /dev/null and b/motionclone/utils/__pycache__/util.cpython-38.pyc differ diff --git a/motionclone/utils/__pycache__/utils_freetraj.cpython-38.pyc b/motionclone/utils/__pycache__/utils_freetraj.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19689008732a2e31955e56010881d6f2603eb13e Binary files /dev/null and 
b/motionclone/utils/__pycache__/utils_freetraj.cpython-38.pyc differ diff --git a/motionclone/utils/__pycache__/xformer_attention.cpython-38.pyc b/motionclone/utils/__pycache__/xformer_attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f16d910a439a265030f4f0aa6b3fda4bc060fa02 Binary files /dev/null and b/motionclone/utils/__pycache__/xformer_attention.cpython-38.pyc differ diff --git a/motionclone/utils/conv_layer.py b/motionclone/utils/conv_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c09c80b60adf5f1e312457da2f70b29c76d6da --- /dev/null +++ b/motionclone/utils/conv_layer.py @@ -0,0 +1,69 @@ +import torch + +def conv_forward(self): + def forward(input_tensor, temb, scale=1.0): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + # import pdb; pdb.set_trace() + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None].repeat(1, 1, hidden_states.shape[2], 1, 1) + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + # record hidden state + self.record_hidden_state = hidden_states + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + return forward + + +def get_conv_feat(unet): + hidden_state_dict = dict() + for i in range(len(unet.up_blocks)): + for j in range(len(unet.up_blocks[i].resnets)): + module = unet.up_blocks[i].resnets[j] + module_name = f"up_blocks.{i}.resnets.{j}" + # print(module_name) + hidden_state_dict[module_name] = module.record_hidden_state + return hidden_state_dict + + +def prep_unet_conv(unet): + for i in range(len(unet.up_blocks)): + for j in range(len(unet.up_blocks[i].resnets)): + module = unet.up_blocks[i].resnets[j] + module.forward = conv_forward(module) + return unet diff --git a/motionclone/utils/convert_from_ckpt.py b/motionclone/utils/convert_from_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..a395d13e040c31f95835ae7cd224c806d7241197 --- /dev/null +++ b/motionclone/utils/convert_from_ckpt.py @@ -0,0 +1,968 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the Stable Diffusion checkpoints.""" + +import re +from io import BytesIO +from typing import Optional + +import requests +import torch +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from diffusers.models import ( + AutoencoderKL, + PriorTransformer, + UNet2DConditionModel, +) +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils.import_utils import BACKENDS_MAPPING + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + 
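+        # the original VAE checkpoint names its attention parameters q/k/v, norm and proj_out;
+        # the replacements below map them onto the query/key/value, group_norm and proj_attn
+        # names used by the diffusers attention block.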
new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + unet_params = original_config.model.params.unet_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + class_embed_type = None + projection_class_embeddings_input_dim = None + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + } + + if not controlnet: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config.model.parms.cond_stage_config.params + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... 
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, 
new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
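+            # a genuine attention block contributes far more than two keys, so exactly two
+            # entries (`conv.bias`, `conv.weight`) indicates that output_blocks.{i}.1 is the
+            # upsampler conv assigned just above; the list is emptied so those weights are not
+            # re-converted as attention parameters.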
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." 
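+    # VAE weights are stored under the `first_stage_model.` prefix in the LDM checkpoint;
+    # the prefix is stripped below before the remaining keys are renamed to the diffusers
+    # AutoencoderKL layout.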
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, 
new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint_concise(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + 
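+            # keep only the CLIP text-encoder weights and drop the `cond_stage_model.transformer.`
+            # prefix so the resulting keys match the CLIPTextModel state dict.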
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + return text_model_dict + +def convert_ldm_clip_checkpoint(checkpoint): + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + text_model.load_state_dict(text_model_dict,strict=True) + + return text_model + + +textenc_conversion_lst = [ + ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), + ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint(checkpoint): + text_model = 
CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + if "cond_stage_model.model.text_projection" in checkpoint: + d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer + continue + if key in textenc_conversion_map: + text_model_dict[textenc_conversion_map[key]] = checkpoint[key] + if key.startswith("cond_stage_model.model.transformer."): + new_key = key[len("cond_stage_model.model.transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. 
a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + controlnet_model = ControlNetModel(**ctrlnet_config) + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True + ) + + controlnet_model.load_state_dict(converted_ctrl_checkpoint) + + return controlnet_model diff --git a/motionclone/utils/convert_lora_safetensor_to_diffusers.py b/motionclone/utils/convert_lora_safetensor_to_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..549336a62011dfa23d74452d64f5b9eec8a6b81b --- /dev/null +++ b/motionclone/utils/convert_lora_safetensor_to_diffusers.py @@ -0,0 +1,152 @@ +# coding=utf-8 +# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Changes were made to this source code by Yuwei Guo. +""" Conversion script for the LoRA's safetensors checkpoints. """ + +import argparse + +import torch +from safetensors.torch import load_file + +from diffusers import StableDiffusionPipeline + + +def load_diffusers_lora(pipeline, state_dict, alpha=1.0): + # directly update weight in diffusers model + for key in state_dict: + # only process lora down key + if "up." 
in key: continue + + up_key = key.replace(".down.", ".up.") + model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") + model_key = model_key.replace("to_out.", "to_out.0.") + layer_infos = model_key.split(".")[:-1] + + curr_layer = pipeline.unet + while len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + curr_layer = curr_layer.__getattr__(temp_name) + + weight_down = state_dict[key] + weight_up = state_dict[up_key] + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + return pipeline + + +def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): + # load base model + # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) + + # load LoRA weight from .safetensors + # state_dict = load_file(checkpoint_path) + + visited = [] + + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + # update visited list + for item in pair_keys: + visited.append(item) + + return pipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
+ ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" + ) + parser.add_argument( + "--lora_prefix_text_encoder", + default="lora_te", + type=str, + help="The prefix of text encoder weight in safetensors", + ) + parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") + parser.add_argument( + "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." + ) + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + args = parser.parse_args() + + base_model_path = args.base_model_path + checkpoint_path = args.checkpoint_path + dump_path = args.dump_path + lora_prefix_unet = args.lora_prefix_unet + lora_prefix_text_encoder = args.lora_prefix_text_encoder + alpha = args.alpha + + pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) + + pipe = pipe.to(args.device) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/motionclone/utils/motionclone_functions.py b/motionclone/utils/motionclone_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..6e4db94942542de077c80c767a33d3df6c529128 --- /dev/null +++ b/motionclone/utils/motionclone_functions.py @@ -0,0 +1,663 @@ +from dataclasses import dataclass +import os +import numpy as np +import torch +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +from typing import Callable, List, Optional, Union +from diffusers.utils import deprecate, logging, BaseOutput +from .xformer_attention import * +from .conv_layer import * +from .util import * +from diffusers.utils.torch_utils import randn_tensor +from typing import List, Optional, Tuple, Union +logger = logging.get_logger(__name__) # pylint: disable=invalid-name +from motionclone.utils.util import video_preprocess +import einops +import torchvision.transforms as transforms + +def add_noise(self, timestep, x_0, noise_pred): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + latents_noise = alpha_prod_t ** 0.5 * x_0 + beta_prod_t ** 0.5 * noise_pred + return latents_noise + +@torch.no_grad() +def obtain_motion_representation(self, generator=None, motion_representation_path: str = None, + duration=None,use_controlnet=False,): + + video_data = video_preprocess(self.input_config.video_path, self.input_config.height, + self.input_config.width, self.input_config.video_length,duration=duration) + video_latents = self.vae.encode(video_data.to(self.vae.dtype).to(self.vae.device)).latent_dist.mode() + video_latents = self.vae.config.scaling_factor * video_latents + video_latents = video_latents.unsqueeze(0) + video_latents = einops.rearrange(video_latents, "b f c h w -> b c f h w") + + uncond_input = self.tokenizer( + [""], padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt" + ) + step_t = int(self.input_config.add_noise_step) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + noise_sampled = randn_tensor(video_latents.shape, generator=generator, device=video_latents.device, dtype=video_latents.dtype) + noisy_latents = self.add_noise(step_t, video_latents, noise_sampled) + + down_block_additional_residuals = mid_block_additional_residual = None + if 
use_controlnet: + controlnet_image_index = self.input_config.image_index + if self.controlnet.use_simplified_condition_embedding: + controlnet_images = video_latents[:,:,controlnet_image_index,:,:] + else: + controlnet_images = (einops.rearrange(video_data.unsqueeze(0).to(self.vae.dtype).to(self.vae.device), "b f c h w -> b c f h w")+1)/2 + controlnet_images = controlnet_images[:,:,controlnet_image_index,:,:] + + controlnet_cond_shape = list(controlnet_images.shape) + controlnet_cond_shape[2] = noisy_latents.shape[2] + controlnet_cond = torch.zeros(controlnet_cond_shape).to(noisy_latents.device).to(noisy_latents.dtype) + + controlnet_conditioning_mask_shape = list(controlnet_cond.shape) + controlnet_conditioning_mask_shape[1] = 1 + controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(noisy_latents.device).to(noisy_latents.dtype) + + controlnet_cond[:,:,controlnet_image_index] = controlnet_images + controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 + + down_block_additional_residuals, mid_block_additional_residual = self.controlnet( + noisy_latents, step_t, + encoder_hidden_states=uncond_embeddings, + controlnet_cond=controlnet_cond, + conditioning_mask=controlnet_conditioning_mask, + conditioning_scale=self.input_config.controlnet_scale, + guess_mode=False, return_dict=False, + ) + + _ = self.unet(noisy_latents, step_t, encoder_hidden_states=uncond_embeddings, return_dict=False, only_motion_feature=True, + down_block_additional_residuals = down_block_additional_residuals, + mid_block_additional_residual = mid_block_additional_residual,) + temp_attn_prob_control = self.get_temp_attn_prob() + + motion_representation = { key: [max_value, max_index.to(torch.uint8)] for key, tensor in temp_attn_prob_control.items() for max_value, max_index in [torch.topk(tensor, k=1, dim=-1)]} + + torch.save(motion_representation, motion_representation_path) + self.motion_representation_path = motion_representation_path + + +def compute_temp_loss(self, temp_attn_prob_control_dict): + temp_attn_prob_loss = [] + for name in temp_attn_prob_control_dict.keys(): + current_temp_attn_prob = temp_attn_prob_control_dict[name] + reference_representation_dict = self.motion_representation_dict[name] + + max_index = reference_representation_dict[1].to(torch.int64).to(current_temp_attn_prob.device) + current_motion_representation = torch.gather(current_temp_attn_prob, index = max_index, dim=-1) + + reference_motion_representation = reference_representation_dict[0].to(dtype = current_motion_representation.dtype, device = current_motion_representation.device) + + module_attn_loss = F.mse_loss(current_motion_representation, reference_motion_representation.detach()) + temp_attn_prob_loss.append(module_attn_loss) + + loss_temp = torch.stack(temp_attn_prob_loss) + return loss_temp.sum() + +def sample_video( + self, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + noisy_latents: Optional[torch.FloatTensor] = None, + add_controlnet: bool = False, +): + # Determine if use controlnet, i.e., conditional image2video + self.add_controlnet = add_controlnet + if self.add_controlnet: + image_transforms = transforms.Compose([ + transforms.Resize((self.input_config.height, self.input_config.width)), + transforms.ToTensor(), + ]) + + controlnet_images = [image_transforms(Image.open(path).convert("RGB")) for path in self.input_config.condition_image_path_list] + controlnet_images = 
torch.stack(controlnet_images).unsqueeze(0).to(dtype=self.vae.dtype,device=self.vae.device) + controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w") + + with torch.no_grad(): + if self.controlnet.use_simplified_condition_embedding: + num_controlnet_images = controlnet_images.shape[2] + controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w") + controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * self.vae.config.scaling_factor + self.controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images) + else: + self.controlnet_images = controlnet_images + + + # Define call parameters + # perform classifier_free_guidance in default + batch_size = 1 + do_classifier_free_guidance = True + device = self._execution_device + + # Encode input prompt + self.text_embeddings = self._encode_prompt(self.input_config.new_prompt, device, 1, do_classifier_free_guidance, self.input_config.negative_prompt) + # [uncond_embeddings, text_embeddings] [2, 77, 768] + + # Prepare latent variables + noisy_latents = self.prepare_latents( + batch_size, + self.unet.config.in_channels, + self.input_config.video_length, + self.input_config.height, + self.input_config.width, + self.text_embeddings.dtype, + device, + generator, + noisy_latents, + ) + + self.motion_representation_dict = torch.load(self.motion_representation_path) + self.motion_scale = self.input_config.motion_guidance_weight + + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # save GPU memory + # self.vae.to(device = "cpu") + # self.text_encoder.to(device = "cpu") + # torch.cuda.empty_cache() + + with self.progress_bar(total=self.input_config.inference_steps) as progress_bar: + for step_index, step_t in enumerate(self.scheduler.timesteps): + noisy_latents = self.single_step_video(noisy_latents, step_index, step_t, extra_step_kwargs) + progress_bar.update() + + # decode latents for videos + video = self.decode_latents(noisy_latents) + return video + +def single_step_video(self, noisy_latents, step_index, step_t, extra_step_kwargs): + + down_block_additional_residuals = mid_block_additional_residual = None + if self.add_controlnet: + with torch.no_grad(): + controlnet_cond_shape = list(self.controlnet_images.shape) + controlnet_cond_shape[2] = noisy_latents.shape[2] + controlnet_cond = torch.zeros(controlnet_cond_shape).to(noisy_latents.device).to(noisy_latents.dtype) + + controlnet_conditioning_mask_shape = list(controlnet_cond.shape) + controlnet_conditioning_mask_shape[1] = 1 + controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(noisy_latents.device).to(noisy_latents.dtype) + + controlnet_image_index = self.input_config.image_index + controlnet_cond[:,:,controlnet_image_index] = self.controlnet_images + controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 + + down_block_additional_residuals, mid_block_additional_residual = self.controlnet( + noisy_latents.expand(2,-1,-1,-1,-1), step_t, + encoder_hidden_states=self.text_embeddings, + controlnet_cond=controlnet_cond, + conditioning_mask=controlnet_conditioning_mask, + conditioning_scale=self.input_config.controlnet_scale, + guess_mode=False, return_dict=False, + ) + + # Only require grad when need to compute the gradient for guidance + if step_index < self.input_config.guidance_steps: + + down_block_additional_residuals_uncond = down_block_additional_residuals_cond = None + mid_block_additional_residual_uncond = 
mid_block_additional_residual_cond = None + if self.add_controlnet: + down_block_additional_residuals_uncond = [tensor[[0],...].detach() for tensor in down_block_additional_residuals] + down_block_additional_residuals_cond = [tensor[[1],...].detach() for tensor in down_block_additional_residuals] + mid_block_additional_residual_uncond = mid_block_additional_residual[[0],...].detach() + mid_block_additional_residual_cond = mid_block_additional_residual[[1],...].detach() + + control_latents = noisy_latents.clone().detach() + control_latents.requires_grad = True + + control_latents = self.scheduler.scale_model_input(control_latents, step_t) + noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t) + + with torch.no_grad(): + noise_pred_uncondition = self.unet(noisy_latents, step_t, encoder_hidden_states=self.text_embeddings[[0]], + down_block_additional_residuals = down_block_additional_residuals_uncond, + mid_block_additional_residual = mid_block_additional_residual_uncond,).sample.to(dtype=noisy_latents.dtype) + + noise_pred_condition = self.unet(control_latents, step_t, encoder_hidden_states=self.text_embeddings[[1]], + down_block_additional_residuals = down_block_additional_residuals_cond, + mid_block_additional_residual = mid_block_additional_residual_cond,).sample.to(dtype=noisy_latents.dtype) + temp_attn_prob_control = self.get_temp_attn_prob() + + loss_motion = self.motion_scale * self.compute_temp_loss(temp_attn_prob_control,) + + if step_index < self.input_config.warm_up_steps: + scale = (step_index+1)/self.input_config.warm_up_steps + loss_motion = scale*loss_motion + + if step_index > self.input_config.guidance_steps-self.input_config.cool_up_steps: + scale = (self.input_config.guidance_steps-step_index)/self.input_config.cool_up_steps + loss_motion = scale*loss_motion + + gradient = torch.autograd.grad(loss_motion, control_latents, allow_unused=True)[0] # [1, 4, 16, 64, 64], + assert gradient is not None, f"Step {step_index}: grad is None" + + noise_pred = noise_pred_condition + self.input_config.cfg_scale * (noise_pred_condition - noise_pred_uncondition) + + control_latents = self.scheduler.customized_step(noise_pred, step_index, control_latents, score=gradient.detach(), + **extra_step_kwargs, return_dict=False)[0] # [1, 4, 16, 64, 64] + return control_latents.detach() + + else: + with torch.no_grad(): + noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t) + noise_pred_group = self.unet( + noisy_latents.expand(2,-1,-1,-1,-1), step_t, + encoder_hidden_states=self.text_embeddings, + down_block_additional_residuals = down_block_additional_residuals, + mid_block_additional_residual = mid_block_additional_residual, + ).sample.to(dtype=noisy_latents.dtype) + + noise_pred = noise_pred_group[[1]] + self.input_config.cfg_scale * (noise_pred_group[[1]] - noise_pred_group[[0]]) + noisy_latents = self.scheduler.customized_step(noise_pred, step_index, noisy_latents, score=None, **extra_step_kwargs, return_dict=False)[0] # [1, 4, 16, 64, 64] + return noisy_latents.detach() + + +def get_temp_attn_prob(self,index_select=None): + + attn_prob_dic = {} + + for name, module in self.unet.named_modules(): + module_name = type(module).__name__ + if "VersatileAttention" in module_name and classify_blocks(self.input_config.motion_guidance_blocks, name): + key = module.processor.key + if index_select is not None: + get_index = torch.repeat_interleave(torch.tensor(index_select), repeats=key.shape[0]//len(index_select)) + index_all = torch.arange(key.shape[0]) + 
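# NOTE (illustrative): `index_select` acts as a chunk-level mask over the batch dimension of the cached
+                # query/key tensors; e.g. index_select=[0, 1] with key.shape[0] == 32 expands via repeat_interleave
+                # into sixteen 0s followed by sixteen 1s, so `index_picked` keeps only the second half of the batch
+                # (e.g. the conditional branch) before the temporal attention probabilities are computed below.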
+                index_picked = index_all[get_index.bool()]
+                key = key[index_picked]
+            key = module.reshape_heads_to_batch_dim(key).contiguous()
+
+            query = module.processor.query
+            if index_select is not None:
+                query = query[index_picked]
+            query = module.reshape_heads_to_batch_dim(query).contiguous()
+            attention_probs = module.get_attention_scores(query, key, None)
+            attention_probs = attention_probs.reshape(-1, module.heads, attention_probs.shape[1], attention_probs.shape[2])
+            attn_prob_dic[name] = attention_probs
+
+    return attn_prob_dic
+
+@torch.no_grad()
+def schedule_customized_step(
+    self,
+    model_output: torch.FloatTensor,
+    step_index: int,
+    sample: torch.FloatTensor,
+    eta: float = 0.0,
+    use_clipped_model_output: bool = False,
+    generator=None,
+    variance_noise: Optional[torch.FloatTensor] = None,
+    return_dict: bool = True,
+
+    # Guidance parameters
+    score=None,
+    guidance_scale=1.0,
+    indices=None,  # [0]
+    return_middle=False,
+):
+    if self.num_inference_steps is None:
+        raise ValueError(
+            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+        )
+
+    # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+    # Ideally, read the DDIM paper in detail for a full understanding.
+
+    # Notation (<variable name> -> <name in paper>)
+    # - pred_noise_t -> e_theta(x_t, t)
+    # - pred_original_sample -> f_theta(x_t, t) or x_0
+    # - std_dev_t -> sigma_t
+    # - eta -> η
+    # - pred_sample_direction -> "direction pointing to x_t"
+    # - pred_prev_sample -> "x_t-1"
+
+    # Support IF models
+    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+    else:
+        predicted_variance = None
+
+    timestep = self.timesteps[step_index]
+    # 1. get previous step value (=t-1)
+    # prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+    prev_timestep = self.timesteps[step_index + 1] if step_index + 1 < len(self.timesteps) else torch.tensor(-1)
+
+    # 2. compute alphas, betas
+    alpha_prod_t = self.alphas_cumprod[timestep]
+    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+    beta_prod_t = 1 - alpha_prod_t
+
+    # 3. compute predicted original sample from predicted noise, also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    if self.config.prediction_type == "epsilon":
+        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        pred_epsilon = model_output
+    elif self.config.prediction_type == "sample":
+        pred_original_sample = model_output
+        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+    elif self.config.prediction_type == "v_prediction":
+        pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
+        pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
+    else:
+        raise ValueError(
+            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+            " `v_prediction`"
+        )
+
+    # 4. Clip or threshold "predicted x_0"
+    if self.config.thresholding:
+        pred_original_sample = self._threshold_sample(pred_original_sample)
+    elif self.config.clip_sample:
+        pred_original_sample = pred_original_sample.clamp(
+            -self.config.clip_sample_range, self.config.clip_sample_range
+        )
+
+    # 5.
compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # [2, 4, 64, 64] + + if score is not None and return_middle: + return pred_epsilon, alpha_prod_t, alpha_prod_t_prev, pred_original_sample + + # 6. apply guidance following the formula (14) from https://arxiv.org/pdf/2105.05233.pdf + if score is not None and guidance_scale > 0.0: + if indices is not None: + # import pdb; pdb.set_trace() + assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape" + pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score + else: + assert pred_epsilon.shape == score.shape + pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score + # + + # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon + + # 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return prev_sample, pred_original_sample, alpha_prod_t_prev + + + +def schedule_set_timesteps(self, num_inference_steps: int, guidance_steps: int = 0, guiduance_scale: float = 0.0, device: Union[str, torch.device] = None,timestep_spacing_type= "uneven"): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # assign more steps in early denoising stage for motion guidance + if timestep_spacing_type == "uneven": + timesteps_guidance = ( + np.linspace(int((1-guiduance_scale)*self.config.num_train_timesteps), self.config.num_train_timesteps - 1, guidance_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + timesteps_vanilla = ( + np.linspace(0, int((1-guiduance_scale)*self.config.num_train_timesteps) - 1, num_inference_steps-guidance_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + timesteps = np.concatenate((timesteps_guidance, timesteps_vanilla)) + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + elif timestep_spacing_type == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif timestep_spacing_type == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif timestep_spacing_type == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{timestep_spacing_type} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + +def unet_customized_forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + + # support controlnet + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + + return_dict: bool = True, + only_motion_feature: bool = False, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
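+    # Worked example (illustrative): the SD-1.5-based AnimateDiff UNet is usually built with three upsamplers,
+    # so default_overall_up_factor == 2**3 == 8; a 512x512 generation gives 64x64 latents, 64 % 8 == 0 in both
+    # spatial dims, and no forced `upsample_size` is needed.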
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + + down_block_res_samples += res_samples + + # support controlnet + down_block_res_samples = list(down_block_res_samples) + if down_block_additional_residuals is not None: + for i, down_block_additional_residual in enumerate(down_block_additional_residuals): + if down_block_additional_residual.dim() == 4: # boardcast + down_block_additional_residual = down_block_additional_residual.unsqueeze(2) + down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # support controlnet + if mid_block_additional_residual is not None: + if mid_block_additional_residual.dim() == 4: # boardcast + mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) + sample = sample + mid_block_additional_residual + + # up + for i, upsample_block in enumerate(self.up_blocks): + if i<= int(self.input_config.motion_guidance_blocks[-1].split(".")[-1]): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = 
down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + ) + else: + if only_motion_feature: + return 0 + with torch.no_grad(): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + diff --git a/motionclone/utils/util.py b/motionclone/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..87a6c8ac16656b49e0992c8b3da70315112c00c9 --- /dev/null +++ b/motionclone/utils/util.py @@ -0,0 +1,460 @@ +import hashlib +import io +import re +import os +import imageio +import numpy as np +from typing import Union + +import cv2 +import numpy as np +import requests +import random +import torch +import PIL.Image +import PIL.ImageOps +from PIL import Image +from typing import Callable, Union + +import torch +import torchvision +import torch.distributed as dist +import torch.nn.functional as F +import decord +decord.bridge.set_bridge('torch') +from PIL import Image, ImageOps + +from safetensors import safe_open +# from tqdm import tqdm +from einops import rearrange +from motionclone.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint,convert_ldm_clip_checkpoint_concise +from motionclone.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora +from huggingface_hub import snapshot_download +# from transformers import ( +# AutoFeatureExtractor, +# BertTokenizerFast, +# CLIPImageProcessor, +# CLIPTextConfig, +# CLIPTextModel, +# CLIPTextModelWithProjection, +# CLIPTokenizer, +# CLIPVisionConfig, +# CLIPVisionModelWithProjection, +# ) + +MOTION_MODULES = [ + "mm_sd_v14.ckpt", + "mm_sd_v15.ckpt", + "mm_sd_v15_v2.ckpt", + "v3_sd15_mm.ckpt", +] + +ADAPTERS = [ + # "mm_sd_v14.ckpt", + 
# "mm_sd_v15.ckpt", + # "mm_sd_v15_v2.ckpt", + # "mm_sdxl_v10_beta.ckpt", + "v2_lora_PanLeft.ckpt", + "v2_lora_PanRight.ckpt", + "v2_lora_RollingAnticlockwise.ckpt", + "v2_lora_RollingClockwise.ckpt", + "v2_lora_TiltDown.ckpt", + "v2_lora_TiltUp.ckpt", + "v2_lora_ZoomIn.ckpt", + "v2_lora_ZoomOut.ckpt", + "v3_sd15_adapter.ckpt", + # "v3_sd15_mm.ckpt", + "v3_sd15_sparsectrl_rgb.ckpt", + "v3_sd15_sparsectrl_scribble.ckpt", +] + +BACKUP_DREAMBOOTH_MODELS = [ + "realisticVisionV60B1_v51VAE.safetensors", + "majicmixRealistic_v4.safetensors", + "leosamsFilmgirlUltra_velvia20Lora.safetensors", + "toonyou_beta3.safetensors", + "majicmixRealistic_v5Preview.safetensors", + "rcnzCartoon3d_v10.safetensors", + "lyriel_v16.safetensors", + "leosamsHelloworldXL_filmGrain20.safetensors", + "TUSUN.safetensors", +] + +def zero_rank_print(s): + if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps) + +def auto_download(local_path, is_dreambooth_lora=False): + hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff" + folder, filename = os.path.split(local_path) + + if not os.path.exists(local_path): + print(f"local file {local_path} does not exist. trying to download from {hf_repo}") + + if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} dose not exist in {hf_repo}" + else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} dose not exist in {hf_repo}" + + folder = "." if folder == "" else folder + os.makedirs(folder, exist_ok=True) + snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename]) + +def load_weights( + animation_pipeline, + # motion module + motion_module_path = "", + motion_module_lora_configs = [], + # domain adapter + adapter_lora_path = "", + adapter_lora_scale = 1.0, + # image layers + dreambooth_model_path = "", + lora_model_path = "", + lora_alpha = 0.8, +): + # motion module + unet_state_dict = {} + if motion_module_path != "": + print(f"load motion module from {motion_module_path}") + motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") + motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict + unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." 
in name}) + unet_state_dict.pop("animatediff_config", "") + + missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) + # assert len(unexpected) == 0 + del unet_state_dict + + # base model + if dreambooth_model_path != "": + print(f"load dreambooth model from {dreambooth_model_path}") + if dreambooth_model_path.endswith(".safetensors"): + # import pdb; pdb.set_trace() + dreambooth_state_dict = {} + # import safetensors + # dreambooth_state_dict = safetensors.torch.load_file(dreambooth_model_path) + # import pdb; pdb.set_trace() + with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + dreambooth_state_dict[key] = f.get_tensor(key) + # import pdb; pdb.set_trace() + elif dreambooth_model_path.endswith(".ckpt"): + dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") + + # 1. vae + converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) + animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) + # 2. unet + converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) + animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) + + # 3. text_model + # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_concise(dreambooth_state_dict) + animation_pipeline.text_encoder.load_state_dict(converted_text_encoder_checkpoint, strict=True) + del dreambooth_state_dict, converted_vae_checkpoint, converted_unet_checkpoint, converted_text_encoder_checkpoint + + # clip_config_name = "models/clip-vit-large-patch14" + # clip_config = CLIPTextConfig.from_pretrained(clip_config_name, local_files_only=True) + # text_model = CLIPTextModel(clip_config) + # keys = list(dreambooth_state_dict.keys()) + # text_model_dict = {} + # for key in keys: + # if key.startswith("cond_stage_model.transformer"): + # text_model_dict[key[len("cond_stage_model.transformer.") :]] = dreambooth_state_dict[key] + # text_model.load_state_dict(text_model_dict) + # animation_pipeline.text_encoder = text_model.to(dtype=animation_pipeline.unet.dtype) + # # import pdb; pdb.set_trace() + # # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) + # del dreambooth_state_dict + + # lora layers + if lora_model_path != "": + print(f"load lora model from {lora_model_path}") + assert lora_model_path.endswith(".safetensors") + lora_state_dict = {} + with safe_open(lora_model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + lora_state_dict[key] = f.get_tensor(key) + + animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) + del lora_state_dict + + # domain adapter lora + if adapter_lora_path != "": + print(f"load domain lora from {adapter_lora_path}") + domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") + domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict + domain_lora_state_dict.pop("animatediff_config", "") + + animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) + + # motion module lora + for motion_module_lora_config in motion_module_lora_configs: + path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] + print(f"load motion LoRA from {path}") + 
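# NOTE (illustrative): each entry of `motion_module_lora_configs` is expected to look like
+        # {"path": "models/MotionLoRA/v2_lora_ZoomIn.ckpt", "alpha": 1.0} (the path shown is only an example); the
+        # checkpoint is loaded on CPU below and merged into the motion modules via load_diffusers_lora. +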
motion_lora_state_dict = torch.load(path, map_location="cpu") + motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict + motion_lora_state_dict.pop("animatediff_config", "") + + animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) + + return animation_pipeline + +def video_preprocess(video_path, height, width, video_length, duration=None, sample_start_idx=0,): + #print("in video_preprocess:") + #print(video_path, height, width, video_length, duration, sample_start_idx) + + video_name = video_path.split('/')[-1].split('.')[0] + vr = decord.VideoReader(video_path) + fps = vr.get_avg_fps() + if duration is None: + # 读取整个视频 + total_frames = len(vr) + else: + # 根据给定的时长(秒)计算帧数 + total_frames = int(fps * duration) + total_frames = min(total_frames, len(vr)) # 确保不超过视频总长度 + + sample_index = np.linspace(0, total_frames - 1, video_length, dtype=int) + print(total_frames,sample_index) + video = vr.get_batch(sample_index) + + ''' + print("after betch :") + print(video) + ''' + import torch + if hasattr(video, "asnumpy"): + video = video.asnumpy() + video = torch.from_numpy(video) + + video = rearrange(video, "f h w c -> f c h w") + + video = F.interpolate(video, size=(height, width), mode="bilinear", align_corners=True) + + # video_sample = rearrange(video, "(b f) c h w -> b f h w c", f=video_length) + # imageio.mimwrite(f"processed_videos/sample_{video_name}.mp4", video_sample[0], fps=8, quality=9) + + video = video / 127.5 - 1.0 + + return video + + +def set_nested_item(dataDict, mapList, value): + """Set item in nested dictionary""" + """ + Example: the mapList contains the name of each key ['injection','self-attn'] + this method will change the content in dataDict['injection']['self-attn'] with value + + """ + for k in mapList[:-1]: + dataDict = dataDict[k] + dataDict[mapList[-1]] = value + + +def merge_sweep_config(base_config, update): + """Merge the updated parameters into the base config""" + + if base_config is None: + raise ValueError("Base config is None") + if update is None: + raise ValueError("Update config is None") + for key in update.keys(): + map_list = key.split("--") + set_nested_item(base_config, map_list, update[key]) + return base_config + + +# Adapt from https://github.com/castorini/daam +def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0): + merge_idxs = [] + tokens = tokenizer.tokenize(prompt.lower()) + if word_idx is None: + word = word.lower() + search_tokens = tokenizer.tokenize(word) + start_indices = [x + offset_idx for x in range(len(tokens)) if + tokens[x:x + len(search_tokens)] == search_tokens] + for indice in start_indices: + merge_idxs += [i + indice for i in range(0, len(search_tokens))] + if not merge_idxs: + raise Exception(f'Search word {word} not found in prompt!') + else: + merge_idxs.append(word_idx) + + return [x + 1 for x in merge_idxs], word_idx # Offset by 1. + + +def extract_data(input_string: str) -> list: + print("input_string:", input_string) + """ + Extract data from a string pattern where contents in () are separated by ; + The first item in each () is considered as 'ref' and the rest as 'gen'. + + Args: + - input_string (str): The input string pattern. + + Returns: + - list: A list of dictionaries containing 'ref' and 'gen'. 
+ """ + pattern = r'\(([^)]+)\)' + matches = re.findall(pattern, input_string) + + data = [] + for match in matches: + parts = [x.strip() for x in match.split(';')] + ref = parts[0].strip() + gen = parts[1].strip() + data.append({'ref': ref, 'gen': gen}) + + return data + + +def generate_hash_key(image, prompt=""): + """ + Generate a hash key for the given image and prompt. + """ + byte_array = io.BytesIO() + image.save(byte_array, format='JPEG') + + # Get byte data + image_byte_data = byte_array.getvalue() + + # Combine image byte data and prompt byte data + combined_data = image_byte_data + prompt.encode('utf-8') + + sha256 = hashlib.sha256() + sha256.update(combined_data) + return sha256.hexdigest() + + +def save_data(data, folder_path, key): + """ + Save data to a file, using key as the file name + """ + + if not os.path.exists(folder_path): + os.makedirs(folder_path) + file_path = os.path.join(folder_path, f"{key}.pt") + + torch.save(data, file_path) + + +def get_data(folder_path, key): + """ + Get data from a file, using key as the file name + :param folder_path: + :param key: + :return: + """ + + file_path = os.path.join(folder_path, f"{key}.pt") + if os.path.exists(file_path): + return torch.load(file_path) + else: + return None + + +def PILtoTensor(data: Image.Image) -> torch.Tensor: + return torch.tensor(np.array(data)).permute(2, 0, 1).unsqueeze(0).float() + + +def TensorToPIL(data: torch.Tensor) -> Image.Image: + return Image.fromarray(data.squeeze().permute(1, 2, 0).numpy().astype(np.uint8)) + +# Adapt from https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/utils/loading_utils.py#L9 +def load_image( + image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None +) -> PIL.Image.Image: + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional): + A conversion method to apply to the image after loading it. + When set to `None` the image will be converted "RGB". + + Returns: + `PIL.Image.Image`: + A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." + ) + + image = PIL.ImageOps.exif_transpose(image) + + if convert_method is not None: + image = convert_method(image) + else: + image = image.convert("RGB") + + return image + +# Take from huggingface/diffusers +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + +def _in_step(config, step): + in_step = False + try: + start_step = config.start_step + end_step = config.end_step + if start_step <= step < end_step: + in_step = True + except: + in_step = False + return in_step + +def classify_blocks(block_list, name): + is_correct_block = False + for block in block_list: + if block in name: + is_correct_block = True + break + return is_correct_block + +def set_all_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True \ No newline at end of file diff --git a/motionclone/utils/xformer_attention.py b/motionclone/utils/xformer_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..468a3d8d13e0711c190a89fc8c7c6d02d7330bf5 --- /dev/null +++ b/motionclone/utils/xformer_attention.py @@ -0,0 +1,98 @@ +import math +from typing import Optional, Callable +import xformers +from omegaconf import OmegaConf +import yaml +from .util import classify_blocks + +def identify_blocks(block_list, name): + block_name = None + for block in block_list: + if block in name: + block_name = block + break + return block_name + + +class MySelfAttnProcessor: + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + + def __call__(self, attn, hidden_states, query, key, value, attention_mask): + # self.attn = attn + self.key = key + self.query = query + # self.value = value + # self.attention_mask = attention_mask + # self.hidden_state = hidden_states.detach() + # return hidden_states + + def record_qkv(self, attn, hidden_states, query, key, value, attention_mask): + # self.attn = attn + self.key = key + self.query = query + # self.value = value + # # self.attention_mask = attention_mask + # self.hidden_state = hidden_states.detach() + # # import pdb; pdb.set_trace() + + def record_attn_mask(self, attn, hidden_states, query, key, value, attention_mask): + self.attn = attn + self.attention_mask = attention_mask + + +def prep_unet_attention(unet,motion_gudiance_blocks): + # replace the fwd function + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if "VersatileAttention" in module_name and classify_blocks(motion_gudiance_blocks, name): # the temporary attention in guidance blocks + module.set_processor(MySelfAttnProcessor()) + # print(module_name) + return unet + + +def get_self_attn_feat(unet, injection_config, config): + hidden_state_dict = dict() + query_dict = dict() + key_dict = dict() + value_dict = dict() + + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if "CrossAttention" in module_name and 'attn1' in name and classify_blocks(injection_config.blocks, name=name): + res = int(math.sqrt(module.processor.hidden_state.shape[1])) + # import pdb; pdb.set_trace() + bs = module.processor.hidden_state.shape[0] # 20 * 16 = 320 + # block_name = identify_blocks(injection_config.blocks, name=name) + # block_id 
= int(block_name.split('.')[-1]) + # h = config.H // (32 * block_id) + # w = config.W // (32 * block_id) + hidden_state_dict[name] = module.processor.hidden_state.cpu().permute(0, 2, 1).reshape(bs, -1, res, res) + res = int(math.sqrt(module.processor.query.shape[1])) + query_dict[name] = module.processor.query.cpu().permute(0, 2, 1).reshape(bs, -1, res, res) + key_dict[name] = module.processor.key.cpu().permute(0, 2, 1).reshape(bs, -1, res, res) + value_dict[name] = module.processor.value.cpu().permute(0, 2, 1).reshape(bs, -1, res, res) + # import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() + return hidden_state_dict, query_dict, key_dict, value_dict + + +def clean_attn_buffer(unet): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "Attention" and 'attn' in name: + if 'injection_config' in module.processor.__dict__.keys(): + module.processor.injection_config = None + if 'injection_mask' in module.processor.__dict__.keys(): + module.processor.injection_mask = None + if 'obj_index' in module.processor.__dict__.keys(): + module.processor.obj_index = None + if 'pca_weight' in module.processor.__dict__.keys(): + module.processor.pca_weight = None + if 'pca_weight_changed' in module.processor.__dict__.keys(): + module.processor.pca_weight_changed = None + if 'pca_info' in module.processor.__dict__.keys(): + module.processor.pca_info = None + if 'step' in module.processor.__dict__.keys(): + module.processor.step = None diff --git a/reference_videos/camera_1.mp4 b/reference_videos/camera_1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..06a8c2008dce9fb6db2550cbfaf96bc401bd1623 --- /dev/null +++ b/reference_videos/camera_1.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62321a2a8bc671b03f912e363d66010a0bbcf6d98d31d9f30ac1cb7d06099a8d +size 400650 diff --git a/reference_videos/camera_pan_down.mp4 b/reference_videos/camera_pan_down.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..42ea580bc58386ae7ba7bc7313658d217fdca262 --- /dev/null +++ b/reference_videos/camera_pan_down.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af5b5fd84728f844fe7ce766e7920e08eaef0944da3462f2e17e954c067e2a57 +size 535648 diff --git a/reference_videos/camera_pan_up.mp4 b/reference_videos/camera_pan_up.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7d57faf25d99312d9c3c859fd7f976ee20e0c5c2 --- /dev/null +++ b/reference_videos/camera_pan_up.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f66af10dd2d110bba3c6047acf477d22b40904949a9098fb753133de7d80dbc4 +size 879335 diff --git a/reference_videos/camera_translation_1.mp4 b/reference_videos/camera_translation_1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7800188b46b8faa6c9ff4149ed0abb6b6e23e581 --- /dev/null +++ b/reference_videos/camera_translation_1.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1151181afa8d0560cb6a2be99c17fd657f6f9364fa10837713f48ab7cada1c3 +size 514540 diff --git a/reference_videos/camera_translation_2.mp4 b/reference_videos/camera_translation_2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..835ad7c38bbe0c2564d773e3ab56c11aac752e3f --- /dev/null +++ b/reference_videos/camera_translation_2.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0af642d5cc70ab35876e8fe7c5da112c4782163e12b9fc7b5ec33222863ee0e +size 684655 diff --git 
a/reference_videos/camera_zoom_in.mp4 b/reference_videos/camera_zoom_in.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..776081d8d45e77e39d3954447cc24c92d932f7d7 --- /dev/null +++ b/reference_videos/camera_zoom_in.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9bdb57576a801c4c049a59a20f9fc8b14eaa57beac50eddfa53c56076b64925 +size 1197273 diff --git a/reference_videos/camera_zoom_out.mp4 b/reference_videos/camera_zoom_out.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7582032f2dd40d97ef944f41e8d072ff488d42b3 --- /dev/null +++ b/reference_videos/camera_zoom_out.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b73e3471737528167cecad1d31a323da1a32708a115cb3fe0938bd42581f9070 +size 984866 diff --git a/reference_videos/sample_astronaut.mp4 b/reference_videos/sample_astronaut.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6d0f49e7030fdecf2e6b6d719bf646a257497e50 --- /dev/null +++ b/reference_videos/sample_astronaut.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70c22ce81987b9c7825af37a22588771ed03c6839bdf7446eec7417202d01916 +size 1456676 diff --git a/reference_videos/sample_blackswan.mp4 b/reference_videos/sample_blackswan.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..22b8d1fb1344d302a12c80989f474f3c7bb776a2 --- /dev/null +++ b/reference_videos/sample_blackswan.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13f05772a2cbed4e800820d22b0afffdfc053446116a3b700a686e7f63bf3662 +size 1351812 diff --git a/reference_videos/sample_cat.mp4 b/reference_videos/sample_cat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..41070c99a32ef3c0f8fe40e38d4f0947f6ddfcb0 --- /dev/null +++ b/reference_videos/sample_cat.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af2ee3a217c85bbffbf2a8e8bc8e93aad93fc69bbad7d7bb98928f6f3641d546 +size 1681425 diff --git a/reference_videos/sample_cow.mp4 b/reference_videos/sample_cow.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a1801e31064d4fe46dbc261ca81cf3ca3296e839 --- /dev/null +++ b/reference_videos/sample_cow.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06db236f3fd3e07c82a49a3f1750558c0ee036c221d41af3baa164604e8723d2 +size 2255955 diff --git a/reference_videos/sample_fox.mp4 b/reference_videos/sample_fox.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..023999513583282722f5b5b8724096edb9ce2ed9 --- /dev/null +++ b/reference_videos/sample_fox.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ac422cdc91ef385361890e0f1267b917b76760c1f650e13586119e4092fa57d +size 1276068 diff --git a/reference_videos/sample_leaves.mp4 b/reference_videos/sample_leaves.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..82b2343b2f3cfdebacc8b2cc364dd3c063d4d386 --- /dev/null +++ b/reference_videos/sample_leaves.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dac5de1d0b77eabf6c3b7cb5a858b59a6577f563065f5daf5fffdaa7ab12ca50 +size 665431 diff --git a/reference_videos/sample_white_tiger.mp4 b/reference_videos/sample_white_tiger.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cae8a746b7176182677cc0307e064bcb6c0f52e0 --- /dev/null +++ b/reference_videos/sample_white_tiger.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:1844a46e5d460b56074a20797908fe7ca71e2c51401c299449661cdc4530b698 +size 1969611 diff --git a/reference_videos/sample_wolf.mp4 b/reference_videos/sample_wolf.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8d38b428f9dd2355eb304b4044250bf0d14ad555 --- /dev/null +++ b/reference_videos/sample_wolf.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be57ba1048e9fab45a6aced93564445ae4a74c83a71888a96ca3507a7f919b9a +size 1894025 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3616de23867c42b623437a8f73d4f80bc1cbd7b4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,33 @@ +# requirements.txt + +# Python 版本建议在安装环境中设置,而不是在 requirements.txt 中指定 +# python=3.11.3 + +# PyTorch 及相关依赖 +#torch==2.0.1+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html +#torchvision==0.15.2+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html +torch==2.0.1 +torchvision==0.15.2 + +# 其他依赖 +#pip +accelerate +diffusers==0.16.0 +transformers==4.28.1 +xformers==0.0.20 +imageio[ffmpeg] +decord==0.6.0 +gdown +einops +omegaconf +safetensors +gradio +wandb +triton +opencv-python + +gradio +huggingface_hub==0.25.0 +httpx[socks] +matplotlib +numpy<2 diff --git a/t2v_video_app.py b/t2v_video_app.py new file mode 100644 index 0000000000000000000000000000000000000000..c115d911cd142f2d8e870316242431c506a3396c --- /dev/null +++ b/t2v_video_app.py @@ -0,0 +1,296 @@ +import gradio as gr +from omegaconf import OmegaConf +import torch +from diffusers import AutoencoderKL, DDIMScheduler +from transformers import CLIPTextModel, CLIPTokenizer +from motionclone.models.unet import UNet3DConditionModel +from motionclone.pipelines.pipeline_animation import AnimationPipeline +from motionclone.utils.util import load_weights +from diffusers.utils.import_utils import is_xformers_available +from motionclone.utils.motionclone_functions import * +import json +from motionclone.utils.xformer_attention import * +import os +import numpy as np +import imageio +import shutil +import subprocess + +# 权重下载函数 +def download_weights(): + try: + # 创建模型目录 + os.makedirs("models", exist_ok=True) + os.makedirs("models/DreamBooth_LoRA", exist_ok=True) + os.makedirs("models/Motion_Module", exist_ok=True) + os.makedirs("models/SparseCtrl", exist_ok=True) + + # 下载 Stable Diffusion 模型 + if not os.path.exists("models/StableDiffusion"): + subprocess.run(["git", "clone", "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5", "models/StableDiffusion"]) + + # 下载 DreamBooth LoRA 模型 + if not os.path.exists("models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"): + subprocess.run(["wget", "https://huggingface.co/svjack/Realistic-Vision-V6.0-B1/resolve/main/realisticVisionV60B1_v51VAE.safetensors", "-O", "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"]) + + # 下载 Motion Module 模型 + if not os.path.exists("models/Motion_Module/v3_sd15_mm.ckpt"): + subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_mm.ckpt", "-O", "models/Motion_Module/v3_sd15_mm.ckpt"]) + if not os.path.exists("models/Motion_Module/v3_sd15_adapter.ckpt"): + subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt", "-O", "models/Motion_Module/v3_sd15_adapter.ckpt"]) + + # 下载 SparseCtrl 模型 + if not os.path.exists("models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"): + subprocess.run(["wget", 
"https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_rgb.ckpt", "-O", "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"]) + if not os.path.exists("models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"): + subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_scribble.ckpt", "-O", "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"]) + + print("Weights downloaded successfully.") + except Exception as e: + print(f"Error downloading weights: {e}") + +# 下载权重 +download_weights() + +# 加载 model_config +model_config_path = "configs/model_config/model_config.yaml" +model_config = OmegaConf.load(model_config_path) + +# 硬编码的配置值 +config = { + "motion_module": "models/Motion_Module/v3_sd15_mm.ckpt", + "dreambooth_path": "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors", + "model_config": model_config, + "W": 512, + "H": 512, + "L": 16, + "motion_guidance_blocks": ['up_blocks.1',] +} + +# 写死 pretrained_model_path +pretrained_model_path = "models/StableDiffusion" + +# 模型初始化逻辑 +def initialize_models(): + # 设置设备 + adopted_dtype = torch.float16 + device = "cuda" + set_all_seed(42) + + # 加载模型组件 + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype) + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype) + + # 更新配置 + config["width"] = config.get("W", 512) + config["height"] = config.get("H", 512) + config["video_length"] = config.get("L", 16) + + # 加载模型配置 + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=config["model_config"]["unet_additional_kwargs"]).to(device).to(dtype=adopted_dtype) + + # 启用 xformers + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + + # 创建 pipeline + pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + controlnet=None, + scheduler=DDIMScheduler(**config["model_config"]["noise_scheduler_kwargs"]), + ).to(device) + + # 加载权重 + pipeline = load_weights( + pipeline, + motion_module_path=config["motion_module"], + dreambooth_model_path=config["dreambooth_path"], + ).to(device) + pipeline.text_encoder.to(dtype=adopted_dtype) + + # 加载自定义函数 + pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler) + pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler) + pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet) + pipeline.sample_video = sample_video.__get__(pipeline) + pipeline.single_step_video = single_step_video.__get__(pipeline) + pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline) + pipeline.add_noise = add_noise.__get__(pipeline) + pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline) + pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline) + + # 冻结 UNet 参数 + for param in pipeline.unet.parameters(): + param.requires_grad = False + pipeline.input_config, pipeline.unet.input_config = config, config + + # 准备 UNet 的 attention 和 conv + pipeline.unet = prep_unet_attention(pipeline.unet, config["motion_guidance_blocks"]) + pipeline.unet = prep_unet_conv(pipeline.unet) + + return pipeline + +# 初始化模型 +pipeline = initialize_models() + +def generate_video(uploaded_video, motion_representation_save_dir, 
generated_videos_save_dir, visible_gpu, default_seed, without_xformers, cfg_scale, negative_prompt, positive_prompt, inference_steps, guidance_scale, guidance_steps, warm_up_steps, cool_up_steps, motion_guidance_weight, motion_guidance_blocks, add_noise_step, new_prompt, seed): + # 更新配置 + config.update({ + "cfg_scale": cfg_scale, + "negative_prompt": negative_prompt, + "positive_prompt": positive_prompt, + "inference_steps": inference_steps, + "guidance_scale": guidance_scale, + "guidance_steps": guidance_steps, + "warm_up_steps": warm_up_steps, + "cool_up_steps": cool_up_steps, + "motion_guidance_weight": motion_guidance_weight, + #"motion_guidance_blocks": motion_guidance_blocks, + "add_noise_step": add_noise_step + }) + + # 设置环境变量 + os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0)) + + device = pipeline.device + + # 创建保存目录 + if not os.path.exists(generated_videos_save_dir): + os.makedirs(generated_videos_save_dir) + + # 处理上传的视频 + if uploaded_video is not None: + pipeline.scheduler.customized_set_timesteps(config["inference_steps"], config["guidance_steps"], config["guidance_scale"], device=device, timestep_spacing_type="uneven") + + # 将上传的视频保存到指定路径 + video_path = os.path.join(generated_videos_save_dir, os.path.basename(uploaded_video)) + #shutil.move(uploaded_video, video_path) + shutil.copy2(uploaded_video, video_path) + + print("video_path :") + print(video_path) + + # 更新配置 + config["video_path"] = video_path + config["new_prompt"] = new_prompt + config.get("positive_prompt", "") + + from types import SimpleNamespace + + pipeline.input_config, pipeline.unet.input_config = SimpleNamespace(**config), SimpleNamespace(**config) + + print("pipeline.input_config.video_path :") + print(pipeline.input_config.video_path) + + # 提取运动表示 + seed_motion = seed if seed is not None else default_seed + generator = torch.Generator(device=pipeline.device) + generator.manual_seed(seed_motion) + if not os.path.exists(motion_representation_save_dir): + os.makedirs(motion_representation_save_dir) + motion_representation_path = os.path.join(motion_representation_save_dir, os.path.splitext(os.path.basename(config["video_path"]))[0] + '.pt') + pipeline.obtain_motion_representation(generator=generator, motion_representation_path=motion_representation_path) + + # 生成视频 + seed = seed_motion + generator = torch.Generator(device=pipeline.device) + generator.manual_seed(seed) + pipeline.input_config.seed = seed + + videos = pipeline.sample_video(generator=generator) + + #print("videos :") + #print(videos) + + videos = rearrange(videos, "b c f h w -> b f h w c") + save_path = os.path.join(generated_videos_save_dir, os.path.splitext(os.path.basename(config["video_path"]))[0] + "_" + config["new_prompt"].strip().replace(' ', '_') + str(seed_motion) + "_" + str(seed) + '.mp4') + videos_uint8 = (videos[0] * 255).astype(np.uint8) + imageio.mimwrite(save_path, videos_uint8, fps=8) + print(save_path, "is done") + + return save_path + else: + return "No video uploaded." + +# 使用 Gradio Blocks 构建界面 +with gr.Blocks() as demo: + # 页面标题和描述 + gr.Markdown("# MotionClone-Text-to-Video Generation") + gr.Markdown("This tool allows you to generate videos from text prompts using a pre-trained model. 
Upload a motion reference video, provide a new prompt, and adjust the settings to create your custom video.") + + # 主要输入区域 + with gr.Row(): + with gr.Column(): + # 视频上传 + uploaded_video = gr.Video(label="Upload Video") + # 新提示词 + new_prompt = gr.Textbox(label="New Prompt", value="A beautiful scene", lines=2) + # 种子 + seed = gr.Number(label="Seed", value=42) + # 生成按钮 + generate_button = gr.Button("Generate Video") + + with gr.Column(): + # 输出视频 + output_video = gr.Video(label="Generated Video") + + # 高级设置区域 + with gr.Accordion("Advanced Settings", open=False): + with gr.Row(): + with gr.Column(): + motion_representation_save_dir = gr.Textbox(label="Motion Representation Save Dir", value="motion_representation/") + generated_videos_save_dir = gr.Textbox(label="Generated Videos Save Dir", value="generated_videos") + visible_gpu = gr.Textbox(label="Visible GPU", value="0") + default_seed = gr.Number(label="Default Seed", value=2025) + without_xformers = gr.Checkbox(label="Without Xformers", value=False) + with gr.Column(): + cfg_scale = gr.Number(label="CFG Scale", value=7.5) + negative_prompt = gr.Textbox(label="Negative Prompt", value="bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers") + positive_prompt = gr.Textbox(label="Positive Prompt", value="8k, high detailed, best quality, film grain, Fujifilm XT3") + inference_steps = gr.Number(label="Inference Steps", value=100) + guidance_scale = gr.Number(label="Guidance Scale", value=0.3) + guidance_steps = gr.Number(label="Guidance Steps", value=50) + warm_up_steps = gr.Number(label="Warm Up Steps", value=10) + cool_up_steps = gr.Number(label="Cool Up Steps", value=10) + motion_guidance_weight = gr.Number(label="Motion Guidance Weight", value=2000) + motion_guidance_blocks = gr.Textbox(label="Motion Guidance Blocks", value="['up_blocks.1']") + add_noise_step = gr.Number(label="Add Noise Step", value=400) + + # 绑定生成函数 + generate_button.click( + generate_video, + inputs=[ + uploaded_video, motion_representation_save_dir, generated_videos_save_dir, visible_gpu, default_seed, without_xformers, cfg_scale, negative_prompt, positive_prompt, inference_steps, guidance_scale, guidance_steps, warm_up_steps, cool_up_steps, motion_guidance_weight, motion_guidance_blocks, add_noise_step, new_prompt, seed + ], + outputs=output_video + ) + + # 定义示例数据 + examples = [ + {"video_path": "reference_videos/camera_zoom_in.mp4", "new_prompt": "Relics on the seabed", "seed": 42}, + {"video_path": "reference_videos/camera_zoom_in.mp4", "new_prompt": "A road in the mountain", "seed": 42}, + {"video_path": "reference_videos/camera_zoom_in.mp4", "new_prompt": "Caves, a path for exploration", "seed": 2026}, + {"video_path": "reference_videos/camera_zoom_in.mp4", "new_prompt": "Railway for train", "seed": 2026}, + {"video_path": "reference_videos/camera_zoom_out.mp4", "new_prompt": "Tree, in the mountain", "seed": 2026}, + {"video_path": "reference_videos/camera_zoom_out.mp4", "new_prompt": "Red car on the track", "seed": 2026}, + {"video_path": "reference_videos/camera_zoom_out.mp4", "new_prompt": "Man, standing in his garden.", "seed": 2026}, + {"video_path": "reference_videos/camera_1.mp4", "new_prompt": "A island, on the ocean, sunny day", "seed": 42}, + {"video_path": "reference_videos/camera_1.mp4", "new_prompt": "A tower, with fireworks", "seed": 42}, + {"video_path": "reference_videos/camera_pan_up.mp4", "new_prompt": "Beautiful house, around with flowers", "seed": 42}, + 
{"video_path": "reference_videos/camera_translation_2.mp4", "new_prompt": "Forest, in winter", "seed": 2028}, + {"video_path": "reference_videos/camera_pan_down.mp4", "new_prompt": "Eagle, standing in the tree", "seed": 2026} + ] + examples = list(map(lambda d: [d["video_path"], d["new_prompt"], d["seed"]], examples)) + + # 添加示例 + gr.Examples( + examples=examples, + inputs=[uploaded_video, new_prompt, seed], + outputs=output_video, + fn=generate_video, + cache_examples=False + ) + + +# 启动应用 +demo.launch(share = True) diff --git a/t2v_video_sample.py b/t2v_video_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..eab68a1c1362552ad2bc76c96c6df5d2bdfbba5e --- /dev/null +++ b/t2v_video_sample.py @@ -0,0 +1,125 @@ +import argparse +from omegaconf import OmegaConf +import torch +from diffusers import AutoencoderKL, DDIMScheduler +from transformers import CLIPTextModel, CLIPTokenizer +from motionclone.models.unet import UNet3DConditionModel +from motionclone.pipelines.pipeline_animation import AnimationPipeline +from motionclone.utils.util import load_weights +from diffusers.utils.import_utils import is_xformers_available +from motionclone.utils.motionclone_functions import * +import json +from motionclone.utils.xformer_attention import * + +def main(args): + + os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0)) + + config = OmegaConf.load(args.inference_config) + adopted_dtype = torch.float16 + device = "cuda" + set_all_seed(42) + + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype) + + config.width = config.get("W", args.W) + config.height = config.get("H", args.H) + config.video_length = config.get("L", args.L) + + if not os.path.exists(args.generated_videos_save_dir): + os.makedirs(args.generated_videos_save_dir) + OmegaConf.save(config, os.path.join(args.generated_videos_save_dir,"inference_config.json")) + + model_config = OmegaConf.load(config.get("model_config", "")) + unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs),).to(device).to(dtype=adopted_dtype) + + # set xformers + if is_xformers_available() and (not args.without_xformers): + unet.enable_xformers_memory_efficient_attention() + + pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + controlnet=None, + scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)), + ).to(device) + + pipeline = load_weights( + pipeline, + # motion module + motion_module_path = config.get("motion_module", ""), + dreambooth_model_path = config.get("dreambooth_path", ""), + ).to(device) + pipeline.text_encoder.to(dtype=adopted_dtype) + + # load customized functions from motionclone_functions + pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler) + pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler) + pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet) + pipeline.sample_video = sample_video.__get__(pipeline) + pipeline.single_step_video = single_step_video.__get__(pipeline) + 
+    pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline)
+    pipeline.add_noise = add_noise.__get__(pipeline)
+    pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline)
+    pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline)
+
+    for param in pipeline.unet.parameters():
+        param.requires_grad = False
+    pipeline.input_config, pipeline.unet.input_config = config, config
+
+    pipeline.unet = prep_unet_attention(pipeline.unet, pipeline.input_config.motion_guidance_blocks)
+    pipeline.unet = prep_unet_conv(pipeline.unet)
+    pipeline.scheduler.customized_set_timesteps(config.inference_steps, config.guidance_steps, config.guidance_scale, device=device, timestep_spacing_type="uneven")
+    # pipeline.scheduler.customized_set_timesteps(config.inference_steps, device=device, timestep_spacing_type="linspace")
+    with open(args.examples, 'r') as files:
+        for line in files:
+            # prepare the info of each case
+            example_infor = json.loads(line)
+            config.video_path = example_infor["video_path"]
+            config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "")
+            pipeline.input_config, pipeline.unet.input_config = config, config  # update config
+
+            # perform motion representation extraction
+            seed_motion = example_infor.get("seed", args.default_seed)
+            generator = torch.Generator(device=pipeline.device)
+            generator.manual_seed(seed_motion)
+            if not os.path.exists(args.motion_representation_save_dir):
+                os.makedirs(args.motion_representation_save_dir)
+            motion_representation_path = os.path.join(args.motion_representation_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + '.pt')
+            pipeline.obtain_motion_representation(generator=generator, motion_representation_path=motion_representation_path)
+
+            # perform video generation
+            seed = seed_motion  # can assign another seed here
+            generator = torch.Generator(device=pipeline.device)
+            generator.manual_seed(seed)
+            pipeline.input_config.seed = seed
+
+            videos = pipeline.sample_video(generator=generator)
+            videos = rearrange(videos, "b c f h w -> b f h w c")
+            save_path = os.path.join(args.generated_videos_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] +
+                                     "_" + config.new_prompt.strip().replace(' ', '_') + str(seed_motion) + "_" + str(seed) + '.mp4')
+            videos_uint8 = (videos[0] * 255).astype(np.uint8)
+
+            imageio.mimwrite(save_path, videos_uint8, fps=8)
+            print(save_path, "is done")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion",)
+
+    parser.add_argument("--inference_config", type=str, default="configs/t2v_camera.yaml")
+    parser.add_argument("--examples", type=str, default="configs/t2v_camera.jsonl")
+    parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/")
+    parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos")
+
+    parser.add_argument("--visible_gpu", type=str, default=None)
+    parser.add_argument("--default-seed", type=int, default=2025)
+    parser.add_argument("--L", type=int, default=16)
+    parser.add_argument("--W", type=int, default=512)
+    parser.add_argument("--H", type=int, default=512)
+
+    parser.add_argument("--without-xformers", action="store_true")
+
+    args = parser.parse_args()
+    main(args)