from model.video_diffusion.models.controlnet3d import ControlNet3DModel
from model.video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
from model.video_diffusion.pipelines.pipeline_stable_diffusion_controlnet3d import Controlnet3DStableDiffusionPipeline
from transformers import DPTForDepthEstimation
from model.annotator.hed import HEDNetwork
import torch
from einops import rearrange, repeat
import imageio
import numpy as np
import cv2
import torch.nn.functional as F
from PIL import Image
import argparse
import tempfile
import os
import gradio as gr

control_mode = 'depth'
control_net_path = f"wf-genius/controlavideo-{control_mode}"

# Load the pseudo-3D UNet and the 3D ControlNet (fp16) from the selected repo.
unet = UNetPseudo3DConditionModel.from_pretrained(
    control_net_path,
    torch_dtype=torch.float16,
    subfolder='unet',
).to("cuda")
controlnet = ControlNet3DModel.from_pretrained(
    control_net_path,
    torch_dtype=torch.float16,
    subfolder='controlnet',
).to("cuda")

# Pick the annotator that produces the control signal for the chosen mode.
if control_mode == 'depth':
    annotator_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
elif control_mode == 'canny':
    annotator_model = None  # Canny edges are computed directly with OpenCV
elif control_mode == 'hed':
    # first download the weights from https://huggingface.co/wf-genius/controlavideo-hed/resolve/main/hed-network.pth
    annotator_model = HEDNetwork('hed-network.pth').to("cuda")
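
# Note: instead of downloading hed-network.pth by hand, the weights could be
# fetched programmatically. A minimal sketch (assumes the `huggingface_hub`
# package is available; not part of the original app):
#
#   from huggingface_hub import hf_hub_download
#   hed_path = hf_hub_download(repo_id="wf-genius/controlavideo-hed",
#                              filename="hed-network.pth")
#   annotator_model = HEDNetwork(hed_path).to("cuda")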

video_controlnet_pipe = Controlnet3DStableDiffusionPipeline.from_pretrained(
    control_net_path,
    unet=unet,
    controlnet=controlnet,
    annotator_model=annotator_model,
    torch_dtype=torch.float16,
).to("cuda")


def to_video(frames, fps: int) -> str:
    """Write the generated frames to a temporary mp4 file and return its path."""
    out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    writer = imageio.get_writer(out_file.name, format='FFMPEG', fps=fps)
    for frame in frames:
        writer.append_data(np.array(frame))
    writer.close()
    return out_file.name


def inference(input_video,
              prompt,
              seed,
              num_inference_steps,
              guidance_scale,
              sampling_rate,
              video_scale,
              init_noise_thres,
              each_sample_frame,
              iter_times,
              h,
              w,
              ):
    num_sample_frames = iter_times * each_sample_frame
    testing_prompt = [prompt]
    np_frames, fps_vid = Controlnet3DStableDiffusionPipeline.get_frames_preprocess(
        input_video, num_frames=num_sample_frames, sampling_rate=sampling_rate, return_np=True)

    # Compute per-frame control maps for the chosen control mode.
    if control_mode == 'depth':
        frames = torch.from_numpy(np_frames).div(255) * 2 - 1
        frames = rearrange(frames, "f h w c -> c f h w").unsqueeze(0)
        frames = rearrange(frames, 'b c f h w -> (b f) c h w')
        control_maps = video_controlnet_pipe.get_depth_map(frames, h, w, return_standard_norm=False)  # (b f) 1 h w
    elif control_mode == 'canny':
        control_maps = np.stack([cv2.Canny(inp, 100, 200) for inp in np_frames])
        control_maps = repeat(control_maps, 'f h w -> f c h w', c=1)
        control_maps = torch.from_numpy(control_maps).div(255)  # 0~1
    elif control_mode == 'hed':
        control_maps = np.stack([video_controlnet_pipe.get_hed_map(inp) for inp in np_frames])
        control_maps = repeat(control_maps, 'f h w -> f c h w', c=1)
        control_maps = torch.from_numpy(control_maps).div(255)  # 0~1

    control_maps = control_maps.to(dtype=controlnet.dtype, device=controlnet.device)
    control_maps = F.interpolate(control_maps, size=(h, w), mode='bilinear', align_corners=False)
    control_maps = rearrange(control_maps, "(b f) c h w -> b c f h w", f=num_sample_frames)
    if control_maps.shape[1] == 1:
        # tile single-channel maps to 3 channels
        control_maps = repeat(control_maps, 'b c f h w -> b (n c) f h w', n=3)

    # Resize the source frames that condition the video-to-video generation.
    frames = torch.from_numpy(np_frames).div(255)
    frames = rearrange(frames, 'f h w c -> f c h w')
    v2v_input_frames = torch.nn.functional.interpolate(
        frames,
        size=(h, w),
        mode="bicubic",
        antialias=True,
    )
    v2v_input_frames = rearrange(v2v_input_frames, '(b f) c h w -> b c f h w', f=num_sample_frames)

    # Generate the video in chunks of `each_sample_frame` frames. Chunks after
    # the first are conditioned on the previous chunk's last generated frame
    # (passed as `first_frame_output`) so consecutive chunks stay consistent.
    out = []
    for i in range(num_sample_frames // each_sample_frame):
        out1 = video_controlnet_pipe(
            controlnet_hint=control_maps[:, :, i*each_sample_frame-1:(i+1)*each_sample_frame-1, :, :] if i > 0 else control_maps[:, :, :each_sample_frame, :, :],
            images=v2v_input_frames[:, :, i*each_sample_frame-1:(i+1)*each_sample_frame-1, :, :] if i > 0 else v2v_input_frames[:, :, :each_sample_frame, :, :],
            first_frame_output=out[-1] if i > 0 else None,
            prompt=testing_prompt,
            num_inference_steps=num_inference_steps,
            width=w,
            height=h,
            guidance_scale=guidance_scale,
            generator=[torch.Generator(device="cuda").manual_seed(seed)],
            video_scale=video_scale,
            init_noise_by_residual_thres=init_noise_thres,  # residual-based init: a larger threshold gives a smoother video
            controlnet_conditioning_scale=1.0,
            fix_first_frame=True,
            in_domain=True,
        )
        out1 = out1.images[0]
        if len(out1) > 1:
            out1 = out1[1:]  # drop the first frame
        out.extend(out1)

    return to_video(out, 8)
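
# Example (illustrative, not part of the Gradio app): `inference` can be called
# directly on a local clip with the same defaults the UI exposes below.
#
#   video_path = inference("bear.mp4", "a bear walking through stars, artstation",
#                          seed=1, num_inference_steps=20, guidance_scale=7.5,
#                          sampling_rate=3, video_scale=1.5, init_noise_thres=0.1,
#                          each_sample_frame=8, iter_times=1, h=512, w=512)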

examples = [
    ["bear.mp4", "a bear walking through stars, artstation"],
    ["car-shadow.mp4", "a car, sunset, cartoon style, artstation."],
    ["libby.mp4", "a dog running, chinese ink painting."],
]


def preview_inference(
        input_video,
        prompt, seed,
        num_inference_steps, guidance_scale,
        sampling_rate, video_scale, init_noise_thres,
        each_sample_frame, iter_times, h, w,
):
    # Fast preview: a single frame (each_sample_frame=1, iter_times=1),
    # with video_scale and init_noise_thres set to 0.
    return inference(input_video,
                     prompt, seed,
                     num_inference_steps, guidance_scale,
                     sampling_rate, 0.0, 0.0, 1, 1, h, w,)


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            # with gr.Column(scale=1):
            input_video = gr.Video(
                label="Input Video", source='upload', format="mp4", visible=True)
            with gr.Column():
                init_noise_thres = gr.Slider(0, 1, value=0.1, step=0.1, label="init_noise_thres")
                each_sample_frame = gr.Slider(6, 16, value=8, step=1, label="each_sample_frame")
                iter_times = gr.Slider(1, 4, value=1, step=1, label="iter_times")
                sampling_rate = gr.Slider(1, 8, value=3, step=1, label="sampling_rate")
                h = gr.Slider(256, 768, value=512, step=64, label="height")
                w = gr.Slider(256, 768, value=512, step=64, label="width")
            with gr.Column():
                seed = gr.Slider(0, 6666, value=1, step=1, label="seed")
                num_inference_steps = gr.Slider(5, 50, value=20, step=1, label="num_inference_steps")
                guidance_scale = gr.Slider(1, 20, value=7.5, step=0.5, label="guidance_scale")
                video_scale = gr.Slider(0, 2.5, value=1.5, step=0.1, label="video_scale")
                prompt = gr.Textbox(label='Prompt')
                # preview_button = gr.Button('Preview')
                run_button = gr.Button('Generate Video')
            # with gr.Column(scale=1):
            result = gr.Video(label="Generated Video")

        inputs = [
            input_video,
            prompt,
            seed,
            num_inference_steps,
            guidance_scale,
            sampling_rate,
            video_scale,
            init_noise_thres,
            each_sample_frame,
            iter_times,
            h,
            w,
        ]

        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=inference,
                    cache_examples=False,
                    run_on_click=False,
                    )

        run_button.click(fn=inference,
                         inputs=inputs,
                         outputs=result,)
        # preview_button.click(fn=preview_inference,
        #                      inputs=inputs,
        #                      outputs=result,)

    demo.launch(server_name="0.0.0.0", server_port=7860)