from model.video_diffusion.models.controlnet3d import ControlNet3DModel
from model.video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
from model.video_diffusion.pipelines.pipeline_stable_diffusion_controlnet3d import Controlnet3DStableDiffusionPipeline
from transformers import DPTForDepthEstimation
from model.annotator.hed import HEDNetwork
import torch
from einops import rearrange, repeat
import imageio
import numpy as np
import cv2
import torch.nn.functional as F
from PIL import Image
import argparse
import tempfile
import os
import gradio as gr

# Control signal the ControlNet is conditioned on: one of 'depth', 'canny', 'hed'.
control_mode = 'depth'
_SUPPORTED_CONTROL_MODES = ('depth', 'canny', 'hed')
if control_mode not in _SUPPORTED_CONTROL_MODES:
    # Fail before downloading/loading any weights instead of hitting a NameError later.
    raise ValueError(
        f"Unsupported control_mode {control_mode!r}; expected one of {_SUPPORTED_CONTROL_MODES}"
    )

control_net_path = f"wf-genius/controlavideo-{control_mode}"

unet = UNetPseudo3DConditionModel.from_pretrained(
    control_net_path,
    torch_dtype=torch.float16,
    subfolder='unet',
).to("cuda")
controlnet = ControlNet3DModel.from_pretrained(
    control_net_path,
    torch_dtype=torch.float16,
    subfolder='controlnet',
).to("cuda")

if control_mode == 'depth':
    annotator_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
elif control_mode == 'canny':
    # Canny edges are computed with OpenCV; no learned annotator is needed.
    annotator_model = None
elif control_mode == 'hed':
    # firstly download from https://huggingface.co/wf-genius/controlavideo-hed/resolve/main/hed-network.pth
    annotator_model = HEDNetwork('hed-network.pth').to("cuda")

video_controlnet_pipe = Controlnet3DStableDiffusionPipeline.from_pretrained(
    control_net_path,
    unet=unet,
    controlnet=controlnet,
    annotator_model=annotator_model,
    torch_dtype=torch.float16,
).to("cuda")


def to_video(frames, fps: int) -> str:
    """Encode *frames* into an MP4 file in the temp directory and return its path.

    Args:
        frames: iterable of frames; each is converted with ``np.array`` before
            being handed to the ffmpeg writer.
        fps: frame rate of the written video.

    Returns:
        Path of the created ``.mp4`` file. It is created with ``delete=False``,
        so nothing in this function removes it.
    """
    # Only reserve a unique file name; close the handle right away so it is not
    # leaked and so ffmpeg can reopen the path on every platform.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
        out_path = out_file.name
    # The context manager closes (and finalizes) the writer even if a frame fails.
    with imageio.get_writer(out_path, format='FFMPEG', fps=fps) as writer:
        for frame in frames:
            writer.append_data(np.array(frame))
    return out_path


def _chunk_bounds(chunk_index: int, each_sample_frame: int) -> tuple:
    """Return the ``[start, end)`` source-frame range processed by chunk *chunk_index*.

    Consecutive chunks overlap by exactly one frame: the first frame of chunk i is
    the last frame of chunk i-1. That is what passing the previous chunk's last
    generated frame as ``first_frame_output`` presumes. The stride is therefore
    ``each_sample_frame - 1``, and the last chunk never reads past
    ``iter_times * each_sample_frame`` frames.
    """
    start = chunk_index * (each_sample_frame - 1)
    return start, start + each_sample_frame


def _compute_control_maps(np_frames, h: int, w: int, num_sample_frames: int):
    """Build the control-map tensor for *np_frames* using the global ``control_mode``.

    Assumes *np_frames* is channel-last ``(f, h, w, c)`` with 0-255 pixel values
    (TODO confirm against ``get_frames_preprocess``). Returns a tensor shaped
    ``(b, c, f, h, w)`` on the ControlNet's device/dtype, where single-channel
    maps are repeated to 3 channels.
    """
    if control_mode == 'depth':
        frames = torch.from_numpy(np_frames).div(255) * 2 - 1
        frames = rearrange(frames, "f h w c -> c f h w").unsqueeze(0)
        frames = rearrange(frames, 'b c f h w -> (b f) c h w')
        control_maps = video_controlnet_pipe.get_depth_map(
            frames, h, w, return_standard_norm=False)  # (b f) 1 h w
    elif control_mode == 'canny':
        control_maps = np.stack([cv2.Canny(inp, 100, 200) for inp in np_frames])
        control_maps = repeat(control_maps, 'f h w -> f c h w', c=1)
        control_maps = torch.from_numpy(control_maps).div(255)  # 0~1
    else:  # 'hed' (control_mode is validated at import time)
        control_maps = np.stack([video_controlnet_pipe.get_hed_map(inp) for inp in np_frames])
        control_maps = repeat(control_maps, 'f h w -> f c h w', c=1)
        control_maps = torch.from_numpy(control_maps).div(255)  # 0~1

    control_maps = control_maps.to(dtype=controlnet.dtype, device=controlnet.device)
    control_maps = F.interpolate(control_maps, size=(h, w), mode='bilinear', align_corners=False)
    control_maps = rearrange(control_maps, "(b f) c h w -> b c f h w", f=num_sample_frames)
    if control_maps.shape[1] == 1:
        control_maps = repeat(control_maps, 'b c f h w -> b (n c) f h w', n=3)
    return control_maps


def _resize_input_frames(np_frames, h: int, w: int, num_sample_frames: int):
    """Scale channel-last frames to [0, 1], resize them to ``(h, w)`` and reshape to ``(b, c, f, h, w)``."""
    frames = torch.from_numpy(np_frames).div(255)
    frames = rearrange(frames, 'f h w c -> f c h w')
    v2v_input_frames = F.interpolate(
        frames,
        size=(h, w),
        mode="bicubic",
        antialias=True,
    )
    return rearrange(v2v_input_frames, '(b f) c h w -> b c f h w ', f=num_sample_frames)


def inference(input_video, prompt, seed, num_inference_steps, guidance_scale,
              sampling_rate, video_scale, init_noise_thres, each_sample_frame,
              iter_times, h, w,
              ):
    """Re-render *input_video* following *prompt* and return the path of the output MP4.

    The video is processed in ``iter_times`` chunks of ``each_sample_frame`` frames.
    Consecutive chunks share one source frame. For every chunk after the first,
    the previous chunk's last generated frame is passed as ``first_frame_output``
    with ``fix_first_frame=True``.

    Args:
        input_video: uploaded video as provided by ``gr.Video``.
        prompt: text prompt for every chunk.
        seed: seed of the CUDA generator used for every chunk.
        num_inference_steps: denoising steps per chunk.
        guidance_scale: classifier-free guidance scale.
        sampling_rate: forwarded to ``get_frames_preprocess``.
        video_scale: forwarded to the pipeline as ``video_scale``.
        init_noise_thres: forwarded as ``init_noise_by_residual_thres``.
        each_sample_frame: frames per chunk (must be >= 1).
        iter_times: number of chunks (must be >= 1).
        h, w: output height and width in pixels.

    Returns:
        Path of the MP4 file produced by :func:`to_video` at 8 fps.

    Raises:
        ValueError: if ``each_sample_frame`` or ``iter_times`` is smaller than 1.
    """
    # Slider values may arrive as floats; these are used as frame counts, tensor
    # sizes, step counts and an RNG seed, all of which require ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    sampling_rate = int(sampling_rate)
    each_sample_frame = int(each_sample_frame)
    iter_times = int(iter_times)
    h, w = int(h), int(w)
    if each_sample_frame < 1 or iter_times < 1:
        raise ValueError("each_sample_frame and iter_times must both be >= 1")

    num_sample_frames = iter_times * each_sample_frame
    testing_prompt = [prompt]
    # NOTE(review): fps_vid is not used; the output is always written at 8 fps.
    np_frames, fps_vid = Controlnet3DStableDiffusionPipeline.get_frames_preprocess(
        input_video, num_frames=num_sample_frames, sampling_rate=sampling_rate, return_np=True)

    control_maps = _compute_control_maps(np_frames, h, w, num_sample_frames)
    v2v_input_frames = _resize_input_frames(np_frames, h, w, num_sample_frames)

    out = []
    for i in range(iter_times):
        start, end = _chunk_bounds(i, each_sample_frame)
        out1 = video_controlnet_pipe(
            controlnet_hint=control_maps[:, :, start:end, :, :],
            images=v2v_input_frames[:, :, start:end, :, :],
            first_frame_output=out[-1] if i > 0 else None,
            prompt=testing_prompt,
            num_inference_steps=num_inference_steps,
            width=w,
            height=h,
            guidance_scale=guidance_scale,
            generator=[torch.Generator(device="cuda").manual_seed(seed)],
            video_scale=video_scale,
            init_noise_by_residual_thres=init_noise_thres,  # residual-based init. larger thres ==> more smooth.
            controlnet_conditioning_scale=1.0,
            fix_first_frame=True,
            in_domain=True,
        )
        out1 = out1.images[0]
        # Drop the first generated frame of every chunk. For chunks after the first
        # it corresponds to the shared overlap frame that is already in `out`.
        # NOTE(review): for chunk 0 this also discards the video's very first frame
        # (kept from the original logic) — confirm that is intended.
        if len(out1) > 1:
            out1 = out1[1:]
        out.extend(out1)
    return to_video(out, 8)


examples = [
    ["bear.mp4", "a bear walking through stars, artstation"],
    ["car-shadow.mp4", "a car, sunset, cartoon style, artstation."],
    ["libby.mp4", "a dog running, chinese ink painting."],
]


def preview_inference(input_video, prompt, seed, num_inference_steps, guidance_scale,
                      sampling_rate, video_scale, init_noise_thres, each_sample_frame, iter_times,
                      h, w,
                      ):
    """Quick preview: one single-frame chunk with video_scale and init_noise_thres forced to 0."""
    return inference(input_video, prompt, seed, num_inference_steps, guidance_scale,
                     sampling_rate, 0.0, 0.0, 1, 1, h, w,)


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(
                    label="Input Video", source='upload', format="mp4", visible=True)
            with gr.Column():
                init_noise_thres = gr.Slider(0, 1, value=0.1, step=0.1, label="init_noise_thres")
                each_sample_frame = gr.Slider(6, 16, value=8, step=1, label="each_sample_frame")
                iter_times = gr.Slider(1, 4, value=1, step=1, label="iter_times")
                sampling_rate = gr.Slider(1, 8, value=3, step=1, label="sampling_rate")
                h = gr.Slider(256, 768, value=512, step=64, label="height")
                w = gr.Slider(256, 768, value=512, step=64, label="width")
            with gr.Column():
                seed = gr.Slider(0, 6666, value=1, step=1, label="seed")
                num_inference_steps = gr.Slider(5, 50, value=20, step=1, label="num_inference_steps")
                guidance_scale = gr.Slider(1, 20, value=7.5, step=0.5, label="guidance_scale")
                video_scale = gr.Slider(0, 2.5, value=1.5, step=0.1, label="video_scale")
                prompt = gr.Textbox(label='Prompt')
                # preview_button = gr.Button('Preview')
                run_button = gr.Button('Generate Video')
            with gr.Column():
                result = gr.Video(label="Generated Video")

        # Order must match the positional parameters of `inference`.
        inputs = [
            input_video,
            prompt,
            seed,
            num_inference_steps,
            guidance_scale,
            sampling_rate,
            video_scale,
            init_noise_thres,
            each_sample_frame,
            iter_times,
            h,
            w,
        ]

        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=inference,
                    cache_examples=False,
                    run_on_click=False,
                    )
        run_button.click(fn=inference,
                         inputs=inputs,
                         outputs=result,)
        # preview_button.click(fn=preview_inference,
        #                      inputs=inputs,
        #                      outputs=result,)
    demo.launch(server_name="0.0.0.0", server_port=7860)

# TODO
# 1. preview
# 2. params