Haoxin Chen committed
Commit 32619a4 · 1 Parent(s): 8cdb359

add variable resolution and frame

Files changed (3)
  1. app.py +7 -5
  2. videocontrol_test.py +34 -14
  3. videocrafter_test.py +4 -0
app.py CHANGED

@@ -15,7 +15,7 @@ t2v_examples = [
 ]
 
 control_examples = [
-    ['input/flamingo.mp4', 'An ostrich walking in the desert, photorealistic, 4k', 0, 50, 15, 1]
+    ['input/flamingo.mp4', 'An ostrich walking in the desert, photorealistic, 4k', 0, 50, 15, 1, 16, 256]
 ]
 
 def videocrafter_demo(result_dir='./tmp/'):
@@ -23,7 +23,7 @@ def videocrafter_demo(result_dir='./tmp/'):
     videocontrol = VideoControl(result_dir)
     with gr.Blocks(analytics_enabled=False) as videocrafter_iface:
         gr.Markdown("<div align='center'> <h2> VideoCrafter: A Toolkit for Text-to-Video Generation and Editing </span> </h2> \
-                     <a style='font-size:18px;color: #efefef' href='https://github.com/VideoCrafter/VideoCrafter'> Github </div>")
+                     <a style='font-size:18px;color: #000000' href='https://github.com/VideoCrafter/VideoCrafter'> Github </div>")
         #######t2v#######
         with gr.Tab(label="Text2Video"):
             with gr.Column():
@@ -70,7 +70,9 @@ def videocrafter_demo(result_dir='./tmp/'):
                         with gr.Row():
                             vc_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="vc_steps", label="Sampling steps", value=50)
                             frame_stride = gr.Slider(minimum=0, maximum=100, step=1, label='Frame Stride', value=0, elem_id="vc_frame_stride")
-
+                        with gr.Row():
+                            resolution = gr.Slider(minimum=128, maximum=512, step=8, label='Long Side Resolution', value=256, elem_id="vc_resolution")
+                            video_frames = gr.Slider(minimum=8, maximum=64, step=1, label='Video Frame Num', value=16, elem_id="vc_video_frames")
                         vc_end_btn = gr.Button("Send")
                     with gr.Tab(label='Result'):
                         vc_output_info = gr.Text(label='Info')
@@ -79,12 +81,12 @@ def videocrafter_demo(result_dir='./tmp/'):
                         vc_output_video = gr.Video(label="Generated Video").style(width=256)
 
             gr.Examples(examples=control_examples,
-                        inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
+                        inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta, video_frames, resolution],
                         outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
                        fn = videocontrol.get_video,
                        cache_examples=os.getenv('SYSTEM') == 'spaces',
                        )
-        vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
+        vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta, video_frames, resolution],
                          outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
                          fn = videocontrol.get_video
                         )
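
Gradio passes the inputs list to videocontrol.get_video positionally, so the order of the example row and of the two new sliders has to match the function signature. As a reading aid (not part of the commit), the updated control example maps onto the call like this:

    # How the new example row lines up with VideoControl.get_video (reading aid only)
    videocontrol.get_video(
        input_video='input/flamingo.mp4',
        input_prompt='An ostrich walking in the desert, photorealistic, 4k',
        frame_stride=0,
        vc_steps=50,
        vc_cfg_scale=15,
        vc_eta=1,
        video_frames=16,   # new "Video Frame Num" slider
        resolution=256,    # new "Long Side Resolution" slider
    )
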
videocontrol_test.py CHANGED

@@ -50,7 +50,8 @@ class VideoControl:
         config_path = "models/adapter_t2v_depth/model_config.yaml"
         ckpt_path = "models/base_t2v/model.ckpt"
         adapter_ckpt = "models/adapter_t2v_depth/adapter.pth"
-
+        if os.path.exists('/dev/shm/model.ckpt'):
+            ckpt_path='/dev/shm/model.ckpt'
         config = OmegaConf.load(config_path)
         model_config = config.pop("model", OmegaConf.create())
         model = instantiate_from_config(model_config)
@@ -59,10 +60,18 @@ class VideoControl:
         model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
         model.eval()
         self.model = model
-        self.resolution=256
-        self.spatial_transform = transforms_video.CenterCropVideo(self.resolution)
 
-    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0):
+    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
+        torch.cuda.empty_cache()
+        if resolution > 512:
+            resolution = 512
+        if resolution < 64:
+            resolution = 64
+        if video_frames > 64:
+            video_frames = 64
+
+        resolution = int(resolution//64)*64
+
         if vc_steps > 60:
             vc_steps = 60
         ## load video
@@ -74,32 +83,43 @@ class VideoControl:
             os.remove(input_video)
             return 'please input video', None, None, None
 
-        if h < w:
-            scale = h / self.resolution
+        if h > w:
+            scale = h / resolution
         else:
-            scale = w / self.resolution
+            scale = w / resolution
         h = math.ceil(h / scale)
         w = math.ceil(w / scale)
         try:
-            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=16)
+            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=video_frames)
         except:
             os.remove(input_video)
             return 'load video error', None, None, None
-        video = self.spatial_transform(video)
+        if h > w:
+            w = int(w//64)*64
+        else:
+            h = int(h//64)*64
+        spatial_transform = transforms_video.CenterCropVideo((h,w))
+        video = spatial_transform(video)
         print('video shape', video.shape)
 
-        h, w = 32, 32
+        rh, rw = h//8, w//8
         bs = 1
         channels = self.model.channels
-        frames = self.model.temporal_length
-        noise_shape = [bs, channels, frames, h, w]
+        # frames = self.model.temporal_length
+        frames = video_frames
+        noise_shape = [bs, channels, frames, rh, rw]
 
         ## inference
         start = time.time()
         prompt = input_prompt
         video = video.unsqueeze(0).to("cuda")
-        with torch.no_grad():
-            batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
+        try:
+            with torch.no_grad():
+                batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
+        except:
+            torch.cuda.empty_cache()
+            info_str="OOM, please enter a smaller resolution or smaller frame num"
+            return info_str, None, None, None
         batch_samples = batch_samples[0]
         os.makedirs(self.savedir, exist_ok=True)
         filename = prompt
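
The new sizing logic scales the clip so its longer side matches the requested resolution, snaps the shorter side down to a multiple of 64 before the center crop, and derives the latent noise shape by dividing both sides by 8; the try/except around adapter_guided_synthesis then turns an out-of-memory failure into an error message instead of a crashed Space. A minimal standalone sketch of that arithmetic (hypothetical helper name, mirroring the code above):

    import math

    def control_sizes(h, w, resolution=256, video_frames=16):
        # Clamp and snap the requested long-side resolution, as get_video does.
        resolution = max(64, min(512, resolution))
        resolution = (resolution // 64) * 64
        video_frames = min(64, video_frames)
        # Scale so the longer side equals the requested resolution.
        scale = (h if h > w else w) / resolution
        h, w = math.ceil(h / scale), math.ceil(w / scale)
        # Snap the shorter side down to a multiple of 64 for the center crop.
        if h > w:
            w = (w // 64) * 64
        else:
            h = (h // 64) * 64
        return (h, w), (h // 8, w // 8), video_frames  # crop size, latent (rh, rw), frames

    # A 1080x1920 clip at resolution=256 is cropped to 128x256, giving a 16x32 latent,
    # so the noise shape becomes [1, channels, 16, 16, 32] for 16 frames.
    print(control_sizes(1080, 1920, 256, 16))  # ((128, 256), (16, 32), 16)
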
videocrafter_test.py CHANGED

@@ -1,4 +1,5 @@
 import os
+import torch
 from omegaconf import OmegaConf
 
 from lvdm.samplers.ddim import DDIMSampler
@@ -29,6 +30,8 @@ class Text2Video():
         self.download_model()
         config_file = 'models/base_t2v/model_config.yaml'
         ckpt_path = 'models/base_t2v/model.ckpt'
+        if os.path.exists('/dev/shm/model.ckpt'):
+            ckpt_path='/dev/shm/model.ckpt'
         config = OmegaConf.load(config_file)
         self.lora_path_list = ['','models/videolora/lora_001_Loving_Vincent_style.ckpt',
                                'models/videolora/lora_002_frozenmovie_style.ckpt',
@@ -45,6 +48,7 @@ class Text2Video():
         self.origin_weight = None
 
     def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0):
+        torch.cuda.empty_cache()
         if steps > 60:
             steps = 60
         if model_index > 0:
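
Both constructors now prefer /dev/shm/model.ckpt when it exists, so the base checkpoint is read from RAM-backed shared memory rather than from disk, and get_prompt clears the CUDA cache at the start of each request. The commit only checks for the shared-memory copy; how it gets there is not shown, but a hypothetical one-time warm-up (assumed, not part of this commit) could look like:

    import os
    import shutil

    # Hypothetical warm-up: copy the base checkpoint into /dev/shm once so that
    # later VideoControl/Text2Video construction loads it from shared memory.
    if not os.path.exists('/dev/shm/model.ckpt'):
        shutil.copy('models/base_t2v/model.ckpt', '/dev/shm/model.ckpt')
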