diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..34d71beb0d9b724d62c11fe723b9beb95e5ec7bb --- /dev/null +++ b/.gitignore @@ -0,0 +1,49 @@ +.ipynb_checkpoints +.idea +__pycache__ + +datasets/ +tmp_imgs +runs/ +runs_last/ +saved_models/ +pre_trained/ +save_log/ +diffusers/ +weights/ +checkpoints/ +validation_videos* +pretrained/* +.gradio/* + +*.pyc +*.sh +*.pth +*.png +*.jpg +*.mp4 +*.txt +*.json +*.jsonl +*.zip +*.mp4 +*.csv +*.webp +*.bin +*.pkl +*.safetensors +*.pt +*.log +events.* +*.yml +*.gif +*.npy +*.out + +!requirements.txt +!saved_models/*.md +!LICENSE.txt +!config/* +!__assets__/* +!__assets__/Bridge_example/* +!pretrained/PUT_YOUR_WEIGHT_HERE.md \ No newline at end of file diff --git a/__assets__/0.jpg b/__assets__/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1c136bc0b5b4fe4ac16c7f7fbd68f96fd414620e Binary files /dev/null and b/__assets__/0.jpg differ diff --git a/__assets__/156.jpg b/__assets__/156.jpg new file mode 100644 index 0000000000000000000000000000000000000000..68bf46dcdb902e1de47fdae417f5162dd8522b6e Binary files /dev/null and b/__assets__/156.jpg differ diff --git a/__assets__/274.jpg b/__assets__/274.jpg new file mode 100644 index 0000000000000000000000000000000000000000..de860085102d9e04712f7ab50ed13d55477bd60e Binary files /dev/null and b/__assets__/274.jpg differ diff --git a/__assets__/375.jpg b/__assets__/375.jpg new file mode 100644 index 0000000000000000000000000000000000000000..51a72d65b350bc002ad8123afb834b3034613604 Binary files /dev/null and b/__assets__/375.jpg differ diff --git a/__assets__/551.jpg b/__assets__/551.jpg new file mode 100644 index 0000000000000000000000000000000000000000..69dfc6128c208960432866649757509af14a3394 Binary files /dev/null and b/__assets__/551.jpg differ diff --git a/__assets__/91.jpg b/__assets__/91.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f25fa54d0d65396cec51c816223c27d4735717f5 Binary files /dev/null and b/__assets__/91.jpg differ diff --git a/__assets__/ThisThat_logo.png b/__assets__/ThisThat_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..23d05502214972bb18c549a803d508596923609f Binary files /dev/null and b/__assets__/ThisThat_logo.png differ diff --git a/app.py b/app.py index 04cc31aa8d0e06aeaac3b59bb361ed71d831e43f..b5584963c9c11e95d04ede327d8dc8d41142177d 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,478 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ************************************************************************* + +import os, shutil, sys +import urllib.request +import argparse +import imageio +import math +import cv2 +import collections +import numpy as np import gradio as gr +from PIL import Image + +import torch +from pathlib import Path +from omegaconf import OmegaConf +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration +from diffusers import ( + AutoencoderKLTemporalDecoder, + DDPMScheduler, +) +from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer, PretrainedConfig + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from train_code.train_svd import import_pretrained_text_encoder +from data_loader.video_dataset import tokenize_captions +from data_loader.video_this_that_dataset import get_thisthat_sam +from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel +from svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline +from svd.temporal_controlnet import ControlNetModel +from svd.pipeline_stable_video_diffusion_controlnet import StableVideoDiffusionControlNetPipeline +from utils.optical_flow_utils import bivariate_Gaussian + + +# For the 2D dilation +blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid = None, isotropic = True) + + +# Import +# LENGTH=480 # length of the square area displaying/editing images +HEIGHT = 256 +WIDTH = 384 + + +MARKDOWN = \ + """ + ##

This&That

+ + [GitHub](https://github.com/Kiteretsu77/This_and_That_VDM) | [Paper](http://arxiv.org/abs/2407.05530) | [Webpage](https://cfeng16.github.io/this-and-that/) + This&That is a Robotics scenario (Bridge-dataset-based for this repo) Language-Gesture-Image-conditioned Video Generation Model for Robot Planning. + + This Demo is on the Video Diffusion Model part. + Only GestureNet is provided in this Gradio Demo, you can check the full test code for all pretrained weight available. + + ### Note: The index we put the gesture point by default here is [4, 10] for two gesture points or [4] for one gesture point. + ### Note: The result now only support is 256x384 + ### Note: Click "Clear All" to restart everything; Click "Undo Point" to cancel the point you put + + If This&That is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/This_and_That_VDM). Thanks! + """ + + +def store_img(img): + + # when new image is uploaded, `selected_points` should be empty + return img, [] + + + +def clear_all(): + return None, \ + gr.Image(value=None, height=HEIGHT, width=WIDTH, interactive=False), \ + None, [] # selected points + + +def undo_points(original_image): + img = original_image.copy() + return img, [] + + +# User click the image to get points, and show the points on the image [From https://github.com/Yujun-Shi/DragDiffusion] +def get_points(img, original_image, sel_pix, evt: gr.SelectData): + + # collect the selected point + sel_pix.append(evt.index) + + if len(sel_pix) > 2: + raise gr.Error("We only at most support two points") + + if original_image is None: + original_image = img.copy() + + # draw points + points = [] + for idx, point in enumerate(sel_pix): + if idx % 2 == 0: + # draw a red circle at the handle point + cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) + else: + # draw a blue circle at the handle point + cv2.circle(img, tuple(point), 10, (0, 255, 0), -1) + points.append(tuple(point)) + # draw an arrow from handle point to target point + # if len(points) == 2: + # cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) + # points = [] + + return [img if isinstance(img, np.ndarray) else np.array(img), original_image] + + +def gesturenet_inference(ref_image, prompt, selected_points): + + # Check some paramter, must have prompt and selected points + if prompt == "" or prompt is None: + raise gr.Error("Please input text prompt") + if selected_points == []: + raise gr.Error("Please click one/two points in the Image") + + # Prepare the setting + frame_idxs = [4, 10] + use_ambiguous_prompt = False + model_type = "GestureNet" + huggingface_pretrained_path = "HikariDawn/This-and-That-1.1" + + print("Text prompt is ", prompt) + + # Prepare tmp folder + store_folder_name = "tmp" + if os.path.exists(store_folder_name): + shutil.rmtree(store_folder_name) + os.makedirs(store_folder_name) + + + # Read the yaml setting files (Very important for loading hyperparamters needed) + if not os.path.exists(huggingface_pretrained_path): + yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="unet", filename="train_image2video.yaml") + if model_type == "GestureNet": + yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="gesturenet", filename="train_image2video_gesturenet.yaml") + else: # If the path is a local path we can concatenate it here + yaml_download_path = os.path.join(huggingface_pretrained_path, "unet", "train_image2video.yaml") + if model_type == "GestureNet": + yaml_download_path = os.path.join(huggingface_pretrained_path, "gesturenet", "train_image2video_gesturenet.yaml") + + # Load the config + assert(os.path.exists(yaml_download_path)) + config = OmegaConf.load(yaml_download_path) + + + ################################################ Prepare vae, unet, image_encoder Same as before ################################################################# + print("Prepare the pretrained model") + accelerator = Accelerator( + gradient_accumulation_steps = config["gradient_accumulation_steps"], + mixed_precision = config["mixed_precision"], + log_with = config["report_to"], + project_config = ProjectConfiguration(project_dir=config["output_dir"], logging_dir=Path(config["output_dir"], config["logging_name"])), + ) + feature_extractor = CLIPImageProcessor.from_pretrained( + config["pretrained_model_name_or_path"], subfolder="feature_extractor", revision=None + ) # This instance has now weight, they are just seeting file + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + config["pretrained_model_name_or_path"], subfolder="image_encoder", revision=None, variant="fp16" + ) + vae = AutoencoderKLTemporalDecoder.from_pretrained( + config["pretrained_model_name_or_path"], subfolder="vae", revision=None, variant="fp16" + ) + unet = UNetSpatioTemporalConditionModel.from_pretrained( + huggingface_pretrained_path, + subfolder = "unet", + low_cpu_mem_usage = True, + # variant = "fp16", + ) + + + # For text .............................................. + tokenizer = AutoTokenizer.from_pretrained( + config["pretrained_tokenizer_name_or_path"], + subfolder = "tokenizer", + revision = None, + use_fast = False, + ) + # Clip Text Encoder + text_encoder_cls = import_pretrained_text_encoder(config["pretrained_tokenizer_name_or_path"], revision=None) + text_encoder = text_encoder_cls.from_pretrained(config["pretrained_tokenizer_name_or_path"], subfolder = "text_encoder", revision = None, variant = None) + + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae + image_encoder to gpu and cast to weight_dtype + vae.requires_grad_(False) + image_encoder.requires_grad_(False) + unet.requires_grad_(False) # Will switch back at the end + text_encoder.requires_grad_(False) + + # Move to accelerator + vae.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # For GestureNet + if model_type == "GestureNet": + unet.to(accelerator.device, dtype=weight_dtype) # There is no need to cast unet in unet training, only needed in controlnet one + + # Handle the Controlnet first from UNet + gesturenet = ControlNetModel.from_pretrained( + huggingface_pretrained_path, + subfolder = "gesturenet", + low_cpu_mem_usage = True, + variant = None, + ) + + gesturenet.requires_grad_(False) + gesturenet.to(accelerator.device) + ############################################################################################################################################################## + + + + + # Init the pipeline + pipeline = StableVideoDiffusionControlNetPipeline.from_pretrained( + config["pretrained_model_name_or_path"], # Still based on regular SVD config + vae = vae, + image_encoder = image_encoder, + unet = unet, + revision = None, # Set None directly now + torch_dtype = weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + + + ############################## Prepare and Process the condition here ############################## + org_height, org_width, _ = ref_image.shape + ref_image_pil = Image.fromarray(ref_image) + ref_image_pil = ref_image_pil.resize((config["width"], config["height"])) + + + # Initial the optical flow format we want + gesture_condition_img = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32) # The last image should be empty + + # Handle the selected points to the condition we want + for point_idx, point in enumerate(selected_points): + + frame_idx = frame_idxs[point_idx] + horizontal, vertical = point + + # Init the base image + base_img = np.zeros((org_height, org_width, 3)).astype(np.float32) # Use the original image size + base_img.fill(255) + + # Draw square around the target position + dot_range = 10 # Diameter + for i in range(-1*dot_range, dot_range+1): + for j in range(-1*dot_range, dot_range+1): + dil_vertical, dil_horizontal = vertical + i, horizontal + j + if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]): + if point_idx == 0: + base_img[dil_vertical][dil_horizontal] = [0, 0, 255] # The first point should be red + else: + base_img[dil_vertical][dil_horizontal] = [0, 255, 0] # The second point should be green to distinguish the first point + + # Dilate + if config["dilate"]: + base_img = cv2.filter2D(base_img, -1, blur_kernel) + + + ############################################################################################################################## + ### The core pipeline of processing is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store + + # Resize frames Don't use negative and don't resize in [0,1] + base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC) + + # Channel Transform and Range Shift + if config["conditioning_channels"] == 3: + # Map to [0, 1] range + base_img = base_img / 255.0 + + else: + raise NotImplementedError() + + # ReOrganize shape + base_img = base_img.transpose(2, 0, 1) # hwc -> chw + + # Write base img based on frame_idx + gesture_condition_img[frame_idx] = base_img # Only the first frame, the rest is 0 initialized + + + #################################################################################################### + + # Use the same tokenize process as the dataset preparation stage + tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device) # Use unsqueeze to expand dim + + + + # Call the pipeline + with torch.autocast("cuda"): + frames = pipeline( + image = ref_image_pil, + condition_img = gesture_condition_img, # numpy [0,1] range + controlnet = accelerator.unwrap_model(gesturenet), + prompt = tokenized_prompt, + use_text = config["use_text"], + text_encoder = text_encoder, + height = config["height"], + width = config["width"], + num_frames = config["video_seq_length"], + decode_chunk_size = 8, + motion_bucket_id = 200, + # controlnet_image_index = controlnet_image_index, + # coordinate_values = coordinate_values, + num_inference_steps = config["num_inference_steps"], + max_guidance_scale = config["inference_max_guidance_scale"], + fps = 7, + use_instructpix2pix = config["use_instructpix2pix"], + noise_aug_strength = config["inference_noise_aug_strength"], + controlnet_conditioning_scale = config["outer_conditioning_scale"], + inner_conditioning_scale = config["inner_conditioning_scale"], + guess_mode = config["inference_guess_mode"], # False in inference + image_guidance_scale = config["image_guidance_scale"], + ).frames[0] + + # Save frames + video_file_path = os.path.join(store_folder_name, "tmp.mp4") + writer = imageio.get_writer(video_file_path, fps=4) + for idx, frame in enumerate(frames): + frame.save(os.path.join(store_folder_name, str(idx)+".png")) + writer.append_data(cv2.cvtColor(cv2.imread(os.path.join(store_folder_name, str(idx)+".png")), cv2.COLOR_BGR2RGB)) + writer.close() + + + + # Cleaning process + del pipeline + torch.cuda.empty_cache() + + return gr.update(value=video_file_path, width=config["width"], height=config["height"]) # Return resuly based on the need + + + +if __name__ == '__main__': + + + # Gradio demo part + with gr.Blocks() as demo: + # layout definition + with gr.Row(): + gr.Markdown(MARKDOWN) + + # UI components for editing real images + with gr.Row(elem_classes=["container"]): + selected_points = gr.State([]) # store points + original_image = gr.State(value=None) # store original input image + with gr.Row(): + with gr.Column(): + gr.Markdown("""

Click two Points

""") + input_image = gr.Image(label="Input Image", height=HEIGHT, width=WIDTH, interactive=False, elem_id="input_img") + # gr.Image(type="numpy", label="Click Points", height=HEIGHT, width=WIDTH, interactive=False) # for points clicking + undo_button = gr.Button("Undo point") + + # Text prompt + with gr.Row(): + prompt = gr.Textbox(label="Text Prompt") + + + with gr.Column(): + gr.Markdown("""

Results

""") + frames = gr.Video(value=None, label="Generate Video", show_label=True, height=HEIGHT, width=WIDTH) + with gr.Row(): + run_button = gr.Button("Run") + clear_all_button = gr.Button("Clear All") + + + + + # with gr.Tab("Base Model Config"): + # with gr.Row(): + # local_models_dir = 'local_pretrained_models' + # local_models_choice = \ + # [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] + # model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", + # label="Diffusion Model Path", + # choices=[ + # "runwayml/stable-diffusion-v1-5", + # "gsdf/Counterfeit-V2.5", + # "stablediffusionapi/anything-v5", + # "SG161222/Realistic_Vision_V2.0", + # ] + local_models_choice + # ) + # vae_path = gr.Dropdown(value="default", + # label="VAE choice", + # choices=["default", + # "stabilityai/sd-vae-ft-mse"] + local_models_choice + # ) + + # Examples + with gr.Row(elem_classes=["container"]): + gr.Examples( + [ + ["__assets__/Bridge_example/Task1_v1_511/im_0.jpg", "take this to there"], + ["__assets__/Bridge_example/Task2_v2_164/im_0.jpg", "put this to there"], + ["__assets__/Bridge_example/Task3_v2_490/im_0.jpg", "fold this"], + ["__assets__/Bridge_example/Task4_v2_119/im_0.jpg", "open this"], + + # ["__assets__/0.jpg", "take this to there"], + ["__assets__/91.jpg", "take this to there"], + ["__assets__/156.jpg", "take this to there"], + # ["__assets__/274.jpg", "take this to there"], + ["__assets__/375.jpg", "take this to there"], + # ["__assets__/551.jpg", "take this to there"], + ], + [input_image, prompt, selected_points], + ) + + + + + ####################################### Event Definition ####################################### + + # Draw the points + input_image.select( + get_points, + [input_image, original_image, selected_points], + [input_image, original_image], + ) + + # Clean the points + undo_button.click( + undo_points, + [original_image], + [input_image, selected_points], + ) + + run_button.click( + gesturenet_inference, + inputs = [ + # vae, unet, gesturenet, image_encoder, text_encoder, tokenizer, + original_image, prompt, selected_points, + # frame_idxs, + # config, accelerator, weight_dtype + ], + outputs = [frames] + ) + + clear_all_button.click( + clear_all, + [], + outputs = [original_image, input_image, prompt, selected_points], + ) -def greet(name): - return "Hello " + name + "!!" -demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() + demo.queue().launch(share=True, debug=True) diff --git a/config/accelerate_config.json b/config/accelerate_config.json new file mode 100644 index 0000000000000000000000000000000000000000..d4ab380c9917211331b88d59a4deca035f63aa00 --- /dev/null +++ b/config/accelerate_config.json @@ -0,0 +1,18 @@ +{ + "compute_environment": "LOCAL_MACHINE", + "debug": false, + "distributed_type": "MULTI_GPU", + "downcast_bf16": "no", + "gpu_ids": "all", + "machine_rank": 0, + "main_training_function": "main", + "mixed_precision": "fp16", + "num_machines": 1, + "num_processes": 8, + "rdzv_backend": "static", + "same_network": true, + "tpu_env": [], + "tpu_use_cluster": false, + "tpu_use_sudo": false, + "use_cpu": false +} \ No newline at end of file diff --git a/config/flowformer_config.py b/config/flowformer_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1b40b051212a9b70ce79f070dab38bb7f3dae6 --- /dev/null +++ b/config/flowformer_config.py @@ -0,0 +1,78 @@ +from yacs.config import CfgNode as CN +_CN = CN() + +_CN.name = 'default' +_CN.suffix ='sintel' +_CN.gamma = 0.75 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 100000000 +_CN.image_size = [432, 960] +_CN.add_noise = False +_CN.use_smoothl1 = False +_CN.critical_params = [] + +_CN.transformer = 'percostformer3' + +### change the path here +_CN.model = "pretrained/sintel.pth" + +_CN.percostformer3 = CN() +_CN.percostformer3.pe = 'linear' +_CN.percostformer3.dropout = 0.0 +_CN.percostformer3.droppath = 0.0 +_CN.percostformer3.encoder_latent_dim = 256 # in twins, this is 256 +_CN.percostformer3.query_latent_dim = 64 +_CN.percostformer3.cost_latent_input_dim = 64 +_CN.percostformer3.cost_latent_token_num = 8 +_CN.percostformer3.cost_latent_dim = 128 +_CN.percostformer3.cost_heads_num = 1 +# encoder +_CN.percostformer3.pretrain = True +_CN.percostformer3.use_convertor = False +_CN.percostformer3.del_layers = True +_CN.percostformer3.encoder_depth = 3 +_CN.percostformer3.expand_factor = 4 +_CN.percostformer3.vertical_encoder_attn = "twins" +_CN.percostformer3.attn_dim = 128 +_CN.percostformer3.patch_size = 8 +_CN.percostformer3.patch_embed = 'single' +_CN.percostformer3.cross_attn = "all" +_CN.percostformer3.gma = "GMA" +_CN.percostformer3.vert_c_dim = 64 +_CN.percostformer3.cost_encoder_res = True +_CN.percostformer3.cnet = 'twins' +_CN.percostformer3.fnet = 'twins' +_CN.percostformer3.flow_or_pe = "and" +_CN.percostformer3.use_patch = False # use cost patch rather than local cost as query +_CN.percostformer3.use_rpe = False +_CN.percostformer3.detach_local = False +_CN.percostformer3.no_sc = False +_CN.percostformer3.r_16 =-1 +_CN.percostformer3.quater_refine = False +# pretrain config +_CN.percostformer3.pretrain_mode = False +_CN.percostformer3.pic_size = [368, 496, 368, 496] +_CN.percostformer3.mask_ratio = 0.5 +_CN.percostformer3.query_num = 30 +_CN.percostformer3.no_border = True +_CN.percostformer3.gt_r = 15 +_CN.percostformer3.fix_pe = False +# decoder +_CN.percostformer3.decoder_depth = 12 +_CN.percostformer3.critical_params = ['vert_c_dim', 'encoder_depth', 'vertical_encoder_attn', "use_patch", "flow_or_pe", "use_rpe", "dropout", "detach_local", "expand_factor"] + + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = 'OneCycleLR' +_CN.trainer.optimizer = 'adamw' +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-5 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = 'linear' +def get_cfg(): + return _CN.clone() \ No newline at end of file diff --git a/config/train_image2video.yaml b/config/train_image2video.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79b0367614fc716fe9ac7be3caa9e776fac7b1bb --- /dev/null +++ b/config/train_image2video.yaml @@ -0,0 +1,78 @@ + +# Model Setting +pretrained_model_name_or_path: stabilityai/stable-video-diffusion-img2vid # -xt is for 25 frames version +load_unet_path: # This is usally used to load pretrained UNet; e.g., you may want to start one of your checkpoints trained before +video_seq_length: 14 # Standardized to 14 +process_fps: 7 +train_noise_aug_strength: 0.1 +scheduler: EDM +conditioning_dropout_prob: 0.1 + + +# Dataset Setting +dataset_name: Bridge # WebVid / Bridge +dataset_path: [../sanity_check/bridge_v1_raw, ../sanity_check/bridge_v2_raw] +output_dir: checkpoints/img2video +height: 256 # Ratio that is functional: 256:384 576:1024 320:512 320:576 +width: 384 # It is said that the height and width should be a scale of 64 +dataloader_num_workers: 4 # Don't set this too large; usually, Video diffusion are slow processing, so don't need that many workers to do early loading +flip_aug_prob: 0.45 # Whether we flip the GT and cond vertically +acceleration_tolerance: 4 # Recommened setting + + +# Text setting +use_text: True # If this is True, we will use text value +pretrained_tokenizer_name_or_path: stabilityai/stable-diffusion-2-1-base # Use SD 2.1 +empty_prompts_proportion: 0.0 # Useless now, we already have CFG in training +mix_ambiguous: False # Whether we mix ambiguous prompt for "this" and "that" + + +# Motion setting Useless right now... +motion_bucket_id: 200 # Set it for exact value; If this is none, we will use below setting +dataset_motion_mean: 35.3 # For 14 fps, it is N(35.3, 18.5) +dataset_motion_std: 18.5 # For 25 fps, it is N(?, ?) +svd_motion_mean: 165 +svd_motion_std: 22.5 + + +# Training setting +resume_from_checkpoint: False # latest/False +num_train_iters: 100000 # Will automatically choose the checkpoints at 99K +partial_finetune: False # Whether we just tune some params to speed up +train_batch_size: 1 # This is the batch size per GPU +checkpointing_steps: 3000 +validation_step: 300 +logging_name: logging +seed: 42 +validation_img_folder: # Prepare your own validation dataset +validation_store_folder: validation_results +checkpoints_total_limit: 15 + +# Noise Strength +noise_mean: 0.5 # Regular Img2Video: (0.7, 1.6); Text2Video: (0.5, 1.4) +noise_std: 1.4 + + +# Inference +num_inference_steps: 25 +inference_noise_aug_strength: 0.1 +inference_max_guidance_scale: 3.0 # Take training and testing at different scenario + + +# Learning Rate and Optimizer +learning_rate: 1e-5 # Usually this is ok +scale_lr: False # TODO: Is it needed to scale the learning rate? +adam_beta1: 0.9 +adam_beta2: 0.999 +use_8bit_adam: True # Need this to save more memory +adam_weight_decay: 1e-2 +adam_epsilon: 1e-08 +lr_warmup_steps: 500 +lr_decay_scale: 0.5 + + +# Other Setting +mixed_precision: fp16 +gradient_accumulation_steps: 1 +gradient_checkpointing: 1 +report_to: tensorboard \ No newline at end of file diff --git a/config/train_image2video_controlnet.yaml b/config/train_image2video_controlnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a5ed3e35a513c680350e8deca55e7fc27c00efff --- /dev/null +++ b/config/train_image2video_controlnet.yaml @@ -0,0 +1,101 @@ + +# Model Setting +pretrained_model_name_or_path: stabilityai/stable-video-diffusion-img2vid # stabilityai/pretrained +load_unet_path: ../saved_weights/v4_VL_paper/checkpoint-99000 # None/specific path This is for pretrained-UNet path +load_controlnet_path: # None/specific path For checkpoint loaded from pretrained-Controlnet Path +video_seq_length: 14 +process_fps: 7 +train_noise_aug_strength: 0.1 +scheduler: EDM +conditioning_dropout_prob: 0.1 + + +# Dataset Setting +data_loader_type: thisthat # thisthat +dataset_name: Bridge # Bridge +dataset_path: [../sanity_check/bridge_v1_TT14, ../sanity_check/bridge_v2_TT14] # ../Bridge_filter_flow, ../Bridge_v2_filter_flow/] +output_dir: checkpoints/img2video +height: 256 # Ratio that is functional: 256:384 576:1024 320:448 320:576 512:640 448:640 +width: 384 # It is said that the height and width should be a scale of 64 +dataloader_num_workers: 4 # For Debug, it only needs 1 +flip_aug_prob: 0.45 # Whether we flip the GT and cond vertically +# No acceleration_tolerance, since TT dataset already filter those out + + +# Text setting +use_text: True # If this is True, we will use text value +pretrained_tokenizer_name_or_path: stabilityai/stable-diffusion-2-1-base # Use SD 2.1 +empty_prompts_proportion: 0.0 +mix_ambiguous: False # Whether we mix ambiguous prompt for "this" and "that" + + +# Mask setting +mask_unet_vae: False # Whether we use mask to map latents to be zero padding +mask_controlnet_vae: False +mask_proportion: 0.0 + + +# Condition Setting +conditioning_channels: 3 # Usually it is 3 +num_points_left: # 1 # For flow: You can only choose one between flow_select_rate and num_points_left; num_points_left should be higher priority +flow_select_rate: 0.99 # For flow +threshold_factor: 0.2 # For flow +dilate: True # Traj must be True for dilate +inner_conditioning_scale: 1.0 # Conditioning scale for the internal value, defauly is starting from 1.0 +outer_conditioning_scale: 1.0 # Outer Conditioning Scale for whole conditioning trainable copy 这里有点意思,直接不小心设定成2.0了 + + +# Motion setting +motion_bucket_id: 200 +dataset_motion_mean: 25 # For 14 fps, it is N(25, 10) +dataset_motion_std: 10 # For 25 fps, it is N(18, 7) +svd_motion_mean: 180 +svd_motion_std: 30 + + + +# Training setting +resume_from_checkpoint: False # latest/False +num_train_iters: 30100 # Will automatically choose the checkpoints +partial_finetune: False # Whether we just tune some params to speed up +train_batch_size: 1 # This is the batch size per GPU +checkpointing_steps: 3000 +validation_step: 300 +logging_name: logging +seed: 42 +validation_img_folder: datasets/validation_TT14 +validation_store_folder: validation_videos +checkpoints_total_limit: 15 + + +# Noise Strength +noise_mean: 0.5 # Regular Img2Video: (0.7, 1.6); Text2Video: (0.5, 1.4) +noise_std: 1.4 + + +# Inference +num_inference_steps: 25 +use_instructpix2pix: False # Whether we will use the instructPix2Pix mode, which involves 3 inputs; it may needs tuning to have better result at the end. +inference_noise_aug_strength: 0.1 +inference_max_guidance_scale: 3.0 # Take training and testing at different scenario +inference_guess_mode: False # Whether we use guess mode in the contorlnet +image_guidance_scale: 2.5 # Empirically, 2.5 is the best value Seems not using this now + + +# Learning Rate and Optimizer +learning_rate: 5e-6 # 5e-6 is the LR we test that is just right +scale_lr: False # TODO: Is it needed to scale the learning rate? +adam_beta1: 0.9 +adam_beta2: 0.999 +use_8bit_adam: True # Need this to save more memory +adam_weight_decay: 1e-2 +adam_epsilon: 1e-08 +lr_warmup_steps: 500 +lr_decay_scale: 0.5 + + +# Other Setting +mixed_precision: fp16 +gradient_accumulation_steps: 1 # ???? +gradient_checkpointing: 1 # ???? +report_to: tensorboard \ No newline at end of file diff --git a/curation_pipeline/add_lang_info.py b/curation_pipeline/add_lang_info.py new file mode 100644 index 0000000000000000000000000000000000000000..1a33ec4d75f395e3f97a903fd6cdb03a0ca07cf2 --- /dev/null +++ b/curation_pipeline/add_lang_info.py @@ -0,0 +1,38 @@ +''' + Add the processed lang information +''' +import os, sys, shutil +import json + + +if __name__ == "__main__": + + # Main config file path information + processed_json_file_path = "updated_bridge_v2.json" + + + # Read the json file + file = open(processed_json_file_path) + data = json.load(file) + + + # Iterate all the folders inside + start_idx = 0 + for seq_instance in data: + target_path = seq_instance["images0"] + print("We are processing ", target_path) + + processed_lang_txt_path = os.path.join(target_path, "processed_lang.txt") + if os.path.exists(processed_lang_txt_path): + os.remove(processed_lang_txt_path) + + # Write the action + This + That into the sequence. + processed_lang_txt = open(processed_lang_txt_path, "a") + processed_lang_txt.write(str(seq_instance["action"])+"\n") + processed_lang_txt.write(str(seq_instance["this"])+"\n") + processed_lang_txt.write(str(seq_instance["that"])+"\n") + + + start_idx += 1 + + print("We have ", start_idx) \ No newline at end of file diff --git a/curation_pipeline/match_dataset_v1.py b/curation_pipeline/match_dataset_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd9f66fd3f3669ce518df6d1f8f601a4c7e784f --- /dev/null +++ b/curation_pipeline/match_dataset_v1.py @@ -0,0 +1,117 @@ +''' + This file is to match the selected frames with the bridge dataset + We need to use some tricks to select the item +''' +import os, sys, shutil +import cv2 +import numpy as np + + + + +def compare_img(imageA, imageB): + # the 'Mean Squared Error' between the two images is the + # sum of the squared difference between the two images; + # NOTE: the two images must have the same dimension + err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2) + err /= float(imageA.shape[0] * imageA.shape[1]) + + # return the MSE, the lower the error, the more "similar" + # the two images are + return err + + + +def search_path(dataset_path, target_path, store_txt_path): + + # We only needs to care about Bridge v1 dataset area + target_img_path = os.path.join(target_path, "im_0.jpg") + target_img = cv2.imread(target_img_path) + + # Iterate all the folders inside + for scene_name in sorted(os.listdir(dataset_path)): + # print("We are reading scene", scene_name) + scene_dir = os.path.join(dataset_path, scene_name) + + for task_name in os.listdir(scene_dir): + task_dir = os.path.join(scene_dir, task_name) + + for time_clock in os.listdir(task_dir): + if time_clock == "lmdb": + continue # Skip lmdb folder + + time_dir = os.path.join(task_dir, time_clock, "raw", "traj_group0") + if not os.path.exists(time_dir): + continue + + for traj_name in os.listdir(time_dir): + traj_path = os.path.join(time_dir, traj_name) + if not os.path.isdir(traj_path): + continue + + # Directly move policy_out_file_path; just in case there is also valuable information there + policy_out_file_path = os.path.join(traj_path, "policy_out.pkl") + if not os.path.exists(policy_out_file_path): + continue + + # Check the lang txt file + lang_txt_file_path = os.path.join(traj_path, "lang.txt") + if not os.path.exists(lang_txt_file_path): + continue + + + # Last thing to locate to the right path + for img_name in os.listdir(traj_path): + if img_name != "images0": # Only consider one camera angle + continue + + img_folder_path = os.path.join(traj_path, img_name) + if not os.path.isdir(img_folder_path): + continue + + + # Compare two image + img_path = os.path.join(img_folder_path, "im_0.jpg") + # print("img_folder_path is ", img_path) + compare_sample_img = cv2.imread(img_path) + error = compare_img(target_img, compare_sample_img) + + if error == 0: + # Continue to all the rest for at least 5 images + status = True + for idx in range (10): + idx_img_path = os.path.join(img_folder_path, "im_"+str(idx)+".jpg") + idx_target_img_path = os.path.join(target_path, "im_"+str(idx)+".jpg") + idx_compare_sample_img = cv2.imread(idx_img_path) + idx_target_img = cv2.imread(idx_target_img_path) + error = compare_img(idx_target_img, idx_compare_sample_img) + + if error != 0: + status = False + break + + if status: + print("We found one at ", img_path) + f = open(store_txt_path, "a") + f.write(target_path + " " + img_folder_path + "\n") + return True + + return False + + +if __name__ == "__main__": + input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/datasets_rob/Bridge_v1_test_raw" + dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v1/berkeley" # 直接从本地新unzip的获取,怕之前的被xuweiyi改动过 + store_txt_path = "match_info.txt" + + if os.path.exists(store_txt_path): + os.remove(store_txt_path) + + for img_name in sorted(os.listdir(input_path)): + target_path = os.path.join(input_path, img_name) + print("We are finding for ", target_path) + + status = search_path(dataset_path, target_path, store_txt_path) + + if not status: + print("we cannot find one") \ No newline at end of file diff --git a/curation_pipeline/match_dataset_v2.py b/curation_pipeline/match_dataset_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..26318a34ee735e9afdbc26667aa6ddba071a86ef --- /dev/null +++ b/curation_pipeline/match_dataset_v2.py @@ -0,0 +1,137 @@ +''' + This file is to match the selected frames with the bridge dataset + We need to use some tricks to select the item +''' +import os, sys, shutil +import cv2 +import numpy as np + + + + +def compare_img(imageA, imageB): + # the 'Mean Squared Error' between the two images is the + # sum of the squared difference between the two images; + # NOTE: the two images must have the same dimension + err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2) + err /= float(imageA.shape[0] * imageA.shape[1]) + + # return the MSE, the lower the error, the more "similar" + # the two images are + return err + + + +def search_path(dataset_path, target_path, store_txt_path): + + # We only needs to care about Bridge v1 dataset area + target_img_path = os.path.join(target_path, "im_0.jpg") + if not os.path.exists(target_img_path): + print("The image we read is False") + return False + target_img = cv2.imread(target_img_path) + + # Iterate all the folders inside + for scene_name in sorted(os.listdir(dataset_path)): + scene_dir = os.path.join(dataset_path, scene_name) + + for task_name in sorted(os.listdir(scene_dir)): + task_dir = os.path.join(scene_dir, task_name) + + for order_name in sorted(os.listdir(task_dir)): + order_dir = os.path.join(task_dir, order_name) + + for time_clock in sorted(os.listdir(order_dir)): + if time_clock == "lmdb": + continue # Skip lmdb folder + + time_dir = os.path.join(order_dir, time_clock, "raw", "traj_group0") + if not os.path.exists(time_dir): + continue + + for traj_name in sorted(os.listdir(time_dir)): + traj_path = os.path.join(time_dir, traj_name) + if not os.path.isdir(traj_path): + continue + + # Directly move policy_out_file_path; just in case there is also valuable information there + policy_out_file_path = os.path.join(traj_path, "policy_out.pkl") + if not os.path.exists(policy_out_file_path): + continue + + # Check the lang txt file + lang_txt_file_path = os.path.join(traj_path, "lang.txt") + if not os.path.exists(lang_txt_file_path): + continue + + + for img_name in sorted(os.listdir(traj_path)): + if img_name != "images0": # Only consider one camera angle + continue + + img_folder_path = os.path.join(traj_path, img_name) + if not os.path.isdir(img_folder_path): + continue + + + # Compare two image + img_path = os.path.join(img_folder_path, "im_0.jpg") + if not os.path.exists(img_path): + print(img_folder_path + " doesn't even have im_0.jpg") + continue + # print("img_folder_path is ", img_path) + compare_sample_img = cv2.imread(img_path) + # try: + # compare_sample_img.shape + # except Exception: + # print("The compare_sample_img cannot be red") + # continue + error = compare_img(target_img, compare_sample_img) + + if error == 0: + # Continue to all the rest for at least 5 images + status = True + for idx in range (10): + idx_img_path = os.path.join(img_folder_path, "im_"+str(idx)+".jpg") + idx_target_img_path = os.path.join(target_path, "im_"+str(idx)+".jpg") + if not os.path.exists(idx_img_path): + print("The idx_img_path long idx we see only at ", idx) + continue + if not os.path.exists(idx_target_img_path): + print("The idx_target_img_path long idx we see only at ", idx) + continue + idx_compare_sample_img = cv2.imread(idx_img_path) + idx_target_img = cv2.imread(idx_target_img_path) + error = compare_img(idx_target_img, idx_compare_sample_img) + + if error != 0: + status = False + break + + if status: + print("We found one at ", img_path) + f = open(store_txt_path, "a") + f.write(target_path + " " + img_folder_path + "\n") + return True + + return False + + +if __name__ == "__main__": + input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/datasets_rob/Bridge_v2_test_raw" + dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2" # 直接从本地新unzip的获取,怕之前的被xuweiyi改动过 + store_txt_path = "match_info_v2_p1.txt" + start_idx = 0 + end_idx = 500 + + if os.path.exists(store_txt_path): + os.remove(store_txt_path) + + for img_name in sorted(os.listdir(input_path))[start_idx:end_idx]: + target_path = os.path.join(input_path, img_name) + print("We are finding for ", target_path) + + status = search_path(dataset_path, target_path, store_txt_path) + + if not status: + print("we cannot find one") \ No newline at end of file diff --git a/curation_pipeline/prepare_bridge_csv.py b/curation_pipeline/prepare_bridge_csv.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c99a9292561a212b1ea4827203099067a72d5a --- /dev/null +++ b/curation_pipeline/prepare_bridge_csv.py @@ -0,0 +1,69 @@ +''' + This file is to prepare the dataset in csv file following the format required by Opne-SORA +''' + +import os, sys, shutil +import json +import csv + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +# from curation_pipeline.prepare_bridge_v1 import read_bridge_v1 +# from curation_pipeline.prepare_bridge_v2 import read_bridge_v2 + + + +def iter_dataset(dataset_path): + lists = [] + for sub_folder_name in os.listdir(dataset_path): + sub_folder_path = os.path.join(dataset_path, sub_folder_name) + + # Check number of frames + max_length = len(os.listdir(sub_folder_path)) + for check_idx in range(max_length): + if not os.path.exists(os.path.join(sub_folder_path, 'im_' + str(check_idx) + '.jpg')): # Should be sequentially exists + break + num_frames = check_idx + + # Read the text + txt_path = os.path.join(sub_folder_path, "lang.txt") + f = open(txt_path, "r") + lang_prompt = f.readline() + + lists.append([sub_folder_path, lang_prompt, num_frames, 480, 640]) + # break + return lists + + + +if __name__ == "__main__": + v1_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/sanity_check/bridge_v1_raw" + v2_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/sanity_check/bridge_v2_raw" + store_name = "Bridge_raw.csv" + + if os.path.exists(store_name): + os.remove(store_name) + + + # Execute + full_lists = [["path", "text", "num_frames", "height", "width"]] + + v1_lists = iter_dataset(v1_dataset_path) + full_lists.extend(v1_lists) + v2_lists = iter_dataset(v2_dataset_path) + full_lists.extend(v2_lists) + print("Full length is ", len(full_lists)) + + + # Store as csv file + with open(store_name, 'w') as outfile: + write = csv.writer(outfile) + write.writerows(full_lists) + + + + # with open('output.jsonl', 'w') as outfile: + # for entry in JSON_file: + # json.dump(entry, outfile) + # outfile.write('\n') \ No newline at end of file diff --git a/curation_pipeline/prepare_bridge_jsonl.py b/curation_pipeline/prepare_bridge_jsonl.py new file mode 100644 index 0000000000000000000000000000000000000000..930aba37e8f763f01f95563b92680f5eee29edaf --- /dev/null +++ b/curation_pipeline/prepare_bridge_jsonl.py @@ -0,0 +1,47 @@ +''' + This file is to prepare the dataset in jsonl file +''' + +import os, sys, shutil +import json + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from curation_pipeline.prepare_bridge_v1 import read_bridge_v1 +from curation_pipeline.prepare_bridge_v2 import read_bridge_v2 + + +if __name__ == "__main__": + v1_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v1/berkeley" + v2_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2" + store_name = "store.jsonl" + + if os.path.exists(store_name): + os.remove(store_name) + + + # Execute + full_lists = [] + + v1_lists = read_bridge_v1(v1_dataset_path, "", copyfile=False) + full_lists.extend(v1_lists) + v2_lists = read_bridge_v2(v2_dataset_path, "", copyfile=False) + full_lists.extend(v2_lists) + print("Full length is ", len(full_lists)) + + + with open(store_name, 'w') as outfile: + for list_name in full_lists: + instance = dict() + instance["file_path"] = list_name + + json.dump(instance, outfile) + outfile.write('\n') + + + + # with open('output.jsonl', 'w') as outfile: + # for entry in JSON_file: + # json.dump(entry, outfile) + # outfile.write('\n') \ No newline at end of file diff --git a/curation_pipeline/prepare_bridge_v1.py b/curation_pipeline/prepare_bridge_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..4de8746f1aa1c9b726db6a266dc63bfa124522e2 --- /dev/null +++ b/curation_pipeline/prepare_bridge_v1.py @@ -0,0 +1,132 @@ +''' + This repository is used to prepare Bridge dataset +''' +import os, sys, shutil + + +def read_bridge_v1(dataset_path, train_store_path, test_store_path, test_dataset_lists, copyfile=True): + # copyfile is True when we need to copy the file to the target destination + + start_idx = 0 + target_lists = [] + prefix_len = len(dataset_path) + 1 + + # Iterate all the folders inside + for scene_name in sorted(os.listdir(dataset_path)): + print("We are reading scene ", scene_name) + scene_dir = os.path.join(dataset_path, scene_name) + for task_name in sorted(os.listdir(scene_dir)): + task_dir = os.path.join(scene_dir, task_name) + + for time_clock in sorted(os.listdir(task_dir)): + if time_clock == "lmdb": + continue # Skip lmdb folder + + time_dir = os.path.join(task_dir, time_clock, "raw", "traj_group0") + if not os.path.exists(time_dir): + continue + + for traj_name in sorted(os.listdir(time_dir)): + traj_path = os.path.join(time_dir, traj_name) + if not os.path.isdir(traj_path): + continue + + # Directly move policy_out_file_path; just in case there is also valuable information there + policy_out_file_path = os.path.join(traj_path, "policy_out.pkl") + if not os.path.exists(policy_out_file_path): + continue + + # Check the lang txt file + lang_txt_file_path = os.path.join(traj_path, "lang.txt") + if not os.path.exists(lang_txt_file_path): + continue + + + for img_name in sorted(os.listdir(traj_path)): + if img_name != "images0": # Only consider one camera angle + continue + + img_folder_path = os.path.join(traj_path, img_name) + if not os.path.isdir(img_folder_path): + continue + + ############################################ Main Process #################################################### + + # # First Sanity check (Make sure the input source is jpg good) + # length = len(os.listdir(img_folder_path)) + # status = True + # for check_idx in range(length): + # if not os.path.exists(os.path.join(img_folder_path, 'im_' + str(check_idx) + '.jpg')): # Should be sequentially exists + # status = False + # break + + # Now we can copy the folder to our destination + target_lists.append(img_folder_path) + if copyfile: + print("img_folder_path[prefix_len:] is ", img_folder_path[prefix_len:]) + if img_folder_path[prefix_len:] in test_dataset_lists: + # Store to test set + target_dir = os.path.join(test_store_path, str(start_idx)) + else: + # This is training set + target_dir = os.path.join(train_store_path, str(start_idx)) + + print("Copy " + str(img_folder_path) + " to " + str(target_dir)) + shutil.copytree(img_folder_path, target_dir) + + + # Sanity check + length = len(os.listdir(target_dir)) + status = True + for check_idx in range(length): + if not os.path.exists(os.path.join(target_dir, 'im_' + str(check_idx) + '.jpg')): # Should be sequentially exists + status = False + break + + if not status: + # If they didn't have sequential files we need, we will remove and begin again without updating start_idx + print("This file cannot pass the sanity check. We will remove it!") + shutil.rmtree(target_dir) + continue + + # Move other auxiliary files + shutil.copy(policy_out_file_path, os.path.join(target_dir, "policy_out.pkl")) + shutil.copy(lang_txt_file_path, os.path.join(target_dir, "lang.txt")) + + ################################################################################################################ + + # Update the idx + start_idx += 1 + + print("We have ", start_idx, " number of cases") + + # Return a list of file path + return target_lists + + + +if __name__ == "__main__": + dataset_path = "/Path/to/Bridge/raw/bridge_data_v1/berkeley" # Until Bridge v1 - berkeley section + train_store_path = "/Path/to/Bridge/train/bridge_v1_raw" + test_store_path = "/Path/to/Bridge/train/bridge_v1_test_raw" + test_dataset_predefined_path = "test_path.txt" # This will be providede by us + + + # Make dir if needed + if os.path.exists(train_store_path): + shutil.rmtree(train_store_path) + os.makedirs(train_store_path) + if os.path.exists(test_store_path): + shutil.rmtree(test_store_path) + os.makedirs(test_store_path) + + + # Read Test dataset path + test_dataset_lists = [] + read_file = open(test_dataset_predefined_path, "r") + for line in read_file.readlines(): + test_dataset_lists.append(line[:-1]) + print("test_dataset_lists is ", test_dataset_lists) + + + read_bridge_v1(dataset_path, train_store_path, test_store_path, test_dataset_lists) diff --git a/curation_pipeline/prepare_bridge_v2.py b/curation_pipeline/prepare_bridge_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..5c56ea2eaa20ab9edb8d196c4c6eedb2d0a06699 --- /dev/null +++ b/curation_pipeline/prepare_bridge_v2.py @@ -0,0 +1,139 @@ +''' + This repository is used to prepare Bridge dataset +''' +import os, sys, shutil + + + +def read_bridge_v2(dataset_path, train_store_path, test_store_path, test_dataset_lists, copyfile=True): + # copyfile is True most of the time + + start_idx = 0 + target_lists = [] + prefix_len = len(dataset_path) + 1 + + # Iterate all the folders inside + for scene_name in sorted(os.listdir(dataset_path)): + print("We are reading scene ", scene_name) + scene_dir = os.path.join(dataset_path, scene_name) + + for task_name in sorted(os.listdir(scene_dir)): + task_dir = os.path.join(scene_dir, task_name) + + for order_name in sorted(os.listdir(task_dir)): + order_dir = os.path.join(task_dir, order_name) + + for time_clock in sorted(os.listdir(order_dir)): + if time_clock == "lmdb": + continue # Skip lmdb folder + + time_dir = os.path.join(order_dir, time_clock, "raw", "traj_group0") + if not os.path.exists(time_dir): + print("time_dir does not exist for ", time_dir) + continue + + for traj_name in sorted(os.listdir(time_dir)): + traj_path = os.path.join(time_dir, traj_name) + if not os.path.isdir(traj_path): + print("traj_path does not exist for ", traj_path) + continue + + # Directly move policy_out_file_path; just in case there is also valuable information there + policy_out_file_path = os.path.join(traj_path, "policy_out.pkl") + if not os.path.exists(policy_out_file_path): + continue + + # Check the lang txt file + lang_txt_file_path = os.path.join(traj_path, "lang.txt") + if not os.path.exists(lang_txt_file_path): + continue + + + for img_name in sorted(os.listdir(traj_path)): + if img_name != "images0": # Only consider one camera angle + continue + + img_folder_path = os.path.join(traj_path, img_name) + if not os.path.isdir(img_folder_path): + print("img_folder_path does not exist for ", img_folder_path) + continue + + ############################################ Main Process #################################################### + + # # First Sanity check (Make sure the input source is jpg good) + # length = len(os.listdir(img_folder_path)) + # status = True + # for check_idx in range(length): + # if not os.path.exists(os.path.join(img_folder_path, 'im_' + str(check_idx) + '.jpg')): # Should be sequentially exists + # status = False + # break + + # Now we can copy the folder to our destination + target_lists.append(img_folder_path) + if copyfile: + print("img_folder_path[prefix_len:] is ", img_folder_path[prefix_len:]) + if img_folder_path[prefix_len:] in test_dataset_lists: + # Store to test set + target_dir = os.path.join(test_store_path, str(start_idx)) + else: + # This is training set + target_dir = os.path.join(train_store_path, str(start_idx)) + + # Now we can copy the folder to our destination + print("Copy " + str(img_folder_path) + " to " + str(os.path.join(train_store_path, str(start_idx)))) + shutil.copytree(img_folder_path, target_dir) + + # Sanity check + length = len(os.listdir(target_dir)) + status = True + for check_idx in range(length): + if not os.path.exists(os.path.join(target_dir, 'im_' + str(check_idx) + '.jpg' )): # Should be sequentially exists + status = False + break + + if not status: + # If they didn't have sequential files we need, we will remove and begin again without updating start_idx + print("This file cannot pass the sanity check. We will remove it!") + shutil.rmtree(target_dir) + continue + + # Move other auxilary files + shutil.copy(policy_out_file_path, os.path.join(target_dir, "policy_out.pkl")) + shutil.copy(lang_txt_file_path, os.path.join(target_dir, "lang.txt")) + + # Update the idx + start_idx += 1 + + print("We have ", start_idx) + + # Return a list of file path + return target_lists + + + +if __name__ == "__main__": + dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2" + train_store_path = "../sanity_check/bridge_v2_raw" + test_store_path = "../sanity_check/bridge_v2_test_raw" + test_dataset_predefined_path = "test_path_v2.txt" + + + # Make dir if needed + if os.path.exists(train_store_path): + shutil.rmtree(train_store_path) + os.makedirs(train_store_path) + if os.path.exists(test_store_path): + shutil.rmtree(test_store_path) + os.makedirs(test_store_path) + + # Read Test dataset path + test_dataset_lists = [] + read_file = open(test_dataset_predefined_path, "r") + for line in read_file.readlines(): + test_dataset_lists.append(line[:-1]) + print("test_dataset_lists is ", test_dataset_lists) + + + read_bridge_v2(dataset_path, train_store_path, test_store_path, test_dataset_lists) + + \ No newline at end of file diff --git a/curation_pipeline/select_frame_with_this_that.py b/curation_pipeline/select_frame_with_this_that.py new file mode 100644 index 0000000000000000000000000000000000000000..92f0ec6545fc3eda6fd5111f5635c24cbb14640b --- /dev/null +++ b/curation_pipeline/select_frame_with_this_that.py @@ -0,0 +1,421 @@ +''' + This repository is used to prepare Bridge dataset with this that conditioning +''' +import os, sys, shutil +import pickle +from ultralytics import YOLO +from PIL import Image, ImageDraw +import numpy as np +import cv2 +import math +import collections +from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry + + +def show_mask(mask, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30/255, 144/255, 255/255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + + return mask_image * 255 + + +def read_center_point(model, img_path, do_visualization, store_path): + + action_img = Image.open(img_path) + prediction = model.predict(source=action_img, save=False)[0] # Only 1 frame + + if not hasattr(prediction, "boxes"): + print("Detection Fail: We cannot have boxes attribute") + return None, None # -1 means NAN and pass this case + + # save at the temp_places for visualizaiton + if do_visualization: + prediction.save(filename=store_path) + + + bounding_boxes = prediction.boxes.xywh + num, dim = bounding_boxes.shape + assert(dim == 4) + + # Catch up all center point of all bounding boxes + edge_point_cord = [] + center_points = [] + for idx in range(num): + x, y, w, h = bounding_boxes[idx].detach().cpu().numpy() + center_point = [x, y] # TODO: y+(h/4) 根据经验,往下飘逸25%的高度,一般来说比较有帮助 + + edge_point_cord.extend([ (x+w//2, y+h//2), (x-w//2, y+h//2), (x-w//2, y-h//2), (x+w//2, y-h//2) ]) + + + if w <= 15 or h <= 15: # If a bounding box is too small, we will disregard this case + return None, None + + # Calculate the distance between current one and previous points for sanity check + for point in center_points: # Check all previous points + give_up_threshold = 90 + if center_point[0] - point[0] >= give_up_threshold: + print("Two points are too far away and neglect the case") + return None, None + if center_point[1] - point[1] >= give_up_threshold: + print("Two points are too far away and neglect the case") + return None, None + + # Append to the list + center_points.append(center_point) + + + if len(center_points) == 0 or len(center_points) > 2: + print("Detection Fail: We cannot detect bounding boxes") + return None, None + + # Calculating the average distance among center_points + if len(center_points) == 2: + first_box, second_box = center_points + + center_x = (first_box[0] + second_box[0]) / 2 + center_y = (first_box[1] + second_box[1]) / 2 + + distance = math.sqrt(abs(first_box[0] - second_box[0])**2 + abs(first_box[1] - second_box[1])**2) + + return [center_x, center_y, distance], edge_point_cord + + return [*center_points[0], 100], edge_point_cord # if len(center_points) == 1, distance is 0; however, to avoid 2-1-2 box detection in sequential, we set it as a higher value + + + +def detect_gripper(gripper_detection_model, input_dir, action_start, action_end, do_visualization, store_dir, sample_failure_collect_folder=None): + + # 先处理第一个point的(这个比较重要,所以要重复3次);然后再快速处理最后一个point + + # Process the first action frame by iterating next three frames and choose the closest one + first_center_points = [] + edge_point_cords = [] + for idx in range(3): # Repeat 3 times + action_start_path = os.path.join(input_dir, "im_"+str(action_start + idx)+".jpg") + first_center_point, edge_point_cord = read_center_point(gripper_detection_model, action_start_path, do_visualization, os.path.join(store_dir, "contact_first"+str(idx)+".jpg")) # The first frame + + if idx == 0 and first_center_point is None: + message = "Cannot find the first contact point!" + + print("The contact point we cannot detect is at ", action_start_path) + if sample_failure_collect_folder != "": + shutil.copyfile(action_start_path, os.path.join(sample_failure_collect_folder, str(len(os.listdir(sample_failure_collect_folder)))+".jpg") ) + + return (None, None, message) + + if first_center_point is not None: + first_center_points.append([action_start + idx, first_center_point]) + + # Add edge points + print(edge_point_cord) + edge_point_cords.extend(edge_point_cord) # 我有点担心所有point就这么extend会对一些的edge case不是那么robust + + + # Select the closest point between two + first_center_points.sort(key=lambda x: x[1][2]) + first_center_point = first_center_points[0][1][:2] + start_idx = first_center_points[0][0] + print("first_center_point is " + str(first_center_point) + " with idx " + str(start_idx)) + order_idx = [start_idx, action_end] + + + # Find the xmin, ymin, xmax, ymax for based all three points as the bounding box for the SAM + edge_point_cords.sort(key=lambda x: x[0]) + xmin = int(edge_point_cords[0][0]) + xmax = int(edge_point_cords[-1][0]) + + edge_point_cords.sort(key=lambda x: x[1]) + ymin = int(edge_point_cords[0][1]) + ymax = int(edge_point_cords[-1][1]) + + bbox_info = (xmin, xmax, ymin, ymax) + + + # Process the last action frame + action_end_path = os.path.join(input_dir, "im_"+str(action_end)+".jpg") + last_center_point, edge_point_cord = read_center_point(gripper_detection_model, action_end_path, do_visualization, os.path.join(store_dir, "contact_last.jpg")) # The last frame + if last_center_point is None: + message = "Cannot find the last contact point!" + + print("The contact point we cannot detect is at ", action_start_path) + if sample_failure_collect_folder != "": + store_name = str(len(os.listdir(sample_failure_collect_folder))) + ".jpg" + shutil.copyfile(action_start_path, os.path.join(sample_failure_collect_folder, store_name) ) + + return (None, bbox_info, message) + last_center_point = last_center_point[:2] + + + # Check if two center points is too close, if they are too close, we will merge to one point + merge_threshold = 30 + if math.sqrt((first_center_point[0] - last_center_point[0])**2 + (first_center_point[1] - last_center_point[1])**2) <= merge_threshold: + print("Merge two points to one!") + message = "Success!" + return ([[first_center_point], order_idx], bbox_info, message) + + + # Return needed information + message = "Success!" + return ([[first_center_point, last_center_point], order_idx], bbox_info, message) + + + + +def visualize_this_that(base_img, bbox_info, this_that_points): + + # Draw a green dot only for the start point + for point in this_that_points: + print("point is ", point) + target_horizontal, target_vertical = point + target_horizontal, target_vertical = int(target_horizontal), int(target_vertical) + + dot_range = 3 + for i in range(-1*dot_range, dot_range+1): + for j in range(-1*dot_range, dot_range+1): + dil_vertical, dil_horizontal = target_vertical + i, target_horizontal + j + if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]): + base_img[dil_vertical, dil_horizontal, :] = [0, 128, 0] + # else: + # # print("The traj is out of boundary!!!!!!!!!!!!!!!!!!!!! and we won't consider it") # 现在 + # return (False, base_img) + + # Draw the bounding box + xmin, xmax, ymin, ymax = bbox_info + base_img = cv2.rectangle(base_img, (xmin, ymin), (xmax, ymax), color=(0,0,255), thickness=2) + + return (True, base_img) + + + +def manage_seq_range(input_dir, store_dir, sample_failure_collect_folder, total_frames_needed, + max_original_input_tolerate, gripper_detection_model, sam_predictor, do_visualization): + + # Find valid image lists + num_frames_input = 0 + for file_name in os.listdir(input_dir): + if file_name.startswith("im_"): + num_frames_input += 1 + for idx in range(num_frames_input): + target_path = os.path.join(input_dir, "im_"+str(idx)+".jpg") + if not os.path.exists(target_path): + print("We don't have ", target_path) + message = "Invalid error" # Make sure that every file in this order is existed, this is quite important + return (False, message) + + + if num_frames_input > max_original_input_tolerate: + message = "The number of frames is too long for constructing the sequence length needed" + return (False, message) + + if num_frames_input < total_frames_needed: + message = "The number of frames is too short for constructing the sequence length needed" + return (False, message) + + + + # Prepare this and that based on policy_out.pkl + policy_out_file_path = os.path.join(input_dir, "policy_out.pkl") + with open(policy_out_file_path, "rb") as f: + policy = pickle.load(f) + + actions_codes = [] + action_start, action_end = None, None + for idx, item in enumerate(policy): + action_value = item["actions"][-1] + if action_start is None and action_value == 0.0: + action_start = idx + + if (action_start is not None) and (action_end is None) and (action_value == 1.0): + action_end = idx # Until record the first 1.0 exists after the first 0.0 appears + actions_codes.append(action_value) + + if action_start is None or action_end is None: + message = "We cannot read an action_start or action_end code!" + return (False, message) # Requires to have both start and end actions (Usually, they are a pair) + + print("actions_codes is ", actions_codes) + print("the start end idx we read is ", action_start, action_end) + + + # Detect the gripper (should return a list with exactly two x,y coordinate points) + detection_retrun_info, bbox_info, detect_message = detect_gripper( + gripper_detection_model, + input_dir, + action_start, + action_end, + do_visualization = do_visualization, + store_dir = store_dir, + sample_failure_collect_folder = sample_failure_collect_folder, + ) + if detection_retrun_info is None: + return (False, detect_message) + + detected_point, old_seq_idx = detection_retrun_info + print("detected_point is ", detected_point) + + + # Visualize if needed + base_img = cv2.imread(os.path.join(input_dir, "im_0.jpg")) + if do_visualization: + status, visual_img = visualize_this_that(base_img, bbox_info, detected_point) + if status: + cv2.imwrite(os.path.join(store_dir, "visualization.png"), visual_img) + + + + # SAM process based on bbox_info + xmin, xmax, ymin, ymax = bbox_info + sam_predictor.set_image(np.uint8(base_img)) + positive_point_cords = np.array([[ int(detected_point[0][0]), int(detected_point[0][1]) ]]) + positive_point_cords = np.array(positive_point_cords) + positive_point_labels = np.ones(len(positive_point_cords)) + + # Predict the mask based on the point and bounding box designed + masks, scores, logits = sam_predictor.predict( + point_coords = positive_point_cords, + point_labels = positive_point_labels, + box = np.array([xmin, ymin, xmax, ymax])[None, :], + multimask_output = False, + ) + print(scores) + for mask_idx, mask in enumerate(masks): + mask_img = show_mask(mask) + cv2.imwrite(os.path.join(store_dir, "mask_" + str(mask_idx) + ".png"), mask_img) + + + + ################################ Move the img ###################################### + # Calculate needed parameters + division_factor = num_frames_input // total_frames_needed + remain_frames = (num_frames_input % total_frames_needed) - 1 # -1 for adaptation + + # Define the gap + gaps = [division_factor for _ in range(total_frames_needed-1)] + for idx in range(remain_frames): + if idx % 2 == 0: + gaps[idx//2] += 1 # Start to end order + else: + gaps[-1*(1+(idx//2))] += 1 # End to start order + + # Map the gap to the specific orders + idx_orders = [1] # 从1还是shift一下问题应该不大 + for global_idx, gap in enumerate(gaps): + idx_orders.append(idx_orders[-1] + gap) + if idx_orders[-1] >= num_frames_input: + message = "Invalid error" + return (False, message) + # assert(idx_orders[-1] < num_frames_input) + assert(len(idx_orders) == total_frames_needed) + + + # Copy the essential files first + for global_idx, cur_idx in enumerate(idx_orders): + source_path = os.path.join(input_dir, "im_"+str(cur_idx)+".jpg") + destination_path = os.path.join(store_dir, "im_"+str(global_idx)+".jpg") + + if not os.path.exists(source_path): # Theoretically, source_path must exists + message = "We couldn't find the source path. Theoretically, source_path must exists!" # 有一种可能就是我们丢失了一些地方,在cp或者本来就没有,记得统计数量 + return (False, message) + + shutil.copyfile(source_path, destination_path) + + # Map order_idx to the cropped version + mapped_seq_idx = [] + for old_idx in old_seq_idx: + tmp = [] + for tmp_idx, new_idx in enumerate(range(len(idx_orders))): + tmp.append((tmp_idx, abs(old_idx - idx_orders[new_idx]))) + # Sort the smallest fistance + tmp.sort(key=lambda x: x[1]) + mapped_seq_idx.append(tmp[0][0]) + + print("Before the idx is ", old_seq_idx) + print("mapped idx is ", mapped_seq_idx) + + + # Write the information to new destination + f = open(os.path.join(store_dir, "data.txt"), "a") + f.write(str(mapped_seq_idx[0]) + " " + str(detected_point[0][0]) + " " + str(detected_point[0][1]) + "\n") + if len(detected_point) == 2: # Two points excluding the last idx + f.write(str(mapped_seq_idx[1]) + " " + str(detected_point[1][0]) + " " + str(detected_point[1][1]) + "\n") + f.close() + + + # Move lang.txt file + shutil.copyfile(os.path.join(input_dir, 'lang.txt'), os.path.join(store_dir, 'lang.txt')) + + + message = "Success!" + return (True, message) + + + + +if __name__ == "__main__": + + # General storage setting + dataset_path = "../datasets_rob/Bridge_v2_raw" + destination_path = "../sanity_check/bridge_v2_TT14_longer_tolerance" + sample_failure_collect_folder = "" # This is to collect cases that fail for active learning + + total_frames_needed = 14 + max_original_input_tolerate = 56 # 40 for 14 fps; 60 for 25fps; + do_visualization = True + + + # YOLO model init + yolo_pretarined_path = "pretrained/yolov8n_best.pt" + gripper_detection_model = YOLO("yolov8n.yaml") # build a new model from scratch + gripper_detection_model = YOLO(yolo_pretarined_path) # load a pretrained model (recommended for training) + + # SAM model init + model_type = "vit_h" + sam_pretrained_path = "pretrained/sam_vit_h_4b8939.pth" + sam = sam_model_registry[model_type](checkpoint=sam_pretrained_path).to(device="cuda") + sam_predictor = SamPredictor(sam) # There is a lot of setting here + + + # Make dir if needed + if os.path.exists(destination_path): + shutil.rmtree(destination_path) + os.makedirs(destination_path) + + # Prepare the folder to collect failure cases + if sample_failure_collect_folder != "": + if os.path.exists(sample_failure_collect_folder): + shutil.rmtree(sample_failure_collect_folder) + os.makedirs(sample_failure_collect_folder) + + + + # Collect the message + message_dict = collections.defaultdict(int) + + + store_idx = 0 + for folder_name in sorted(os.listdir(dataset_path)): + input_folder_path = os.path.join(dataset_path, folder_name) + store_folder_path = os.path.join(destination_path, "0"*(6-len(str(store_idx)))+str(store_idx)) + print("We are processing ", input_folder_path) + + # Prepare store_folder_path folder + os.makedirs(store_folder_path) + + status, message = manage_seq_range(input_folder_path, store_folder_path, sample_failure_collect_folder, total_frames_needed, max_original_input_tolerate, gripper_detection_model, sam_predictor, do_visualization) + if status: # We will only update the store_idx only when this file is successfully written + store_idx += 1 + else: + print("This status failed! Message: " + message) + shutil.rmtree(store_folder_path) + # break # For debug + + # Collect the infor to dict + message_dict[message] += 1 + + print("We have " + str(store_idx) + " valid dataset") + print("message_dict info is ", message_dict) + diff --git a/curation_pipeline/tracking_by_keypoint.py b/curation_pipeline/tracking_by_keypoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ac036ade8c7d13e09012ad649703f14eecf76b9b --- /dev/null +++ b/curation_pipeline/tracking_by_keypoint.py @@ -0,0 +1,136 @@ +import os, shutil, sys +import argparse +import gdown +import cv2 +import numpy as np +import os +import sys +import requests +import json +import torchvision +import torch +import psutil +import time +try: + from mmcv.cnn import ConvModule +except: + os.system("mim install mmcv") + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from track_anything_code.model import TrackingAnything +from track_anything_code.track_anything_module import get_frames_from_video, download_checkpoint, parse_augment, sam_refine, vos_tracking_video +from scripts.compress_videos import compress_video + + + + +if __name__ == "__main__": + dataset_path = "Bridge_v1_TT14" + video_name = "combined.mp4" + verbose = True # If this is verbose, you will continue to write the code + + + ################################################## Model setup #################################################### + # check and download checkpoints if needed + sam_checkpoint = "sam_vit_h_4b8939.pth" + sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" + xmem_checkpoint = "XMem-s012.pth" + xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth" + + + folder ="./pretrained" + SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint) + xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint) + + # argument + args = parse_augment() + args.device = "cuda" # Any GPU is ok + + # Initialize the Track model + track_model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args) + ################################################################################################################### + + + # Iterate all files under the folder + for sub_folder_name in sorted(os.listdir(dataset_path)): + + ################################################## Setting #################################################### + sub_folder_path = os.path.join(dataset_path, sub_folder_name) + + click_state = [[],[]] + interactive_state = { + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": args.mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_number": None, + "resize_ratio": 1 + } + ################################################################################################################### + + + video_path = os.path.join(sub_folder_path, video_name) + if not os.path.exists(video_path): + print("We cannot find the path of the ", video_path, " and we will compress one") + status = compress_video(sub_folder_path, video_name) + if not status: + print("We still cannot generate a video") + continue + + # Read video state + video_state = { + "user_name": "", + "video_name": "", + "origin_images": None, + "painted_images": None, + "masks": None, + "inpaint_masks": None, + "logits": None, + "select_frame_number": 0, + "fps": 30 + } + video_state, template_frame = get_frames_from_video(video_path, video_state, track_model) + + + + ########################################################## Get the sam point based on the data.txt ########################################################### + data_txt_path = os.path.join(sub_folder_path, "data.txt") + if not os.path.exists(data_txt_path): + print("We cannot find data.txt in this folder") + continue + + data_file = open(data_txt_path, 'r') + lines = data_file.readlines() + frame_idx, horizontal, vertical = lines[0][:-2].split(' ') # Only read the first point + point_cord = [int(float(horizontal)), int(float(vertical))] + + # Process by SAM + track_model.samcontroler.sam_controler.reset_image() # Reset the image to clean history + painted_image, video_state, interactive_state, operation_log = sam_refine(track_model, video_state, "Positive", click_state, interactive_state, point_cord) + ################################################################################################################################################################ + + + + ######################################################### Get the tracking output ######################################################################## + + # Track the video for processing + segment_output_path = os.path.join(sub_folder_path, "segment_output.gif") + video_state = vos_tracking_video(track_model, segment_output_path, video_state, interactive_state, mask_dropdown=[])[0] # mask_dropdown is empty now + + # Extract the mask needed by us for further point calculating + masks = video_state["masks"] # In the range [0, 1] + + if verbose: + for idx, mask in enumerate(masks): + cv2.imwrite(os.path.join(sub_folder_path, "mask"+str(idx)+".png"), mask*255) + + ############################################################################################################################################################## + + diff --git a/data_loader/video_dataset.py b/data_loader/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc2a55fa801303ad28ef0d5616dc6be3725e408 --- /dev/null +++ b/data_loader/video_dataset.py @@ -0,0 +1,323 @@ +import os, sys +import json +import cv2 +import math +import shutil +import numpy as np +import random +import collections +from PIL import Image +import torch +from torch.utils.data import Dataset + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from utils.img_utils import resize_with_antialiasing, numpy_to_pt + + + +def get_video_frames(config, video_frame_path, flip = False): + + video_seq_length = config["video_seq_length"] + + # Calculate needed parameters + num_frames_input = 0 + for file_name in os.listdir(video_frame_path): + if file_name.startswith("im_"): + num_frames_input += 1 + total_frames_needed = video_seq_length + division_factor = num_frames_input // total_frames_needed + remain_frames = (num_frames_input % total_frames_needed) - 1 # -1 for adaptation + + + # Define the gap + gaps = [division_factor for _ in range(total_frames_needed-1)] + for idx in range(remain_frames): + if idx % 2 == 0: + gaps[idx//2] += 1 # Start to end order + else: + gaps[-1*(1+(idx//2))] += 1 # End to start order + + + # Find needed file + needed_img_path = [] + cur_idx = 0 + for gap in gaps: + img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg") + needed_img_path.append(img_path) + + # Update the idx + cur_idx += gap + # Append the last one + img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg") + needed_img_path.append(img_path) + + + # Read all img_path based on the order + video_frames = [] + for img_path in needed_img_path: + if not os.path.exists(img_path): + print("We don't have ", img_path) + frame = cv2.imread(img_path) + + try: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + except Exception: + print("The exception places is ", img_path) + + # Resize frames + frame = cv2.resize(frame, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC) + + # Flip aug + if flip: + frame = np.fliplr(frame) + + # Collect frames + video_frames.append(np.expand_dims(frame, axis=0)) # The frame is already RGB, there is no need to convert here. + + + # Concatenate + video_frames = np.concatenate(video_frames, axis=0) + assert(len(video_frames) == video_seq_length) + + return video_frames + + + +def tokenize_captions(prompt, tokenizer, config, is_train=True): + ''' + Tokenize text prompt be prepared tokenizer from SD2.1 + ''' + + captions = [] + if random.random() < config["empty_prompts_proportion"]: + captions.append("") + elif isinstance(prompt, str): + captions.append(prompt) + elif isinstance(prompt, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(prompt) if is_train else prompt[0]) + else: + raise ValueError( + f"Caption column should contain either strings or lists of strings." + ) + + inputs = tokenizer( + captions, max_length = tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids[0] + + + +class Video_Dataset(Dataset): + ''' + Video Dataset to load sequential frames for training with needed pre-processing + ''' + + def __init__(self, config, device, normalize=True, tokenizer=None): + + # Attribute variables + self.config = config + self.device = device + self.normalize = normalize + self.tokenizer = tokenizer + + # Obtain values + self.video_seq_length = config["video_seq_length"] + self.height = config["height"] + self.width = config["width"] + + # Process data + self.video_lists = [] + stats_analysis = collections.defaultdict(int) + print("Process all files to check valid datasets....") + for dataset_path in config["dataset_path"]: + for video_name in sorted(os.listdir(dataset_path)): + video_path = os.path.join(dataset_path, video_name) + all_files = os.listdir(video_path) + + + valid = True + # Valid check 1: the number of files should be in sequential order + num_frames_input = 0 + for file_name in os.listdir(video_path): + if file_name.startswith("im_"): + num_frames_input += 1 + for idx in range(num_frames_input): + img_path = 'im_' + str(idx) + '.jpg' + if img_path not in all_files: # Should be sequential existing + valid = False + stats_analysis["incomplete_img"] += 1 + break + + + # Valid check 1.5: the number of files must be longer than video_seq_length and less than self.config["acceleration_tolerance"]*self.config["video_seq_length"] + if num_frames_input < self.config["video_seq_length"]: + stats_analysis["too_little_frames"] += 1 + valid = False + if num_frames_input > self.config["acceleration_tolerance"] * self.config["video_seq_length"]: + stats_analysis["too_many_frames"] += 1 + valid = False + + if not valid: # SpeedUp so set in the middle here + continue + + + # Valid check 2: language if needed + if config["use_text"] and not os.path.exists(os.path.join(dataset_path, video_name, "lang.txt")): + stats_analysis["no_lang_txt"] += 1 + valid = False + + + # Valid check 3: motion if needed + if config["motion_bucket_id"] is None: + flow_path = os.path.join(dataset_path, video_name, "flow.txt") + if "flow.txt" not in all_files: + stats_analysis["no_flow_txt"] += 1 + valid = False + else: + file = open(flow_path, 'r') + info = file.readlines() + if len(info) == 0: + stats_analysis["no_flow_txt"] += 1 + valid = False + + + if valid: + self.video_lists.append(video_path) + print("stats_analysis is ", stats_analysis) + print("Valid dataset length is ", len(self.video_lists)) + + + def __len__(self): + return len(self.video_lists) + + + + def _get_motion_value(self, sub_folder_path): + ''' Read the motion value from the flow.txt file prepared; preprocess the flow to accelerate + ''' + + # Read the flow.txt + flow_path = os.path.join(sub_folder_path, 'flow.txt') + file = open(flow_path, 'r') + info = file.readlines() + per_video_movement = float(info[0][:-2]) + + # Map the raw reflected_motion_bucket_id to target range based on the number of images have + num_frames_input = 0 + for file_name in os.listdir(sub_folder_path): # num_frames_input is the total number of files with name begin with im_ + if file_name.startswith("im_"): + num_frames_input += 1 + + # Correct the value based on the number of frames relative to video_seq_length + per_video_movement_correct = per_video_movement * (num_frames_input/self.config["video_seq_length"]) + + # Map from one Normal Distribution to another Normal Distribution + z = (per_video_movement_correct - self.config["dataset_motion_mean"]) / (self.config["dataset_motion_std"] + 0.001) + reflected_motion_bucket_id = int((z * self.config["svd_motion_std"]) + self.config["svd_motion_mean"]) + + + print("We map " + str(per_video_movement) + " to " + str(per_video_movement_correct) + " by length " + str(num_frames_input) + " to bucket_id of " + str(reflected_motion_bucket_id)) + return reflected_motion_bucket_id + + + + def __getitem__(self, idx): + ''' Get item by idx and pre-process by Resize and Normalize to [0, 1] + Args: + idx (int): The index to the file in the directory + Returns: + video_frames (torch.float32): The Pytorch tensor format of obtained frames (max: 1.0; min: 0.0) + reflected_motion_bucket_id (tensor): Motion value is there is optical flow provided, else they are fixed value from config + prompt (tensor): Tokenized text + ''' + + # Prepare the text if needed: + if self.config["use_text"]: + # Read the file + file_path = os.path.join(self.video_lists[idx], "lang.txt") + file = open(file_path, 'r') + prompt = file.readlines()[0] # Only read the first line + + if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")): + # If we don't have this txt file, we skip + + ######################################################## Mix up prompt ######################################################## + + # Read the file + file_path = os.path.join(self.video_lists[idx], "processed_text.txt") + file = open(file_path, 'r') + prompts = [line for line in file.readlines()] # Only read the first line + + # Get the componenet + action = prompts[0][:-1] + this = prompts[1][:-1] + there = prompts[2][:-1] + + + random_value = random.random() + # If less than 0.4, we don't care, just use the most concrete one + if random_value >= 0.4 and random_value < 0.6: + # Mask pick object to "This" + prompt = action + " this to " + there + elif random_value >= 0.6 and random_value < 0.8: + # Mask place position to "There" + prompt = action + " " + this + " to there" + elif random_value >= 0.8 and random_value < 1.0: + # Just be like "this to there" + prompt = action + " this to there" + + # print("New prompt is ", prompt) + ################################################################################################################################################### + + # else: + # print("We don't have llama processed prompt at ", self.video_lists[idx]) + + else: + prompt = "" + + # Tokenize text prompt + tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config) + + + # Dataset aug by chance (it is needed to check whether there is any object position words [left|right] in the prompt text) + flip = False + if random.random() < self.config["flip_aug_prob"]: + if self.config["use_text"]: + if prompt.find("left") == -1 and prompt.find("right") == -1: # Cannot have position word, like left and right (up and down is ok) + flip = True + else: + flip = True + + + # Read frames for different datasets; Currently, we have WebVid / Bridge + if self.config["dataset_name"] == "Bridge": + video_frames = get_video_frames(self.config, self.video_lists[idx], flip=flip) + else: + raise NotImplementedError("We don't support this dataset loader") + + + # Scale [0, 255] -> [-1, 1] + if self.normalize: + video_frames = video_frames.astype(np.float32) / 127.5 - 1 # Be careful to cast to float32 + + # Transform to Pytorch Tensor in the range [-1, 1] + video_frames = numpy_to_pt(video_frames) + # print("length of input frames has ", len(video_frames)) + + + # Get the motion value based on the optical flow + if self.config["motion_bucket_id"] is None: + reflected_motion_bucket_id = self._get_motion_value(self.video_lists[idx]) + else: + reflected_motion_bucket_id = self.config["motion_bucket_id"] + + + # The tensor we returned is torch float32. We won't cast here for mixed precision training! + return { + "video_frames" : video_frames, + "reflected_motion_bucket_id" : reflected_motion_bucket_id, + "prompt": tokenized_prompt, + } \ No newline at end of file diff --git a/data_loader/video_this_that_dataset.py b/data_loader/video_this_that_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..59013c2bacbdc667e4165737e8e5643ccfb35c5a --- /dev/null +++ b/data_loader/video_this_that_dataset.py @@ -0,0 +1,326 @@ +import os, sys +import json +import cv2 +import math +import shutil +import numpy as np +import random +from PIL import Image +import torch.nn.functional as F +import torch +import os.path as osp +import time +from moviepy.editor import VideoFileClip +from torch.utils.data import Dataset + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from utils.img_utils import resize_with_antialiasing, numpy_to_pt +from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian +from data_loader.video_dataset import tokenize_captions + + +# For the 2D dilation +blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid = None, isotropic = True) + + +def get_thisthat_sam(config, intput_dir, store_dir = None, flip = False, verbose=False): + ''' + Args: + idx (int): The index to the folder we need to process + ''' + + # Read file + file_path = os.path.join(intput_dir, "data.txt") + file1 = open(file_path, 'r') + Lines = file1.readlines() + + + # Initial the optical flow format we want + thisthat_condition = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32) # The last image should be empty + + + # Init the image + sample_img = cv2.imread(os.path.join(intput_dir, "im_0.jpg")) + org_height, org_width, _ = sample_img.shape + + # Prepare masking + controlnet_image_index = [] + coordinate_values = [] + + # Iterate all points in the txt file + for idx in range(len(Lines)): + + # Read points + frame_idx, horizontal, vertical = Lines[idx].split(' ') + frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal)) + + # Read the mask frame idx + controlnet_image_index.append(frame_idx) + coordinate_values.append((vertical, horizontal)) + + + # Init the base image + base_img = np.zeros((org_height, org_width, 3)).astype(np.float32) # Use the original image size + base_img.fill(255) + + # Draw square around the target position + dot_range = 10 # Diameter + for i in range(-1*dot_range, dot_range+1): + for j in range(-1*dot_range, dot_range+1): + dil_vertical, dil_horizontal = vertical + i, horizontal + j + if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]): + if idx == 0: + base_img[dil_vertical][dil_horizontal] = [0, 0, 255] # The first point should be red + else: + base_img[dil_vertical][dil_horizontal] = [0, 255, 0] # The second point should be green to distinguish the first point + + # Dilate + if config["dilate"]: + base_img = cv2.filter2D(base_img, -1, blur_kernel) + + + ############################################################################################################################## + ### The core pipeline of processing is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store + + # Resize frames Don't use negative and don't resize in [0,1] + base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC) + + + # Flip the image for aug if needed + if flip: + base_img = np.fliplr(base_img) + + + # Channel Transform and Range Shift + if config["conditioning_channels"] == 3: + # Map to [0, 1] range + if store_dir is not None and verbose: # For the first frame condition visualization + cv2.imwrite(os.path.join(store_dir, "condition_TT"+str(idx)+".png"), base_img) + base_img = base_img / 255.0 + + else: + raise NotImplementedError() + + + # ReOrganize shape + base_img = base_img.transpose(2, 0, 1) # hwc -> chw + + + # Check the min max value range + # if verbose: + # print("{} min, max range value is {} - {}".format(intput_dir, np.min(base_img), np.max(base_img))) + + + # Write base img based on frame_idx + thisthat_condition[frame_idx] = base_img # Only the first frame, the rest is 0 initialized + + ############################################################################################################################## + + + if config["motion_bucket_id"] is None: + # take the motion to stats collected before + reflected_motion_bucket_id = 200 + else: + reflected_motion_bucket_id = config["motion_bucket_id"] + + + # print("Motion Bucket ID is ", reflected_motion_bucket_id) + return (thisthat_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values) + + + +class Video_ThisThat_Dataset(Dataset): + ''' + Video Dataset to load sequential frames for training with needed pre-processing and process with optical flow + ''' + + def __init__(self, config, device, normalize=True, tokenizer=None): + # Attribute variables + self.config = config + self.device = device + self.normalize = normalize + self.tokenizer = tokenizer + + # Obtain values + self.video_seq_length = config["video_seq_length"] + self.height = config["height"] + self.width = config["width"] + + # Process data + self.video_lists = [] + for dataset_path in config["dataset_path"]: + for video_name in sorted(os.listdir(dataset_path)): + if not os.path.exists(os.path.join(dataset_path, video_name, "data.txt")): + continue + + self.video_lists.append(os.path.join(dataset_path, video_name)) + print("length of the dataset is ", len(self.video_lists)) + + + + + def __len__(self): + return len(self.video_lists) + + + def _extract_frame_bridge(self, idx, flip=False): + ''' Extract the frame in video based on the needed fps from already extracted frame + Args: + idx (int): The index to the file in the directory + flip (bool): Bool for whether we will flip + Returns: + video_frames (numpy): Extracted video frames in numpy format + ''' + + # Init the the Video Reader + # The naming of the Bridge dataset follow a pattern: im_x.jpg, so we need to + video_frame_path = self.video_lists[idx] + + + # Find needed file + needed_img_path = [] + for idx in range(self.video_seq_length): + img_path = os.path.join(video_frame_path, "im_" + str(idx) + ".jpg") + needed_img_path.append(img_path) + + + + # Read all img_path based on the order + video_frames = [] + for img_path in needed_img_path: + if not os.path.exists(img_path): + print("We don't have ", img_path) + frame = cv2.imread(img_path) + + try: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + except Exception: + print("The exception place is ", img_path) + # Resize frames + frame = cv2.resize(frame, (self.width, self.height), interpolation = cv2.INTER_CUBIC) + + # Flip aug + if flip: + frame = np.fliplr(frame) + + # Collect frames + video_frames.append(np.expand_dims(frame, axis=0)) # The frame is already RGB, there is no need to convert here. + + + # Concatenate + video_frames = np.concatenate(video_frames, axis=0) + assert(len(video_frames) == self.video_seq_length) + + # Returns + return video_frames + + + + + def __getitem__(self, idx): + ''' Get item by idx and pre-process by Resize and Normalize to [0, 1] + Args: + idx (int): The index to the file in the directory + Returns: + return_dict (dict): video_frames (torch.float32) [-1, 1] and controlnet_condition (torch.float32) [0, 1] + ''' + + # Prepare the text if needed: + if self.config["use_text"]: + # Read the file + file_path = os.path.join(self.video_lists[idx], "lang.txt") + file = open(file_path, 'r') + prompt = file.readlines()[0] # Only read the first line + + if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")): + # If we don't have this txt file, we skip + + ######################################################## Mix up prompt ######################################################## + + # Read the file + file_path = os.path.join(self.video_lists[idx], "processed_text.txt") + file = open(file_path, 'r') + prompts = [line for line in file.readlines()] # Only read the first line + + # Get the componenet + action = prompts[0][:-1] + this = prompts[1][:-1] + there = prompts[2][:-1] + + + random_value = random.random() + # If less than 0.4, we don't care, just use the most concrete one + if random_value >= 0.4 and random_value < 0.6: + # Mask pick object to "This" + prompt = action + " this to " + there + elif random_value >= 0.6 and random_value < 0.8: + # Mask place position to "There" + prompt = action + " " + this + " to there" + elif random_value >= 0.8 and random_value < 1.0: + # Just be like "this to there" + prompt = action + " this to there" + + # print("New prompt is ", prompt) + ################################################################################################################################################### + + # else: + # print("We don't have llama processed prompt at ", self.video_lists[idx]) + + else: + prompt = "" + + # Tokenize text prompt + tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config) + + + + # Dataset aug by chance (it is needed to check whether there is any object position words [left|right] in the prompt text) + flip = False + if random.random() < self.config["flip_aug_prob"]: + if self.config["use_text"]: + if prompt.find("left") == -1 and prompt.find("right") == -1: # Cannot have position word, like left and right (up and down is ok) + flip = True + else: + flip = True + + + + # Read frames for different dataset; Currently, we have WebVid / Bridge + if self.config["dataset_name"] == "Bridge": + video_frames_raw = self._extract_frame_bridge(idx, flip=flip) + else: + raise NotImplementedError("We don't support this dataset loader") + + + # Scale [0, 255] -> [-1, 1] if needed + if self.normalize: + video_frames = video_frames_raw.astype(np.float32) / 127.5 - 1 # Be careful to cast to float32 + + # Transform to Pytorch Tensor in the range [-1, 1] + video_frames = numpy_to_pt(video_frames) + + + # Generate the pairs we need + intput_dir = self.video_lists[idx] + + # Get the This That point information + controlnet_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(self.config, intput_dir, flip=flip) + controlnet_condition = torch.from_numpy(controlnet_condition) + + # Cast other value to tensor + reflected_motion_bucket_id = torch.tensor(reflected_motion_bucket_id, dtype=torch.float32) + controlnet_image_index = torch.tensor(controlnet_image_index, dtype=torch.int32) + coordinate_values = torch.tensor(coordinate_values, dtype=torch.int32) + + + # The tensor we returned is torch float32. We won't cast here for mixed precision training! + return {"video_frames" : video_frames, + "controlnet_condition" : controlnet_condition, + "reflected_motion_bucket_id" : reflected_motion_bucket_id, + "controlnet_image_index": controlnet_image_index, + "prompt": tokenized_prompt, + "coordinate_values": coordinate_values, # Useless now, but I still passed back + } + diff --git a/pretrained/PUT_YOUR_WEIGHT_HERE.md b/pretrained/PUT_YOUR_WEIGHT_HERE.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0b8d4fb982d228438e1d412b8abe4a87bcb68fe8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,27 @@ +# Non-strict version lib +opencv-python +transformers +accelerate +requests +moviepy +omegaconf +# xformers +tensorboard +einops +yacs +loguru +imageio +pyparsing +ultralytics +lpips +matplotlib +gradio +torch==2.0.1 +torchvision + +# Strict version lib +bitsandbytes==0.43.0 +diffusers==0.25.1 +timm==0.4.12 +scipy==1.9.3 +pyiqa==0.1.7 \ No newline at end of file diff --git a/scripts/active_learning_select.py b/scripts/active_learning_select.py new file mode 100644 index 0000000000000000000000000000000000000000..70142ad2729100fdfaba520f716fed21ee288358 --- /dev/null +++ b/scripts/active_learning_select.py @@ -0,0 +1,27 @@ +import os, shutil +import random + + +if __name__ == "__main__": + start_idx = 950 + end_idx = 1020 + select_num = 70 + + label_start_idx = 632 + input_parent_dir = "../Bridge" + store_dir = "../bridge_select3" + + if os.path.exists(store_dir): + shutil.rmtree(store_dir) + os.makedirs(store_dir) + + for idx in range(start_idx, end_idx): + folder_path = os.path.join(input_parent_dir, str(idx)) + select_idx = random.randint(0, len(os.listdir(folder_path))) + for idx, img_name in enumerate(os.listdir(folder_path)): + if idx == select_idx and img_name != "policy_out.pkl": + img_path = os.path.join(folder_path, img_name) + target_path = os.path.join(store_dir, str(label_start_idx) + ".jpg") + label_start_idx += 1 + shutil.copy(img_path, target_path) + diff --git a/scripts/add_point2img.py b/scripts/add_point2img.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0e4895ad46258b8f80f63f55c1fd39c898ebd2 --- /dev/null +++ b/scripts/add_point2img.py @@ -0,0 +1,51 @@ +''' + This file is to add point to the first image +''' + +import os, shutil, sys + +if __name__ == "__main__": + input_folder_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/Human_Study/Input_Bridge_human_evaluation" + store_path = "point_highlighted" + + if os.path.exists(input_folder_path): + shutil.rmtree(input_folder_path) + os.makedirs(input_folder_path) + + + for instance_name in os.listdir(input_folder_path): + + sub_folder_dir = os.path.join(input_folder_path, instance_name) + + # Read file + file_path = os.path.join(sub_folder_dir, "data.txt") + file1 = open(file_path, 'r') + Lines = file1.readlines() + + # Read the first img + first_img_path = os.path.join(sub_folder_dir, "im_0.jpg") + + + # Init the image + base_img = cv2.imread(first_img_path).astype(np.float32) # Use the original image size + + # Draw the point + for idx in range(len(Lines)): + # Read points + frame_idx, horizontal, vertical = Lines[idx].split(' ') + frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal)) + + # Draw square around the target position + dot_range = 15 # Diameter + for i in range(-1*dot_range, dot_range+1): + for j in range(-1*dot_range, dot_range+1): + dil_vertical, dil_horizontal = vertical + i, horizontal + j + if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]): + if idx == 0: + base_img[dil_vertical][dil_horizontal] = [0, 0, 255] # The first point should be red + else: + base_img[dil_vertical][dil_horizontal] = [0, 255, 0] # The second point should be green to distinguish the first point + + + + diff --git a/scripts/check_video.py b/scripts/check_video.py new file mode 100644 index 0000000000000000000000000000000000000000..53beca2e2885c8c5c7cc66676de1f8b6b91e6bd7 --- /dev/null +++ b/scripts/check_video.py @@ -0,0 +1,19 @@ +''' + This file is to make sure that the video files is readeable by moviepy, such that the data loader can read these files. +''' +import os +from moviepy.editor import VideoFileClip + +if __name__ == "__main__": + video_dir = "../webvid_sample" + delete_abnormal_video = True # Whether you want to delete these abnormal video directly + + for video_name in sorted(os.listdir(video_dir)): + video_path = os.path.join(video_dir, video_name) + try: + objVideoreader = VideoFileClip(filename=video_path) + except Exception: + print("There is an exception of reading: ", video_path) + if delete_abnormal_video: + print("We will remove this abnormal video source") + os.remove(video_path) \ No newline at end of file diff --git a/scripts/clean_bridge_dataset.py b/scripts/clean_bridge_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6ae157f609da9e66c9f7cdd117a67709f73325 --- /dev/null +++ b/scripts/clean_bridge_dataset.py @@ -0,0 +1,22 @@ +''' + Sometimes, Bridge dataset will contain strange downloads, we need to clean them +''' +import os, shutil + +# TODO: 后面把这个直接merge 到prepare_bridge_dataset中 +if __name__ == "__main__": + dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/Bridge" + + for sub_folder in sorted(os.listdir(dataset_path)): + sub_folder_path = os.path.join(dataset_path, sub_folder) + + img_lists = os.listdir(sub_folder_path) + if len(img_lists) < 14: + print("The folder is too short, we will remove them all") + shutil.rmtree(sub_folder_path) + continue + for img_name in img_lists: + img_path = os.path.join(sub_folder_path, img_name) + if not img_name.startswith("im_"): + print("We remove ", img_path) + os.remove(img_path) \ No newline at end of file diff --git a/scripts/collect_lang.py b/scripts/collect_lang.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c6a263b7da1787f7549bf356ae60eec798b4d0 --- /dev/null +++ b/scripts/collect_lang.py @@ -0,0 +1,31 @@ +''' + THis file is to collect all lang.txt and move to a new directory, this is for the convenience to compress and scp the lang for post-processing +''' +import os, sys, shutil + +if __name__ == "__main__": + parent_dir = "../datasets_rob" + dataset_paths = ["Bridge_v1_TT14", "Bridge_v2_TT14"] + store_folder = "../full_text_tmp" + + # Manage the store folder + if os.path.exists(store_folder): + shutil.rmtree(store_folder) + os.makedirs(store_folder) + + + for dataset_name in dataset_paths: + store_path = os.path.join(store_folder, dataset_name) + if os.path.exists(store_path): + shutil.rmtree(store_path) + os.makedirs(store_path) + + # Iterate all the files + for sub_folder_name in os.listdir(os.path.join(parent_dir, dataset_name)): + print("We are processing ", sub_folder_name) + lang_txt_path = os.path.join(parent_dir, dataset_name, sub_folder_name, "lang.txt") + + # Store on the new address + store_file_path = os.path.join(store_path, sub_folder_name) + os.makedirs(store_file_path) + shutil.copyfile(lang_txt_path, os.path.join(store_file_path, "lang.txt")) diff --git a/scripts/combine_results.py b/scripts/combine_results.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb83b8f9d48b9fcd1055c3e32a75a5666347e76 --- /dev/null +++ b/scripts/combine_results.py @@ -0,0 +1,85 @@ +''' + This repo is to combine multiple generated images with same index together +''' + +import os, shutil, sys +import imageio +import math +import cv2 +from PIL import Image +import collections +import numpy as np + + +if __name__ == "__main__": + + # Basic setting + data_paths = [ + "human_evaluation_v3_V_raw_prompt", + "human_evaluation_v3_VG_raw_prompt_no_sam", + "human_evaluation_v3_VL_ambiguous_prompt", + + "../datasets_rob/Bridge_human_evaluation", + + "human_evaluation_v3_VL_raw_prompt", + "human_evaluation_v3_VGL_raw_prompt_no_sam", + "human_evaluation_v3_VGL_ambiguous_prompt_no_sam", + ] + store_path = "combined_results_human_evaluation" + sample_data_path = data_paths[0] + gif_per_row = 4 # Number of GIF files per row + + + # Create folder + if os.path.exists(store_path): + shutil.rmtree(store_path) + os.makedirs(store_path) + + + # Iterate the sample + for instance_idx, sub_folder_name in enumerate(sorted(os.listdir(sample_data_path))): + print("we are processing ", sub_folder_name) + + collected_gif_paths = [] + for data_path in data_paths: + collected_gif_paths.append(os.path.join(data_path, sub_folder_name, 'combined.gif')) + + # Merge frames together + rows = math.ceil(len(collected_gif_paths) / gif_per_row) + cols = gif_per_row + + # Read all input GIFs and find maximum dimensions + gifs = [] + max_width, max_height = 0, 0 + for path in collected_gif_paths: + gif = imageio.mimread(path) + max_width = max(max_width, gif[0].shape[1]) + max_height = max(max_height, gif[0].shape[0]) + gifs.append(gif) + + # Create blank canvas for concatenated GIF + frames_length = len(gifs[0]) + canvas_width = max_width * cols + canvas_height = max_height * rows + canvas = np.zeros((frames_length, canvas_height, canvas_width, 3), dtype=np.uint8) + + + # push each frame into the canvas placeholder + gif_index = 0 + for row in range(rows): + for col in range(cols): + gif = gifs[gif_index] + gif_height, gif_width, _ = gif[0].shape + start_y = row * max_height + start_x = col * max_width + for i in range(frames_length): + canvas[i, start_y:start_y+gif_height, start_x:start_x+gif_width, :] = gif[i] + + # Update index + gif_index += 1 + if gif_index == len(collected_gif_paths): + break + + + # Write the concatenated GIF + imageio.mimsave(os.path.join(store_path, sub_folder_name + ".gif"), canvas, duration=0.05, quality=100) \ No newline at end of file diff --git a/scripts/compress_gif.py b/scripts/compress_gif.py new file mode 100644 index 0000000000000000000000000000000000000000..a659ac1c487dad0c62e33ad928a68d136040b2b6 --- /dev/null +++ b/scripts/compress_gif.py @@ -0,0 +1,52 @@ +import os, shutil, sys +import cv2 +import imageio +import numpy as np + + +def compress_gif(sub_folder_path): + + # Check valid length + all_files = os.listdir(sub_folder_path) + num_frames_input = 0 + valid = True + for file_name in os.listdir(sub_folder_path): + if file_name.startswith("im_"): + num_frames_input += 1 + for idx in range(num_frames_input): + img_path = 'im_' + str(idx) + '.jpg' + if img_path not in all_files: # Should be sequential existing + valid = False + break + if not valid: + print("We cannot generate a video because the video is not sequential") + return False + + + if num_frames_input == 0: + print("We cannot generate a video because the input length is 0") + return False + + img_lists = [] + for idx in range(num_frames_input): + img_path = os.path.join(sub_folder_path, "im_" + str(idx) + ".jpg") + img_lists.append(cv2.resize(cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB), (384, 256))) + + imageio.mimsave(os.path.join(sub_folder_path, 'combined.gif'), np.array(img_lists), duration=0.05, quality=100) + + return True + + +if __name__ == "__main__": + dataset_path = "../datasets_rob/Bridge_human_evaluation" # ../datasets_rob/Bridge_v1_raw + + for sub_folder_name in sorted(os.listdir(dataset_path)): + print("We are processing ", sub_folder_name) + sub_folder_path = os.path.join(dataset_path, sub_folder_name) + + status = compress_gif(sub_folder_path) + + + + + \ No newline at end of file diff --git a/scripts/compress_videos.py b/scripts/compress_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..671c40531c6b336fc8cf44734b9790b67e090723 --- /dev/null +++ b/scripts/compress_videos.py @@ -0,0 +1,55 @@ +import os, shutil, sys +from moviepy.editor import ImageSequenceClip + + +def compress_video(sub_folder_path, video_name): + store_path = os.path.join(sub_folder_path, video_name) + + if os.path.exists(store_path): + os.remove(store_path) + + + # Check valid length + all_files = os.listdir(sub_folder_path) + num_frames_input = 0 + valid = True + for file_name in os.listdir(sub_folder_path): + if file_name.startswith("im_"): + num_frames_input += 1 + for idx in range(num_frames_input): + img_path = 'im_' + str(idx) + '.jpg' + if img_path not in all_files: # Should be sequential existing + valid = False + break + if not valid: + print("We cannot generate a video because the video is not sequential") + return False + + + if num_frames_input == 0: + print("We cannot generate a video because the input length is 0") + return False + + img_lists = [] + for idx in range(num_frames_input): + img_path = os.path.join(sub_folder_path, "im_" + str(idx) + ".jpg") + img_lists.append(img_path) + + clip = ImageSequenceClip(img_lists, fps=4) + clip.write_videofile(store_path) + + return True + + +if __name__ == "__main__": + dataset_path = "../datasets_rob/Bridge_v2_raw" # ../datasets_rob/Bridge_v1_raw + + for sub_folder_name in sorted(os.listdir(dataset_path)): + sub_folder_path = os.path.join(dataset_path, sub_folder_name) + + status = compress_video(sub_folder_path) + + + + + \ No newline at end of file diff --git a/scripts/crop_video_frames.py b/scripts/crop_video_frames.py new file mode 100644 index 0000000000000000000000000000000000000000..78093bb8dfb5ecfd6f38033809a3d5d8fca8dcda --- /dev/null +++ b/scripts/crop_video_frames.py @@ -0,0 +1,22 @@ +''' + This file is to split the video sources in a folder to folder with images, for the mass evaluation +''' +import os, shutil, sys +import cv2 + + +if __name__ == "__main__": + input_folder = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/StreamingT2V_results" + needed_frame_length = 14 + + idx = 0 + for file_name in sorted(os.listdir(input_folder)): + print("We are processing ", file_name) + sub_folder_path = os.path.join(input_folder, file_name) + + for idx in range(len(os.listdir(sub_folder_path))): + if idx >= needed_frame_length: + target_path = os.path.join(sub_folder_path, str(idx)+".png") + os.remove(target_path) + + diff --git a/scripts/extract_test_dataset.py b/scripts/extract_test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a3ad850c93c935d1714fcd09290e5993702b651d --- /dev/null +++ b/scripts/extract_test_dataset.py @@ -0,0 +1,18 @@ +''' + Extract the test dataset from the txt file +''' + +if __name__ == "__main__": + txt_path = "match_info_v2.txt" + store_path = "test_path_v2.txt" + start_idx = len("/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2/") + + read_file = open(txt_path, "r") + write_file = open(store_path, "w") + for line in read_file.readlines(): + test_dataset_path = line.split(' ')[1] + test_instance = test_dataset_path[start_idx:] + + write_file.write(test_instance) + + \ No newline at end of file diff --git a/scripts/generate_noise.py b/scripts/generate_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..e5080bd1f34be2ec4ef8e3be9df7a2468c9af6f5 --- /dev/null +++ b/scripts/generate_noise.py @@ -0,0 +1,14 @@ +import cv2 +import numpy as np +import matplotlib.pyplot as plt + +# Set the dimensions of the image +height = 256 +width = 256 + +# Generate random pixel values +noise = np.random.rand(height, width, 3) * 255 # Scale to 255 for grayscale image + + +for idx in range (4): + cv2.imwrite("noise"+str(idx)+".png", noise) \ No newline at end of file diff --git a/scripts/generate_sam.py b/scripts/generate_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..76e9ad692662ad2a83c829cea63ae749c5ad250f --- /dev/null +++ b/scripts/generate_sam.py @@ -0,0 +1,56 @@ +import os, sys, shutil +import cv2 +import numpy as np +import matplotlib.pyplot as plt +from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry + + +def show_anns(anns): + if len(anns) == 0: + return + + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + ax = plt.gca() + ax.set_autoscale_on(True) + + img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 3)) + # img[:,:,3] = 0 + for ann in sorted_anns: + m = ann['segmentation'] + color_mask = np.concatenate([np.random.random(3)]) + img[m] = color_mask + + return img*255 + + + + +if __name__ == "__main__": + input_parent_folder = "../Bridge_filter_flow" + + + # Init SAM for segmentation task + model_type = "vit_h" + weight_path = "pretrained/sam_vit_h_4b8939.pth" + + + + sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda") + mask_generator = SamAutomaticMaskGenerator(sam) # There is a lot of setting here + + + for sub_dir_name in sorted(os.listdir(input_parent_folder)): + print("We are processing ", sub_dir_name) + ref_img_path = os.path.join(input_parent_folder, sub_dir_name, 'im_0.jpg') + store_path = os.path.join(input_parent_folder, sub_dir_name, 'sam.png') + + image = cv2.imread(ref_img_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + mask = mask_generator.generate(image) + mask_img = show_anns(mask) + + cv2.imwrite(store_path, mask_img) + + + diff --git a/scripts/generate_sam_this_that.py b/scripts/generate_sam_this_that.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd61024edb9c111e9f2b8a4bb1b7e68618ae223 --- /dev/null +++ b/scripts/generate_sam_this_that.py @@ -0,0 +1,108 @@ +import os, sys, shutil +import cv2 +import numpy as np +import matplotlib.pyplot as plt +from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry + + +def show_anns(anns): + if len(anns) == 0: + return + + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + ax = plt.gca() + ax.set_autoscale_on(True) + + img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 3)) + # img[:,:,3] = 0 + for ann in sorted_anns: + m = ann['segmentation'] + color_mask = np.concatenate([np.random.random(3)]) + img[m] = color_mask + + return img*255 + + +def show_mask(mask, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30/255, 144/255, 255/255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + + return mask_image * 255 + + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1) + + +if __name__ == "__main__": + input_parent_folder = "validation_tmp" + + + # Init SAM for segmentation task + model_type = "vit_h" + weight_path = "pretrained/sam_vit_h_4b8939.pth" + + + + sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda") + sam_predictor = SamPredictor(sam) + mask_generator = SamAutomaticMaskGenerator(sam) + + + # Iterate the folder + for sub_dir_name in sorted(os.listdir(input_parent_folder)): + print("We are processing ", sub_dir_name) + ref_img_path = os.path.join(input_parent_folder, sub_dir_name, 'im_0.jpg') + data_txt_path = os.path.join(input_parent_folder, sub_dir_name, 'data.txt') + + + # Read the image and process + image = cv2.imread(ref_img_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + + # Read the positive point + data_file = open(data_txt_path, 'r') + lines = data_file.readlines() + for idx in range(len(lines)): + frame_idx, horizontal, vertical = lines[idx].split(' ') + vertical, horizontal = int(float(vertical)), int(float(horizontal)) + positive_point_cords = [[horizontal, vertical]] + + positive_point_cords = np.array(positive_point_cords) + positive_point_labels = np.ones(len(positive_point_cords)) + print(positive_point_cords) + + + + # Set the SAM predictor + sam_predictor.set_image(np.uint8(image)) + masks, scores, logits = sam_predictor.predict( + point_coords = positive_point_cords, # Only positive points here + point_labels = positive_point_labels, + multimask_output = False, + ) + # print("Detected mask length is ", len(masks)) + + # Visualize + mask_img = show_mask(masks[0]) + cv2.imwrite(os.path.join(input_parent_folder, sub_dir_name, "first_contact0.png"), mask_img) + + break + + + # SAM all + sam_all = mask_generator.generate(image) + all_sam_imgs = show_anns(sam_all) + cv2.imwrite("sam_all.png", all_sam_imgs) + + + + diff --git a/scripts/generate_traj.py b/scripts/generate_traj.py new file mode 100644 index 0000000000000000000000000000000000000000..15a1b4646dab39272b6d7461bf9ca4dfd876198b --- /dev/null +++ b/scripts/generate_traj.py @@ -0,0 +1,601 @@ +import sys +import argparse +import copy +import os, shutil +import imageio +import cv2 +from PIL import Image, ImageDraw +import os.path as osp +import random +import numpy as np +import torch.multiprocessing as mp +from multiprocessing import set_start_method +import math, time, gc +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt +from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry + + +# Import files from the local path +root_path = os.path.abspath('.') +sys.path.append(root_path) +from config.flowformer_config import get_cfg +from flowformer_code.utils import flow_viz, frame_utils +from flowformer_code.utils.utils import InputPadder +from flowformer_code.FlowFormer import build_flowformer + + + + +TRAIN_SIZE = [432, 960] + +def show_anns(anns): + if len(anns) == 0: + return + + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + ax = plt.gca() + ax.set_autoscale_on(True) + + img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4)) + img[:,:,3] = 0 + for ann in sorted_anns: + m = ann['segmentation'] + color_mask = np.concatenate([np.random.random(3), [0.35]]) + img[m] = color_mask + + return img*255 + + +def show_mask(mask, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30/255, 144/255, 255/255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + + return mask_image * 255 + + +def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20): + if min_overlap >= TRAIN_SIZE[0] or min_overlap >= TRAIN_SIZE[1]: + raise ValueError( + f"Overlap should be less than size of patch (got {min_overlap}" + f"for patch size {patch_size}).") + if image_shape[0] == TRAIN_SIZE[0]: + hs = list(range(0, image_shape[0], TRAIN_SIZE[0])) + else: + hs = list(range(0, image_shape[0], TRAIN_SIZE[0] - min_overlap)) + if image_shape[1] == TRAIN_SIZE[1]: + ws = list(range(0, image_shape[1], TRAIN_SIZE[1])) + else: + ws = list(range(0, image_shape[1], TRAIN_SIZE[1] - min_overlap)) + + # Make sure the final patch is flush with the image boundary + hs[-1] = image_shape[0] - patch_size[0] + ws[-1] = image_shape[1] - patch_size[1] + return [(h, w) for h in hs for w in ws] + + + +def compute_flow(model, image1, image2, weights=None): + print(f"computing flow...") + + image_size = image1.shape[1:] + + image1, image2 = image1[None].cuda(), image2[None].cuda() + + hws = compute_grid_indices(image_size) + if weights is None: # no tile + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + + flow_pre, _ = model(image1, image2) + + flow_pre = padder.unpad(flow_pre) + flow = flow_pre[0].permute(1, 2, 0).cpu().numpy() + else: # tile + flows = 0 + flow_count = 0 + + for idx, (h, w) in enumerate(hws): + image1_tile = image1[:, :, h:h+TRAIN_SIZE[0], w:w+TRAIN_SIZE[1]] + image2_tile = image2[:, :, h:h+TRAIN_SIZE[0], w:w+TRAIN_SIZE[1]] + flow_pre, _ = model(image1_tile, image2_tile) + padding = (w, image_size[1]-w-TRAIN_SIZE[1], h, image_size[0]-h-TRAIN_SIZE[0], 0, 0) + flows += F.pad(flow_pre * weights[idx], padding) + flow_count += F.pad(weights[idx], padding) + + flow_pre = flows / flow_count + flow = flow_pre[0].permute(1, 2, 0).cpu().numpy() + + return flow + + +def compute_adaptive_image_size(image_size): + target_size = TRAIN_SIZE + scale0 = target_size[0] / image_size[0] + scale1 = target_size[1] / image_size[1] + + if scale0 > scale1: + scale = scale0 + else: + scale = scale1 + + image_size = (int(image_size[1] * scale), int(image_size[0] * scale)) + + return image_size + + +def prepare_image(viz_root_dir, fn1, fn2, keep_size): + print(f"preparing image...") + + image1 = frame_utils.read_gen(fn1) + image2 = frame_utils.read_gen(fn2) + image1 = np.array(image1).astype(np.uint8)[..., :3] + image2 = np.array(image2).astype(np.uint8)[..., :3] + if not keep_size: + dsize = compute_adaptive_image_size(image1.shape[0:2]) + image1 = cv2.resize(image1, dsize=dsize, interpolation=cv2.INTER_CUBIC) + image2 = cv2.resize(image2, dsize=dsize, interpolation=cv2.INTER_CUBIC) + image1 = torch.from_numpy(image1).permute(2, 0, 1).float() + image2 = torch.from_numpy(image2).permute(2, 0, 1).float() + + + dirname = osp.dirname(fn1) + filename = osp.splitext(osp.basename(fn1))[0] + + viz_dir = osp.join(viz_root_dir, dirname) + # if not osp.exists(viz_dir): + # os.makedirs(viz_dir) + + viz_fn = osp.join(viz_dir, filename + '.png') + + return image1, image2, viz_fn + + +def build_model(): + print(f"building model...") + cfg = get_cfg() + model = torch.nn.DataParallel(build_flowformer(cfg)) + model.load_state_dict(torch.load(cfg.model)) + + model.cuda() + model.eval() + + return model + + +def filter_uv(flow, threshold_factor = 0.2): + u = flow[:,:,0] + v = flow[:,:,1] + + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + + threshold = threshold_factor * rad_max + flow[:,:,0][rad < threshold] = 0 + flow[:,:,1][rad < threshold] = 0 + + return flow + + +def visualize_traj(base_img, traj_path, connect_points = True): + target_vertical, target_horizontal = traj_path[-1] + + if connect_points and len(traj_path) > 1: + # Draw a line to connect two point to show motion direction + start_coordinate = (traj_path[-2][1], traj_path[-2][0]) + end_coordinate = (traj_path[-1][1], traj_path[-1][0]) + pil_img = Image.fromarray(base_img) + + # Draw the line + color = 'red' + draw = ImageDraw.Draw(pil_img) + draw.line([start_coordinate, end_coordinate], fill = color, width = 3) + + base_img = np.array(pil_img) + + + # Draw a green dot only for the start point + if len(traj_path) == 1: + dot_range = 3 + for i in range(-1*dot_range, dot_range+1): + for j in range(-1*dot_range, dot_range+1): + dil_vertical, dil_horizontal = target_vertical + i, target_horizontal + j + if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]): + base_img[dil_vertical][dil_horizontal] = [0, 128, 0] + else: + print("The traj is out of boundary!!!!!!!!!!!!!!!!!!!!! and we won't consider it") # 现在 + return (False, base_img) + + return (True, base_img) + + + +def calculate_flow(viz_root_dir, store_dir, img_pairs, optical_flow_model, sam_predictor, SAM_positive_sample_num, SAM_negative_sample_num, mask_generator, traj_visualization, keep_size, verbose=False): + + # Trajectory prepare + traj_path = [] # It collects all points traversed in a temporal order + is_hard_to_track = False # If this is True, it means that, we have a time in tracking hard to find dx and dy movement. Under this circumstance, we are not very recommended to use it + hard_track_idxs = set() + traj_image_lists = [] + + + # Iterate all image pairs + for idx, img_pair in enumerate(img_pairs): + + fn1, fn2 = img_pair + print(f"processing {fn1}, {fn2}...") + + image1, image2, viz_fn = prepare_image(viz_root_dir, fn1, fn2, keep_size) # Be very careful, image1 and image2 may be different resolution shape if keep_size is False + # Generate the optical flow and filter those that is small motion + flow_uv = filter_uv(compute_flow(optical_flow_model, image1, image2, None)) + + # if verbose: + # Store the visualization of flow_uv + # flow_img = flow_viz.flow_to_image(flow_uv) + # cv2.imwrite("optical_flow_" + str(idx+1) + ".png", flow_img[:, :, [2,1,0]]) + + if idx == 0: + # We will store the first image to memory for further visualization purpose + + # Base img + # base_img = np.uint8(np.transpose(image1.numpy(), (1,2,0))) + + # SAM figure + # sam_all = mask_generator.generate(image1) + # base_img = show_anns(sam_all) + # base_img = np.transpose(base_img, (1,2,0)) + + # Plain white image + base_img = np.zeros(np.transpose(image1.numpy(), (1,2,0)).shape, dtype=np.uint8) + base_img.fill(255) + + + + + # Extract moving points (positive point) + positive_point_cords = [] + nonzeros = np.nonzero(flow_uv) # [(vertical), (horizontal)] + if len(nonzeros[0]) < SAM_positive_sample_num: + # We require the number of points to be more than SAM_positive_sample_num + return False + positive_orders = np.random.choice(len(nonzeros[0]), SAM_positive_sample_num, replace=False) # we have randomly select instead of use all in the sam_predictor prediction + for i in range(len(nonzeros[0])): + if i in positive_orders: + positive_point_cords.append([nonzeros[1][i], nonzeros[0][i]]) # 根据document来看,这个就应该是先horizontal再vertical,也就是这个顺序 + positive_point_cords = np.array(positive_point_cords) + positive_point_labels = np.ones(len(positive_point_cords)) + + + # Define negative sample (outside the optical flow choice) + if SAM_negative_sample_num != 0: + skip_prob = 2 * SAM_negative_sample_num / (flow_uv.shape[0]*flow_uv.shape[1] - len(nonzeros[0])) + negative_point_cords = [] + for i in range(flow_uv.shape[0]): + for j in range(flow_uv.shape[1]): + if flow_uv[i][j][0] == 0 and flow_uv[i][j][1] == 0: # 0 means the no motion zone and we have already filter low motion as zero before + if random.random() < skip_prob: + negative_point_cords.append([j, i]) # 根据document来看,这个就应该是先horizontal再vertical,也就是这个顺序 + negative_point_cords = np.array(negative_point_cords) # [:SAM_negative_sample_num] + negative_point_labels = np.zeros(len(negative_point_cords)) # Make sure that it is less than / equals to SAM_negative_sample_num quantity + + + + ################## Use SAM to filter out what we need (& use negative points) ################## + if idx == 0: # Only consider the first frame now. + # With sample coordinate + sam_predictor.set_image(np.uint8(np.transpose(image1.numpy(), (1,2,0)))) + if SAM_negative_sample_num != 0 and len(negative_point_cords) != 0: + all_point_cords = np.concatenate((positive_point_cords, negative_point_cords), axis=0) + all_point_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=0) + else: + all_point_cords = positive_point_cords + all_point_labels = positive_point_labels + + masks, scores, logits = sam_predictor.predict( + point_coords=all_point_cords, + point_labels=all_point_labels, + multimask_output=False, + ) + mask = masks[0] # TODO: 一定要确定我们这里选择了最大的mask,而没有考虑的第二大和其他的, 这里可能有bug,我们默认了第一个就是最大的mask + # if verbose: + # cv2.imwrite("mask_"+str(idx+1)+".png", (np.uint8(mask)*255)) + # annotated_img = show_mask(mask) + # cv2.imwrite("annotated.png", annotated_img) + + + ################## Choose the one we need as the reference for the future tracking ################## + # Choose a random point in the mask + target_zone = np.nonzero(mask) # [(vertical), (horizontal)] + target_zone = [(target_zone[0][i], target_zone[1][i]) for i in range(len(target_zone[0]))] # Now, the sturcture is [(vertical, horizontal), ...] + + repeat_time = 0 + loop2find = True + while loop2find: + loop2find = False + start_point = target_zone[np.random.choice(len(target_zone), 1, replace=False)[0]] + start_vertical, start_horizontal = start_point + + repeat_time += 1 + if repeat_time == 100: + # In some minor case, it may have infinite loop, so we need to manually break if it is looping + print("We are still hard to find a optimal first point, but we cannot let it loop") + break + + # Try to choose a start_point that is more centralized (Not close to the border) + fast_break = False + for i in range(-15, 15): + for j in range(-15, 15): + dil_vertical, dil_horizontal = start_vertical + i, start_horizontal + j + if (0 <= dil_vertical and dil_vertical < mask.shape[0]) and (0 <= dil_horizontal and dil_horizontal < mask.shape[1]): + if mask[dil_vertical][dil_horizontal] == 0: + print("We need to change to a new position for the start p Since this one is close to the border of the object...........") + loop2find = True + fast_break = True + break + else: + # We won't want to consider those that is close to the boundary + print("We need to change to a new position Since this one is close to the border of the image...........") + loop2find = True + fast_break = True + break + if fast_break: + break + traj_path.append(start_point) + + status, base_img = visualize_traj(base_img, traj_path) + if status == False: # If the traj is False, we won't consider it anymore. + file = open("log.txt", "a") + file.write("Invalid start point\n") + return False + + # Read from the last one in traj + ref_vertical, ref_horizontal = traj_path[-1][0], traj_path[-1][1] + + + # Get the average motion vector for point surrounding (8+1 directions) the ref_point; This is because this is the most accurate statistics + horizon_lists, vertical_lists = [], [] + start_range, end_range = -5, 5 + + # Calculate the average motion based on surrounding motion + search_times = 0 + while len(horizon_lists) == 0: # If we cannot find a direction, we use average value inside this mask, but we will flag it. + search_times += 1 + + if search_times > 1: + print("This is hard to track!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! and we have tracked " + str(search_times) + " times") + # TODO: 如果out of boundary那种,search times到了8-10次的就砍掉那后面frame吧,这种非常inaccurate了, 你也可以retrack一个新的点,但是没有什么意义,看整体数量来定吧 + is_hard_to_track = True + hard_track_idxs.add(idx) + + if abs(start_range) >= flow_uv.shape[0]//2: + file = open("log.txt", "a") + file.write("This folder has search all space but didn't find any place to track optical flow\n") + return False # If we have already search for the whole graph but didn't find anything to track, we discard this sample + + # Search for a larger space which is nearby 我觉得扩大搜索范围应该是最稳定的选择吧 + for i in range(start_range, end_range): + for j in range(start_range, end_range): + target_vertical, target_horizontal = ref_vertical + i, ref_horizontal + j + if 0 <= target_vertical and target_vertical < flow_uv.shape[0] and 0 <= target_horizontal and target_horizontal < flow_uv.shape[1]: + if flow_uv[target_vertical, target_horizontal, 0] == 0 or flow_uv[target_vertical, target_horizontal, 1] == 0: + continue # Ignore zero vector to ensure only calculate moving position + horizon_lists.append(flow_uv[target_vertical, target_horizontal, 0]) # Horizontal motion strength + vertical_lists.append(flow_uv[target_vertical, target_horizontal, 1]) # Vertical motion strength + + # If there isn't any to search, we kepp on a larger space + start_range -= 10 + end_range += 10 + + average_dx = sum(horizon_lists)/len(horizon_lists) + average_dy = sum(vertical_lists)/len(vertical_lists) + print("average movement is ", (average_dx, average_dy)) + traj_path.append(( int(traj_path[-1][0] + average_dy), int(traj_path[-1][1] + average_dx))) # Append the motion in independent order + + print(traj_path) + + + ##################### Visualize the trajectory path (Debug Purpose) ##################### + status, base_img = visualize_traj(base_img, traj_path) + if status == False: # If the traj is False, we won't consider it anymore. + return False + + cv2.imwrite(os.path.join(store_dir, "traj_path.png"), cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB)) + + if traj_visualization: + status, single_traj_img = visualize_traj(np.uint8(np.transpose(image1.numpy(), (1,2,0))), traj_path[:-1], connect_points=False) + if status == False: # If the traj is False, we won't consider it anymore. + return False + + traj_write_path = os.path.join(store_dir, "traj_"+str(idx)+".png") + # cv2.imwrite(traj_write_path, cv2.cvtColor(single_traj_img, cv2.COLOR_BGR2RGB)) + traj_image_lists.append(traj_write_path) + + + # if traj_visualization: + # images = [] + # for filename in traj_image_lists: + # images.append(imageio.imread(filename)) + # # os.remove(filename) # Remove when used + # imageio.mimsave(os.path.join(store_dir, 'traj_motion.gif'), images, duration=0.05) + + + # TODO: 可以如果hard to track,就aggressivly多试即便,我们根据这个hard_track_idxs的长度来粗略判断哪个最好,三次里面选最好的 + if is_hard_to_track: + if len(hard_track_idxs) >= len(img_pairs)//3: # If more than half of the traj is hard to track, we need to consider discard this one + file = open("log.txt", "a") + file.write("we have a lot of times hard to find dx and dy movement. Under this circumstance, we are not very recommended to use the track\n") + return False + + + # Write a file store all position for further utilization + txt_path = os.path.join(store_dir, "traj_data.txt") + if os.path.exists(txt_path): + os.remove(txt_path) + file = open(txt_path, "a") + for traj in traj_path: + file.write(str(traj[0]) + " " + str(traj[1]) + "\n") + # Save in numpy information + # with open(os.path.join(store_dir, 'traj_data.npy'), 'wb') as f: + # np.save(f, flow_uv) + print("We write ", traj_path) + return True + + + +def manage_seq_range(input_dir, store_dir, total_frame_needed): + + lists = os.listdir(input_dir) + lists = lists[2:-2] + num_frames_input = len(lists) + + if num_frames_input < total_frame_needed: + print("The number of frames is too short for constructing the sequnece length needed") + return False + + + division_factor = num_frames_input // total_frame_needed + remain_frame = num_frames_input % total_frame_needed + + gaps = [division_factor for _ in range(total_frame_needed)] + for idx in range(remain_frame): + gaps[idx] += 1 + + + cur_idx = 2 + for global_idx, gap in enumerate(gaps): + source_path = os.path.join(input_dir, "im_"+str(cur_idx)+".jpg") + destination_path = os.path.join(store_dir, "im_"+str(global_idx)+".jpg") + + shutil.copyfile(source_path, destination_path) + cur_idx += gap + + return True + + +def generate_pairs(dirname, start_idx, end_idx): + img_pairs = [] + for idx in range(start_idx, end_idx): + img1 = osp.join(dirname, f'im_{idx}.jpg') + img2 = osp.join(dirname, f'im_{idx+1}.jpg') + # img1 = f'{idx:06}.png' + # img2 = f'{idx+1:06}.png' + img_pairs.append((img1, img2)) + + return img_pairs + + +def process_partial_request(request_list, num_frames, traj_visualization, viz_root_dir): + + + # Init the optical flow model + optical_flow_model = build_model() + + # Init SAM for segmentation task + model_type = "vit_h" + weight_path = "pretrained/sam_vit_h_4b8939.pth" + SAM_positive_sample_num = 20 # How many points we use for the positive sample num () + SAM_negative_sample_num = 0 # How many points we use for the negative sample num + + print("In multi processing, we will build an instance of mask_generator independently") + sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda") + mask_generator = SamAutomaticMaskGenerator(sam) + print("In multi processing, we will build an instance of sam_predictor independently") + sam_predictor = SamPredictor(sam) + + + counter = 0 + while True: + counter += 1 + if counter == 10: + counter = 0 + gc.collect() + print("We will sleep here to clear memory") + time.sleep(5) + info = request_list[0] + request_list = request_list[1:] + if info == None: + print("This queue ends") + break + + + # Process each sub_input_dir and store the information there + sub_input_dir = info + + + img_pairs = generate_pairs(sub_input_dir, 0, num_frames-1) + print(img_pairs) + + with torch.no_grad(): + + # Calculate the optical flow and return a status to say whther this generated flow is usable + status = calculate_flow(viz_root_dir, sub_input_dir, img_pairs, optical_flow_model, sam_predictor, SAM_positive_sample_num, SAM_negative_sample_num, + mask_generator, traj_visualization, keep_size = True) + + # file = open("log.txt", "a") + print("The status for folder " + sub_input_dir + " is " + str(status) + "\n") + + if status == False: + # If the status is failed, we will remove it afterwords + print("The status is Failed, so we won't store this one as one promising data") + else: + print("We have successfully process one!") + + +if __name__ == '__main__': + + # Manage the paramter + parser = argparse.ArgumentParser() + parser.add_argument('--input_dir', default = '../validation_flow14/') + parser.add_argument('--num_workers', type = int, default = 1) # starting index of the image sequence + parser.add_argument('--viz_root_dir', default = 'viz_results') + parser.add_argument('--traj_visualization', default = True) # If this is True, + + # list_start = 0 + # list_end = 25000 + num_frames = 14 + + args = parser.parse_args() + input_dir = args.input_dir + num_workers = args.num_workers + viz_root_dir = args.viz_root_dir + traj_visualization = args.traj_visualization + + + + store_idx = 0 + dir_list = [] + for sub_input_name in sorted(os.listdir(input_dir)): + sub_input_dir = os.path.join(input_dir, sub_input_name) + # sub_store_dir = os.path.join(store_dir, "0"*(7-len(str(store_idx)))+str(store_idx)) + store_idx += 1 + dir_list.append(sub_input_dir) + + # Truncate the list to the target + # dir_list = dir_list[list_start:] + + + # Use multiprocessing to handle to speed up + num = math.ceil(len(dir_list) / num_workers) + for idx in range(num_workers): + # set_start_method('spawn', force=True) + + request_list = dir_list[:num] + request_list.append(None) + dir_list = dir_list[num:] + + + process_partial_request(request_list, num_frames, traj_visualization, viz_root_dir) # This is for debug purpose + # p = mp.Process(target=process_partial_request, args=(request_list, num_frames, traj_visualization, viz_root_dir, )) + # p.start() + + print("Submitted all jobs!") + # p.join() # 好像不加这个multiprocess就莫名自己结束了 + print("All task finished!") + + + \ No newline at end of file diff --git a/scripts/interpolate_by_repeat.py b/scripts/interpolate_by_repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..f558741cfead297521fb4ffd509b54ae660fe83e --- /dev/null +++ b/scripts/interpolate_by_repeat.py @@ -0,0 +1,55 @@ +''' + This file is trying to repeat the frames such the it reaches target frames needed +''' +import os, shutil, sys + +if __name__ == "__main__": + input_path = "/nfs/turbo/coe-jjparkcv/boyangwa/AVDC/AVDC_results" + store_path = "/nfs/turbo/coe-jjparkcv/boyangwa/AVDC/AVDC_results_interpolated" + total_frames_needed = 14 + + # Handle the file folder management + if os.path.exists(store_path): + shutil.rmtree(store_path) + os.makedirs(store_path) + + for video_name in sorted(os.listdir(input_path)): + sub_input_path = os.path.join(input_path, video_name) + sub_store_path = os.path.join(store_path, video_name) + + # Create the store place + os.makedirs(sub_store_path) + + # Find valid image lists + num_frames_input = 0 + for file_name in os.listdir(sub_input_path): + if file_name.endswith("png"): + num_frames_input += 1 + print("num_frames_input is ", num_frames_input) + + # Calculate needed parameters + division_factor = total_frames_needed // num_frames_input + remain_frames = (total_frames_needed % num_frames_input) - 1 # -1 for adaptation + + # Define the gap + gaps = [division_factor for _ in range(num_frames_input)] + for idx in range(remain_frames): + if idx % 2 == 0: + gaps[idx//2] += 1 # Start to end order + else: + gaps[-1*(1+(idx//2))] += 1 # End to start order + + print("gaps is ", gaps) + + + # Write to the new folder + store_idx = 0 + for frame_idx, gap in enumerate(gaps): + for tmp in range(gap): # Repeat copy gap num of times + img_path = os.path.join(sub_input_path, str(frame_idx)+".png") + shutil.copyfile(img_path, os.path.join(sub_store_path, str(store_idx)+".png")) + store_idx += 1 + + + + diff --git a/scripts/length_stats.py b/scripts/length_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f600e14dffe1e0d3b078c0843ae13a9784ed90 --- /dev/null +++ b/scripts/length_stats.py @@ -0,0 +1,21 @@ +import os, sys, shutil +import numpy as np +import matplotlib.pyplot as plt + + +if __name__ == "__main__": + input_folder_path = "../Bridge_v2" + + average_length = [] + + # Iterate each file + for sub_folder_name in sorted(os.listdir(input_folder_path)): + sub_folder_path = os.path.join(input_folder_path, sub_folder_name) + + average_length.append(len(os.listdir(sub_folder_path))) # Have more than one than expected, but we keep this + print("average length of {} is {}".format(sub_folder_name, average_length[-1])) + + print("average_movement_list is ", average_length) + n, bins, patches = plt.hist(average_length, bins=100) + plt.savefig("dataset_length2.png") + diff --git a/scripts/motion_stats.py b/scripts/motion_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..effbb9d3d4b334cb77600f90cc46f9aa8357a52a --- /dev/null +++ b/scripts/motion_stats.py @@ -0,0 +1,75 @@ +import os, sys, shutil +import numpy as np +import math +from statistics import mean +import matplotlib.pyplot as plt + + +if __name__ == "__main__": + input_folder_paths = ["../datasets_rob/Bridge_v1_raw", "../datasets_rob/Bridge_v2_raw"] # "../datasets_rob/Bridge_v1_raw", "../datasets_rob/Bridge_v2_raw" + num_frames = 14 + store_name = "movement.png" + + + average_movement_list = [] + not_valid_num = 0 + not_exists_num = 0 + # Iterate each file + for input_folder_path in input_folder_paths: + for sub_folder_name in sorted(os.listdir(input_folder_path)): + sub_folder_path = os.path.join(input_folder_path, sub_folder_name) + flow_path = os.path.join(sub_folder_path, 'flow.txt') + + if not os.path.exists(flow_path): + not_exists_num += 1 + continue + + + # Read the movement + file = open(flow_path, 'r') + info = file.readlines() + print(info) + if len(info) == 0: + not_valid_num += 1 + continue + info = info[0][:-2] + per_video_movement = float(info) + + + # Calculate the number of frames in this video + num_frames_input = 0 + valid = True + for file_name in os.listdir(sub_folder_path): # num_frames_input is the total number of files with name begin with im_ + if file_name.startswith("im_"): + num_frames_input += 1 + for idx in range(num_frames_input): # Ensure that this number is concurrent + img_path = os.path.join(sub_folder_path, 'im_' + str(idx) + '.jpg') + if not os.path.exists(img_path): # Should be sequential existing + valid = False + break + if num_frames_input < 2: + valid = False + if not valid: + not_valid_num += 1 + print("This is not valid path") + continue + + average_movement_list.append(per_video_movement * (num_frames_input/num_frames)) # Have more than one than expected, but we keep this + print("average movement of {} is {}".format(sub_folder_name, average_movement_list[-1])) + + print("not_exists_num is ", not_exists_num) + print("not_valid_num is ", not_valid_num) + print("average_movement_list length is ", len(average_movement_list)) + + # Get mean and variance data + mean_value = mean(average_movement_list) + std_value = math.sqrt(np.var(average_movement_list)) + print("Mean is ", mean_value) + print("std_value is ", std_value) + + # Plot the figure + n, bins, patches = plt.hist(average_movement_list, bins=100) + plt.title("Mean" + str(mean_value) + "_STD"+str(std_value)) + plt.savefig(store_name) + + diff --git a/scripts/process_llama.py b/scripts/process_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..00ba8cd69e4161cb2d5bf34c44a761cc2edff734 --- /dev/null +++ b/scripts/process_llama.py @@ -0,0 +1,74 @@ +''' + Process the llama file for the next step +''' +import os, shutil, sys +import json +import pandas as pd +import collections + + +if __name__ == "__main__": + + # Define important path + json_path = "../SVD1/v1.jsonl" + folder_path = "/home/kiteret/Desktop/StableVideoDiffusion/full_text_tmp/" + + + # Read the json file + with open(json_path, 'r') as json_file: + json_list = list(json_file) + + # Iterate all the json files + length_stats = collections.defaultdict(int) + for json_info in json_list: + json_info = json.loads(json_info) + + + # Define the path to write + key_start = len("/home/chfeng/llama3/full_text_tmp/") + key_end = len("lang.txt") + sub_path = json_info["file_path"][key_start:int(-1*key_end)] + new_text_path = os.path.join(folder_path, sub_path, "processed_text.txt") + if os.path.exists(new_text_path): + os.remove(new_text_path) + + + # Sanity check for the case where input is missed + if json_info["input"] == "": + print("It is weird for the input is empty in the LLM process for ", sub_path) + continue + + + # Re-Define the content + outputs = json_info["output"] + if outputs.find("action:") != 0: + print("It is weird for no actions: keyword in the outputs for ", sub_path, " with prompt ", outputs) + continue + + # Prepare write file + contents = outputs.split('\n') + f = open(new_text_path, "a") + + # Itearte + effective_length = 0 + for idx, content in enumerate(contents): + key_word = content.split(":")[1][1:] + if key_word != "": + effective_length += 1 + else: + if idx == 1: + print("It is abnormal for the this content to be empty ", sub_path, " with prompt ", outputs) + f.write(key_word + "\n") + # if effective_length == 2: + # print("short prompt case is ", sub_path, " with prompt ", outputs) + if effective_length < 2: # For those only 1 or zero, we won't consider them + print("The prompt is too short for ", sub_path, " with prompt ", outputs) + os.remove(new_text_path) + + length_stats[effective_length] += 1 + + print("length_stats is ", length_stats) + + + + diff --git a/scripts/process_sim.py b/scripts/process_sim.py new file mode 100644 index 0000000000000000000000000000000000000000..4a06b3aa4ecd8a0c8b89e45d039b94bd2b71ebff --- /dev/null +++ b/scripts/process_sim.py @@ -0,0 +1,59 @@ +''' + This is a script to processs Mark's data. +''' +import os, sys, shutil + +if __name__ == "__main__": + file_path = "/nfs/turbo/coe-jjparkcv/datasets/isaac-gym-pick-place/full/dataset_v3_proc" + store_path = "../datasets_rob/sim_raw" + most_descriptive_prompt_idx = 6 # Start from the 0 + + + # Folder management + if os.path.exists(store_path): + shutil.rmtree(store_path) + os.makedirs(store_path) + + # Check length + file_names = os.listdir(file_path) + target_length = len(file_names) // 10 # 10 files as a cycle + + + for idx in range(target_length): + sub_folder_path = os.path.join(file_path, "run_"+str(10*idx)) + if not os.path.exists(sub_folder_path): + continue + + # Prepare the target position + sub_store_path = os.path.join(store_path, str(idx)) + os.makedirs(sub_store_path) + + # Find the key prompt to read it + prompt_content = [] + for tmp_idx in range(10): + tmp_text_path = os.path.join(file_path, "run_"+str(10*idx + tmp_idx), "lang.txt") # Usually, the 6th is the most concrete version + if not os.path.exists(tmp_text_path): + continue + file = open(tmp_text_path, 'r') + prompt_content.append(file.readlines()[0]) + file.close() + print("prompt_content we have num ", len(prompt_content)) + + + + # Copy the image into the target position and copy the data.txt + for file_name in os.listdir(sub_folder_path): + if file_name == "lang.txt": + continue + shutil.copyfile(os.path.join(sub_folder_path, file_name), os.path.join(sub_store_path, file_name)) + + # Handle the lang.txt + target_lang_txt_path = os.path.join(sub_store_path, "lang.txt") + f = open(target_lang_txt_path, "a") + f.write(prompt_content[most_descriptive_prompt_idx]+"\n") + for tmp_idx in range(10): + if tmp_idx == most_descriptive_prompt_idx: + continue + f.write(prompt_content[tmp_idx]+"\n") + f.close() + diff --git a/scripts/resize_img.py b/scripts/resize_img.py new file mode 100644 index 0000000000000000000000000000000000000000..b98cba71eb7cdd35fb189e8574c152856748a3be --- /dev/null +++ b/scripts/resize_img.py @@ -0,0 +1,17 @@ +import os, sys, shutil +import cv2 + +if __name__ == "__main__": + input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/resize" + output_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/resize_resized" + + if os.path.exists(output_path): + shutil.rmtree(output_path) + os.makedirs(output_path) + + for img_name in os.listdir(input_path): + img_path = os.path.join(input_path, img_name) + img = cv2.imread(img_path) + img = cv2.resize(img, (384, 256)) + store_path = os.path.join(output_path, img_name) + cv2.imwrite(store_path, img) \ No newline at end of file diff --git a/scripts/resize_video_seq.py b/scripts/resize_video_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..4eea0d2efa04a83d388ae4d0b75ba91cd09e1553 --- /dev/null +++ b/scripts/resize_video_seq.py @@ -0,0 +1,33 @@ +''' + This file is designed to resize the video sequence to the target resolution +''' +import os, sys, shutil +import cv2 + +if __name__ == "__main__": + input_folder = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/SVD_results" + store_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/SVD_results_resized" + target_height, target_width = 256, 384 + + if os.path.exists(store_path): + shutil.rmtree(store_path) + os.makedirs(store_path) + + for video_name in sorted(os.listdir(input_folder)): + print("We are processing ", video_name) + sub_video_folder = os.path.join(input_folder, video_name) + sub_store_folder = os.path.join(store_path, video_name) + os.makedirs(sub_store_folder) + + for img_name in os.listdir(sub_video_folder): + if not img_name.endswith("jpg") and not img_name.endswith("png"): + continue + + img_path = os.path.join(sub_video_folder, img_name) + store_img_path = os.path.join(sub_store_folder, img_name) + img = cv2.imread(img_path) + + # Resize + img = cv2.resize(img, (target_width, target_height)) + cv2.imwrite(store_img_path, img) + diff --git a/scripts/train_test_split.py b/scripts/train_test_split.py new file mode 100644 index 0000000000000000000000000000000000000000..61c0791b889108a466e585b30b40a4149b703f74 --- /dev/null +++ b/scripts/train_test_split.py @@ -0,0 +1,23 @@ +import os, sys, shutil +import random + + +if __name__ == "__main__": + base_dataset_path = "../datasets_rob/Bridge_v1_raw" + test_store_path = "../datasets_rob/Bridge_v1_test_raw" + split_ratio = 0.1 # [0, 1] range + + # Prepare the folder + if os.path.exists(test_store_path): + shutil.rmtree(test_store_path) + os.makedirs(test_store_path) + + full_img_lists = os.listdir(base_dataset_path) + random.shuffle(full_img_lists) + target_test_length = int(len(full_img_lists) * split_ratio) + test_img_lists = full_img_lists[-1 * target_test_length : ] + + # Move the lists based on test_img_lists + for test_img_name in test_img_lists: + shutil.move(os.path.join(base_dataset_path, test_img_name), os.path.join(test_store_path, test_img_name)) + diff --git a/scripts/visualize_thisthat_point.py b/scripts/visualize_thisthat_point.py new file mode 100644 index 0000000000000000000000000000000000000000..fb0c3703b4276cc6fe4323c3ca523412a378d547 --- /dev/null +++ b/scripts/visualize_thisthat_point.py @@ -0,0 +1,43 @@ +''' + This repo is provided to change the destination area. +''' + +import os, cv2 + + +def draw_dot(ref_img, new_h, new_w): + # Draw the dot + dot_range = 3 + for i in range(-1*dot_range, dot_range+1): + for j in range(-1*dot_range, dot_range+1): + dil_vertical, dil_horizontal = new_h + i, new_w + j + if (0 <= dil_vertical and dil_vertical < ref_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < ref_img.shape[1]): + ref_img[dil_vertical, dil_horizontal, :] = [0, 128, 0] + + return ref_img + + +if __name__ == "__main__": + instance_path = "datasets/validation_thisthat14/000049/" + new_w, new_h = 385, 310 + # 256.1850280761719 241.71287155151367 + + # Read the items + data_path = os.path.join(instance_path, "data.txt") + ref_img_path = os.path.join(instance_path, "im_0.jpg") + ref_img = cv2.imread(ref_img_path) + + + # Read the first point + file1 = open(data_path, 'r') + Lines = file1.readlines() + frame_idx, horizontal, vertical = Lines[0].split(' ') + ref_img = draw_dot(ref_img, int(float(vertical)), int(float(horizontal))) + + # Second dot + ref_img = draw_dot(ref_img, new_h, new_w) + + + + # Store the image + cv2.imwrite("visual.png", ref_img) \ No newline at end of file diff --git a/svd/diffusion_arch/transformer_temporal.py b/svd/diffusion_arch/transformer_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..fb1bbbe269462154d7030125e1024a8c4c75f028 --- /dev/null +++ b/svd/diffusion_arch/transformer_temporal.py @@ -0,0 +1,381 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.resnet import AlphaBlender + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + The output of [`TransformerTemporalModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: torch.LongTensor = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> TransformerTemporalModelOutput: + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) + + +class TransformerSpatioTemporalModel(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + out_channels (`int`, *optional*): + The number of channels in the output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 320, + out_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + + # 2. Define input layers + self.in_channels = in_channels + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for d in range(num_layers) + ] + ) + + time_mix_inner_dim = inner_dim + self.temporal_transformer_blocks = nn.ModuleList( + [ + TemporalBasicTransformerBlock( + inner_dim, + time_mix_inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + time_embed_dim = in_channels * 4 + self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) + self.time_proj = Timesteps(in_channels, True, 0) + self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + # TODO: should use out_channels for continuous projections + self.proj_out = nn.Linear(inner_dim, in_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input hidden_states. + num_frames (`int`): + The number of frames to be processed per batch. This is used to reshape the hidden states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): + A tensor indicating whether the input contains only images. 1 indicates that the input contains only + images, 0 indicates that the input contains video frames. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, _, height, width = hidden_states.shape + num_frames = image_only_indicator.shape[-1] + batch_size = batch_frames // num_frames + + time_context = encoder_hidden_states + time_context_first_timestep = time_context[None, :].reshape( + batch_size, num_frames, -1, time_context.shape[-1] + )[:, 0] # This part means that the cross attn section for the temporal blocks only consider ths first frames + + + encoder_hidden_states_dim = time_context_first_timestep.shape[1] + time_context = time_context_first_timestep[None, :].broadcast_to( + height * width, batch_size, encoder_hidden_states_dim, time_context.shape[-1] + ) + time_context = time_context.reshape(height * width * batch_size, encoder_hidden_states_dim, time_context.shape[-1]) + + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + num_frames_emb = torch.arange(num_frames, device=hidden_states.device) + num_frames_emb = num_frames_emb.repeat(batch_size, 1) + num_frames_emb = num_frames_emb.reshape(-1) + t_emb = self.time_proj(num_frames_emb) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + # 2. Blocks + for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + None, + encoder_hidden_states, + None, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states_mix = hidden_states + hidden_states_mix = hidden_states_mix + emb + + hidden_states_mix = temporal_block( + hidden_states_mix, + num_frames=num_frames, + encoder_hidden_states=time_context, + ) + hidden_states = self.time_mixer( + x_spatial=hidden_states, + x_temporal=hidden_states_mix, + image_only_indicator=image_only_indicator, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/svd/diffusion_arch/unet_3d_blocks.py b/svd/diffusion_arch/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..fe615a2f4246f1e3392b0ca9196537577df46186 --- /dev/null +++ b/svd/diffusion_arch/unet_3d_blocks.py @@ -0,0 +1,2396 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union +import os, sys +import torch +from torch import nn + +from diffusers.utils import is_torch_version +from diffusers.utils.torch_utils import apply_freeu +from diffusers.models.attention import Attention +from diffusers.models.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.resnet import ( + Downsample2D, + ResnetBlock2D, + SpatioTemporalResBlock, + TemporalConvLayer, + Upsample2D, +) +from diffusers.models.transformer_2d import Transformer2DModel + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from svd.diffusion_arch.transformer_temporal import TransformerSpatioTemporalModel, TransformerTemporalModel + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: int = 1, +) -> Union[ + "DownBlock3D", + "CrossAttnDownBlock3D", + "DownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", +]: + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + if down_block_type == "DownBlockMotion": + return DownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif down_block_type == "CrossAttnDownBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") + return CrossAttnDownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif down_block_type == "DownBlockSpatioTemporal": + # added for SDV + return DownBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlockSpatioTemporal": + # added for SVD + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") + return CrossAttnDownBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + ) + + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resolution_idx: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: int = 1, + dropout: float = 0.0, +) -> Union[ + "UpBlock3D", + "CrossAttnUpBlock3D", + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", +]: + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + ) + if up_block_type == "UpBlockMotion": + return UpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif up_block_type == "CrossAttnUpBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") + return CrossAttnUpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif up_block_type == "UpBlockSpatioTemporal": + # added for SDV + return UpBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal") + return CrossAttnUpBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_upsample=add_upsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resolution_idx=resolution_idx, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resolution_idx: Optional[int] = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + resolution_idx: Optional[int] = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + num_frames: int = 1, + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + num_frames: int = 1, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, scale + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + num_frames, + ) + + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + additional_residuals: Optional[torch.FloatTensor] = None, + ): + output_states = () + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + + return hidden_states + + +class UpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temporal_norm_num_groups: int = 32, + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=temporal_norm_num_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size=None, + scale: float = 1.0, + num_frames: int = 1, + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + ) + + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: float = False, + use_linear_projection: float = False, + upcast_attention: float = False, + attention_type: str = "default", + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + attention_head_dim=in_channels // temporal_num_attention_heads, + in_channels=in_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + activation_fn="geglu", + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class MidBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + attention_head_dim: int = 512, + num_layers: int = 1, + upcast_attention: bool = False, + ): + super().__init__() + + resnets = [] + attentions = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=1e-6, + upcast_attention=upcast_attention, + norm_num_groups=32, + bias=True, + residual_connection=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + image_only_indicator: torch.FloatTensor, + ): + hidden_states = self.resnets[0]( + hidden_states, + image_only_indicator=image_only_indicator, + ) + for resnet, attn in zip(self.resnets[1:], self.attentions): + hidden_states = attn(hidden_states) + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class UpBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + image_only_indicator: torch.FloatTensor, + ) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ] + attentions = [] + + for i in range(num_layers): + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0]( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class DownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + ) + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-6, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=1, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + for resnet, attn in blocks: + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + resnet_eps: float = 1e-6, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + ) + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states diff --git a/svd/pipeline_stable_video_diffusion.py b/svd/pipeline_stable_video_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..532488ff4721cde3d19c857fe233cbbbd0cb74b2 --- /dev/null +++ b/svd/pipeline_stable_video_diffusion.py @@ -0,0 +1,687 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Union +import cv2 +import numpy as np +import PIL.Image +import torch +import torch.nn as nn +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers import EulerDiscreteScheduler +from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers import DiffusionPipeline + + +# Import files from the local folder +from utils.img_utils import tensor2np + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def tensor2vid(video: torch.Tensor, processor, output_type="np"): + # Based on: + # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + return outputs + + +@dataclass +class StableVideoDiffusionPipelineOutput(BaseOutput): + r""" + Output class for zero-shot text-to-video pipeline. + + Args: + frames (`[List[PIL.Image.Image]`, `np.ndarray`]): + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + """ + + frames: Union[List[PIL.Image.Image], np.ndarray] + + +class StableVideoDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline to generate video from an input image using Stable Video Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionModel`]): + A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. + """ + + model_cpu_offload_seq = "image_encoder->unet->vae" + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + + def encode_clip(self, image, prompt, use_text, text_encoder, device, num_videos_per_prompt, do_classifier_free_guidance): + # Encode image and text prompt by the clip + + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 # [0, 1] -> [-1, 1] + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 # [-1, 1] -> [0, 1] + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + encoder_hidden_states = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + + # Prepare for the text embeddings if needed + if use_text: + text_embeddings = text_encoder(prompt)[0] + + # Concat two embeddings together on dim 1 + encoder_hidden_states = torch.cat((text_embeddings, encoder_hidden_states), dim=1) + + layer_norm = nn.LayerNorm((78, 1024)).to(device=device, dtype=dtype) + encoder_hidden_states = layer_norm(encoder_hidden_states) + + + # CFG in inference + if do_classifier_free_guidance: + negative_encoder_hidden_states = torch.zeros_like(encoder_hidden_states) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + encoder_hidden_states = torch.cat([negative_encoder_hidden_states, encoder_hidden_states]) + + + + return encoder_hidden_states + + + def _encode_vae_image( + self, + image: torch.Tensor, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.mode() + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + def _get_add_time_ids( + self, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + def decode_latents(self, latents, num_frames, decode_chunk_size=14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, image, height, width): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + def prepare_latents( + self, + batch_size, + num_frames, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + prompt = None, + use_text: bool = False, + text_encoder = None, + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: int = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after generation. + Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. + noise_aug_strength (`int`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency + between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once + for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list of list with the generated frames. + + Examples: + + ```py + from diffusers import StableVideoDiffusionPipeline + from diffusers.utils import load_image, export_to_video + + pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") + pipe.to("cuda") + + image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") + image = image.resize((1024, 576)) + + frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] + export_to_video(frames, "generated.mp4", fps=7) + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = max_guidance_scale > 1.0 + + + # 3. Encode input image by CLIP + encoder_hidden_states = self.encode_clip(image, prompt, use_text, text_encoder, device, num_videos_per_prompt, do_classifier_free_guidance) + + + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.image_processor.preprocess(image, height=height, width=width) + # cv2.imwrite("no_noise.png", cv2.cvtColor(tensor2np(image), cv2.COLOR_BGR2RGB)) + noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) + image = image + noise_aug_strength * noise + # cv2.imwrite("noise.png", cv2.cvtColor(tensor2np(image), cv2.COLOR_BGR2RGB)) + + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + image_latents = image_latents.to(encoder_hidden_states.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + encoder_hidden_states.dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + encoder_hidden_states.dtype, + device, + generator, + latents, + ) + + + # 7. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimension + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) # image_latents is fixed and latent_model_input will be based on latents which is updated frequently + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, # [batch, frames, 4*2, height, width] + t, + encoder_hidden_states = encoder_hidden_states, + added_time_ids = added_time_ids, + return_dict = False, + )[0] # encoder_hidden_states is used for cross attention metioned in the paper + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) # There are two noises here: one is unconditional and one is conditional + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size) + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) + + +# resizing utils +# TODO: clean up later +def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) + return output + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out diff --git a/svd/pipeline_stable_video_diffusion_controlnet.py b/svd/pipeline_stable_video_diffusion_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8208f8a6b09c667cc8e9ff5684a0399cc9342f20 --- /dev/null +++ b/svd/pipeline_stable_video_diffusion_controlnet.py @@ -0,0 +1,852 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Union +import cv2, os, sys +import numpy as np +import PIL.Image +import torch +import torch.nn as nn +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + + +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers import EulerDiscreteScheduler +from diffusers.models import AutoencoderKLTemporalDecoder +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers import DiffusionPipeline + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from utils.img_utils import tensor2np +from svd.temporal_controlnet import ControlNetModel +from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def tensor2vid(video: torch.Tensor, processor, output_type="np"): + # Based on: + # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + return outputs + + +@dataclass +class StableVideoDiffusionPipelineOutput(BaseOutput): + r""" + Output class for zero-shot text-to-video pipeline. + + Args: + frames (`[List[PIL.Image.Image]`, `np.ndarray`]): + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + """ + + frames: Union[List[PIL.Image.Image], np.ndarray] + + +class StableVideoDiffusionControlNetPipeline(DiffusionPipeline): + r""" + Pipeline to generate video from an input image using Stable Video Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionModel`]): + A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. + """ + + model_cpu_offload_seq = "image_encoder->unet->vae" + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + ): + super().__init__() + + # TODO: multi-controlnet consideration + self.register_modules( + vae = vae, + image_encoder = image_encoder, + unet = unet, + scheduler = scheduler, + feature_extractor = feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) # The vae_scale_factor is for image dimension, not for image size + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + + def encode_clip(self, image, prompt, use_text, text_encoder, device, num_videos_per_prompt, do_classifier_free_guidance, use_instructpix2pix): + + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) # Map [0, 255] to [0, 1] range + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then, we unnormalize it after resizing. + image = image * 2.0 - 1.0 # [-1, 1] range + image = _resize_with_antialiasing(image, (224, 224)) # Resize to square image + image = (image + 1.0) / 2.0 # [0, 1] range + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values # The value range is a little deviated now, and I got [-1.76, 2.15] for one sample + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + encoder_hidden_states = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + + # Prepare for the text embeddings if needed + if use_text: + text_embeddings = text_encoder(prompt)[0] + + # Concat two embeddings together on dim 1 + encoder_hidden_states = torch.cat((text_embeddings, encoder_hidden_states), dim=1) + + # Layer norm on the last dim TODO: 这里order小改了一下顺序,变成先encoder hidden states了 + layer_norm = nn.LayerNorm((78, 1024)).to(device=device, dtype=dtype) + encoder_hidden_states = layer_norm(encoder_hidden_states) + + + if do_classifier_free_guidance: + negative_encoder_hidden_states = torch.zeros_like(encoder_hidden_states) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if use_instructpix2pix: + encoder_hidden_states = torch.cat([encoder_hidden_states, negative_encoder_hidden_states, negative_encoder_hidden_states]) + else: + encoder_hidden_states = torch.cat([negative_encoder_hidden_states, encoder_hidden_states]) + + + return encoder_hidden_states + + + def _encode_vae_image( + self, + image: torch.Tensor, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + use_instructpix2pix, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.mode() + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if use_instructpix2pix: + image_latents = torch.cat([image_latents, image_latents, negative_image_latents]) + else: + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + + def _get_add_time_ids( + self, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + guess_mode, + use_instructpix2pix, + ): + # Define the default values from SVD + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + + # Sanity Check + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + if use_instructpix2pix: + add_time_ids = torch.cat([add_time_ids, add_time_ids, add_time_ids]) + else: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + + # Return the info + return add_time_ids + + + def decode_latents(self, latents, num_frames, decode_chunk_size=14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, + image, + height, + width, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # TODO: I didn't test input for controlnet_conditioning_scale, control_guidance_start, and control_guidance_end + + def prepare_latents( + self, + batch_size, + num_frames, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # We don't directly have do_classifier_free_guidance function, we judge simply by max_guidance + + @property + def num_timesteps(self): + return self._num_timesteps + + + def prepare_condition_image( + self, + condition_img, + width, + height, + batch_size, + num_videos_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + # The input of condition_img is already in the range [0, 1] + condition_img = torch.from_numpy(condition_img) # hwc -> chw + condition_img = condition_img.to(torch.float16).to(self._execution_device) # Set this in default + + # CFG will be done in main function, not here now + return condition_img # [0, 1] range && Torch data type + + + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + condition_img: np, + controlnet: ControlNetModel, + prompt = None, + use_text: bool = False, + text_encoder = None, + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + # controlnet_image_index: Optional[int] = [0], + # coordinate_values = None, + noise_aug_strength: int = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + use_instructpix2pix: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + inner_conditioning_scale: float = 1.0, + guess_mode: bool = True, + image_guidance_scale: float = 7.5, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after generation. + Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. + noise_aug_strength (`int`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency + between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once + for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list of list with the generated frames. + + Examples: + + ```py + from diffusers import StableVideoDiffusionPipeline + from diffusers.utils import load_image, export_to_video + + pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") + pipe.to("cuda") + + image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") + image = image.resize((1024, 576)) + + frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] + export_to_video(frames, "generated.mp4", fps=7) + ``` + """ + + # align format for control guidance + mult = 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # It seems that self.unet.config.sample_size * self.vae_scale_factor is a default image size input setting + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = max_guidance_scale > 1.0 + if do_classifier_free_guidance: + print("We will use CFG!!!") + + + # 3. Encode input image + encoder_hidden_states = self.encode_clip(image, prompt, use_text, text_encoder, device, num_videos_per_prompt, do_classifier_free_guidance, use_instructpix2pix) + + + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.image_processor.preprocess(image, height=height, width=width) # [0, 255] to [-1, 1] + noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance, use_instructpix2pix) + image_latents = image_latents.to(encoder_hidden_states.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + + + # 4.5 Prepare control image (Will need to consider multiControlNet) + condition_img = self.prepare_condition_image( + condition_img = condition_img, + width = width, + height = height, + batch_size = batch_size * num_videos_per_prompt, + num_videos_per_prompt = num_videos_per_prompt, + device = device, + dtype = controlnet.dtype, + do_classifier_free_guidance = do_classifier_free_guidance, + guess_mode = guess_mode, + ) # [0, 255] to [0, 1] range + + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + # coordinate_values, + encoder_hidden_states.dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + guess_mode = guess_mode, + use_instructpix2pix = use_instructpix2pix, + ) + added_time_ids = added_time_ids.to(device) + + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + encoder_hidden_states.dtype, + device, + generator, + latents, + ) # Nosiy latents across all frames needed + + + # 7. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + self._guidance_scale = guidance_scale + + + # 7.5 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + + # expand the latents if we are doing classifier free guidance + if use_instructpix2pix: + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + else: + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # I think that this is where sequential generation takes influence + + # Concatenate image_latents over channels dimension for video diffusion purposes + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) # image_latents is fixed and latent_model_input will be based on latents which is updated frequently + + + # ControlNet Scale + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + + # assert condition_img.shape[1] >= len(controlnet_image_index) + + + # VAE encode + controlnet_cond = self.vae.encode(condition_img).latent_dist.mode() + + + if do_classifier_free_guidance: + if use_instructpix2pix: + controlnet_cond = torch.cat([controlnet_cond, controlnet_cond, controlnet_cond]) + # controlnet_conditioning_mask = torch.cat([controlnet_conditioning_mask, controlnet_conditioning_mask, controlnet_conditioning_mask]) + else: + controlnet_cond = torch.cat([controlnet_cond, controlnet_cond]) + # controlnet_conditioning_mask = torch.cat([controlnet_conditioning_mask, controlnet_conditioning_mask]) + + + down_block_res_samples, mid_block_res_sample = controlnet( + sample = latent_model_input, + timestep = t, + encoder_hidden_states = encoder_hidden_states, + added_time_ids = added_time_ids, + controlnet_cond = controlnet_cond, + return_dict = False, + inner_conditioning_scale = inner_conditioning_scale, # Inner conditioning scale + conditioning_scale = cond_scale, # Outer conditioning scale + guess_mode = guess_mode, + ) + + if guess_mode and do_classifier_free_guidance: # Won't consider this one, since we don't use guess mode + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + + # predict the noise residual + noise_pred = self.unet( + sample = latent_model_input, # [batch, frames, 4*2, height, width] + timestep = t, + encoder_hidden_states = encoder_hidden_states, + added_time_ids = added_time_ids, + down_block_additional_residuals = down_block_res_samples, + mid_block_additional_residual = mid_block_res_sample, + return_dict = False, + )[0] # image_embeddings is used for cross attention metioned in the paper + + + # perform guidance + if do_classifier_free_guidance: + if use_instructpix2pix: + noise_pred_1st_frame, noise_pred_cond, noise_pred_uncond = noise_pred.chunk(3) # There are two noises here: one is unconditional and one is conditional + noise_pred = noise_pred_uncond + \ + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + \ + image_guidance_scale * (noise_pred_cond - noise_pred_1st_frame) # InstructPix2Pix is (noise_pred_1st_frame - noise_pred_cond) + else: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) # There are two noises here: one is unconditional and one is conditional + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size) + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) + + +# resizing utils +# TODO: clean up later (put to shared utils file) +def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) + return output + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out + + +def is_compiled_module(module) -> bool: + """Check whether the module was compiled with torch.compile()""" + if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): + return False + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) \ No newline at end of file diff --git a/svd/temporal_controlnet.py b/svd/temporal_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..32c685b552c157ea5063bbb7cbf5e02ba3448964 --- /dev/null +++ b/svd/temporal_controlnet.py @@ -0,0 +1,644 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' + This is a ControlNet for sptio temporal unet (SVD) +''' +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import os, sys +import random +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers import AutoencoderKLTemporalDecoder +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalControlnetMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel +from svd.diffusion_arch.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +@dataclass +class ControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + + +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 8, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + mid_block_type: Optional[str] = "UNetMidBlockSpatioTemporal", + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + layers_per_block: int = 2, + act_fn: str = "silu", + cross_attention_dim: int = 1024, + projection_class_embeddings_input_dim: Optional[int] = 768, + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), # This is modified to SVD config setting for the default case + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + controlnet_conditioning_channel_order = 'rgb', + ): + super().__init__() + + self.controlnet_conditioning_channel_order = controlnet_conditioning_channel_order + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + + + ########################## First convolution channel for sample (noise) ########################## + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in_concat = zero_module(nn.Conv2d( + 12, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + )) # Input is 12 channels (8 + 4) right now + + + ########################## Time embedding and so on ########################## + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) # defualt flip_sin_to_cos True + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # Additional time embedding for other purpose + self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) # This will include hyperparameter like fps, motion_bucket_id, noise_aug_strength + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + + + ############################# Down and Mid Blocks Init ############################# + # Init ModuleList and prepare information needed + self.down_blocks = nn.ModuleList([]) + output_channel = block_out_channels[0] + + + # Check instance + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + + + # ControlNet Module!!!!! + self.controlnet_down_blocks = nn.ModuleList([]) + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) # Zero Convolution + self.controlnet_down_blocks.append(controlnet_block) + + # Down block init one by one + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-5, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + resnet_act_fn="silu", + ) + self.down_blocks.append(down_block) + + + # ControlNet Module !!!! + for _ in range(layers_per_block[0]): # Loop 2 times here + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: # Loop only once + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # Mid block + mid_block_channel = block_out_channels[-1] + + # ControlNet Module !!!! + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if mid_block_type == "UNetMidBlockSpatioTemporal": + self.mid_block = UNetMidBlockSpatioTemporal( + block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + + @classmethod + def from_unet( + cls, + unet: UNetSpatioTemporalConditionModel, + conditioning_channels: int = 3, + load_weights_from_unet: bool = True, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNetSpatioTemporalConditionModel`]. + + Parameters: + unet (`UNetSpatioTemporalConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + load_weights_from_unet (bool): + Whether we used unet as trainable copy (Should be True in default) + """ + + controlnet = cls(conditioning_channels=conditioning_channels) + + if load_weights_from_unet: + # controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) # Won't load this conv_in now, we will replace it with another zero conv + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + added_time_ids: torch.Tensor, + added_positions: torch.Tensor = None, + controlnet_cond: torch.FloatTensor = None, + conditioning_scale: float = 1.0, + inner_conditioning_scale: float = 1.0, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, 4, hidden_size)` which is already encoded in VAE. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + + # check channel order + channel_order = self.controlnet_conditioning_channel_order + + # if channel_order == "rgb": + # # in rgb order by default + # ... + # elif channel_order == "bgr": + # controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + # else: + # raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] # Take the classifier guidance also as an input in batch + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) # No more timestep_cond because usually this is None + + # motion score + fps + aug strength embeds + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + + # Wrap up + emb = emb + aug_emb + + + sample = sample.flatten(0, 1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + + + # 2. Pre-Process + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) + + + # Feature: Concat the sample && controlnet_cond at dim 1 (channel-wise) !!! + sample = torch.cat([sample, controlnet_cond], dim=1) + + + # Merge sample and controlnet_cond together + sample = self.conv_in_concat(sample) + + + + # 3. Down block + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, # Vae encode + noise + temb=emb, + encoder_hidden_states=encoder_hidden_states, # Clip encode + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + + # 4. Mid block + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + + + # 5. ControlNet blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + # Mid block + mid_block_res_sample = self.controlnet_mid_block(sample) + + + # 6. Scaling + if guess_mode: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + + diff --git a/svd/unet_spatio_temporal_condition.py b/svd/unet_spatio_temporal_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ad52c6b9e52c8486444738d4fa9ee2b5f9a067 --- /dev/null +++ b/svd/unet_spatio_temporal_condition.py @@ -0,0 +1,537 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union +import os, sys +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +# from diffusers.models.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from svd.diffusion_arch.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetSpatioTemporalConditionOutput(BaseOutput): + """ + The output of [`UNetSpatioTemporalConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + addition_time_embed_dim: (`int`, defaults to 256): + Dimension to to encode the additional time ids. + projection_class_embeddings_input_dim (`int`, defaults to 768): + The dimension of the projection of encoded `added_time_ids`. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): + The number of attention heads. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 8, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types: Tuple[str] = ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + projection_class_embeddings_input_dim: int = 768, + layers_per_block: Union[int, Tuple[int]] = 2, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), + num_frames: int = 25, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + padding=1, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + # self.position_embedding = TimestepEmbedding(1024, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-5, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + resnet_act_fn="silu", + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockSpatioTemporal( + block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-5, + resolution_idx=i, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + resnet_act_fn="silu", + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) + self.conv_act = nn.SiLU() + + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + added_time_ids: torch.Tensor, + added_positions: torch.Tensor = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: + r""" + The [`UNetSpatioTemporalConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. + added_time_ids: (`torch.FloatTensor`): + The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal + embeddings and added to the time embeddings. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain + tuple. + Returns: + [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] # Take the classifier guidance also as an input in batch + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + # motion score + fps + aug strength embeds + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + + # Wrap up + emb = emb + aug_emb + + + # if added_positions is not None: + # # position embeds + # position_embeds = self.add_time_proj(added_positions.flatten()) + # position_embeds = position_embeds.reshape((batch_size, -1)) + # position_embeds = position_embeds.to(emb.dtype) + # emb += self.position_embedding(position_embeds) + + + + + # Flatten the batch and frames dimensions + # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] + sample = sample.flatten(0, 1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + + # 2. pre-process + sample = self.conv_in(sample) + + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) + + + # 3. Down + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + if is_controlnet: + # For Controlnet, we will append down block together + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + image_only_indicator=image_only_indicator, + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 7. Reshape back to original shape + sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) + + if not return_dict: + return (sample,) + + return UNetSpatioTemporalConditionOutput(sample=sample) + diff --git a/test_code/inference.py b/test_code/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..1360c92d3c710162fb93a8856daec281f8f78a17 --- /dev/null +++ b/test_code/inference.py @@ -0,0 +1,465 @@ +''' + This file is to test UNet and GestureNet. +''' + +import os, shutil, sys +import urllib.request +import argparse +import imageio +import math +import cv2 +from PIL import Image +import collections +import numpy as np + +import torch +from pathlib import Path +from omegaconf import OmegaConf +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration +from diffusers import ( + AutoencoderKLTemporalDecoder, + DDPMScheduler, +) +from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer, PretrainedConfig + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from train_code.train_svd import import_pretrained_text_encoder +from data_loader.video_dataset import tokenize_captions +from data_loader.video_this_that_dataset import get_thisthat_sam +from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel +from svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline +from svd.temporal_controlnet import ControlNetModel +from svd.pipeline_stable_video_diffusion_controlnet import StableVideoDiffusionControlNetPipeline + + + +# Seed +# torch.manual_seed(42) +# np.random.seed(42) + + +def unet_inference(vae, unet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step, + parent_store_folder = None, force_close_flip = False, use_ambiguous_prompt=False): + + # Init + validation_source_folder = config["validation_img_folder"] + + + # Init the pipeline + pipeline = StableVideoDiffusionPipeline.from_pretrained( + config["pretrained_model_name_or_path"], + vae = accelerator.unwrap_model(vae), + image_encoder = accelerator.unwrap_model(image_encoder), + unet = accelerator.unwrap_model(unet), + revision = None, # Set None directly now + torch_dtype = weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + + # Process all image in the folder + frames_collection = [] + for image_name in sorted(os.listdir(validation_source_folder)): + if accelerator.is_main_process: + if parent_store_folder is None: + validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name) + else: + validation_store_folder = os.path.join(parent_store_folder, image_name) + + if os.path.exists(validation_store_folder): + shutil.rmtree(validation_store_folder) + os.makedirs(validation_store_folder) + + image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg') + ref_image = load_image(image_path) + ref_image = ref_image.resize((config["width"], config["height"])) + + + # Decide the motion score in SVD (mostly what we use is fix value now) + if config["motion_bucket_id"] is None: + raise NotImplementedError("We need a fixed motion_bucket_id in the config") + else: + reflected_motion_bucket_id = config["motion_bucket_id"] + # print("Inference Motion Bucket ID is ", reflected_motion_bucket_id) + + + # Prepare text prompt + if config["use_text"]: + # Read the file + file_path = os.path.join(validation_source_folder, image_name, "lang.txt") + file = open(file_path, 'r') + prompt = file.readlines()[0] # Only read the first line + if use_ambiguous_prompt: + prompt = prompt.split(" ")[0] + " this to there" + print("We are creating ambiguous prompt, which is: ", prompt) + else: + prompt = "" + # Use the same tokenize process as the dataset preparation stage + tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device) # Use unsqueeze to expand dim + + # Store the prompt for the sanity check + f = open(os.path.join(validation_store_folder, "lang_cond.txt"), "a") + f.write(prompt) + f.close() + + + # Flip the image by chance (it is needed to check whether there is any object position words [left|right] in the prompt text) + flip = False + if not force_close_flip: # force_close_flip is True in testing time; else, we cannot match in the same standard + if random.random() < config["flip_aug_prob"]: + if config["use_text"]: + if prompt.find("left") == -1 and prompt.find("right") == -1: # Cannot have position word, like left and right (up and down is ok) + flip = True + else: + flip = True + if flip: + print("Use flip in validation!") + ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT) + + + # Call the model for inference + with torch.autocast("cuda"): + frames = pipeline( + ref_image, + tokenized_prompt, + config["use_text"], + text_encoder, + height = config["height"], + width = config["width"], + num_frames = config["video_seq_length"], + num_inference_steps = config["num_inference_steps"], + decode_chunk_size = 8, + motion_bucket_id = reflected_motion_bucket_id, + fps = 7, + noise_aug_strength = config["inference_noise_aug_strength"], + ).frames[0] + + # Store the frames + # breakpoint() + for idx, frame in enumerate(frames): + frame.save(os.path.join(validation_store_folder, str(idx)+".png")) + imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames) # gif storage quality is not high, recommend to check png images + + frames_collection.append(frames) + + + # Cleaning process + del pipeline + torch.cuda.empty_cache() + + return frames_collection # Return resuly based on the need + + + +def gesturenet_inference(vae, unet, controlnet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step, + parent_store_folder=None, force_close_flip=False, use_ambiguous_prompt=False): + + + # Init + validation_source_folder = config["validation_img_folder"] + + + # Init the pipeline + pipeline = StableVideoDiffusionControlNetPipeline.from_pretrained( + config["pretrained_model_name_or_path"], # Still based on regular SVD config + vae = vae, + image_encoder = image_encoder, + unet = unet, + revision = None, # Set None directly now + torch_dtype = weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + + # Process all image in the folder + frames_collection = [] + for image_name in sorted(os.listdir(validation_source_folder)): + if accelerator.is_main_process: + if parent_store_folder is None: + validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name) + else: + validation_store_folder = os.path.join(parent_store_folder, image_name) + + if os.path.exists(validation_store_folder): + shutil.rmtree(validation_store_folder) + os.makedirs(validation_store_folder) + + image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg') + ref_image = load_image(image_path) # [0, 255] Range + ref_image = ref_image.resize((config["width"], config["height"])) + + + # Prepare text prompt + if config["use_text"]: + # Read the file + file_path = os.path.join(validation_source_folder, image_name, "lang.txt") + file = open(file_path, 'r') + prompt = file.readlines()[0] # Only read the first line + if use_ambiguous_prompt: + prompt = prompt.split(" ")[0] + " this to there" + print("We are creating ambiguous prompt, which is: ", prompt) + else: + prompt = "" + # Use the same tokenize process as the dataset preparation stage + tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device) # Use unsqueeze to expand dim + + # Store the prompt for the sanity check + f = open(os.path.join(validation_store_folder, "lang_cond.txt"), "a") + f.write(prompt) + f.close() + + # Flip the image by chance (it is needed to check whether there is any object position words [left|right] in the prompt text) + flip = False + if not force_close_flip: # force_close_flip is True in testing time; else, we cannot match in the same standard + if random.random() < config["flip_aug_prob"]: + if config["use_text"]: + if prompt.find("left") == -1 and prompt.find("right") == -1: # Cannot have position word, like left and right (up and down is ok) + flip = True + else: + flip = True + if flip: + print("Use flip in validation!") + ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT) + + + if config["data_loader_type"] == "thisthat": + condition_img, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(config, + os.path.join(validation_source_folder, image_name), + flip = flip, + store_dir = validation_store_folder, + verbose = True) + else: + raise NotImplementedError("We don't support such data loader type") + + + + # Call the pipeline + with torch.autocast("cuda"): + frames = pipeline( + image = ref_image, + condition_img = condition_img, # numpy [0,1] range + controlnet = accelerator.unwrap_model(controlnet), + prompt = tokenized_prompt, + use_text = config["use_text"], + text_encoder = text_encoder, + height = config["height"], + width = config["width"], + num_frames = config["video_seq_length"], + decode_chunk_size = 8, + motion_bucket_id = reflected_motion_bucket_id, + # controlnet_image_index = controlnet_image_index, + # coordinate_values = coordinate_values, + num_inference_steps = config["num_inference_steps"], + max_guidance_scale = config["inference_max_guidance_scale"], + fps = 7, + use_instructpix2pix = config["use_instructpix2pix"], + noise_aug_strength = config["inference_noise_aug_strength"], + controlnet_conditioning_scale = config["outer_conditioning_scale"], + inner_conditioning_scale = config["inner_conditioning_scale"], + guess_mode = config["inference_guess_mode"], # False in inference + image_guidance_scale = config["image_guidance_scale"], + ).frames[0] + + for idx, frame in enumerate(frames): + frame.save(os.path.join(validation_store_folder, str(idx)+".png")) + imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames, duration=0.05) + + frames_collection.append(frames) + + + # Cleaning process + del pipeline + torch.cuda.empty_cache() + + return frames_collection # Return resuly based on the need + + + +def execute_inference(huggingface_pretrained_path, model_type, validation_path, parent_store_folder, use_ambiguous_prompt): + + # Check path + if os.path.exists(parent_store_folder): + shutil.rmtree(parent_store_folder) + os.makedirs(parent_store_folder) + + + # Read the yaml setting files (Very important for loading hyperparamters needed) + if not os.path.exists(huggingface_pretrained_path): + yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="unet", filename="train_image2video.yaml") + if model_type == "GestureNet": + yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="gesturenet", filename="train_image2video_gesturenet.yaml") + else: # If the path is a local path we can concatenate it here + yaml_download_path = os.path.join(huggingface_pretrained_path, "unet", "train_image2video.yaml") + if model_type == "GestureNet": + yaml_download_path = os.path.join(huggingface_pretrained_path, "gesturenet", "train_image2video_gesturenet.yaml") + + # Load the config + assert(os.path.exists(yaml_download_path)) + base_config = OmegaConf.load(yaml_download_path) + + + # Other Settings + base_config["validation_img_folder"] = validation_path + + + + ################################################ Prepare vae, unet, image_encoder Same as before ################################################################# + accelerator = Accelerator( + gradient_accumulation_steps = base_config["gradient_accumulation_steps"], + mixed_precision = base_config["mixed_precision"], + log_with = base_config["report_to"], + project_config = ProjectConfiguration(project_dir=base_config["output_dir"], logging_dir=Path(base_config["output_dir"], base_config["logging_name"])), + ) + feature_extractor = CLIPImageProcessor.from_pretrained( + base_config["pretrained_model_name_or_path"], subfolder="feature_extractor", revision=None + ) # This instance has now weight, they are just seeting file + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + base_config["pretrained_model_name_or_path"], subfolder="image_encoder", revision=None, variant="fp16" + ) + vae = AutoencoderKLTemporalDecoder.from_pretrained( + base_config["pretrained_model_name_or_path"], subfolder="vae", revision=None, variant="fp16" + ) + unet = UNetSpatioTemporalConditionModel.from_pretrained( + huggingface_pretrained_path, + subfolder = "unet", + low_cpu_mem_usage = True, + # variant = "fp16", + ) + + + # For text .............................................. + tokenizer = AutoTokenizer.from_pretrained( + base_config["pretrained_tokenizer_name_or_path"], + subfolder = "tokenizer", + revision = None, + use_fast = False, + ) + # Clip Text Encoder + text_encoder_cls = import_pretrained_text_encoder(base_config["pretrained_tokenizer_name_or_path"], revision=None) + text_encoder = text_encoder_cls.from_pretrained(base_config["pretrained_tokenizer_name_or_path"], subfolder = "text_encoder", revision = None, variant = None) + + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae + image_encoder to gpu and cast to weight_dtype + vae.requires_grad_(False) + image_encoder.requires_grad_(False) + unet.requires_grad_(False) # Will switch back at the end + text_encoder.requires_grad_(False) + + # Move to accelerator + vae.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # For GestureNet + if model_type == "GestureNet": + unet.to(accelerator.device, dtype=weight_dtype) # There is no need to cast unet in unet training, only needed in controlnet one + + # Handle the Controlnet first from UNet + gesturenet = ControlNetModel.from_pretrained( + huggingface_pretrained_path, + subfolder = "gesturenet", + low_cpu_mem_usage = True, + variant = None, + ) + + gesturenet.requires_grad_(False) + gesturenet.to(accelerator.device) + ############################################################################################################################################################## + + + + ############################################################### Execution ##################################################################################### + + # Prepare the iterative calling + if model_type == "UNet": + generated_frames = unet_inference( + vae, unet, image_encoder, text_encoder, tokenizer, + base_config, accelerator, weight_dtype, step="", + parent_store_folder=parent_store_folder, force_close_flip = True, + use_ambiguous_prompt = use_ambiguous_prompt, + ) + + elif model_type == "GestureNet": + generated_frames = gesturenet_inference( + vae, unet, gesturenet, image_encoder, text_encoder, tokenizer, + base_config, accelerator, weight_dtype, step="", + parent_store_folder=parent_store_folder, force_close_flip = True, + use_ambiguous_prompt = use_ambiguous_prompt, + ) + + else: + raise NotImplementedError("model_type is no the predefined choices we provide!") + + ################################################################################################################################################################ + + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--model_type", + type=str, + default="GestureNet", + help="\"UNet\" for VL (vision language) / \"GestureNet\" for VGL (vision gesture language)", + ) + parser.add_argument( + "--huggingface_pretrained_path", + type=str, + default="HikariDawn/This-and-That-1.1", + help="Path to the unet folder path.", + ) + parser.add_argument( + "--validation_path", + type=str, + default="__assets__/Bridge_example/", + help="Sample dataset path, default to the Bridge example.", + ) + parser.add_argument( + "--parent_store_folder", + type=str, + default="generated_results/", + help="Path to the store result folder.", + ) + parser.add_argument( + "--use_ambiguous_prompt", + type=str, + default=False, + help="Whether we will use action verb + \"this to there\" ambgiguous prompt combo.", + ) + args = parser.parse_args() + + + # File Setting + model_type = args.model_type + huggingface_pretrained_path = args.huggingface_pretrained_path + # validation_path Needs to have subforder for each instance. + # Each instance requries "im_0.jpg" for the first image; data.txt for the gesture position; lang.txt for the language + validation_path = args.validation_path + parent_store_folder = args.parent_store_folder + use_ambiguous_prompt = args.use_ambiguous_prompt + + + # Execution + execute_inference(huggingface_pretrained_path, model_type, validation_path, parent_store_folder, use_ambiguous_prompt) + + + print("All finished!!!") + + diff --git a/track_anything_code/model.py b/track_anything_code/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd1985af3596111df492800a206121cae701fc86 --- /dev/null +++ b/track_anything_code/model.py @@ -0,0 +1,82 @@ +import os, sys +import PIL +from tqdm import tqdm +import numpy as np +import argparse + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from track_anything_code.tools.interact_tools import SamControler +from track_anything_code.tracker.base_tracker import BaseTracker + + +class TrackingAnything(): + def __init__(self, sam_checkpoint, xmem_checkpoint, args): + self.args = args + self.sam_checkpoint = sam_checkpoint + self.xmem_checkpoint = xmem_checkpoint + self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device) + self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device) + + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): + mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) + return mask, logit, painted_image + + + def generator(self, images: list, template_mask:np.ndarray): + + masks = [] + logits = [] + painted_images = [] + for i in tqdm(range(len(images)), desc="Tracking image"): + if i ==0: + mask, logit, painted_image = self.xmem.track(images[i], template_mask) + masks.append(mask) + logits.append(logit) + painted_images.append(painted_image) + + else: + mask, logit, painted_image = self.xmem.track(images[i]) + masks.append(mask) + logits.append(logit) + painted_images.append(painted_image) + return masks, logits, painted_images + + +def parse_augment(): + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default="cuda:0") + parser.add_argument('--sam_model_type', type=str, default="vit_h") + parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications") + parser.add_argument('--debug', action="store_true") + parser.add_argument('--mask_save', default=False) + args = parser.parse_args() + + if args.debug: + print(args) + return args + + +# if __name__ == "__main__": +# masks = None +# logits = None +# painted_images = None +# images = [] +# image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg')) +# args = parse_augment() +# # images.append(np.ones((20,20,3)).astype('uint8')) +# # images.append(np.ones((20,20,3)).astype('uint8')) +# images.append(image) +# images.append(image) + +# mask = np.zeros_like(image)[:,:,0] +# mask[0,0]= 1 +# trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args) +# masks, logits ,painted_images= trackany.generator(images, mask) + + + + + \ No newline at end of file diff --git a/track_anything_code/tools/__init__.py b/track_anything_code/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/track_anything_code/tools/base_segmenter.py b/track_anything_code/tools/base_segmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..2b975bb779b47485f9e6ba7435646b4db40a2c6a --- /dev/null +++ b/track_anything_code/tools/base_segmenter.py @@ -0,0 +1,129 @@ +import time +import torch +import cv2 +from PIL import Image, ImageDraw, ImageOps +import numpy as np +from typing import Union +from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib.pyplot as plt +import PIL +from .mask_painter import mask_painter + + +class BaseSegmenter: + def __init__(self, SAM_checkpoint, model_type, device='cuda:0'): + """ + device: model device + SAM_checkpoint: path of SAM checkpoint + model_type: vit_b, vit_l, vit_h + """ + print(f"Initializing BaseSegmenter to {device}") + assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h' + + self.device = device + self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 + self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint) + self.model.to(device=self.device) + self.predictor = SamPredictor(self.model) + self.embedded = False + + @torch.no_grad() + def set_image(self, image: np.ndarray): + # PIL.open(image_path) 3channel: RGB + # image embedding: avoid encode the same image multiple times + self.orignal_image = image + if self.embedded: + print('repeat embedding, please reset_image.') + return + self.predictor.set_image(image) + self.embedded = True + return + + @torch.no_grad() + def reset_image(self): + # reset image embeding + self.predictor.reset_image() + self.embedded = False + + def predict(self, prompts, mode, multimask=True): + """ + image: numpy array, h, w, 3 + prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input' + prompts['point_coords']: numpy array [N,2] + prompts['point_labels']: numpy array [1,N] + prompts['mask_input']: numpy array [1,256,256] + mode: 'point' (points only), 'mask' (mask only), 'both' (consider both) + mask_outputs: True (return 3 masks), False (return 1 mask only) + whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :] + """ + assert self.embedded, 'prediction is called before set_image (feature embedding).' + assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both' + + if mode == 'point': + masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], + point_labels=prompts['point_labels'], + multimask_output=multimask) + elif mode == 'mask': + masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'], + multimask_output=multimask) + elif mode == 'both': # both + masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], + point_labels=prompts['point_labels'], + mask_input=prompts['mask_input'], + multimask_output=multimask) + else: + raise("Not implement now!") + # masks (n, h, w), scores (n,), logits (n, 256, 256) + return masks, scores, logits + + +if __name__ == "__main__": + # load and show an image + image = cv2.imread('/hhd3/gaoshang/truck.jpg') + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3) + + # initialise BaseSegmenter + SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' + model_type = 'vit_h' + device = "cuda:4" + base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device) + + # image embedding (once embedded, multiple prompts can be applied) + base_segmenter.set_image(image) + + # examples + # point only ------------------------ + mode = 'point' + prompts = { + 'point_coords': np.array([[500, 375], [1125, 625]]), + 'point_labels': np.array([1, 1]), + } + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) + + # both ------------------------ + mode = 'both' + mask_input = logits[np.argmax(scores), :, :] + prompts = {'mask_input': mask_input [None, :, :]} + prompts = { + 'point_coords': np.array([[500, 375], [1125, 625]]), + 'point_labels': np.array([1, 0]), + 'mask_input': mask_input[None, :, :] + } + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image) + + # mask only ------------------------ + mode = 'mask' + mask_input = logits[np.argmax(scores), :, :] + + prompts = {'mask_input': mask_input[None, :, :]} + + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image) diff --git a/track_anything_code/tools/interact_tools.py b/track_anything_code/tools/interact_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..daecc73e5f54c95b53c04520110775281a6e0560 --- /dev/null +++ b/track_anything_code/tools/interact_tools.py @@ -0,0 +1,265 @@ +import time +import torch +import cv2 +from PIL import Image, ImageDraw, ImageOps +import numpy as np +from typing import Union +from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib.pyplot as plt +import PIL +from .mask_painter import mask_painter as mask_painter2 +from .base_segmenter import BaseSegmenter +from .painter import mask_painter, point_painter +import os +import requests +import sys + + +mask_color = 3 +mask_alpha = 0.7 +contour_color = 1 +contour_width = 5 +point_color_ne = 8 +point_color_ps = 50 +point_alpha = 0.9 +point_radius = 15 +contour_color = 2 +contour_width = 5 + + +class SamControler(): + def __init__(self, SAM_checkpoint, model_type, device): + ''' + initialize sam controler + ''' + + + self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) + + + # def seg_again(self, image: np.ndarray): + # ''' + # it is used when interact in video + # ''' + # self.sam_controler.reset_image() + # self.sam_controler.set_image(image) + # return + + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3): + ''' + it is used in first frame in video + return: mask, logit, painted image(mask+point) + ''' + # self.sam_controler.set_image(image) + origal_image = self.sam_controler.orignal_image + neg_flag = labels[-1] + if neg_flag==1: + #find neg + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + prompts = { + 'point_coords': points, + 'point_labels': labels, + 'mask_input': logit[None, :, :] + } + masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + else: + #find positive + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + + + assert len(points)==len(labels) + + painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) + painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) + painted_image = Image.fromarray(painted_image) + + return mask, logit, painted_image + + # def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): + # origal_image = self.sam_controler.orignal_image + # if same: + # ''' + # true; loop in the same image + # ''' + # prompts = { + # 'point_coords': points, + # 'point_labels': labels, + # 'mask_input': logits[None, :, :] + # } + # masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) + # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + + # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) + # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) + # painted_image = Image.fromarray(painted_image) + + # return mask, logit, painted_image + # else: + # ''' + # loop in the different image, interact in the video + # ''' + # if image is None: + # raise('Image error') + # else: + # self.seg_again(image) + # prompts = { + # 'point_coords': points, + # 'point_labels': labels, + # } + # masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + + # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) + # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) + # painted_image = Image.fromarray(painted_image) + + # return mask, logit, painted_image + + + + + + +# def initialize(): +# ''' +# initialize sam controler +# ''' +# checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" +# folder = "segmenter" +# SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth' +# download_checkpoint(checkpoint_url, folder, SAM_checkpoint) + + +# model_type = 'vit_h' +# device = "cuda:0" +# sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) +# return sam_controler + + +# def seg_again(sam_controler, image: np.ndarray): +# ''' +# it is used when interact in video +# ''' +# sam_controler.reset_image() +# sam_controler.set_image(image) +# return + + +# def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): +# ''' +# it is used in first frame in video +# return: mask, logit, painted image(mask+point) +# ''' +# sam_controler.set_image(image) +# prompts = { +# 'point_coords': points, +# 'point_labels': labels, +# } +# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask) +# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + +# assert len(points)==len(labels) + +# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) +# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) +# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) +# painted_image = Image.fromarray(painted_image) + +# return mask, logit, painted_image + +# def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): +# if same: +# ''' +# true; loop in the same image +# ''' +# prompts = { +# 'point_coords': points, +# 'point_labels': labels, +# 'mask_input': logits[None, :, :] +# } +# masks, scores, logits = sam_controler.predict(prompts, 'both', multimask) +# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + +# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) +# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) +# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) +# painted_image = Image.fromarray(painted_image) + +# return mask, logit, painted_image +# else: +# ''' +# loop in the different image, interact in the video +# ''' +# if image is None: +# raise('Image error') +# else: +# seg_again(sam_controler, image) +# prompts = { +# 'point_coords': points, +# 'point_labels': labels, +# } +# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask) +# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + +# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) +# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) +# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) +# painted_image = Image.fromarray(painted_image) + +# return mask, logit, painted_image + + + + +# if __name__ == "__main__": +# points = np.array([[500, 375], [1125, 625]]) +# labels = np.array([1, 1]) +# image = cv2.imread('/hhd3/gaoshang/truck.jpg') +# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + +# sam_controler = initialize() +# mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True) +# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) +# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) +# cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) +# cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image) +# painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg') + +# mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True) +# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) +# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) +# cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image) +# painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg') + +# mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True) +# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) +# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) +# cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image) +# painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg') + + + + + + + + + + + + \ No newline at end of file diff --git a/track_anything_code/tools/mask_painter.py b/track_anything_code/tools/mask_painter.py new file mode 100644 index 0000000000000000000000000000000000000000..f471ea0116d656e2cc236832893b07c6d7be1643 --- /dev/null +++ b/track_anything_code/tools/mask_painter.py @@ -0,0 +1,288 @@ +import cv2 +import torch +import numpy as np +from PIL import Image +import copy +import time + + +def colormap(rgb=True): + color_list = np.array( + [ + 0.000, 0.000, 0.000, + 1.000, 1.000, 1.000, + 1.000, 0.498, 0.313, + 0.392, 0.581, 0.929, + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list + + +color_list = colormap() +color_list = color_list.astype('uint8').tolist() + + +def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha): + background_color = np.array(background_color) + contour_color = np.array(contour_color) + + # background_mask = 1 - background_mask + # contour_mask = 1 - contour_mask + + for i in range(3): + image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \ + + background_color[i] * (background_alpha-background_mask*background_alpha) + + image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \ + + contour_color[i] * (contour_alpha-contour_mask*contour_alpha) + + return image.astype('uint8') + + +def mask_generator_00(mask, background_radius, contour_radius): + # no background width when '00' + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + return mask, contour_mask + + +def mask_generator_01(mask, background_radius, contour_radius): + # no background width when '00' + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + return mask, contour_mask + + +def mask_generator_10(mask, background_radius, contour_radius): + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # .....:::::!!!!! + background_mask = np.clip(dist_map, -background_radius, background_radius) + background_mask = (background_mask - np.min(background_mask)) + background_mask = background_mask / np.max(background_mask) + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + return background_mask, contour_mask + + +def mask_generator_11(mask, background_radius, contour_radius): + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # .....:::::!!!!! + background_mask = np.clip(dist_map, -background_radius, background_radius) + background_mask = (background_mask - np.min(background_mask)) + background_mask = background_mask / np.max(background_mask) + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + return background_mask, contour_mask + + +def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'): + """ + Input: + input_image: numpy array + input_mask: numpy array + background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing + background_blur_radius: radius of background blur, must be odd number + contour_width: width of mask contour, must be odd number + contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others + contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted + mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both + + Output: + painted_image: numpy array + """ + assert input_image.shape[:2] == input_mask.shape, 'different shape' + assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD' + assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11' + + # downsample input image and mask + width, height = input_image.shape[0], input_image.shape[1] + res = 1024 + ratio = min(1.0 * res / max(width, height), 1.0) + input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio))) + input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio))) + + # 0: background, 1: foreground + msk = np.clip(input_mask, 0, 1) + + # generate masks for background and contour pixels + background_radius = (background_blur_radius - 1) // 2 + contour_radius = (contour_width - 1) // 2 + generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11} + background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius) + + # paint + painted_image = vis_add_mask\ + (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background + + return painted_image + + +if __name__ == '__main__': + + background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing + background_blur_radius = 31 # radius of background blur, must be odd number + contour_width = 11 # contour width, must be odd number + contour_color = 3 # id in color map, 0: black, 1: white, >1: others + contour_alpha = 1 # transparency of background, 0: no contour highlighted + + # load input image and mask + input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB')) + input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P')) + + # paint + overall_time_1 = 0 + overall_time_2 = 0 + overall_time_3 = 0 + overall_time_4 = 0 + overall_time_5 = 0 + + for i in range(50): + t2 = time.time() + painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00') + e2 = time.time() + + t3 = time.time() + painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10') + e3 = time.time() + + t1 = time.time() + painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha) + e1 = time.time() + + t4 = time.time() + painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01') + e4 = time.time() + + t5 = time.time() + painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11') + e5 = time.time() + + overall_time_1 += (e1 - t1) + overall_time_2 += (e2 - t2) + overall_time_3 += (e3 - t3) + overall_time_4 += (e4 - t4) + overall_time_5 += (e5 - t5) + + print(f'average time w gaussian: {overall_time_1/50}') + print(f'average time w/o gaussian00: {overall_time_2/50}') + print(f'average time w/o gaussian10: {overall_time_3/50}') + print(f'average time w/o gaussian01: {overall_time_4/50}') + print(f'average time w/o gaussian11: {overall_time_5/50}') + + # save + painted_image_00 = Image.fromarray(painted_image_00) + painted_image_00.save('./test_img/painter_output_image_00.png') + + painted_image_10 = Image.fromarray(painted_image_10) + painted_image_10.save('./test_img/painter_output_image_10.png') + + painted_image_01 = Image.fromarray(painted_image_01) + painted_image_01.save('./test_img/painter_output_image_01.png') + + painted_image_11 = Image.fromarray(painted_image_11) + painted_image_11.save('./test_img/painter_output_image_11.png') diff --git a/track_anything_code/tools/painter.py b/track_anything_code/tools/painter.py new file mode 100644 index 0000000000000000000000000000000000000000..0e711d35aa8348d15cdad9d1cd413da41ea4f1ab --- /dev/null +++ b/track_anything_code/tools/painter.py @@ -0,0 +1,215 @@ +# paint masks, contours, or points on images, with specified colors +import cv2 +import torch +import numpy as np +from PIL import Image +import copy +import time + + +def colormap(rgb=True): + color_list = np.array( + [ + 0.000, 0.000, 0.000, + 1.000, 1.000, 1.000, + 1.000, 0.498, 0.313, + 0.392, 0.581, 0.929, + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list + + +color_list = colormap() +color_list = color_list.astype('uint8').tolist() + + +def vis_add_mask(image, mask, color, alpha): + color = np.array(color_list[color]) + mask = mask > 0.5 + image[mask] = image[mask] * (1-alpha) + color * alpha + return image.astype('uint8') + +def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5): + h, w = input_image.shape[:2] + point_mask = np.zeros((h, w)).astype('uint8') + for point in input_points: + point_mask[point[1], point[0]] = 1 + + kernel = cv2.getStructuringElement(2, (point_radius, point_radius)) + point_mask = cv2.dilate(point_mask, kernel) + + contour_radius = (contour_width - 1) // 2 + dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + # paint mask + painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha) + # paint contour + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) + return painted_image + +def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3): + assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' + # 0: background, 1: foreground + mask = np.clip(input_mask, 0, 1) + contour_radius = (contour_width - 1) // 2 + + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + # paint mask + painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha) + # paint contour + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) + + return painted_image + +def background_remover(input_image, input_mask): + """ + input_image: H, W, 3, np.array + input_mask: H, W, np.array + + image_wo_background: PIL.Image + """ + assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' + # 0: background, 1: foreground + mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255 + image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4 + image_wo_background = Image.fromarray(image_wo_background).convert('RGBA') + + return image_wo_background + +if __name__ == '__main__': + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P')) + + # example of mask painter + mask_color = 3 + mask_alpha = 0.7 + contour_color = 1 + contour_width = 5 + + # save + painted_image = Image.fromarray(input_image) + painted_image.save('images/original.png') + + painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width) + # save + painted_image = Image.fromarray(input_image) + painted_image.save('images/original1.png') + + # example of point painter + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + input_points = np.array([[500, 375], [70, 600]]) # x, y + point_color = 5 + point_alpha = 0.9 + point_radius = 15 + contour_color = 2 + contour_width = 5 + painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width) + # save + painted_image = Image.fromarray(painted_image_1) + painted_image.save('images/point_painter_1.png') + + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29) + # save + painted_image = Image.fromarray(painted_image_2) + painted_image.save('images/point_painter_2.png') + + # example of background remover + input_image = np.array(Image.open('images/original.png').convert('RGB')) + image_wo_background = background_remover(input_image, input_mask) # return PIL.Image + image_wo_background.save('images/image_wo_background.png') diff --git a/track_anything_code/track_anything_module.py b/track_anything_code/track_anything_module.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6c5088c28d1f94460e91ca8c209eceb6503fe6 --- /dev/null +++ b/track_anything_code/track_anything_module.py @@ -0,0 +1,352 @@ +import gradio as gr +import argparse +import gdown +import cv2 +import numpy as np +import os +import sys +import requests +import json +import torchvision +import torch +import psutil +import time +import imageio +try: + from mmcv.cnn import ConvModule +except: + os.system("mim install mmcv") + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from track_anything_code.model import TrackingAnything + + + + +def parse_augment(): + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default="cuda:0") + parser.add_argument('--sam_model_type', type=str, default="vit_h") + parser.add_argument('--debug', action="store_true") + parser.add_argument('--mask_save', default=False) + args = parser.parse_args() + + if args.debug: + print(args) + return args + + +# download checkpoints +def download_checkpoint(url, folder, filename): + os.makedirs(folder, exist_ok=True) + filepath = os.path.join(folder, filename) + + if not os.path.exists(filepath): + print("download checkpoints ......") + response = requests.get(url, stream=True) + with open(filepath, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + print("download successfully!") + + return filepath + +def download_checkpoint_from_google_drive(file_id, folder, filename): + os.makedirs(folder, exist_ok=True) + filepath = os.path.join(folder, filename) + + if not os.path.exists(filepath): + print("Downloading checkpoints from Google Drive... tips: If you cannot see the progress bar, please try to download it manuall \ + and put it in the checkpointes directory. E2FGVI-HQ-CVPR22.pth: https://github.com/MCG-NKU/E2FGVI(E2FGVI-HQ model)") + url = f"https://drive.google.com/uc?id={file_id}" + gdown.download(url, filepath, quiet=False) + print("Downloaded successfully!") + + return filepath + +# convert points input to prompt state +def get_prompt(click_state, click_input): + inputs = json.loads(click_input) + points = click_state[0] + labels = click_state[1] + for input in inputs: + points.append(input[:2]) + labels.append(input[2]) + click_state[0] = points + click_state[1] = labels + prompt = { + "prompt_type":["click"], + "input_point":click_state[0], + "input_label":click_state[1], + "multimask_output":"False", + } + return prompt + + + +# extract frames from upload video +def get_frames_from_video(video_path, video_state, model): + """ Extract video information based on the input + Args: + video_path: str + timestamp: float64 + Return + [[0:nearest_frame], [nearest_frame:], nearest_frame] + """ + frames = [] + user_name = time.time() + try: + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + while cap.isOpened(): + ret, frame = cap.read() + if ret == True: + current_memory_usage = psutil.virtual_memory().percent + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + if current_memory_usage > 90: + print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.") + break + else: + break + except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e: + print("read_frame_source:{} error. {}\n".format(video_path, str(e))) + image_size = (frames[0].shape[0],frames[0].shape[1]) + + # initialize video_state + video_state = { + "user_name": user_name, + "video_name": os.path.split(video_path)[-1], + "origin_images": frames, + "painted_images": frames.copy(), + "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames), + "logits": [None]*len(frames), + "select_frame_number": 0, + "fps": fps + } + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) + return video_state, video_state["origin_images"][0] + + + +def run_example(example): + return video_input +# get the select frame from gradio slider +def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown): + + # images = video_state[1] + image_selection_slider -= 1 + video_state["select_frame_number"] = image_selection_slider + + # once select a new template frame, set the image in sam + + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) + + if mask_dropdown: + print("ok") + operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")] + + + return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log + +# set the tracking end frame +def get_end_number(track_pause_number_slider, video_state, interactive_state): + interactive_state["track_end_number"] = track_pause_number_slider + operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")] + + return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log + +def get_resize_ratio(resize_ratio_slider, interactive_state): + interactive_state["resize_ratio"] = resize_ratio_slider + + return interactive_state + +# use sam to get the mask +def sam_refine(model, video_state, point_prompt, click_state, interactive_state, point_cord): + """ + Args: + template_frame: PIL.Image + point_prompt: flag for positive or negative button click + click_state: [[points], [labels]] + """ + if point_prompt == "Positive": + coordinate = "[[{},{},1]]".format(point_cord[0], point_cord[1]) # Height and Width + interactive_state["positive_click_times"] += 1 + else: + coordinate = "[[{},{},0]]".format(point_cord[0], point_cord[1]) + interactive_state["negative_click_times"] += 1 + + # prompt for sam model + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]]) + prompt = get_prompt(click_state=click_state, click_input=coordinate) + + mask, logit, painted_image = model.first_frame_click( + image=video_state["origin_images"][video_state["select_frame_number"]], + points=np.array(prompt["input_point"]), + labels=np.array(prompt["input_label"]), + multimask=False, # False by default + ) + video_state["masks"][video_state["select_frame_number"]] = mask + video_state["logits"][video_state["select_frame_number"]] = logit + video_state["painted_images"][video_state["select_frame_number"]] = painted_image + + operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")] + return painted_image, video_state, interactive_state, operation_log + + +def clear_click(video_state, click_state): + click_state = [[],[]] + template_frame = video_state["origin_images"][video_state["select_frame_number"]] + operation_log = [("",""), ("Clear points history and refresh the image.","Normal")] + return template_frame, click_state, operation_log + +def remove_multi_mask(interactive_state, mask_dropdown): + interactive_state["multi_mask"]["mask_names"]= [] + interactive_state["multi_mask"]["masks"] = [] + + operation_log = [("",""), ("Remove all mask, please add new masks","Normal")] + return interactive_state, gr.update(choices=[],value=[]), operation_log + + + +# tracking vos +def vos_tracking_video(model, output_path, video_state, interactive_state, mask_dropdown): + operation_log = [("",""), ("Track the selected masks, and then you can select the masks for inpainting.","Normal")] + model.xmem.clear_memory() + if interactive_state["track_end_number"]: + following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] + else: + following_frames = video_state["origin_images"][video_state["select_frame_number"]:] + + if interactive_state["multi_mask"]["masks"]: + if len(mask_dropdown) == 0: + mask_dropdown = ["mask_001"] + mask_dropdown.sort() + template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1])) + for i in range(1,len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1) + video_state["masks"][video_state["select_frame_number"]]= template_mask + else: + template_mask = video_state["masks"][video_state["select_frame_number"]] + fps = video_state["fps"] + + # operation error + if len(np.unique(template_mask))==1: + template_mask[0][0]=1 + operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")] + # return video_output, video_state, interactive_state, operation_error + masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) + # clear GPU memory + model.xmem.clear_memory() + + if interactive_state["track_end_number"]: + video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks + video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits + video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images + else: + video_state["masks"][video_state["select_frame_number"]:] = masks + video_state["logits"][video_state["select_frame_number"]:] = logits + video_state["painted_images"][video_state["select_frame_number"]:] = painted_images + + generate_video_from_frames(video_state["painted_images"], output_path=output_path, fps=fps) # import video_input to name the output video + interactive_state["inference_times"] += 1 + + print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"], + interactive_state["positive_click_times"]+interactive_state["negative_click_times"], + interactive_state["positive_click_times"], + interactive_state["negative_click_times"])) + + #### shanggao code for mask save + if interactive_state["mask_save"]: # May not need to use this branch + if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])): + os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0])) + i = 0 + print("save mask") + for mask in video_state["masks"]: + np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask) + i+=1 + # save_mask(video_state["masks"], video_state["video_name"]) + #### shanggao code for mask save + return video_state, video_state, interactive_state, operation_log + +# extracting masks from mask_dropdown +# def extract_sole_mask(video_state, mask_dropdown): +# combined_masks = +# unique_masks = np.unique(combined_masks) +# return 0 + +# inpaint +def inpaint_video(video_state, interactive_state, mask_dropdown): + operation_log = [("",""), ("Removed the selected masks.","Normal")] + + frames = np.asarray(video_state["origin_images"]) + fps = video_state["fps"] + inpaint_masks = np.asarray(video_state["masks"]) + if len(mask_dropdown) == 0: + mask_dropdown = ["mask_001"] + mask_dropdown.sort() + # convert mask_dropdown to mask numbers + inpaint_mask_numbers = [int(mask_dropdown[i].split("_")[1]) for i in range(len(mask_dropdown))] + # interate through all masks and remove the masks that are not in mask_dropdown + unique_masks = np.unique(inpaint_masks) + num_masks = len(unique_masks) - 1 + for i in range(1, num_masks + 1): + if i in inpaint_mask_numbers: + continue + inpaint_masks[inpaint_masks==i] = 0 + # inpaint for videos + + try: + inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3 + except: + operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")] + inpainted_frames = video_state["origin_images"] + video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video + + return video_output, operation_log + + +# generate video after vos inference +def generate_video_from_frames(frames, output_path=None, fps=30): + """ + Generates a video from a list of frames. + + Args: + frames (list of numpy arrays): The frames to include in the video. + output_path (str): If provided, it is the path to save the generated video. Else, we won't store it + fps (int, optional): The frame rate of the output video. Defaults to 30. + """ + + # frames = torch.from_numpy(np.asarray(frames)) + imageio.mimsave(output_path, frames) + # return output_path + + + + +if __name__ == "__main__": + # args, defined in track_anything.py + args = parse_augment() + + # check and download checkpoints if needed + sam_checkpoint = "sam_vit_h_4b8939.pth" + sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" + xmem_checkpoint = "XMem-s012.pth" + xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth" + + + folder ="./pretrained" + SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint) + xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint) + args.device = "cuda" # Any GPU is ok + + # initialize sam, xmem, e2fgvi models + model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args) diff --git a/track_anything_code/tracker/base_tracker.py b/track_anything_code/tracker/base_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..b65bf85b955a1b52e5c8d02041dae4491fee6770 --- /dev/null +++ b/track_anything_code/tracker/base_tracker.py @@ -0,0 +1,265 @@ +# import for debugging +import os, sys +import glob +import numpy as np +from PIL import Image +# import for base_tracker +import torch +import yaml +import torch.nn.functional as F +from .inference.inference_core import InferenceCore +from torchvision import transforms +from torchvision.transforms import Resize +import progressbar + + +# Import files from the local folder +# root_path = os.path.abspath('.') +# sys.path.append(root_path) +from .model.network import XMem +from .util.mask_mapper import MaskMapper +from .util.range_transform import im_normalization +from ..tools.painter import mask_painter +from ..tools.base_segmenter import BaseSegmenter + + +class BaseTracker: + def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) -> None: + """ + device: model device + xmem_checkpoint: checkpoint of XMem model + """ + # load configurations + with open("track_anything_code/tracker/config/config.yaml", 'r') as stream: + config = yaml.safe_load(stream) + # initialise XMem + network = XMem(config, xmem_checkpoint).to(device).eval() + # initialise IncerenceCore + self.tracker = InferenceCore(network, config) + # data transformation + self.im_transform = transforms.Compose([ + transforms.ToTensor(), + im_normalization, + ]) + self.device = device + + # changable properties + self.mapper = MaskMapper() + self.initialised = False + + # # SAM-based refinement + # self.sam_model = sam_model + # self.resizer = Resize([256, 256]) + + @torch.no_grad() + def resize_mask(self, mask): + # mask transform is applied AFTER mapper, so we need to post-process it in eval.py + h, w = mask.shape[-2:] + min_hw = min(h, w) + return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), + mode='nearest') + + @torch.no_grad() + def track(self, frame, first_frame_annotation=None): + """ + Input: + frames: numpy arrays (H, W, 3) + logit: numpy array (H, W), logit + + Output: + mask: numpy arrays (H, W) + logit: numpy arrays, probability map (H, W) + painted_image: numpy array (H, W, 3) + """ + + if first_frame_annotation is not None: # first frame mask + # initialisation + mask, labels = self.mapper.convert_mask(first_frame_annotation) + mask = torch.Tensor(mask).to(self.device) + self.tracker.set_all_labels(list(self.mapper.remappings.values())) + else: + mask = None + labels = None + # prepare inputs + frame_tensor = self.im_transform(frame).to(self.device) + # track one frame + probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W + # # refine + # if first_frame_annotation is None: + # out_mask = self.sam_refinement(frame, logits[1], ti) + + # convert to mask + out_mask = torch.argmax(probs, dim=0) + out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + + final_mask = np.zeros_like(out_mask) + + # map back + for k, v in self.mapper.remappings.items(): + final_mask[out_mask == v] = k + + num_objs = final_mask.max() + painted_image = frame + for obj in range(1, num_objs+1): + if np.max(final_mask==obj) == 0: + continue + painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1) + + # print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB') + + return final_mask, final_mask, painted_image + + @torch.no_grad() + def sam_refinement(self, frame, logits, ti): + """ + refine segmentation results with mask prompt + """ + # convert to 1, 256, 256 + self.sam_model.set_image(frame) + mode = 'mask' + logits = logits.unsqueeze(0) + logits = self.resizer(logits).cpu().numpy() + prompts = {'mask_input': logits} # 1 256 256 + masks, scores, logits = self.sam_model.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(frame, masks[np.argmax(scores)].astype('uint8'), mask_alpha=0.8) + painted_image = Image.fromarray(painted_image) + painted_image.save(f'/ssd1/gaomingqi/refine/{ti:05d}.png') + self.sam_model.reset_image() + + @torch.no_grad() + def clear_memory(self): + self.tracker.clear_memory() + self.mapper.clear_labels() + torch.cuda.empty_cache() + + +## how to use: +## 1/3) prepare device and xmem_checkpoint +# device = 'cuda:2' +# XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' +## 2/3) initialise Base Tracker +# tracker = BaseTracker(XMEM_checkpoint, device, None, device) # leave an interface for sam model (currently set None) +## 3/3) + + +if __name__ == '__main__': + # video frames (take videos from DAVIS-2017 as examples) + video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg')) + video_path_list.sort() + # load frames + frames = [] + for video_path in video_path_list: + frames.append(np.array(Image.open(video_path).convert('RGB'))) + frames = np.stack(frames, 0) # T, H, W, C + # load first frame annotation + first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png' + first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C + + # ------------------------------------------------------------------------------------ + # how to use + # ------------------------------------------------------------------------------------ + # 1/4: set checkpoint and device + device = 'cuda:2' + XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' + # SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' + # model_type = 'vit_h' + # ------------------------------------------------------------------------------------ + # 2/4: initialise inpainter + tracker = BaseTracker(XMEM_checkpoint, device, None, device) + # ------------------------------------------------------------------------------------ + # 3/4: for each frame, get tracking results by tracker.track(frame, first_frame_annotation) + # frame: numpy array (H, W, C), first_frame_annotation: numpy array (H, W), leave it blank when tracking begins + painted_frames = [] + for ti, frame in enumerate(frames): + if ti == 0: + mask, prob, painted_frame = tracker.track(frame, first_frame_annotation) + # mask: + else: + mask, prob, painted_frame = tracker.track(frame) + painted_frames.append(painted_frame) + # ---------------------------------------------- + # 3/4: clear memory in XMEM for the next video + tracker.clear_memory() + # ---------------------------------------------- + # end + # ---------------------------------------------- + print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB') + # set saving path + save_path = '/ssd1/gaomingqi/results/TAM/blackswan' + if not os.path.exists(save_path): + os.mkdir(save_path) + # save + for painted_frame in progressbar.progressbar(painted_frames): + painted_frame = Image.fromarray(painted_frame) + painted_frame.save(f'{save_path}/{ti:05d}.png') + + # tracker.clear_memory() + # for ti, frame in enumerate(frames): + # print(ti) + # # if ti > 200: + # # break + # if ti == 0: + # mask, prob, painted_image = tracker.track(frame, first_frame_annotation) + # else: + # mask, prob, painted_image = tracker.track(frame) + # # save + # painted_image = Image.fromarray(painted_image) + # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png') + + # # track anything given in the first frame annotation + # for ti, frame in enumerate(frames): + # if ti == 0: + # mask, prob, painted_image = tracker.track(frame, first_frame_annotation) + # else: + # mask, prob, painted_image = tracker.track(frame) + # # save + # painted_image = Image.fromarray(painted_image) + # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png') + + # # ---------------------------------------------------------- + # # another video + # # ---------------------------------------------------------- + # # video frames + # video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/camel', '*.jpg')) + # video_path_list.sort() + # # first frame + # first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/camel/00000.png' + # # load frames + # frames = [] + # for video_path in video_path_list: + # frames.append(np.array(Image.open(video_path).convert('RGB'))) + # frames = np.stack(frames, 0) # N, H, W, C + # # load first frame annotation + # first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C + + # print('first video done. clear.') + + # tracker.clear_memory() + # # track anything given in the first frame annotation + # for ti, frame in enumerate(frames): + # if ti == 0: + # mask, prob, painted_image = tracker.track(frame, first_frame_annotation) + # else: + # mask, prob, painted_image = tracker.track(frame) + # # save + # painted_image = Image.fromarray(painted_image) + # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png') + + # # failure case test + # failure_path = '/ssd1/gaomingqi/failure' + # frames = np.load(os.path.join(failure_path, 'video_frames.npy')) + # # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB')) + # first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P')) + # first_mask = np.clip(first_mask, 0, 1) + + # for ti, frame in enumerate(frames): + # if ti == 0: + # mask, probs, painted_image = tracker.track(frame, first_mask) + # else: + # mask, probs, painted_image = tracker.track(frame) + # # save + # painted_image = Image.fromarray(painted_image) + # painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png') + # prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8')) + + # # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png') diff --git a/track_anything_code/tracker/config/config.yaml b/track_anything_code/tracker/config/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c99064e04262eb50827056bef225877bbc12822 --- /dev/null +++ b/track_anything_code/tracker/config/config.yaml @@ -0,0 +1,15 @@ +# config info for XMem +benchmark: False +disable_long_term: False +max_mid_term_frames: 10 +min_mid_term_frames: 5 +max_long_term_elements: 1000 +num_prototypes: 128 +top_k: 30 +mem_every: 5 +deep_update_every: -1 +save_scores: False +flip: False +size: 480 +enable_long_term: True +enable_long_term_count_usage: True diff --git a/track_anything_code/tracker/inference/__init__.py b/track_anything_code/tracker/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/track_anything_code/tracker/inference/inference_core.py b/track_anything_code/tracker/inference/inference_core.py new file mode 100644 index 0000000000000000000000000000000000000000..f1dadc179329d3df6b2a44c7f69ed84314019f70 --- /dev/null +++ b/track_anything_code/tracker/inference/inference_core.py @@ -0,0 +1,115 @@ +from .memory_manager import MemoryManager +from ..model.network import XMem +from ..model.aggregate import aggregate + +from ..util.tensor_util import pad_divide_by, unpad + + +class InferenceCore: + def __init__(self, network:XMem, config): + self.config = config + self.network = network + self.mem_every = config['mem_every'] + self.deep_update_every = config['deep_update_every'] + self.enable_long_term = config['enable_long_term'] + + # if deep_update_every < 0, synchronize deep update with memory frame + self.deep_update_sync = (self.deep_update_every < 0) + + self.clear_memory() + self.all_labels = None + + def clear_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + if not self.deep_update_sync: + self.last_deep_update_ti = -self.deep_update_every + self.memory = MemoryManager(config=self.config) + + def update_config(self, config): + self.mem_every = config['mem_every'] + self.deep_update_every = config['deep_update_every'] + self.enable_long_term = config['enable_long_term'] + + # if deep_update_every < 0, synchronize deep update with memory frame + self.deep_update_sync = (self.deep_update_every < 0) + self.memory.update_config(config) + + def set_all_labels(self, all_labels): + # self.all_labels = [l.item() for l in all_labels] + self.all_labels = all_labels + + def step(self, image, mask=None, valid_labels=None, end=False): + # image: 3*H*W + # mask: num_objects*H*W or None + self.curr_ti += 1 + image, self.pad = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + + is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) + need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) + is_deep_update = ( + (self.deep_update_sync and is_mem_frame) or # synchronized + (not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync + ) and (not end) + is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end) + + key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, + need_ek=(self.enable_long_term or need_segment), + need_sk=is_mem_frame) + multi_scale_features = (f16, f8, f4) + + # segment the current frame is needed + if need_segment: + memory_readout = self.memory.match_memory(key, selection).unsqueeze(0) + + hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout, + self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False) + # remove batch dim + pred_prob_with_bg = pred_prob_with_bg[0] + pred_prob_no_bg = pred_prob_with_bg[1:] + + pred_logits_with_bg = pred_logits_with_bg[0] + pred_logits_no_bg = pred_logits_with_bg[1:] + + if is_normal_update: + self.memory.set_hidden(hidden) + else: + pred_prob_no_bg = pred_prob_with_bg = pred_logits_with_bg = pred_logits_no_bg = None + + # use the input mask if any + if mask is not None: + mask, _ = pad_divide_by(mask, 16) + + if pred_prob_no_bg is not None: + # if we have a predicted mask, we work on it + # make pred_prob_no_bg consistent with the input mask + mask_regions = (mask.sum(0) > 0.5) + pred_prob_no_bg[:, mask_regions] = 0 + # shift by 1 because mask/pred_prob_no_bg do not contain background + mask = mask.type_as(pred_prob_no_bg) + if valid_labels is not None: + shift_by_one_non_labels = [i for i in range(pred_prob_no_bg.shape[0]) if (i+1) not in valid_labels] + # non-labelled objects are copied from the predicted mask + mask[shift_by_one_non_labels] = pred_prob_no_bg[shift_by_one_non_labels] + pred_prob_with_bg = aggregate(mask, dim=0) + + # also create new hidden states + self.memory.create_hidden_state(len(self.all_labels), key) + + # save as memory if needed + if is_mem_frame: + value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), + pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update) + self.memory.add_memory(key, shrinkage, value, self.all_labels, + selection=selection if self.enable_long_term else None) + self.last_mem_ti = self.curr_ti + + if is_deep_update: + self.memory.set_hidden(hidden) + self.last_deep_update_ti = self.curr_ti + + if pred_logits_with_bg is None: + return unpad(pred_prob_with_bg, self.pad), None + else: + return unpad(pred_prob_with_bg, self.pad), unpad(pred_logits_with_bg, self.pad) diff --git a/track_anything_code/tracker/inference/kv_memory_store.py b/track_anything_code/tracker/inference/kv_memory_store.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1113096c652ef8ce0504a4e8583007914e1957 --- /dev/null +++ b/track_anything_code/tracker/inference/kv_memory_store.py @@ -0,0 +1,214 @@ +import torch +from typing import List + +class KeyValueMemoryStore: + """ + Works for key/value pairs type storage + e.g., working and long-term memory + """ + + """ + An object group is created when new objects enter the video + Objects in the same group share the same temporal extent + i.e., objects initialized in the same frame are in the same group + For DAVIS/interactive, there is only one object group + For YouTubeVOS, there can be multiple object groups + """ + + def __init__(self, count_usage: bool): + self.count_usage = count_usage + + # keys are stored in a single tensor and are shared between groups/objects + # values are stored as a list indexed by object groups + self.k = None + self.v = [] + self.obj_groups = [] + # for debugging only + self.all_objects = [] + + # shrinkage and selection are also single tensors + self.s = self.e = None + + # usage + if self.count_usage: + self.use_count = self.life_count = None + + def add(self, key, value, shrinkage, selection, objects: List[int]): + new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7 + + # add the key + if self.k is None: + self.k = key + self.s = shrinkage + self.e = selection + if self.count_usage: + self.use_count = new_count + self.life_count = new_life + else: + self.k = torch.cat([self.k, key], -1) + if shrinkage is not None: + self.s = torch.cat([self.s, shrinkage], -1) + if selection is not None: + self.e = torch.cat([self.e, selection], -1) + if self.count_usage: + self.use_count = torch.cat([self.use_count, new_count], -1) + self.life_count = torch.cat([self.life_count, new_life], -1) + + # add the value + if objects is not None: + # When objects is given, v is a tensor; used in working memory + assert isinstance(value, torch.Tensor) + # First consume objects that are already in the memory bank + # cannot use set here because we need to preserve order + # shift by one as background is not part of value + remaining_objects = [obj-1 for obj in objects] + for gi, group in enumerate(self.obj_groups): + for obj in group: + # should properly raise an error if there are overlaps in obj_groups + remaining_objects.remove(obj) + self.v[gi] = torch.cat([self.v[gi], value[group]], -1) + + # If there are remaining objects, add them as a new group + if len(remaining_objects) > 0: + new_group = list(remaining_objects) + self.v.append(value[new_group]) + self.obj_groups.append(new_group) + self.all_objects.extend(new_group) + + assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order ' + else: + # When objects is not given, v is a list that already has the object groups sorted + # used in long-term memory + assert isinstance(value, list) + for gi, gv in enumerate(value): + if gv is None: + continue + if gi < self.num_groups: + self.v[gi] = torch.cat([self.v[gi], gv], -1) + else: + self.v.append(gv) + + def update_usage(self, usage): + # increase all life count by 1 + # increase use of indexed elements + if not self.count_usage: + return + + self.use_count += usage.view_as(self.use_count) + self.life_count += 1 + + def sieve_by_range(self, start: int, end: int, min_size: int): + # keep only the elements *outside* of this range (with some boundary conditions) + # i.e., concat (a[:start], a[end:]) + # min_size is only used for values, we do not sieve values under this size + # (because they are not consolidated) + + if end == 0: + # negative 0 would not work as the end index! + self.k = self.k[:,:,:start] + if self.count_usage: + self.use_count = self.use_count[:,:,:start] + self.life_count = self.life_count[:,:,:start] + if self.s is not None: + self.s = self.s[:,:,:start] + if self.e is not None: + self.e = self.e[:,:,:start] + + for gi in range(self.num_groups): + if self.v[gi].shape[-1] >= min_size: + self.v[gi] = self.v[gi][:,:,:start] + else: + self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1) + if self.count_usage: + self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1) + self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1) + if self.s is not None: + self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1) + if self.e is not None: + self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1) + + for gi in range(self.num_groups): + if self.v[gi].shape[-1] >= min_size: + self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1) + + def remove_obsolete_features(self, max_size: int): + # normalize with life duration + usage = self.get_usage().flatten() + + values, _ = torch.topk(usage, k=(self.size-max_size), largest=False, sorted=True) + survived = (usage > values[-1]) + + self.k = self.k[:, :, survived] + self.s = self.s[:, :, survived] if self.s is not None else None + # Long-term memory does not store ek so this should not be needed + self.e = self.e[:, :, survived] if self.e is not None else None + if self.num_groups > 1: + raise NotImplementedError("""The current data structure does not support feature removal with + multiple object groups (e.g., some objects start to appear later in the video) + The indices for "survived" is based on keys but not all values are present for every key + Basically we need to remap the indices for keys to values + """) + for gi in range(self.num_groups): + self.v[gi] = self.v[gi][:, :, survived] + + self.use_count = self.use_count[:, :, survived] + self.life_count = self.life_count[:, :, survived] + + def get_usage(self): + # return normalized usage + if not self.count_usage: + raise RuntimeError('I did not count usage!') + else: + usage = self.use_count / self.life_count + return usage + + def get_all_sliced(self, start: int, end: int): + # return k, sk, ek, usage in order, sliced by start and end + + if end == 0: + # negative 0 would not work as the end index! + k = self.k[:,:,start:] + sk = self.s[:,:,start:] if self.s is not None else None + ek = self.e[:,:,start:] if self.e is not None else None + usage = self.get_usage()[:,:,start:] + else: + k = self.k[:,:,start:end] + sk = self.s[:,:,start:end] if self.s is not None else None + ek = self.e[:,:,start:end] if self.e is not None else None + usage = self.get_usage()[:,:,start:end] + + return k, sk, ek, usage + + def get_v_size(self, ni: int): + return self.v[ni].shape[2] + + def engaged(self): + return self.k is not None + + @property + def size(self): + if self.k is None: + return 0 + else: + return self.k.shape[-1] + + @property + def num_groups(self): + return len(self.v) + + @property + def key(self): + return self.k + + @property + def value(self): + return self.v + + @property + def shrinkage(self): + return self.s + + @property + def selection(self): + return self.e diff --git a/track_anything_code/tracker/inference/memory_manager.py b/track_anything_code/tracker/inference/memory_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6d4246324d420009a152f63e957894e9126223 --- /dev/null +++ b/track_anything_code/tracker/inference/memory_manager.py @@ -0,0 +1,286 @@ +import torch +import warnings + +from .kv_memory_store import KeyValueMemoryStore +from ..model.memory_util import * + + +class MemoryManager: + """ + Manages all three memory stores and the transition between working/long-term memory + """ + def __init__(self, config): + self.hidden_dim = config['hidden_dim'] + self.top_k = config['top_k'] + + self.enable_long_term = config['enable_long_term'] + self.enable_long_term_usage = config['enable_long_term_count_usage'] + if self.enable_long_term: + self.max_mt_frames = config['max_mid_term_frames'] + self.min_mt_frames = config['min_mid_term_frames'] + self.num_prototypes = config['num_prototypes'] + self.max_long_elements = config['max_long_term_elements'] + + # dimensions will be inferred from input later + self.CK = self.CV = None + self.H = self.W = None + + # The hidden state will be stored in a single tensor for all objects + # B x num_objects x CH x H x W + self.hidden = None + + self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term) + if self.enable_long_term: + self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage) + + self.reset_config = True + + def update_config(self, config): + self.reset_config = True + self.hidden_dim = config['hidden_dim'] + self.top_k = config['top_k'] + + assert self.enable_long_term == config['enable_long_term'], 'cannot update this' + assert self.enable_long_term_usage == config['enable_long_term_count_usage'], 'cannot update this' + + self.enable_long_term_usage = config['enable_long_term_count_usage'] + if self.enable_long_term: + self.max_mt_frames = config['max_mid_term_frames'] + self.min_mt_frames = config['min_mid_term_frames'] + self.num_prototypes = config['num_prototypes'] + self.max_long_elements = config['max_long_term_elements'] + + def _readout(self, affinity, v): + # this function is for a single object group + return v @ affinity + + def match_memory(self, query_key, selection): + # query_key: B x C^k x H x W + # selection: B x C^k x H x W + num_groups = self.work_mem.num_groups + h, w = query_key.shape[-2:] + + query_key = query_key.flatten(start_dim=2) + selection = selection.flatten(start_dim=2) if selection is not None else None + + """ + Memory readout using keys + """ + + if self.enable_long_term and self.long_mem.engaged(): + # Use long-term memory + long_mem_size = self.long_mem.size + memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1) + shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1) + + similarity = get_similarity(memory_key, shrinkage, query_key, selection) + work_mem_similarity = similarity[:, long_mem_size:] + long_mem_similarity = similarity[:, :long_mem_size] + + # get the usage with the first group + # the first group always have all the keys valid + affinity, usage = do_softmax( + torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1), + top_k=self.top_k, inplace=True, return_usage=True) + affinity = [affinity] + + # compute affinity group by group as later groups only have a subset of keys + for gi in range(1, num_groups): + if gi < self.long_mem.num_groups: + # merge working and lt similarities before softmax + affinity_one_group = do_softmax( + torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):], + work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1), + top_k=self.top_k, inplace=True) + else: + # no long-term memory for this group + affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):], + top_k=self.top_k, inplace=(gi==num_groups-1)) + affinity.append(affinity_one_group) + + all_memory_value = [] + for gi, gv in enumerate(self.work_mem.value): + # merge the working and lt values before readout + if gi < self.long_mem.num_groups: + all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1)) + else: + all_memory_value.append(gv) + + """ + Record memory usage for working and long-term memory + """ + # ignore the index return for long-term memory + work_usage = usage[:, long_mem_size:] + self.work_mem.update_usage(work_usage.flatten()) + + if self.enable_long_term_usage: + # ignore the index return for working memory + long_usage = usage[:, :long_mem_size] + self.long_mem.update_usage(long_usage.flatten()) + else: + # No long-term memory + similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection) + + if self.enable_long_term: + affinity, usage = do_softmax(similarity, inplace=(num_groups==1), + top_k=self.top_k, return_usage=True) + + # Record memory usage for working memory + self.work_mem.update_usage(usage.flatten()) + else: + affinity = do_softmax(similarity, inplace=(num_groups==1), + top_k=self.top_k, return_usage=False) + + affinity = [affinity] + + # compute affinity group by group as later groups only have a subset of keys + for gi in range(1, num_groups): + affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):], + top_k=self.top_k, inplace=(gi==num_groups-1)) + affinity.append(affinity_one_group) + + all_memory_value = self.work_mem.value + + # Shared affinity within each group + all_readout_mem = torch.cat([ + self._readout(affinity[gi], gv) + for gi, gv in enumerate(all_memory_value) + ], 0) + + return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w) + + def add_memory(self, key, shrinkage, value, objects, selection=None): + # key: 1*C*H*W + # value: 1*num_objects*C*H*W + # objects contain a list of object indices + if self.H is None or self.reset_config: + self.reset_config = False + self.H, self.W = key.shape[-2:] + self.HW = self.H*self.W + if self.enable_long_term: + # convert from num. frames to num. nodes + self.min_work_elements = self.min_mt_frames*self.HW + self.max_work_elements = self.max_mt_frames*self.HW + + # key: 1*C*N + # value: num_objects*C*N + key = key.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) + value = value[0].flatten(start_dim=2) + + self.CK = key.shape[1] + self.CV = value.shape[1] + + if selection is not None: + if not self.enable_long_term: + warnings.warn('the selection factor is only needed in long-term mode', UserWarning) + selection = selection.flatten(start_dim=2) + + self.work_mem.add(key, value, shrinkage, selection, objects) + + # long-term memory cleanup + if self.enable_long_term: + # Do memory compressed if needed + if self.work_mem.size >= self.max_work_elements: + # print('remove memory') + # Remove obsolete features if needed + if self.long_mem.size >= (self.max_long_elements-self.num_prototypes): + self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes) + + self.compress_features() + + def create_hidden_state(self, n, sample_key): + # n is the TOTAL number of objects + h, w = sample_key.shape[-2:] + if self.hidden is None: + self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device) + elif self.hidden.shape[1] != n: + self.hidden = torch.cat([ + self.hidden, + torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device) + ], 1) + + assert(self.hidden.shape[1] == n) + + def set_hidden(self, hidden): + self.hidden = hidden + + def get_hidden(self): + return self.hidden + + def compress_features(self): + HW = self.HW + candidate_value = [] + total_work_mem_size = self.work_mem.size + for gv in self.work_mem.value: + # Some object groups might be added later in the video + # So not all keys have values associated with all objects + # We need to keep track of the key->value validity + mem_size_in_this_group = gv.shape[-1] + if mem_size_in_this_group == total_work_mem_size: + # full LT + candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) + else: + # mem_size is smaller than total_work_mem_size, but at least HW + assert HW <= mem_size_in_this_group < total_work_mem_size + if mem_size_in_this_group > self.min_work_elements+HW: + # part of this object group still goes into LT + candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) + else: + # this object group cannot go to the LT at all + candidate_value.append(None) + + # perform memory consolidation + prototype_key, prototype_value, prototype_shrinkage = self.consolidation( + *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value) + + # remove consolidated working memory + self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW) + + # add to long-term memory + self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None) + # print(f'long memory size: {self.long_mem.size}') + # print(f'work memory size: {self.work_mem.size}') + + def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value): + # keys: 1*C*N + # values: num_objects*C*N + N = candidate_key.shape[-1] + + # find the indices with max usage + _, max_usage_indices = torch.topk(usage, k=self.num_prototypes, dim=-1, sorted=True) + prototype_indices = max_usage_indices.flatten() + + # Prototypes are invalid for out-of-bound groups + validity = [prototype_indices >= (N-gv.shape[2]) if gv is not None else None for gv in candidate_value] + + prototype_key = candidate_key[:, :, prototype_indices] + prototype_selection = candidate_selection[:, :, prototype_indices] if candidate_selection is not None else None + + """ + Potentiation step + """ + similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, prototype_selection) + + # convert similarity to affinity + # need to do it group by group since the softmax normalization would be different + affinity = [ + do_softmax(similarity[:, -gv.shape[2]:, validity[gi]]) if gv is not None else None + for gi, gv in enumerate(candidate_value) + ] + + # some values can be have all False validity. Weed them out. + affinity = [ + aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity + ] + + # readout the values + prototype_value = [ + self._readout(affinity[gi], gv) if affinity[gi] is not None else None + for gi, gv in enumerate(candidate_value) + ] + + # readout the shrinkage term + prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None + + return prototype_key, prototype_value, prototype_shrinkage \ No newline at end of file diff --git a/track_anything_code/tracker/model/__init__.py b/track_anything_code/tracker/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/track_anything_code/tracker/model/aggregate.py b/track_anything_code/tracker/model/aggregate.py new file mode 100644 index 0000000000000000000000000000000000000000..7622391fb3ac9aa8b515df88cf3ea5297b367538 --- /dev/null +++ b/track_anything_code/tracker/model/aggregate.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + + +# Soft aggregation from STM +def aggregate(prob, dim, return_logits=False): + new_prob = torch.cat([ + torch.prod(1-prob, dim=dim, keepdim=True), + prob + ], dim).clamp(1e-7, 1-1e-7) + logits = torch.log((new_prob /(1-new_prob))) + prob = F.softmax(logits, dim=dim) + + if return_logits: + return logits, prob + else: + return prob \ No newline at end of file diff --git a/track_anything_code/tracker/model/cbam.py b/track_anything_code/tracker/model/cbam.py new file mode 100644 index 0000000000000000000000000000000000000000..6423358429e2843b1f36ceb2bc1a485ea72b8eb4 --- /dev/null +++ b/track_anything_code/tracker/model/cbam.py @@ -0,0 +1,77 @@ +# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + + def forward(self, x): + x = self.conv(x) + return x + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type=='avg': + avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( avg_pool ) + elif pool_type=='max': + max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( max_pool ) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) + return x * scale + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2) + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = torch.sigmoid(x_out) # broadcasting + return x * scale + +class CBAM(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(CBAM, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial=no_spatial + if not no_spatial: + self.SpatialGate = SpatialGate() + def forward(self, x): + x_out = self.ChannelGate(x) + if not self.no_spatial: + x_out = self.SpatialGate(x_out) + return x_out diff --git a/track_anything_code/tracker/model/group_modules.py b/track_anything_code/tracker/model/group_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..749ef2386a992a468b7cf631293ebd22036b2777 --- /dev/null +++ b/track_anything_code/tracker/model/group_modules.py @@ -0,0 +1,82 @@ +""" +Group-specific modules +They handle features that also depends on the mask. +Features are typically of shape + batch_size * num_objects * num_channels * H * W + +All of them are permutation equivariant w.r.t. to the num_objects dimension +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def interpolate_groups(g, ratio, mode, align_corners): + batch_size, num_objects = g.shape[:2] + g = F.interpolate(g.flatten(start_dim=0, end_dim=1), + scale_factor=ratio, mode=mode, align_corners=align_corners) + g = g.view(batch_size, num_objects, *g.shape[1:]) + return g + +def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False): + return interpolate_groups(g, ratio, mode, align_corners) + +def downsample_groups(g, ratio=1/2, mode='area', align_corners=None): + return interpolate_groups(g, ratio, mode, align_corners) + + +class GConv2D(nn.Conv2d): + def forward(self, g): + batch_size, num_objects = g.shape[:2] + g = super().forward(g.flatten(start_dim=0, end_dim=1)) + return g.view(batch_size, num_objects, *g.shape[1:]) + + +class GroupResBlock(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + + if in_dim == out_dim: + self.downsample = None + else: + self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) + + self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1) + + def forward(self, g): + out_g = self.conv1(F.relu(g)) + out_g = self.conv2(F.relu(out_g)) + + if self.downsample is not None: + g = self.downsample(g) + + return out_g + g + + +class MainToGroupDistributor(nn.Module): + def __init__(self, x_transform=None, method='cat', reverse_order=False): + super().__init__() + + self.x_transform = x_transform + self.method = method + self.reverse_order = reverse_order + + def forward(self, x, g): + num_objects = g.shape[1] + + if self.x_transform is not None: + x = self.x_transform(x) + + if self.method == 'cat': + if self.reverse_order: + g = torch.cat([g, x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1)], 2) + else: + g = torch.cat([x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1), g], 2) + elif self.method == 'add': + g = x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1) + g + else: + raise NotImplementedError + + return g diff --git a/track_anything_code/tracker/model/losses.py b/track_anything_code/tracker/model/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..60a2894b6f5b330aa4baa56db226e8a59cb8c1ae --- /dev/null +++ b/track_anything_code/tracker/model/losses.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import defaultdict + + +def dice_loss(input_mask, cls_gt): + num_objects = input_mask.shape[1] + losses = [] + for i in range(num_objects): + mask = input_mask[:,i].flatten(start_dim=1) + # background not in mask, so we add one to cls_gt + gt = (cls_gt==(i+1)).float().flatten(start_dim=1) + numerator = 2 * (mask * gt).sum(-1) + denominator = mask.sum(-1) + gt.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + losses.append(loss) + return torch.cat(losses).mean() + + +# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch +class BootstrappedCE(nn.Module): + def __init__(self, start_warm, end_warm, top_p=0.15): + super().__init__() + + self.start_warm = start_warm + self.end_warm = end_warm + self.top_p = top_p + + def forward(self, input, target, it): + if it < self.start_warm: + return F.cross_entropy(input, target), 1.0 + + raw_loss = F.cross_entropy(input, target, reduction='none').view(-1) + num_pixels = raw_loss.numel() + + if it > self.end_warm: + this_p = self.top_p + else: + this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm)) + loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False) + return loss.mean(), this_p + + +class LossComputer: + def __init__(self, config): + super().__init__() + self.config = config + self.bce = BootstrappedCE(config['start_warm'], config['end_warm']) + + def compute(self, data, num_objects, it): + losses = defaultdict(int) + + b, t = data['rgb'].shape[:2] + + losses['total_loss'] = 0 + for ti in range(1, t): + for bi in range(b): + loss, p = self.bce(data[f'logits_{ti}'][bi:bi+1, :num_objects[bi]+1], data['cls_gt'][bi:bi+1,ti,0], it) + losses['p'] += p / b / (t-1) + losses[f'ce_loss_{ti}'] += loss / b + + losses['total_loss'] += losses['ce_loss_%d'%ti] + losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti,0]) + losses['total_loss'] += losses[f'dice_loss_{ti}'] + + return losses diff --git a/track_anything_code/tracker/model/memory_util.py b/track_anything_code/tracker/model/memory_util.py new file mode 100644 index 0000000000000000000000000000000000000000..faf6197b8c4ea990317476e2e3aeb8952a78aedf --- /dev/null +++ b/track_anything_code/tracker/model/memory_util.py @@ -0,0 +1,80 @@ +import math +import numpy as np +import torch +from typing import Optional + + +def get_similarity(mk, ms, qk, qe): + # used for training/inference and memory reading/memory potentiation + # mk: B x CK x [N] - Memory keys + # ms: B x 1 x [N] - Memory shrinkage + # qk: B x CK x [HW/P] - Query keys + # qe: B x CK x [HW/P] - Query selection + # Dimensions in [] are flattened + CK = mk.shape[1] + mk = mk.flatten(start_dim=2) + ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None + qk = qk.flatten(start_dim=2) + qe = qe.flatten(start_dim=2) if qe is not None else None + + if qe is not None: + # See appendix for derivation + # or you can just trust me ヽ(ー_ー )ノ + mk = mk.transpose(1, 2) + a_sq = (mk.pow(2) @ qe) + two_ab = 2 * (mk @ (qk * qe)) + b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) + similarity = (-a_sq+two_ab-b_sq) + else: + # similar to STCN if we don't have the selection term + a_sq = mk.pow(2).sum(1).unsqueeze(2) + two_ab = 2 * (mk.transpose(1, 2) @ qk) + similarity = (-a_sq+two_ab) + + if ms is not None: + similarity = similarity * ms / math.sqrt(CK) # B*N*HW + else: + similarity = similarity / math.sqrt(CK) # B*N*HW + + return similarity + +def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False): + # normalize similarity with top-k softmax + # similarity: B x N x [HW/P] + # use inplace with care + if top_k is not None: + values, indices = torch.topk(similarity, k=top_k, dim=1) + + x_exp = values.exp_() + x_exp /= torch.sum(x_exp, dim=1, keepdim=True) + if inplace: + similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW + affinity = similarity + else: + affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW + else: + maxes = torch.max(similarity, dim=1, keepdim=True)[0] + x_exp = torch.exp(similarity - maxes) + x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) + affinity = x_exp / x_exp_sum + indices = None + + if return_usage: + return affinity, affinity.sum(dim=2) + + return affinity + +def get_affinity(mk, ms, qk, qe): + # shorthand used in training with no top-k + similarity = get_similarity(mk, ms, qk, qe) + affinity = do_softmax(similarity) + return affinity + +def readout(affinity, mv): + B, CV, T, H, W = mv.shape + + mo = mv.view(B, CV, T*H*W) + mem = torch.bmm(mo, affinity) + mem = mem.view(B, CV, H, W) + + return mem diff --git a/track_anything_code/tracker/model/modules.py b/track_anything_code/tracker/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6e6af37bc5a0c2c5854faadaef2f8eee9610f6 --- /dev/null +++ b/track_anything_code/tracker/model/modules.py @@ -0,0 +1,250 @@ +""" +modules.py - This file stores the rather boring network blocks. + +x - usually means features that only depends on the image +g - usually means features that also depends on the mask. + They might have an extra "group" or "num_objects" dimension, hence + batch_size * num_objects * num_channels * H * W + +The trailing number of a variable usually denote the stride + +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .group_modules import * +from .resnet import resnet18, resnet50 +from .cbam import CBAM + + +class FeatureFusionBlock(nn.Module): + def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim): + super().__init__() + + self.distributor = MainToGroupDistributor() + self.block1 = GroupResBlock(x_in_dim+g_in_dim, g_mid_dim) + self.attention = CBAM(g_mid_dim) + self.block2 = GroupResBlock(g_mid_dim, g_out_dim) + + def forward(self, x, g): + batch_size, num_objects = g.shape[:2] + + g = self.distributor(x, g) + g = self.block1(g) + r = self.attention(g.flatten(start_dim=0, end_dim=1)) + r = r.view(batch_size, num_objects, *r.shape[1:]) + + g = self.block2(g+r) + + return g + + +class HiddenUpdater(nn.Module): + # Used in the decoder, multi-scale feature + GRU + def __init__(self, g_dims, mid_dim, hidden_dim): + super().__init__() + self.hidden_dim = hidden_dim + + self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1) + self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1) + self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1) + + self.transform = GConv2D(mid_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g, h): + g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ + self.g4_conv(downsample_groups(g[2], ratio=1/4)) + + g = torch.cat([g, h], 2) + + # defined slightly differently than standard GRU, + # namely the new value is generated before the forget gate. + # might provide better gradient but frankly it was initially just an + # implementation error that I never bothered fixing + values = self.transform(g) + forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim]) + update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2]) + new_value = torch.tanh(values[:,:,self.hidden_dim*2:]) + new_h = forget_gate*h*(1-update_gate) + update_gate*new_value + + return new_h + + +class HiddenReinforcer(nn.Module): + # Used in the value encoder, a single GRU + def __init__(self, g_dim, hidden_dim): + super().__init__() + self.hidden_dim = hidden_dim + self.transform = GConv2D(g_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g, h): + g = torch.cat([g, h], 2) + + # defined slightly differently than standard GRU, + # namely the new value is generated before the forget gate. + # might provide better gradient but frankly it was initially just an + # implementation error that I never bothered fixing + values = self.transform(g) + forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim]) + update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2]) + new_value = torch.tanh(values[:,:,self.hidden_dim*2:]) + new_h = forget_gate*h*(1-update_gate) + update_gate*new_value + + return new_h + + +class ValueEncoder(nn.Module): + def __init__(self, value_dim, hidden_dim, single_object=False): + super().__init__() + + self.single_object = single_object + network = resnet18(pretrained=True, extra_dim=1 if single_object else 2) + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu # 1/2, 64 + self.maxpool = network.maxpool + + self.layer1 = network.layer1 # 1/4, 64 + self.layer2 = network.layer2 # 1/8, 128 + self.layer3 = network.layer3 # 1/16, 256 + + self.distributor = MainToGroupDistributor() + self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim) + if hidden_dim > 0: + self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim) + else: + self.hidden_reinforce = None + + def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True): + # image_feat_f16 is the feature from the key encoder + if not self.single_object: + g = torch.stack([masks, others], 2) + else: + g = masks.unsqueeze(2) + g = self.distributor(image, g) + + batch_size, num_objects = g.shape[:2] + g = g.flatten(start_dim=0, end_dim=1) + + g = self.conv1(g) + g = self.bn1(g) # 1/2, 64 + g = self.maxpool(g) # 1/4, 64 + g = self.relu(g) + + g = self.layer1(g) # 1/4 + g = self.layer2(g) # 1/8 + g = self.layer3(g) # 1/16 + + g = g.view(batch_size, num_objects, *g.shape[1:]) + g = self.fuser(image_feat_f16, g) + + if is_deep_update and self.hidden_reinforce is not None: + h = self.hidden_reinforce(g, h) + + return g, h + + +class KeyEncoder(nn.Module): + def __init__(self): + super().__init__() + network = resnet50(pretrained=True) + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu # 1/2, 64 + self.maxpool = network.maxpool + + self.res2 = network.layer1 # 1/4, 256 + self.layer2 = network.layer2 # 1/8, 512 + self.layer3 = network.layer3 # 1/16, 1024 + + def forward(self, f): + x = self.conv1(f) + x = self.bn1(x) + x = self.relu(x) # 1/2, 64 + x = self.maxpool(x) # 1/4, 64 + f4 = self.res2(x) # 1/4, 256 + f8 = self.layer2(f4) # 1/8, 512 + f16 = self.layer3(f8) # 1/16, 1024 + + return f16, f8, f4 + + +class UpsampleBlock(nn.Module): + def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2): + super().__init__() + self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1) + self.distributor = MainToGroupDistributor(method='add') + self.out_conv = GroupResBlock(g_up_dim, g_out_dim) + self.scale_factor = scale_factor + + def forward(self, skip_f, up_g): + skip_f = self.skip_conv(skip_f) + g = upsample_groups(up_g, ratio=self.scale_factor) + g = self.distributor(skip_f, g) + g = self.out_conv(g) + return g + + +class KeyProjection(nn.Module): + def __init__(self, in_dim, keydim): + super().__init__() + + self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) + # shrinkage + self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1) + # selection + self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) + + nn.init.orthogonal_(self.key_proj.weight.data) + nn.init.zeros_(self.key_proj.bias.data) + + def forward(self, x, need_s, need_e): + shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None + selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None + + return self.key_proj(x), shrinkage, selection + + +class Decoder(nn.Module): + def __init__(self, val_dim, hidden_dim): + super().__init__() + + self.fuser = FeatureFusionBlock(1024, val_dim+hidden_dim, 512, 512) + if hidden_dim > 0: + self.hidden_update = HiddenUpdater([512, 256, 256+1], 256, hidden_dim) + else: + self.hidden_update = None + + self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8 + self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4 + + self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1) + + def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True): + batch_size, num_objects = memory_readout.shape[:2] + + if self.hidden_update is not None: + g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2)) + else: + g16 = self.fuser(f16, memory_readout) + + g8 = self.up_16_8(f8, g16) + g4 = self.up_8_4(f4, g8) + logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1))) + + if h_out and self.hidden_update is not None: + g4 = torch.cat([g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2) + hidden_state = self.hidden_update([g16, g8, g4], hidden_state) + else: + hidden_state = None + + logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False) + logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) + + return hidden_state, logits diff --git a/track_anything_code/tracker/model/network.py b/track_anything_code/tracker/model/network.py new file mode 100644 index 0000000000000000000000000000000000000000..c7492140f790347f9d90e9223d123df9a4c04732 --- /dev/null +++ b/track_anything_code/tracker/model/network.py @@ -0,0 +1,198 @@ +""" +This file defines XMem, the highest level nn.Module interface +During training, it is used by trainer.py +During evaluation, it is used by inference_core.py + +It further depends on modules.py which gives more detailed implementations of sub-modules +""" + +import torch +import torch.nn as nn + +from .aggregate import aggregate +from .modules import * +from .memory_util import * + + +class XMem(nn.Module): + def __init__(self, config, model_path=None, map_location=None): + """ + model_path/map_location are used in evaluation only + map_location is for converting models saved in cuda to cpu + """ + super().__init__() + model_weights = self.init_hyperparameters(config, model_path, map_location) + + self.single_object = config.get('single_object', False) + print(f'Single object mode: {self.single_object}') + + self.key_encoder = KeyEncoder() + self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object) + + # Projection from f16 feature space to key/value space + self.key_proj = KeyProjection(1024, self.key_dim) + + self.decoder = Decoder(self.value_dim, self.hidden_dim) + + if model_weights is not None: + self.load_weights(model_weights, init_as_zero_if_needed=True) + + def encode_key(self, frame, need_sk=True, need_ek=True): + # Determine input shape + if len(frame.shape) == 5: + # shape is b*t*c*h*w + need_reshape = True + b, t = frame.shape[:2] + # flatten so that we can feed them into a 2D CNN + frame = frame.flatten(start_dim=0, end_dim=1) + elif len(frame.shape) == 4: + # shape is b*c*h*w + need_reshape = False + else: + raise NotImplementedError + + f16, f8, f4 = self.key_encoder(frame) + key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek) + + if need_reshape: + # B*C*T*H*W + key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous() + if shrinkage is not None: + shrinkage = shrinkage.view(b, t, *shrinkage.shape[-3:]).transpose(1, 2).contiguous() + if selection is not None: + selection = selection.view(b, t, *selection.shape[-3:]).transpose(1, 2).contiguous() + + # B*T*C*H*W + f16 = f16.view(b, t, *f16.shape[-3:]) + f8 = f8.view(b, t, *f8.shape[-3:]) + f4 = f4.view(b, t, *f4.shape[-3:]) + + return key, shrinkage, selection, f16, f8, f4 + + def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True): + num_objects = masks.shape[1] + if num_objects != 1: + others = torch.cat([ + torch.sum( + masks[:, [j for j in range(num_objects) if i!=j]] + , dim=1, keepdim=True) + for i in range(num_objects)], 1) + else: + others = torch.zeros_like(masks) + + g16, h16 = self.value_encoder(frame, image_feat_f16, h16, masks, others, is_deep_update) + + return g16, h16 + + # Used in training only. + # This step is replaced by MemoryManager in test time + def read_memory(self, query_key, query_selection, memory_key, + memory_shrinkage, memory_value): + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + memory_value : B * num_objects * CV * T * H * W + """ + batch_size, num_objects = memory_value.shape[:2] + memory_value = memory_value.flatten(start_dim=1, end_dim=2) + + affinity = get_affinity(memory_key, memory_shrinkage, query_key, query_selection) + memory = readout(affinity, memory_value) + memory = memory.view(batch_size, num_objects, self.value_dim, *memory.shape[-2:]) + + return memory + + def segment(self, multi_scale_features, memory_readout, + hidden_state, selector=None, h_out=True, strip_bg=True): + + hidden_state, logits = self.decoder(*multi_scale_features, hidden_state, memory_readout, h_out=h_out) + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + + logits, prob = aggregate(prob, dim=1, return_logits=True) + if strip_bg: + # Strip away the background + prob = prob[:, 1:] + + return hidden_state, logits, prob + + def forward(self, mode, *args, **kwargs): + if mode == 'encode_key': + return self.encode_key(*args, **kwargs) + elif mode == 'encode_value': + return self.encode_value(*args, **kwargs) + elif mode == 'read_memory': + return self.read_memory(*args, **kwargs) + elif mode == 'segment': + return self.segment(*args, **kwargs) + else: + raise NotImplementedError + + def init_hyperparameters(self, config, model_path=None, map_location=None): + """ + Init three hyperparameters: key_dim, value_dim, and hidden_dim + If model_path is provided, we load these from the model weights + The actual parameters are then updated to the config in-place + + Otherwise we load it either from the config or default + """ + if model_path is not None: + # load the model and key/value/hidden dimensions with some hacks + # config is updated with the loaded parameters + model_weights = torch.load(model_path, map_location="cpu") + self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0] + self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0] + self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights + if self.disable_hidden: + self.hidden_dim = 0 + else: + self.hidden_dim = model_weights['decoder.hidden_update.transform.weight'].shape[0]//3 + print(f'Hyperparameters read from the model weights: ' + f'C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}') + else: + model_weights = None + # load dimensions from config or default + if 'key_dim' not in config: + self.key_dim = 64 + print(f'key_dim not found in config. Set to default {self.key_dim}') + else: + self.key_dim = config['key_dim'] + + if 'value_dim' not in config: + self.value_dim = 512 + print(f'value_dim not found in config. Set to default {self.value_dim}') + else: + self.value_dim = config['value_dim'] + + if 'hidden_dim' not in config: + self.hidden_dim = 64 + print(f'hidden_dim not found in config. Set to default {self.hidden_dim}') + else: + self.hidden_dim = config['hidden_dim'] + + self.disable_hidden = (self.hidden_dim <= 0) + + config['key_dim'] = self.key_dim + config['value_dim'] = self.value_dim + config['hidden_dim'] = self.hidden_dim + + return model_weights + + def load_weights(self, src_dict, init_as_zero_if_needed=False): + # Maps SO weight (without other_mask) to MO weight (with other_mask) + for k in list(src_dict.keys()): + if k == 'value_encoder.conv1.weight': + if src_dict[k].shape[1] == 4: + print('Converting weights from single object to multiple objects.') + pads = torch.zeros((64,1,7,7), device=src_dict[k].device) + if not init_as_zero_if_needed: + print('Randomly initialized padding.') + nn.init.orthogonal_(pads) + else: + print('Zero-initialized padding.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + + self.load_state_dict(src_dict) diff --git a/track_anything_code/tracker/model/resnet.py b/track_anything_code/tracker/model/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..984ea3cbfac047537e7de6cfc47108e637e9dde7 --- /dev/null +++ b/track_anything_code/tracker/model/resnet.py @@ -0,0 +1,165 @@ +""" +resnet.py - A modified ResNet structure +We append extra channels to the first conv by some network surgery +""" + +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +from torch.utils import model_zoo + + +def load_weights_add_extra_dim(target, source_state, extra_dim=1): + new_dict = OrderedDict() + + for k1, v1 in target.state_dict().items(): + if not 'num_batches_tracked' in k1: + if k1 in source_state: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + pads = torch.zeros((c,extra_dim,w,h), device=tar_v.device) + nn.init.orthogonal_(pads) + tar_v = torch.cat([tar_v, pads], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, + padding=dilation, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + +def resnet18(pretrained=True, extra_dim=0): + model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim) + return model + +def resnet50(pretrained=True, extra_dim=0): + model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim) + return model + diff --git a/track_anything_code/tracker/model/trainer.py b/track_anything_code/tracker/model/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..bc1ced345bed7282e25d0e8abe21b4523947e088 --- /dev/null +++ b/track_anything_code/tracker/model/trainer.py @@ -0,0 +1,244 @@ +""" +trainer.py - warpper and utility functions for network training +Compute loss, back-prop, update parameters, logging, etc. +""" +import datetime +import os +import time +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +from network import XMem +from losses import LossComputer +from util.log_integrator import Integrator +from util.image_saver import pool_pairs + + +class XMemTrainer: + def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1): + self.config = config + self.num_frames = config['num_frames'] + self.num_ref_frames = config['num_ref_frames'] + self.deep_update_prob = config['deep_update_prob'] + self.local_rank = local_rank + + self.XMem = nn.parallel.DistributedDataParallel( + XMem(config).cuda(), + device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False) + + # Set up logger when local_rank=0 + self.logger = logger + self.save_path = save_path + if logger is not None: + self.last_time = time.time() + self.logger.log_string('model_size', str(sum([param.nelement() for param in self.XMem.parameters()]))) + self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size) + self.loss_computer = LossComputer(config) + + self.train() + self.optimizer = optim.AdamW(filter( + lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay']) + self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma']) + if config['amp']: + self.scaler = torch.cuda.amp.GradScaler() + + # Logging info + self.log_text_interval = config['log_text_interval'] + self.log_image_interval = config['log_image_interval'] + self.save_network_interval = config['save_network_interval'] + self.save_checkpoint_interval = config['save_checkpoint_interval'] + if config['debug']: + self.log_text_interval = self.log_image_interval = 1 + + def do_pass(self, data, max_it, it=0): + # No need to store the gradient outside training + torch.set_grad_enabled(self._is_train) + + for k, v in data.items(): + if type(v) != list and type(v) != dict and type(v) != int: + data[k] = v.cuda(non_blocking=True) + + out = {} + frames = data['rgb'] + first_frame_gt = data['first_frame_gt'].float() + b = frames.shape[0] + num_filled_objects = [o.item() for o in data['info']['num_objects']] + num_objects = first_frame_gt.shape[2] + selector = data['selector'].unsqueeze(2).unsqueeze(2) + + global_avg = 0 + + with torch.cuda.amp.autocast(enabled=self.config['amp']): + # image features never change, compute once + key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames) + + filler_one = torch.zeros(1, dtype=torch.int64) + hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:])) + v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0]) + values = v16.unsqueeze(3) # add the time dimension + + for ti in range(1, self.num_frames): + if ti <= self.num_ref_frames: + ref_values = values + ref_keys = key[:,:,:ti] + ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None + else: + # pick num_ref_frames random frames + # this is not very efficient but I think we would + # need broadcasting in gather which we don't have + indices = [ + torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1]) + for _ in range(b)] + ref_values = torch.stack([ + values[bi, :, :, indices[bi]] for bi in range(b) + ], 0) + ref_keys = torch.stack([ + key[bi, :, indices[bi]] for bi in range(b) + ], 0) + ref_shrinkage = torch.stack([ + shrinkage[bi, :, indices[bi]] for bi in range(b) + ], 0) if shrinkage is not None else None + + # Segment frame ti + memory_readout = self.XMem('read_memory', key[:,:,ti], selection[:,:,ti] if selection is not None else None, + ref_keys, ref_shrinkage, ref_values) + hidden, logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]), memory_readout, + hidden, selector, h_out=(ti < (self.num_frames-1))) + + # No need to encode the last frame + if ti < (self.num_frames-1): + is_deep_update = np.random.rand() < self.deep_update_prob + v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti], hidden, masks, is_deep_update=is_deep_update) + values = torch.cat([values, v16.unsqueeze(3)], 3) + + out[f'masks_{ti}'] = masks + out[f'logits_{ti}'] = logits + + if self._do_log or self._is_train: + losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it) + + # Logging + if self._do_log: + self.integrator.add_dict(losses) + if self._is_train: + if it % self.log_image_interval == 0 and it != 0: + if self.logger is not None: + images = {**data, **out} + size = (384, 384) + self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it) + + if self._is_train: + + if (it) % self.log_text_interval == 0 and it != 0: + time_spent = time.time()-self.last_time + + if self.logger is not None: + self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it) + self.logger.log_metrics('train', 'time', (time_spent)/self.log_text_interval, it) + + global_avg = 0.5*(global_avg) + 0.5*(time_spent) + eta_seconds = global_avg * (max_it - it) / 100 + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + print(f'ETA: {eta_string}') + + self.last_time = time.time() + self.train_integrator.finalize('train', it) + self.train_integrator.reset_except_hooks() + + if it % self.save_network_interval == 0 and it != 0: + if self.logger is not None: + self.save_network(it) + + if it % self.save_checkpoint_interval == 0 and it != 0: + if self.logger is not None: + self.save_checkpoint(it) + + # Backward pass + self.optimizer.zero_grad(set_to_none=True) + if self.config['amp']: + self.scaler.scale(losses['total_loss']).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + else: + losses['total_loss'].backward() + self.optimizer.step() + + self.scheduler.step() + + def save_network(self, it): + if self.save_path is None: + print('Saving has been disabled.') + return + + os.makedirs(os.path.dirname(self.save_path), exist_ok=True) + model_path = f'{self.save_path}_{it}.pth' + torch.save(self.XMem.module.state_dict(), model_path) + print(f'Network saved to {model_path}.') + + def save_checkpoint(self, it): + if self.save_path is None: + print('Saving has been disabled.') + return + + os.makedirs(os.path.dirname(self.save_path), exist_ok=True) + checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth' + checkpoint = { + 'it': it, + 'network': self.XMem.module.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'scheduler': self.scheduler.state_dict()} + torch.save(checkpoint, checkpoint_path) + print(f'Checkpoint saved to {checkpoint_path}.') + + def load_checkpoint(self, path): + # This method loads everything and should be used to resume training + map_location = 'cuda:%d' % self.local_rank + checkpoint = torch.load(path, map_location={'cpu': map_location}) + + it = checkpoint['it'] + network = checkpoint['network'] + optimizer = checkpoint['optimizer'] + scheduler = checkpoint['scheduler'] + + map_location = 'cuda:%d' % self.local_rank + self.XMem.module.load_state_dict(network) + self.optimizer.load_state_dict(optimizer) + self.scheduler.load_state_dict(scheduler) + + print('Network weights, optimizer states, and scheduler states loaded.') + + return it + + def load_network_in_memory(self, src_dict): + self.XMem.module.load_weights(src_dict) + print('Network weight loaded from memory.') + + def load_network(self, path): + # This method loads only the network weight and should be used to load a pretrained model + map_location = 'cuda:%d' % self.local_rank + src_dict = torch.load(path, map_location={'cpu': map_location}) + + self.load_network_in_memory(src_dict) + print(f'Network weight loaded from {path}') + + def train(self): + self._is_train = True + self._do_log = True + self.integrator = self.train_integrator + self.XMem.eval() + return self + + def val(self): + self._is_train = False + self._do_log = True + self.XMem.eval() + return self + + def test(self): + self._is_train = False + self._do_log = False + self.XMem.eval() + return self + diff --git a/track_anything_code/tracker/util/__init__.py b/track_anything_code/tracker/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/track_anything_code/tracker/util/mask_mapper.py b/track_anything_code/tracker/util/mask_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..815807bf4b98c6674ab3ede55517f38a29bb59fb --- /dev/null +++ b/track_anything_code/tracker/util/mask_mapper.py @@ -0,0 +1,78 @@ +import numpy as np +import torch + +def all_to_onehot(masks, labels): + if len(masks.shape) == 3: + Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) + else: + Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) + + for ni, l in enumerate(labels): + Ms[ni] = (masks == l).astype(np.uint8) + + return Ms + +class MaskMapper: + """ + This class is used to convert a indexed-mask to a one-hot representation. + It also takes care of remapping non-continuous indices + It has two modes: + 1. Default. Only masks with new indices are supposed to go into the remapper. + This is also the case for YouTubeVOS. + i.e., regions with index 0 are not "background", but "don't care". + + 2. Exhaustive. Regions with index 0 are considered "background". + Every single pixel is considered to be "labeled". + """ + def __init__(self): + self.labels = [] + self.remappings = {} + + # if coherent, no mapping is required + self.coherent = True + + def clear_labels(self): + self.labels = [] + self.remappings = {} + # if coherent, no mapping is required + self.coherent = True + + def convert_mask(self, mask, exhaustive=False): + # mask is in index representation, H*W numpy array + labels = np.unique(mask).astype(np.uint8) + labels = labels[labels!=0].tolist() + + new_labels = list(set(labels) - set(self.labels)) + if not exhaustive: + assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' + + # add new remappings + for i, l in enumerate(new_labels): + self.remappings[l] = i+len(self.labels)+1 + if self.coherent and i+len(self.labels)+1 != l: + self.coherent = False + + if exhaustive: + new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1) + else: + if self.coherent: + new_mapped_labels = new_labels + else: + new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1) + + self.labels.extend(new_labels) + mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float() + + # mask num_objects*H*W + return mask, new_mapped_labels + + + def remap_index_mask(self, mask): + # mask is in index representation, H*W numpy array + if self.coherent: + return mask + + new_mask = np.zeros_like(mask) + for l, i in self.remappings.items(): + new_mask[mask==i] = l + return new_mask \ No newline at end of file diff --git a/track_anything_code/tracker/util/range_transform.py b/track_anything_code/tracker/util/range_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1b0b3b2a01a061b9b2220a93cdf7f7a6357bfb --- /dev/null +++ b/track_anything_code/tracker/util/range_transform.py @@ -0,0 +1,12 @@ +import torchvision.transforms as transforms + +im_mean = (124, 116, 104) + +im_normalization = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + +inv_im_trans = transforms.Normalize( + mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], + std=[1/0.229, 1/0.224, 1/0.225]) diff --git a/track_anything_code/tracker/util/tensor_util.py b/track_anything_code/tracker/util/tensor_util.py new file mode 100644 index 0000000000000000000000000000000000000000..05189d38e2b0b0d1d08bd7804b8e43418d6da637 --- /dev/null +++ b/track_anything_code/tracker/util/tensor_util.py @@ -0,0 +1,47 @@ +import torch.nn.functional as F + + +def compute_tensor_iu(seg, gt): + intersection = (seg & gt).float().sum() + union = (seg | gt).float().sum() + + return intersection, union + +def compute_tensor_iou(seg, gt): + intersection, union = compute_tensor_iu(seg, gt) + iou = (intersection + 1e-6) / (union + 1e-6) + + return iou + +# STM +def pad_divide_by(in_img, d): + h, w = in_img.shape[-2:] + + if h % d > 0: + new_h = h + d - h % d + else: + new_h = h + if w % d > 0: + new_w = w + d - w % d + else: + new_w = w + lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) + lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) + pad_array = (int(lw), int(uw), int(lh), int(uh)) + out = F.pad(in_img, pad_array) + return out, pad_array + +def unpad(img, pad): + if len(img.shape) == 4: + if pad[2]+pad[3] > 0: + img = img[:,:,pad[2]:-pad[3],:] + if pad[0]+pad[1] > 0: + img = img[:,:,:,pad[0]:-pad[1]] + elif len(img.shape) == 3: + if pad[2]+pad[3] > 0: + img = img[:,pad[2]:-pad[3],:] + if pad[0]+pad[1] > 0: + img = img[:,:,pad[0]:-pad[1]] + else: + raise NotImplementedError + return img \ No newline at end of file diff --git a/train_code/train_csvd.py b/train_code/train_csvd.py new file mode 100644 index 0000000000000000000000000000000000000000..611ba75e0c10fc999e74295efe50eea391205425 --- /dev/null +++ b/train_code/train_csvd.py @@ -0,0 +1,1008 @@ +#!/usr/bin/env python +''' + This file is to train Stable Video Diffusion with Conditioning design by my peronal implementation which is based on diffusers' training example code. +''' + +import argparse +import logging +import math +import os, sys +import time +import random +import shutil +import warnings +from PIL import Image +from einops import rearrange, repeat +from pathlib import Path +from omegaconf import OmegaConf +import imageio +import cv2 + + +import accelerate +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import RandomSampler +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + + +import diffusers +from diffusers import ( + AutoencoderKLTemporalDecoder, + DDPMScheduler, + UniPCMultistepScheduler, +) +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video +from diffusers.utils.import_utils import is_xformers_available +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.torch_utils import randn_tensor +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +if is_wandb_available(): + import wandb + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from svd.pipeline_stable_video_diffusion_controlnet import StableVideoDiffusionControlNetPipeline +from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel +from svd.temporal_controlnet import ControlNetModel +from utils.img_utils import resize_with_antialiasing +from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian +from data_loader.video_dataset import tokenize_captions +from data_loader.video_this_that_dataset import Video_ThisThat_Dataset, get_thisthat_sam +from train_code.train_svd import import_pretrained_text_encoder + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +# check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) +warnings.filterwarnings('ignore') + + +################################################################################################################################################### +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--config_path", + type=str, + default="config/train_image2video_controlnet.yaml", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + + args = parser.parse_args() + return args + + + +def log_validation(vae, unet, controlnet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step, + parent_store_folder=None, force_close_flip=False, use_ambiguous_prompt=False): + # This function will also be used in other files + print("Running validation... ") + + + # Init + validation_source_folder = config["validation_img_folder"] + + + # Init the pipeline + pipeline = StableVideoDiffusionControlNetPipeline.from_pretrained( + config["pretrained_model_name_or_path"], # Still based on regular SVD config + vae = vae, + image_encoder = image_encoder, + unet = unet, + revision = None, # Set None directly now + torch_dtype = weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + + # Process all image in the folder + frames_collection = [] + for image_name in sorted(os.listdir(validation_source_folder)): + if accelerator.is_main_process: + if parent_store_folder is None: + validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name) + else: + validation_store_folder = os.path.join(parent_store_folder, image_name) + + if os.path.exists(validation_store_folder): + shutil.rmtree(validation_store_folder) + os.makedirs(validation_store_folder) + + image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg') + ref_image = load_image(image_path) # [0, 255] Range + ref_image = ref_image.resize((config["width"], config["height"])) + + + # Prepare text prompt + if config["use_text"]: + # Read the file + file_path = os.path.join(validation_source_folder, image_name, "lang.txt") + file = open(file_path, 'r') + prompt = file.readlines()[0] # Only read the first line + if use_ambiguous_prompt: + prompt = prompt.split(" ")[0] + " this to there" + print("We are creating ambiguous prompt, which is: ", prompt) + else: + prompt = "" + # Use the same tokenize process as the dataset preparation stage + tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device) # Use unsqueeze to expand dim + + # Store the prompt for the sanity check + f = open(os.path.join(validation_store_folder, "lang_cond.txt"), "a") + f.write(prompt) + f.close() + + # Flip the image by chance (it is needed to check whether there is any object position words [left|right] in the prompt text) + flip = False + if not force_close_flip: # force_close_flip is True in testing time; else, we cannot match in the same standard + if random.random() < config["flip_aug_prob"]: + if config["use_text"]: + if prompt.find("left") == -1 and prompt.find("right") == -1: # Cannot have position word, like left and right (up and down is ok) + flip = True + else: + flip = True + if flip: + print("Use flip in validation!") + ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT) + + + if config["data_loader_type"] == "thisthat": + condition_img, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(config, + os.path.join(validation_source_folder, image_name), + flip = flip, + store_dir = validation_store_folder, + verbose = True) + else: + raise NotImplementedError("We don't support such data loader type") + + + + # Call the pipeline + with torch.autocast("cuda"): + frames = pipeline( + image = ref_image, + condition_img = condition_img, # numpy [0,1] range + controlnet = accelerator.unwrap_model(controlnet), + prompt = tokenized_prompt, + use_text = config["use_text"], + text_encoder = text_encoder, + height = config["height"], + width = config["width"], + num_frames = config["video_seq_length"], + decode_chunk_size = 8, + motion_bucket_id = reflected_motion_bucket_id, + controlnet_image_index = controlnet_image_index, + coordinate_values = coordinate_values, + num_inference_steps = config["num_inference_steps"], + max_guidance_scale = config["inference_max_guidance_scale"], + fps = 7, + use_instructpix2pix = config["use_instructpix2pix"], + noise_aug_strength = config["inference_noise_aug_strength"], + controlnet_conditioning_scale = config["outer_conditioning_scale"], + inner_conditioning_scale = config["inner_conditioning_scale"], + guess_mode = config["inference_guess_mode"], # False in inference + image_guidance_scale = config["image_guidance_scale"], + ).frames[0] + + for idx, frame in enumerate(frames): + frame.save(os.path.join(validation_store_folder, str(idx)+".png")) + imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames, duration=0.05) + + frames_collection.append(frames) + + + # Cleaning process + del pipeline + torch.cuda.empty_cache() + + return frames_collection # Return resuly based on the need + + +def tensor_to_vae_latent(inputs, vae): + video_length = inputs.shape[1] + + inputs = rearrange(inputs, "b f c h w -> (b f) c h w") + latents = vae.encode(inputs).latent_dist.mode() + latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length) # Use f or b to rearrage should have the same effect + latents = latents * vae.config.scaling_factor + + return latents + + + +def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): + """Draws samples from an lognormal distribution.""" + u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7 # In the range [0, 1] + # TODO: "* (1 - 2e-7) + 1e-7" is not included in previous code, I add it back, don't why whether there is any influence now + return torch.distributions.Normal(loc, scale).icdf(u).exp() + + +def get_add_time_ids( + unet_config, + expected_add_embed_dim, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance = False, + ): + + # Construct Basic add_time_ids items + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + + # Sanity Check + passed_add_embed_dim = unet_config.addition_time_embed_dim * len(add_time_ids) + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + return add_time_ids +#################################################################################################################################################################### + + + +def main(config): + # Read Config Setting + resume_from_checkpoint = config["resume_from_checkpoint"] + output_dir = config["output_dir"] + logging_name = config["logging_name"] + mixed_precision = config["mixed_precision"] + report_to = config["report_to"] + pretrained_model_name_or_path = config["pretrained_model_name_or_path"] + pretrained_tokenizer_name_or_path = config["pretrained_tokenizer_name_or_path"] + gradient_checkpointing = config["gradient_checkpointing"] + learning_rate = config["learning_rate"] + adam_beta1 = config["adam_beta1"] + adam_beta2 = config["adam_beta2"] + adam_weight_decay = config["adam_weight_decay"] + adam_epsilon = config["adam_epsilon"] + train_batch_size = config["train_batch_size"] + dataloader_num_workers = config["dataloader_num_workers"] + gradient_accumulation_steps = config["gradient_accumulation_steps"] + num_train_iters = config["num_train_iters"] + lr_warmup_steps = config["lr_warmup_steps"] + checkpointing_steps = config["checkpointing_steps"] + process_fps = config["process_fps"] + train_noise_aug_strength = config["train_noise_aug_strength"] + use_8bit_adam = config["use_8bit_adam"] + scale_lr = config["scale_lr"] + conditioning_dropout_prob = config["conditioning_dropout_prob"] + checkpoints_total_limit = config["checkpoints_total_limit"] + validation_step = config["validation_step"] + partial_finetune = config['partial_finetune'] + load_unet_path = config['load_unet_path'] + + if mixed_precision == 'None': # For mixed precision use + mixed_precision = 'no' + + + # Default Setting + revision = None + variant = "fp16" # TODO: 这里进行了调整,不知道会有多少区别,现在跟unet training保持一致 + lr_scheduler = "constant" + max_grad_norm = 1.0 + tracker_project_name = "img2video" + num_videos_per_prompt = 1 + seed = 42 + # No CFG in training now + + + + # Define the accelerator + logging_dir = Path(output_dir, logging_name) + accelerator_project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps = gradient_accumulation_steps, + mixed_precision = mixed_precision, + log_with = report_to, + project_config = accelerator_project_config, + ) + generator = torch.Generator(device=accelerator.device).manual_seed(seed) + + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + + # Handle the repository creation + if accelerator.is_main_process and resume_from_checkpoint != "latest": # For the latest checkpoint version, we don't need to delete our folders + # Validation file + validation_store_folder = config["validation_store_folder"] + "_" + config["scheduler"] + print("We will remove ", validation_store_folder) + if os.path.exists(validation_store_folder): + archive_name = validation_store_folder + "_archive" + if os.path.exists(archive_name): + shutil.rmtree(archive_name) + print("We will move to archive ", archive_name) + os.rename(validation_store_folder, archive_name) + os.makedirs(validation_store_folder) + + # Output Dir + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + # os.makedirs(output_dir, exist_ok=True) + + # Log + if os.path.exists("runs"): + shutil.rmtree("runs") + + # Copy the config to here + os.system(" cp config/train_image2video_controlnet.yaml " + validation_store_folder + "/") + + + # Load All Module Needed + feature_extractor = CLIPImageProcessor.from_pretrained( + pretrained_model_name_or_path, subfolder="feature_extractor", revision=revision + ) # This instance has now weight, they are just seeting file + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + pretrained_model_name_or_path, subfolder="image_encoder", revision=revision, variant=variant + ) + vae = AutoencoderKLTemporalDecoder.from_pretrained( + pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant + ) + if load_unet_path != None: + print("We will use pretrained UNet path by our, at ", load_unet_path) + unet = UNetSpatioTemporalConditionModel.from_pretrained( + load_unet_path, + subfolder = "unet", + low_cpu_mem_usage = True, + ) # For the variant, we don't have fp16 version, so we will read from fp32 + else: + print("We will still use provided UNet path") + unet = UNetSpatioTemporalConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder = "unet", + low_cpu_mem_usage = True, + variant = variant, + ) + + # Prepare for the tokenizer if use text + tokenizer = AutoTokenizer.from_pretrained( + pretrained_tokenizer_name_or_path, + subfolder = "tokenizer", + revision = revision, + use_fast = False, + ) + + if config["use_text"]: + # Clip Text Encoder + text_encoder_cls = import_pretrained_text_encoder(pretrained_tokenizer_name_or_path, revision) + text_encoder = text_encoder_cls.from_pretrained( + pretrained_tokenizer_name_or_path, subfolder = "text_encoder", revision = revision, variant = variant + ) + else: + text_encoder = None + + # Init for the Controlnet (check if has pretrained path to load) + if config["load_controlnet_path"] != None: + print("We will load pre-trained controlnet from ", config["load_controlnet_path"]) + controlnet = ControlNetModel.from_pretrained(config["load_controlnet_path"], subfolder="controlnet") + else: + controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True, conditioning_channels=config["conditioning_channels"]) + + + # Store the config due to the disappearance after accelerator prepare + unet_config = unet.config + expected_add_embed_dim = unet.add_embedding.linear_1.in_features + + + # Freeze vae + feature_extractor + image_encoder, but set unet to trainable + vae.requires_grad_(False) + image_encoder.requires_grad_(False) + unet.requires_grad_(False) # UNet won't be trained in conditioning branch + controlnet.requires_grad_(False) # Will turn back to requires grad later on + if config["use_text"]: + text_encoder.requires_grad_(False) + + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + + # Move vae + unet + image_encoder to gpu and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) # we don't train UNet anymore, so we cast it here + image_encoder.to(accelerator.device, dtype=weight_dtype) + if config["use_text"]: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + + if gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + if accelerator.unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + + + + ################################ Make Training dataset ###################################### + if config["data_loader_type"] == "thisthat": # Only keep thisthat mode now + train_dataset = Video_ThisThat_Dataset(config, accelerator.device, tokenizer=tokenizer) + else: + raise NotImplementedError("We don't support such data loader type") + + sampler = RandomSampler(train_dataset) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + sampler = sampler, + batch_size = train_batch_size, + num_workers = dataloader_num_workers * accelerator.num_processes, + ) + ############################################################################################## + + + + ####################################### Optimizer Setting ############################################################## + if scale_lr: + learning_rate = ( + learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes + ) + + # 8bit adam to save more memory + if use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + + # Make ControlNet require Grad + controlnet.requires_grad_(True) + + + ###################### For partial fine-tune setting ####################### + parameters_list = [] + for name, para in controlnet.named_parameters(): + if partial_finetune: # The partial finetune we use is to only train attn layers, which will be ~190M params (TODO:needs to check later for exact value) + if not name.find("attn") != -1: # Only block the spatial Transformer + para.requires_grad = False + else: + parameters_list.append(para) + para.requires_grad = True + else: + parameters_list.append(para) + para.requires_grad = True + + # Double check the weight that will be trained + total_params_for_training = 0 + for name, param in controlnet.named_parameters(): + if param.requires_grad: + total_params_for_training += param.numel() + print(name + " requires grad update") + print("Total parameter that will be trained in controlnet has ", total_params_for_training) + ############################################################################# + + # Optimizer creation + optimizer = optimizer_cls( + parameters_list, + lr = learning_rate, + betas = (adam_beta1, adam_beta2), + weight_decay = adam_weight_decay, + eps = adam_epsilon, + ) + + + # Scheduler and Training steps + dataset_length = len(train_dataset) + print("Dataset length read from the train side is ", dataset_length) + num_update_steps_per_epoch = math.ceil(dataset_length / gradient_accumulation_steps) + max_train_steps = num_train_iters * train_batch_size + + # Learning Rate Scheduler (we all use constant) + lr_scheduler = get_scheduler( + "constant", + optimizer = optimizer, + num_warmup_steps = lr_warmup_steps * accelerator.num_processes, + num_training_steps = max_train_steps * accelerator.num_processes, + num_cycles = 1, + power = 1.0, + ) + ####################################################################################################################### + + + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + + + # We need to RECALCULATE our total training steps as the size of the training dataloader may have changed. + print("accelerator.num_processes is ", accelerator.num_processes) + print("num_train_iters is ", num_train_iters) + num_train_epochs = math.ceil(num_train_iters * accelerator.num_processes * gradient_accumulation_steps / dataset_length) + print("num_train_epochs is ", num_train_epochs) + + # We need to initialize the trackers we use, and also store our configuration. + if accelerator.is_main_process: # Only on the main process! + tracker_config = dict(vars(args)) + accelerator.init_trackers(tracker_project_name, tracker_config) + + + + # Train! + logger.info("***** Running training *****") + logger.info(f" Dataset Length = {dataset_length}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {train_batch_size}") + logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + + + + # Load the Closest / Best weight TODO: need to check how to use checkpoints from pre-trained weights!!! + global_step = 0 # Catch the current iteration + first_epoch = 0 + if resume_from_checkpoint: # Resume Checkpoints!!!!! + if resume_from_checkpoint != "latest": + path = os.path.basename(resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + print("We will resume the latest weight ", path) + + if path is None: + accelerator.print( + f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run." + ) + resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + if accelerator.is_main_process: + print("Initial Learning rate is ", optimizer.param_groups[0]['lr']) + print("global_step will start from ", global_step) + + progress_bar = tqdm( + range(initial_global_step, max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + + + # Prepare tensorboard log + writer = SummaryWriter() + + + ################################### Auxiliary Function ################################################################################################ + def encode_clip(pixel_values, prompt): + ''' Encoder hidden states input source + pixel_values: first frame pixel information + prompt: language prompt with takenized + ''' + + ########################################## Prepare the Text Embedding ##################################################### + # pixel_values is in the range [-1, 1] + pixel_values = resize_with_antialiasing(pixel_values, (224, 224)) + pixel_values = (pixel_values + 1.0) / 2.0 # [-1, 1] -> [0, 1] + + # Normalize the image with for CLIP input + pixel_values = feature_extractor( + images=pixel_values, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + # The following is the same as _encode_image in SVD pipeline + pixel_values = pixel_values.to(device=accelerator.device, dtype=weight_dtype) + image_embeddings = image_encoder(pixel_values).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + encoder_hidden_states = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + ############################################################################################################################# + + + + ########################################## Prepare the Text embedding if needed ############################################# + if config["use_text"]: + text_embeddings = text_encoder(prompt)[0] + + # Concat two embeddings together on dim 1 + encoder_hidden_states = torch.cat((text_embeddings, encoder_hidden_states), dim=1) # 目前先用text_embeddings 再用encoder_hidden_states感觉好一点 + + # Layer norm on the last dim + layer_norm = nn.LayerNorm((78, 1024)).to(device=accelerator.device, dtype=weight_dtype) + encoder_hidden_states_norm = layer_norm(encoder_hidden_states) + + # Return + return encoder_hidden_states_norm + + else: # Just return back default on + return encoder_hidden_states + ############################################################################################################################# + + ######################################################################################################################################################### + + + ############################################################################################################################ + # For the training, we mimic the code from test2image in diffusers TODO: check the data loader conflict + for epoch in range(first_epoch, num_train_epochs): + controlnet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # batch is a dictionary with video_frames and controlnet_condition + video_frames = batch["video_frames"].to(weight_dtype).to(accelerator.device, non_blocking=True) # [-1, 1] range + condition_img = batch["controlnet_condition"].to(dtype=weight_dtype) # [0, 1] range + reflected_motion_bucket_id = batch["reflected_motion_bucket_id"] + controlnet_image_index = batch["controlnet_image_index"] + prompt = batch["prompt"] + + + # Images to VAE latent space + latents = tensor_to_vae_latent(video_frames, vae) # For all frames + + + ##################################### Add Noise ######################################## + bsz, num_frames = latents.shape[:2] + + + # Encode the first frame + conditional_pixel_values = video_frames[:, 0, :, :, :] # First frame + # Following AnimateSomething, we use constant to repace cond_sigmas + conditional_pixel_values = conditional_pixel_values + torch.randn_like(conditional_pixel_values) * train_noise_aug_strength # cond_sigmas + conditional_latents = vae.encode(conditional_pixel_values).latent_dist.mode() + conditional_latents = repeat(conditional_latents, 'b c h w->b f c h w', f=num_frames) # conditional_latents没有noise的成分的 + + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + sigmas = rand_log_normal(shape=[bsz,], loc=config["noise_mean"], scale=config["noise_std"]).to(weight_dtype).to(latents.device) # TODO: 我觉得noise这块,sigma算法是最不确定是否正确的地方 + sigmas = sigmas[:, None, None, None, None] + noisy_latents = latents + torch.randn_like(latents) * sigmas + inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5) # multiplied by c_in in paper + + + # For the encoder hidden states based on the first frame and prompt + encoder_hidden_states = encode_clip(video_frames[:, 0, :, :, :].float(), prompt) # First Frame + Text Prompt + + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. + if conditioning_dropout_prob != 0: + random_p = torch.rand(bsz, device=latents.device, generator=generator) + + # Sample masks for the encoder_hidden_states (to replace prompts in InstructPix2Pix). + prompt_mask = random_p < 2 * conditioning_dropout_prob + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + # Final encoder_hidden_states conditioning. + null_conditioning = torch.zeros_like(encoder_hidden_states) # encoder_hidden_states has already been used with .unsqueeze(1) + encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) + + # Sample masks for the original image latents. + image_mask_dtype = conditional_latents.dtype + image_mask = 1 - ((random_p >= conditioning_dropout_prob).to(image_mask_dtype) * (random_p < 3 * conditioning_dropout_prob).to(image_mask_dtype)) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + + # Final image conditioning. + conditional_latents = image_mask * conditional_latents + + # The Concatenation is move downward with the masking feature + + + # GT noise + target = latents + ########################################################################################## + + + ################################ Other Embedding and Conditioning ################################### + reflected_motion_bucket_id = torch.sum(reflected_motion_bucket_id)/len(reflected_motion_bucket_id) + reflected_motion_bucket_id = int(reflected_motion_bucket_id.cpu().detach().numpy()) + # print("Training reflected_motion_bucket_id is ", reflected_motion_bucket_id) + + added_time_ids = get_add_time_ids( + unet_config, + expected_add_embed_dim, + process_fps, + reflected_motion_bucket_id, + train_noise_aug_strength, # Note: noise strength + weight_dtype, + train_batch_size, + num_videos_per_prompt, + ) # The same as SVD pipeline's _get_add_time_ids + added_time_ids = added_time_ids.to(accelerator.device) + + timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(accelerator.device) + ########################################################################################## + + + + ################################### Get ControlNet Output ################################### + + # Transform controlnet_image_index to the data format we want + controlnet_image_index = list(controlnet_image_index.cpu().detach().numpy()[0]) + assert condition_img.shape[1] >= len(controlnet_image_index) + + # Designing the 0/1 mask for Sparse Conditioning + controlnet_conditioning_mask_shape = list(condition_img.shape) + controlnet_conditioning_mask_shape[2] = 1 # frame dim + controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(dtype=weight_dtype).to(accelerator.device) + controlnet_conditioning_mask[:, controlnet_image_index] = 1 + + + # Add vae latent mask to controlnet noise + if config["mask_controlnet_vae"]: + b, f, c, h, w = conditional_latents.shape + + # Create a mask: Value less than the threshold is set to be True + mask = torch.rand((b, f, 1, h, w), device=accelerator.device) < (1-config["mask_proportion"]) # channel sync + # mask[:,0,:,:,:] = 1 # For the first frame, we still keep it + + # Multiply to the conditional latents, we will just make the mean and variance zero to present those with zero masking + masked_conditional_latents = conditional_latents * mask + controlnet_inp_noisy_latents = torch.cat([inp_noisy_latents, masked_conditional_latents], dim=2) + else: + controlnet_inp_noisy_latents = torch.cat([inp_noisy_latents, conditional_latents], dim=2) + + + # VAE encode + controlnet_cond = condition_img.flatten(0, 1) + controlnet_cond = vae.encode(controlnet_cond).latent_dist.mode() + + + down_block_res_samples, mid_block_res_sample = controlnet( + sample = controlnet_inp_noisy_latents, + timestep = timesteps, + encoder_hidden_states = encoder_hidden_states, + added_time_ids = added_time_ids, + controlnet_cond = controlnet_cond, + return_dict = False, + conditioning_scale = config["outer_conditioning_scale"], + inner_conditioning_scale = config["inner_conditioning_scale"], + guess_mode = False, # No Guess Mode + ) + + ############################################################################################# + + + + ###################################### Predict Noise ######################################## + # Add vae latent mask to controlnet noise + if config["mask_unet_vae"]: + b, f, c, h, w = conditional_latents.shape + + # Create a mask + mask = torch.rand((b, f, 1, h, w), device=accelerator.device) < (1-config["mask_proportion"]) # channel sync + # mask[:,0,:,:,:] = 1 # For the first frame, we still keep it + + # Multiply to the conditional latents, we will just make the mean and variance zero to present those with zero masking + if not config["mask_controlnet_vae"]: + masked_conditional_latents = conditional_latents * mask + unet_inp_noisy_latents = torch.cat([inp_noisy_latents, masked_conditional_latents], dim=2) + else: + unet_inp_noisy_latents = torch.cat([inp_noisy_latents, conditional_latents], dim=2) + + # Add with controlnet middle output layers + model_pred = unet( + unet_inp_noisy_latents, + timesteps, + encoder_hidden_states, + added_time_ids = added_time_ids, + down_block_additional_residuals = [ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual = mid_block_res_sample.to(dtype=weight_dtype), + ).sample + + + # Denoise the latents + c_out = -sigmas / ((sigmas**2 + 1)**0.5) + c_skip = 1 / (sigmas**2 + 1) + denoised_latents = model_pred * c_out + c_skip * noisy_latents # What our loss will optimize with + weighing = (1 + sigmas ** 2) * (sigmas**-2.0) + ########################################################################################## + + + ############################### Calculate Loss and Update Optimizer ####################### + # MSE loss + loss = torch.mean( + ( weighing.float() * (denoised_latents.float() - target.float())**2 ).reshape(target.shape[0], -1), + dim=1, + ) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean() + train_loss += avg_loss.item() / gradient_accumulation_steps + + # Update Tensorboard + writer.add_scalar('Loss/train-Loss-Step', avg_loss.item()/ gradient_accumulation_steps, global_step) # 我觉得loss的report就用这个avg_loss就行了 + + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: # For ControlNet + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, max_grad_norm) + optimizer.step() + lr_scheduler.step() # I think constant will take no influence here + optimizer.zero_grad(set_to_none=True) + ########################################################################################## + + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + ########################################## Checkpoints ######################################### + if global_step != 0 and global_step % checkpointing_steps == 0: + if accelerator.is_main_process: + start = time.time() + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + print("Save time use " + str(time.time() - start) + " s") + ######################################################################################################## + + + # Update Log + logs = {"step_loss": loss.detach().item()} + progress_bar.set_postfix(**logs) + + + ##################################### Validation per XXX iterations ####################################### + if accelerator.is_main_process: + if global_step > -1 and global_step % validation_step == 0: # Fixed 100 steps to validate + + log_validation( + vae, + unet, + controlnet, + image_encoder, + text_encoder, + tokenizer, + config, + accelerator, + weight_dtype, + global_step, + use_ambiguous_prompt = config["mix_ambiguous"], + ) + + ############################################################################################################### + + # Update Steps and Break if needed global step should be updated together + global_step += 1 + + if global_step >= max_train_steps: + break + + ############################################################################################################################ + + +if __name__ == "__main__": + args = parse_args() + + config = OmegaConf.load(args.config_path) + main(config) diff --git a/train_code/train_svd.py b/train_code/train_svd.py new file mode 100644 index 0000000000000000000000000000000000000000..f30ed13e9b3d6ac157a2838d22f710904037dd40 --- /dev/null +++ b/train_code/train_svd.py @@ -0,0 +1,908 @@ +#!/usr/bin/env python +''' + This file is to train stable video diffusion by my personal implementation which is based on diffusers' training example code. +''' + +import argparse +import logging +import math +import os, sys +import time +import random +import shutil +import warnings +import cv2 +from PIL import Image +from einops import rearrange, repeat +from pathlib import Path +from omegaconf import OmegaConf +import imageio + +import accelerate +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import RandomSampler +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKLTemporalDecoder, + DDPMScheduler, +) +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video +from diffusers.utils.import_utils import is_xformers_available +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.torch_utils import randn_tensor +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +if is_wandb_available(): + import wandb + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline +from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel +from data_loader.video_dataset import Video_Dataset, get_video_frames, tokenize_captions +from utils.img_utils import resize_with_antialiasing + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +# check_min_version("0.25.0.dev0") +logger = get_logger(__name__) +warnings.filterwarnings('ignore') + + +################################################################################################################################################### +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--config_path", + type=str, + default="config/train_image2video.yaml", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + + args = parser.parse_args() + return args + + + +def log_validation(vae, unet, image_encoder, text_encoder, tokenizer, config, accelerator, weight_dtype, step, + parent_store_folder = None, force_close_flip = False, use_ambiguous_prompt=False): + # This function will also be used in other files + print("Running validation... ") + + + # Init + validation_source_folder = config["validation_img_folder"] + + + # Init the pipeline + pipeline = StableVideoDiffusionPipeline.from_pretrained( + config["pretrained_model_name_or_path"], + vae = accelerator.unwrap_model(vae), + image_encoder = accelerator.unwrap_model(image_encoder), + unet = accelerator.unwrap_model(unet), + revision = None, # Set None directly now + torch_dtype = weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + + # Process all image in the folder + frames_collection = [] + for image_name in sorted(os.listdir(validation_source_folder)): + if accelerator.is_main_process: + if parent_store_folder is None: + validation_store_folder = os.path.join(config["validation_store_folder"] + "_" + config["scheduler"], "step_" + str(step), image_name) + else: + validation_store_folder = os.path.join(parent_store_folder, image_name) + + if os.path.exists(validation_store_folder): + shutil.rmtree(validation_store_folder) + os.makedirs(validation_store_folder) + + image_path = os.path.join(validation_source_folder, image_name, 'im_0.jpg') + ref_image = load_image(image_path) + ref_image = ref_image.resize((config["width"], config["height"])) + + + # Decide the motion score in SVD (mostly what we use is fix value now) + if config["motion_bucket_id"] is None: + raise NotImplementedError("We need a fixed motion_bucket_id in the config") + else: + reflected_motion_bucket_id = config["motion_bucket_id"] + print("Inference Motion Bucket ID is ", reflected_motion_bucket_id) + + + # Prepare text prompt + if config["use_text"]: + # Read the file + file_path = os.path.join(validation_source_folder, image_name, "lang.txt") + file = open(file_path, 'r') + prompt = file.readlines()[0] # Only read the first line + if use_ambiguous_prompt: + prompt = prompt.split(" ")[0] + " this to there" + print("We are creating ambiguous prompt, which is: ", prompt) + else: + prompt = "" + # Use the same tokenize process as the dataset preparation stage + tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device) # Use unsqueeze to expand dim + + # Store the prompt for the sanity check + f = open(os.path.join(validation_store_folder, "lang_cond.txt"), "a") + f.write(prompt) + f.close() + + + # Flip the image by chance (it is needed to check whether there is any object position words [left|right] in the prompt text) + flip = False + if not force_close_flip: # force_close_flip is True in testing time; else, we cannot match in the same standard + if random.random() < config["flip_aug_prob"]: + if config["use_text"]: + if prompt.find("left") == -1 and prompt.find("right") == -1: # Cannot have position word, like left and right (up and down is ok) + flip = True + else: + flip = True + if flip: + print("Use flip in validation!") + ref_image = ref_image.transpose(Image.FLIP_LEFT_RIGHT) + + + # Call the model for inference + with torch.autocast("cuda"): + frames = pipeline( + ref_image, + tokenized_prompt, + config["use_text"], + text_encoder, + height = config["height"], + width = config["width"], + num_frames = config["video_seq_length"], + num_inference_steps = config["num_inference_steps"], + decode_chunk_size = 8, + motion_bucket_id = reflected_motion_bucket_id, + fps = 7, + noise_aug_strength = config["inference_noise_aug_strength"], + ).frames[0] + + # Store the frames + # breakpoint() + for idx, frame in enumerate(frames): + frame.save(os.path.join(validation_store_folder, str(idx)+".png")) + imageio.mimsave(os.path.join(validation_store_folder, 'combined.gif'), frames) # gif storage quality is not high, recommend to check png images + + frames_collection.append(frames) + + + # Cleaning process + del pipeline + torch.cuda.empty_cache() + + return frames_collection # Return resuly based on the need + + +def tensor_to_vae_latent(inputs, vae): + video_length = inputs.shape[1] + inputs = rearrange(inputs, "b f c h w -> (b f) c h w") + latents = vae.encode(inputs).latent_dist.mode() + latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length) # Use f or b to rearrage should have the same effect + latents = latents * vae.config.scaling_factor + + return latents + + +def import_pretrained_text_encoder(pretrained_model_name_or_path: str, revision: str): + ''' Import Text encoder information + + ''' + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + + else: # No other cases will be considerred + raise ValueError(f"{model_class} is not supported.") + + + +def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): + """Draws samples from an lognormal distribution.""" + u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7 + return torch.distributions.Normal(loc, scale).icdf(u).exp() + + +def get_add_time_ids( + unet_config, + expected_add_embed_dim, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + ): + + # Construct Basic add_time_ids items + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = unet_config.addition_time_embed_dim * len(add_time_ids) + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + return add_time_ids + + +#################################################################################################################################################################### + + + +def main(config): + + # Read Config Setting + resume_from_checkpoint = config["resume_from_checkpoint"] + output_dir = config["output_dir"] + logging_name = config["logging_name"] + mixed_precision = config["mixed_precision"] + report_to = config["report_to"] + pretrained_model_name_or_path = config["pretrained_model_name_or_path"] + pretrained_tokenizer_name_or_path = config["pretrained_tokenizer_name_or_path"] + gradient_checkpointing = config["gradient_checkpointing"] + learning_rate = config["learning_rate"] + adam_beta1 = config["adam_beta1"] + adam_beta2 = config["adam_beta2"] + adam_weight_decay = config["adam_weight_decay"] + adam_epsilon = config["adam_epsilon"] + train_batch_size = config["train_batch_size"] + dataloader_num_workers = config["dataloader_num_workers"] + gradient_accumulation_steps = config["gradient_accumulation_steps"] + num_train_iters = config["num_train_iters"] + lr_warmup_steps = config["lr_warmup_steps"] + checkpointing_steps = config["checkpointing_steps"] + process_fps = config["process_fps"] + train_noise_aug_strength = config["train_noise_aug_strength"] + use_8bit_adam = config["use_8bit_adam"] + scale_lr = config["scale_lr"] + conditioning_dropout_prob = config["conditioning_dropout_prob"] + checkpoints_total_limit = config["checkpoints_total_limit"] + validation_step = config["validation_step"] + partial_finetune = config['partial_finetune'] + + + # Default Setting + revision = None + variant = "fp16" + lr_scheduler = "constant" + max_grad_norm = 1.0 + tracker_project_name = "img2video" + num_videos_per_prompt = 1 + seed = 42 + # No CFG in training now + + + + # Define the accelerator + logging_dir = Path(output_dir, logging_name) + accelerator_project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps = gradient_accumulation_steps, + mixed_precision = mixed_precision, + log_with = report_to, + project_config = accelerator_project_config, + ) + generator = torch.Generator(device=accelerator.device).manual_seed(seed) + + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + + # Handle the repository creation + if accelerator.is_main_process and resume_from_checkpoint != "latest": # For the latest checkpoint version, we don't need to delete our folders + # Validation file + validation_store_folder = config["validation_store_folder"] + "_" + config["scheduler"] + print("We will remove ", validation_store_folder) + if os.path.exists(validation_store_folder): + archive_name = validation_store_folder + "_archive" + if os.path.exists(archive_name): + shutil.rmtree(archive_name) + print("We will move to archive ", archive_name) + os.rename(validation_store_folder, archive_name) + os.makedirs(validation_store_folder) + + # Output Dir + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + # os.makedirs(output_dir, exist_ok=True) + + # Log + if os.path.exists("runs"): + shutil.rmtree("runs") + + # Copy the config to here + os.system(" cp config/train_image2video.yaml " + validation_store_folder + "/") + + + # Load All Module Needed + feature_extractor = CLIPImageProcessor.from_pretrained( + pretrained_model_name_or_path, subfolder="feature_extractor", revision=revision + ) # This instance has now weight, they are just seeting file + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + pretrained_model_name_or_path, subfolder="image_encoder", revision=revision, variant=variant + ) + vae = AutoencoderKLTemporalDecoder.from_pretrained( + pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant + ) + if config["load_unet_path"] != None: + print("We will load UNet from ", config["load_unet_path"]) + unet = UNetSpatioTemporalConditionModel.from_pretrained( + config["load_unet_path"], + subfolder = "unet", + low_cpu_mem_usage = True, + ) # For the variant, we don't have fp16 version, so we will read from fp32 + else: + print("We will only use SVD pretrained UNet") + unet = UNetSpatioTemporalConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder = "unet", + low_cpu_mem_usage = True, + variant = variant, + ) + + # Prepare for the tokenizer if use text + tokenizer = AutoTokenizer.from_pretrained( + pretrained_tokenizer_name_or_path, + subfolder = "tokenizer", + revision = revision, + use_fast = False, + ) + + if config["use_text"]: + # Clip Text Encoder + text_encoder_cls = import_pretrained_text_encoder(pretrained_tokenizer_name_or_path, revision) + text_encoder = text_encoder_cls.from_pretrained( + pretrained_tokenizer_name_or_path, subfolder = "text_encoder", revision = revision, variant = variant + ) + else: + text_encoder = None + + + # Store the config due to the disappearance after accelerator prepare (This is written to handle some unknown phenomenon) + unet_config = unet.config + expected_add_embed_dim = unet.add_embedding.linear_1.in_features + + + # Freeze vae + feature_extractor + image_encoder, but set unet to trainable + vae.requires_grad_(False) + image_encoder.requires_grad_(False) + unet.requires_grad_(False) # Will switch back to train mode later on + if config["use_text"]: + text_encoder.requires_grad_(False) # All set with no grad needed (like VAE) follow other T2I papers + + + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + + # Move vae + image_encoder to gpu and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + # unet.to(accelerator.device, dtype=weight_dtype) + if config["use_text"]: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + + + # Acceleration: `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNetSpatioTemporalConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if gradient_checkpointing: + unet.enable_gradient_checkpointing() + + + + ################################ Make Training dataset ############################### + train_dataset = Video_Dataset(config, device = accelerator.device, tokenizer=tokenizer) + sampler = RandomSampler(train_dataset) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + sampler = sampler, + batch_size = train_batch_size, + num_workers = dataloader_num_workers * accelerator.num_processes, + ) + ####################################################################################### + + + ####################################### Optimizer Setting ##################################################################### + if scale_lr: + learning_rate = ( + learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes + ) + + # 8bit adam to save more memory (Usally we need this to save the memory) + if use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + + # Switch back to unet in training mode + unet.requires_grad_(True) + + + ############################## For partial fine-tune setting ############################## + parameters_list = [] + for name, param in unet.named_parameters(): + if partial_finetune: # The partial finetune we use is to only train attn layers, which will be ~190M params (TODO:needs to check later for exact value) + # Full Spatial: .transformer_blocks. && spatial_ + # Attn + All emb: attn && emb + if name.find("attn") != -1 or name.find("emb") != -1: # Only block the spatial Transformer + parameters_list.append(param) + param.requires_grad = True + else: + param.requires_grad = False + else: + parameters_list.append(param) + param.requires_grad = True + + # Double check what will be trained + total_params_for_training = 0 + # if os.path.exists("param_lists.txt"): + # os.remove("param_lists.txt") + # file1 = open("param_lists.txt","a") + for name, param in unet.named_parameters(): + # file1.write(name + "\n") + if param.requires_grad: + total_params_for_training += param.numel() + print(name + " requires grad update") + print("Total parameter that will be trained has ", total_params_for_training) + ########################################################################################## + + # Optimizer creation + optimizer = optimizer_cls( + parameters_list, + lr = learning_rate, + betas = (adam_beta1, adam_beta2), + weight_decay = adam_weight_decay, + eps = adam_epsilon, + ) + + + # Scheduler and Training steps + dataset_length = len(train_dataset) + print("Dataset length read from the train side is ", dataset_length) + num_update_steps_per_epoch = math.ceil(dataset_length / gradient_accumulation_steps) + max_train_steps = num_train_iters * train_batch_size + + # Learning Rate Scheduler (we all use constant) + lr_scheduler = get_scheduler( + "constant", + optimizer = optimizer, + num_warmup_steps = lr_warmup_steps * accelerator.num_processes, + num_training_steps = max_train_steps * accelerator.num_processes, + num_cycles = 1, + power = 1.0, + ) + ##################################################################################################################################### + + + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + + # We need to RECALCULATE our total training steps as the size of the training dataloader may have changed. + print("accelerator.num_processes is ", accelerator.num_processes) + print("num_train_iters is ", num_train_iters) + num_train_epochs = math.ceil(num_train_iters * accelerator.num_processes * gradient_accumulation_steps / dataset_length) + print("num_train_epochs is ", num_train_epochs) + + # We need to initialize the trackers we use, and also store our configuration. + if accelerator.is_main_process: # Only on the main process! + tracker_config = dict(vars(args)) + accelerator.init_trackers(tracker_project_name, tracker_config) + + + + # Train! + logger.info("***** Running training *****") + logger.info(f" Dataset Length = {dataset_length}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {train_batch_size}") + logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + + + + # Load the Closest / Best weight + global_step = 0 # Catch the current iteration + first_epoch = 0 + if resume_from_checkpoint: + if resume_from_checkpoint != "latest": + path = os.path.basename(resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + print("We will resume the latest weight ", path) + + if path is None: # Don't resume + accelerator.print( + f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run." + ) + resume_from_checkpoint = None + initial_global_step = 0 + else: # Resume from the closest checkpoint + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + + if accelerator.is_main_process: + print("Initial Learning rate is ", optimizer.param_groups[0]['lr']) + print("global_step will start from ", global_step) + + progress_bar = tqdm( + range(initial_global_step, max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + + + # Prepare tensorboard log + writer = SummaryWriter() + + + ######################################################### Auxiliary Function ################################################################# + def encode_clip(pixel_values, prompt): + ''' Encoder hidden states input source + pixel_values: first frame pixel information + prompt: language prompt with takenized + ''' + + ########################################## Prepare the Text Embedding ##################################################### + # pixel_values is in the range [-1, 1] + pixel_values = resize_with_antialiasing(pixel_values, (224, 224)) + pixel_values = (pixel_values + 1.0) / 2.0 # [-1, 1] -> [0, 1] + + # Normalize the image with for CLIP input + pixel_values = feature_extractor( + images=pixel_values, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + # The following is the same as _encode_image in SVD pipeline + pixel_values = pixel_values.to(device=accelerator.device, dtype=weight_dtype) + image_embeddings = image_encoder(pixel_values).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + encoder_hidden_states = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + ############################################################################################################################# + + + + ########################################## Prepare the Text embedding if needed ############################################# + if config["use_text"]: + text_embeddings = text_encoder(prompt)[0] + + # Concat two embeddings together on dim 1 + encoder_hidden_states = torch.cat((text_embeddings, encoder_hidden_states), dim=1) + + # Layer norm on the last dim + layer_norm = nn.LayerNorm((78, 1024)).to(device=accelerator.device, dtype=weight_dtype) + encoder_hidden_states_norm = layer_norm(encoder_hidden_states) + + # Return + return encoder_hidden_states_norm + + else: # Just return back default on + return encoder_hidden_states + ############################################################################################################################# + + #################################################################################################################################################### + + + ############################################################################################################################ + # For the training, we mimic the code from T2I in diffusers + for epoch in range(first_epoch, num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # batch is a torch tensor with range of [-1, 1] but no other pre-porcessing + video_frames = batch["video_frames"].to(weight_dtype).to(accelerator.device, non_blocking=True) + reflected_motion_bucket_id = batch["reflected_motion_bucket_id"] + prompt = batch["prompt"] + + + # Images to VAE latent space + latents = tensor_to_vae_latent(video_frames, vae) + + + ##################################### Add Noise ######################################## + bsz, num_frames = latents.shape[:2] + + # Encode the first frame + conditional_pixel_values = video_frames[:, 0, :, :, :] # First frame + # Following AnimateSomething, we use constant to repace cond_sigmas + conditional_pixel_values = conditional_pixel_values + torch.randn_like(conditional_pixel_values) * train_noise_aug_strength + conditional_latents = vae.encode(conditional_pixel_values).latent_dist.mode() # mode() returns mean value no std influence + conditional_latents = repeat(conditional_latents, 'b c h w->b f c h w', f=num_frames) # copied across the frame axis to be the same shape as noise + + + # Add noise to the latents according to the noise magnitude at each timestep + # This is the forward diffusion process + sigmas = rand_log_normal(shape=[bsz,], loc=config["noise_mean"], scale=config["noise_std"]).to(latents.device) + sigmas = sigmas[:, None, None, None, None] + noisy_latents = latents + torch.randn_like(latents) * sigmas + inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5) + + + # For the encoder hidden states based on the first frame and prompt + encoder_hidden_states = encode_clip(video_frames[:, 0, :, :, :].float(), prompt) # First Frame + Text Prompt + + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800 (InstructPix2Pix). + if conditioning_dropout_prob != 0: + random_p = torch.rand(bsz, device=latents.device, generator=generator) + + # Sample masks for the edit prompts. + prompt_mask = random_p < 2 * conditioning_dropout_prob + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + # Final text conditioning. + null_conditioning = torch.zeros_like(encoder_hidden_states) + encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) + + # Sample masks for the original images. + image_mask_dtype = conditional_latents.dtype + image_mask = 1 - ((random_p >= conditioning_dropout_prob).to(image_mask_dtype) * (random_p < 3 * conditioning_dropout_prob).to(image_mask_dtype)) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + + # Final image conditioning. + conditional_latents = image_mask * conditional_latents + + + # Concatenate the `conditional_latents` with the `noisy_latents`. + inp_noisy_latents = torch.cat([inp_noisy_latents, conditional_latents], dim=2) + + + # GT noise + target = latents + ########################################################################################## + + + ################################ Other Embedding and Conditioning ################################### + reflected_motion_bucket_id = torch.sum(reflected_motion_bucket_id)/len(reflected_motion_bucket_id) + reflected_motion_bucket_id = int(reflected_motion_bucket_id.cpu().detach().numpy()) + # print("Training reflected_motion_bucket_id is ", reflected_motion_bucket_id) + + added_time_ids = get_add_time_ids( + unet_config, + expected_add_embed_dim, + process_fps, + reflected_motion_bucket_id, + train_noise_aug_strength, + weight_dtype, + train_batch_size, + num_videos_per_prompt, + ) # The same as SVD pipeline's _get_add_time_ids + added_time_ids = added_time_ids.to(accelerator.device) + + timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(accelerator.device) + ##################################################################################################### + + + + ###################################### Predict Noise ###################################### + model_pred = unet( + inp_noisy_latents, + timesteps, + encoder_hidden_states, + added_time_ids = added_time_ids + ).sample + + # Denoise the latents + c_out = -sigmas / ((sigmas**2 + 1)**0.5) + c_skip = 1 / (sigmas**2 + 1) + denoised_latents = model_pred * c_out + c_skip * noisy_latents + weighing = (1 + sigmas ** 2) * (sigmas**-2.0) + ########################################################################################## + + + ############################### Calculate Loss and Update Optimizer ####################### + # MSE loss + loss = torch.mean( + ( weighing.float() * (denoised_latents.float() - target.float())**2 ).reshape(target.shape[0], -1), + dim=1, + ) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean() + train_loss += avg_loss.item() / gradient_accumulation_steps + + # Update Tensorboard + writer.add_scalar('Loss/train-Loss-Step', avg_loss, global_step) + + + # Backpropagate + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + ########################################################################################## + + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + ########################################## Checkpoints ######################################### + if global_step != 0 and global_step % checkpointing_steps == 0: + if accelerator.is_main_process: + start = time.time() + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + print("Save time use " + str(time.time() - start) + " s") + ######################################################################################################## + + + # Update Log + logs = {"step_loss": loss.detach().item(), "lr": optimizer.param_groups[0]['lr']} + progress_bar.set_postfix(**logs) + + + ##################################### Validation per XXX iterations ####################################### + if accelerator.is_main_process: + if global_step % validation_step == 0: # Fixed 100 steps to validate + + if config["validation_img_folder"] is not None: + log_validation( + vae, + unet, + image_encoder, + text_encoder, + tokenizer, + config, + accelerator, + weight_dtype, + global_step, + use_ambiguous_prompt = config["mix_ambiguous"], + ) + + ############################################################################################################### + + # Update Steps and Break if needed + global_step += 1 + + if global_step >= max_train_steps: + break + + ############################################################################################################################ + + +if __name__ == "__main__": + args = parse_args() + + config = OmegaConf.load(args.config_path) + main(config) diff --git a/utils/img_utils.py b/utils/img_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1f0d6e186dc3499e44a67c6981c7d6f68b470e27 --- /dev/null +++ b/utils/img_utils.py @@ -0,0 +1,140 @@ +import os, sys +import cv2 +import numpy as np +import torch +from torch.nn import functional as F +import random +import math + + +def np2tensor(np_frame): + return torch.from_numpy(np.transpose(np_frame, (2, 0, 1))).unsqueeze(0).cuda().float()/255 + +def tensor2np(tensor): + # tensor should be batch size1 and cannot be grayscale input + return (np.transpose(tensor.detach().squeeze(0).cpu().numpy(), (1, 2, 0))) * 255 + + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out + + +def resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + ''' Resize with antialiasing (from StableVideoDiffusion Pipeline) + Args: + input (numpy): The input image + size (tuple): (height, width) in int format + ''' + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) + return output + + + +def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: + """ + Convert a NumPy image to a PyTorch tensor. + """ + if images.ndim == 3: + images = images[None, ...] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images diff --git a/utils/optical_flow_utils.py b/utils/optical_flow_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0bb8b1d14123f257fcaddce9a39e6fd62d81968 --- /dev/null +++ b/utils/optical_flow_utils.py @@ -0,0 +1,219 @@ +import numpy as np + + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] in range [0, 255] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) + + + +def filter_uv(flow, threshold_factor = 0.1, sample_prob = 1.0): + ''' + Args: + flow (numpy): A 2-dim array that stores x and y change in optical flow + threshold_factor (float): Prob of discarding outliers vector + sample_prob (float): The selection rate of how much proportion of points we need to store + ''' + u = flow[:,:,0] + v = flow[:,:,1] + + # Filter out those less than the threshold + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + + threshold = threshold_factor * rad_max + flow[:,:,0][rad < threshold] = 0 + flow[:,:,1][rad < threshold] = 0 + + + # Randomly sample based on sample_prob + zero_prob = 1 - sample_prob + random_array = np.random.randn(*flow.shape) + random_array[random_array < zero_prob] = 0 + random_array[random_array >= zero_prob] = 1 + flow = flow * random_array + + + return flow + + + +############################################# The following is for dilation method in optical flow ###################################### +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + Returns: + ndarray: Rotated sigma matrix. + """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + Args: + kernel_size (int): + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + +def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): + """Generate a bivariate isotropic or anisotropic Gaussian kernel. + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + isotropic (bool): + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel