Commit 32619a4 · Haoxin Chen committed · 1 Parent(s): 8cdb359

add variable resolution and frame

Files changed:
- app.py: +7 -5
- videocontrol_test.py: +34 -14
- videocrafter_test.py: +4 -0
app.py CHANGED

@@ -15,7 +15,7 @@ t2v_examples = [
 ]
 
 control_examples = [
-    ['input/flamingo.mp4', 'An ostrich walking in the desert, photorealistic, 4k', 0, 50, 15, 1]
+    ['input/flamingo.mp4', 'An ostrich walking in the desert, photorealistic, 4k', 0, 50, 15, 1, 16, 256]
 ]
 
 def videocrafter_demo(result_dir='./tmp/'):
@@ -23,7 +23,7 @@ def videocrafter_demo(result_dir='./tmp/'):
     videocontrol = VideoControl(result_dir)
     with gr.Blocks(analytics_enabled=False) as videocrafter_iface:
         gr.Markdown("<div align='center'> <h2> VideoCrafter: A Toolkit for Text-to-Video Generation and Editing </span> </h2> \
-                     <a style='font-size:18px;color: #
+                     <a style='font-size:18px;color: #000000' href='https://github.com/VideoCrafter/VideoCrafter'> Github </div>")
         #######t2v#######
         with gr.Tab(label="Text2Video"):
             with gr.Column():
@@ -70,7 +70,9 @@ def videocrafter_demo(result_dir='./tmp/'):
                     with gr.Row():
                         vc_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="vc_steps", label="Sampling steps", value=50)
                         frame_stride = gr.Slider(minimum=0 , maximum=100, step=1, label='Frame Stride', value=0, elem_id="vc_frame_stride")
-
+                    with gr.Row():
+                        resolution = gr.Slider(minimum=128 , maximum=512, step=8, label='Long Side Resolution', value=256, elem_id="vc_resolution")
+                        video_frames = gr.Slider(minimum=8 , maximum=64, step=1, label='Video Frame Num', value=16, elem_id="vc_video_frames")
                     vc_end_btn = gr.Button("Send")
                 with gr.Tab(label='Result'):
                     vc_output_info = gr.Text(label='Info')
@@ -79,12 +81,12 @@
             vc_output_video = gr.Video(label="Generated Video").style(width=256)
 
             gr.Examples(examples=control_examples,
-                        inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
+                        inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta, video_frames, resolution],
                         outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
                         fn = videocontrol.get_video,
                         cache_examples=os.getenv('SYSTEM') == 'spaces',
             )
-            vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
+            vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta, video_frames, resolution],
                              outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
                              fn = videocontrol.get_video
             )
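Note on the wiring above: Gradio passes component values to the handler positionally, in the order listed in `inputs`, which is why `video_frames` and `resolution` must be appended both to the `gr.Examples` inputs and to the `vc_end_btn.click` inputs, and why each `control_examples` row gains two trailing values (16 frames, 256 resolution). A minimal runnable sketch of that pattern, with a stub handler standing in for `videocontrol.get_video` (the stub, component set, and ranges are illustrative, not the Space's full layout):

import gradio as gr

def get_video_stub(video, prompt, frame_stride, steps, cfg_scale, eta, video_frames, resolution):
    # Placeholder for VideoControl.get_video: values arrive in the same order as `inputs` below.
    return f"would sample {video_frames} frames, long side {resolution}px, {steps} steps"

with gr.Blocks() as demo:
    inp = gr.Video(label="Input Video")
    txt = gr.Textbox(label="Prompt")
    stride = gr.Slider(0, 100, value=0, step=1, label="Frame Stride")
    steps = gr.Slider(1, 60, value=50, step=1, label="Sampling steps")
    cfg = gr.Slider(1, 30, value=15, step=0.5, label="CFG scale")
    eta = gr.Slider(0, 1, value=1.0, step=0.1, label="ETA")
    frames = gr.Slider(8, 64, value=16, step=1, label="Video Frame Num")
    res = gr.Slider(128, 512, value=256, step=8, label="Long Side Resolution")
    out = gr.Text(label="Info")
    btn = gr.Button("Send")
    # The two new sliders are simply two more positional inputs to the same handler.
    btn.click(fn=get_video_stub,
              inputs=[inp, txt, stride, steps, cfg, eta, frames, res],
              outputs=[out])

demo.launch()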
videocontrol_test.py CHANGED

@@ -50,7 +50,8 @@ class VideoControl:
         config_path = "models/adapter_t2v_depth/model_config.yaml"
         ckpt_path = "models/base_t2v/model.ckpt"
         adapter_ckpt = "models/adapter_t2v_depth/adapter.pth"
-
+        if os.path.exists('/dev/shm/model.ckpt'):
+            ckpt_path='/dev/shm/model.ckpt'
         config = OmegaConf.load(config_path)
         model_config = config.pop("model", OmegaConf.create())
         model = instantiate_from_config(model_config)
@@ -59,10 +60,18 @@ class VideoControl:
         model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
         model.eval()
         self.model = model
-        self.resolution=256
-        self.spatial_transform = transforms_video.CenterCropVideo(self.resolution)
 
-    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0):
+    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
+        torch.cuda.empty_cache()
+        if resolution > 512:
+            resolution = 512
+        if resolution < 64:
+            resolution = 64
+        if video_frames > 64:
+            video_frames = 64
+
+        resolution = int(resolution//64)*64
+
         if vc_steps > 60:
             vc_steps = 60
         ## load video
@@ -74,32 +83,43 @@ class VideoControl:
             os.remove(input_video)
             return 'please input video', None, None, None
 
-        if h
-            scale = h /
+        if h > w:
+            scale = h / resolution
         else:
-            scale = w /
+            scale = w / resolution
         h = math.ceil(h / scale)
         w = math.ceil(w / scale)
         try:
-            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=
+            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=video_frames)
         except:
             os.remove(input_video)
             return 'load video error', None, None, None
-
+        if h > w:
+            w = int(w//64)*64
+        else:
+            h = int(h//64)*64
+        spatial_transform = transforms_video.CenterCropVideo((h,w))
+        video = spatial_transform(video)
         print('video shape', video.shape)
 
-
+        rh, rw = h//8, w//8
         bs = 1
         channels = self.model.channels
-        frames = self.model.temporal_length
-
+        # frames = self.model.temporal_length
+        frames = video_frames
+        noise_shape = [bs, channels, frames, rh, rw]
 
         ## inference
         start = time.time()
         prompt = input_prompt
         video = video.unsqueeze(0).to("cuda")
-
-
+        try:
+            with torch.no_grad():
+                batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
+        except:
+            torch.cuda.empty_cache()
+            info_str="OOM, please enter a smaller resolution or smaller frame num"
+            return info_str, None, None, None
         batch_samples = batch_samples[0]
         os.makedirs(self.savedir, exist_ok=True)
         filename = prompt
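The sizing logic added to `get_video` reduces to a small piece of arithmetic: clamp the requested long-side resolution to [64, 512] and round it down to a multiple of 64, resize so the longer side matches it while keeping the aspect ratio, round the shorter side down to a multiple of 64 before the center crop, and derive the latent grid as h//8 by w//8. A minimal sketch that pulls this out as a pure helper for easy checking (the name `plan_sizes` and the standalone form are illustrative, not code from the repo):

import math

def plan_sizes(h, w, resolution=256, video_frames=16):
    # Clamp the user inputs to the ranges enforced in get_video.
    resolution = max(64, min(512, resolution))
    video_frames = min(64, video_frames)
    # The long side is rounded down to a multiple of 64 before resizing.
    resolution = (resolution // 64) * 64
    # Scale so the longer side matches `resolution`, preserving aspect ratio.
    scale = (h if h > w else w) / resolution
    h, w = math.ceil(h / scale), math.ceil(w / scale)
    # The shorter side is rounded down to a multiple of 64, then center-cropped to (h, w).
    if h > w:
        w = (w // 64) * 64
    else:
        h = (h // 64) * 64
    # The diffusion noise shape uses an 8x spatial downsampling: [bs, C, T, h//8, w//8].
    noise_hw = (h // 8, w // 8)
    return (h, w), noise_hw, video_frames

# Example: a 1080x1920 clip at resolution=300 -> 300 rounds down to 256,
# frames resize to 144x256, the crop keeps 128x256, and latents are 16x32.
print(plan_sizes(1080, 1920, resolution=300))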
videocrafter_test.py CHANGED

@@ -1,4 +1,5 @@
 import os
+import torch
 from omegaconf import OmegaConf
 
 from lvdm.samplers.ddim import DDIMSampler
@@ -29,6 +30,8 @@ class Text2Video():
         self.download_model()
         config_file = 'models/base_t2v/model_config.yaml'
         ckpt_path = 'models/base_t2v/model.ckpt'
+        if os.path.exists('/dev/shm/model.ckpt'):
+            ckpt_path='/dev/shm/model.ckpt'
         config = OmegaConf.load(config_file)
         self.lora_path_list = ['','models/videolora/lora_001_Loving_Vincent_style.ckpt',
                                'models/videolora/lora_002_frozenmovie_style.ckpt',
@@ -45,6 +48,7 @@ class Text2Video():
         self.origin_weight = None
 
     def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0):
+        torch.cuda.empty_cache()
         if steps > 60:
             steps = 60
         if model_index > 0:
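Two small patterns recur across both test files in this commit: prefer a checkpoint copy staged in /dev/shm (a RAM-backed tmpfs, so loading avoids a slower disk read on Spaces) when one exists, and call torch.cuda.empty_cache() at the top of each request so memory reserved by a previous run is released before the next sample. A minimal sketch of both, with illustrative helper names not taken from the repo:

import os
import torch

def resolve_ckpt(default_path='models/base_t2v/model.ckpt',
                 shm_path='/dev/shm/model.ckpt'):
    # Use the shared-memory copy if it was staged; otherwise fall back to the repo path.
    return shm_path if os.path.exists(shm_path) else default_path

def before_request():
    # Release GPU memory held by the caching allocator from a previous (possibly OOM'd) run.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()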