bubbliiiing
commited on
Commit
·
14d2973
1
Parent(s):
933a5a0
update to v1.1
Browse files- app.py +5 -2
- cogvideox/api/api.py +25 -1
- cogvideox/api/post_infer.py +1 -1
- cogvideox/data/dataset_image_video.py +222 -0
- cogvideox/models/transformer3d.py +43 -1
- cogvideox/pipeline/pipeline_cogvideox_control.py +843 -0
- cogvideox/pipeline/pipeline_cogvideox_inpaint.py +18 -1
- cogvideox/ui/ui.py +387 -176
- cogvideox/utils/utils.py +25 -6
app.py
CHANGED
@@ -19,11 +19,14 @@ if __name__ == "__main__":
|
|
19 |
server_port = 7860
|
20 |
|
21 |
# Params below is used when ui_mode = "modelscope"
|
22 |
-
model_name = "models/Diffusion_Transformer/CogVideoX-Fun-5b-InP"
|
|
|
|
|
|
|
23 |
savedir_sample = "samples"
|
24 |
|
25 |
if ui_mode == "modelscope":
|
26 |
-
demo, controller = ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
27 |
elif ui_mode == "eas":
|
28 |
demo, controller = ui_eas(model_name, savedir_sample)
|
29 |
else:
|
|
|
19 |
server_port = 7860
|
20 |
|
21 |
# Params below is used when ui_mode = "modelscope"
|
22 |
+
model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-5b-InP"
|
23 |
+
# "Inpaint" or "Control"
|
24 |
+
model_type = "Inpaint"
|
25 |
+
# Save dir of this model
|
26 |
savedir_sample = "samples"
|
27 |
|
28 |
if ui_mode == "modelscope":
|
29 |
+
demo, controller = ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
30 |
elif ui_mode == "eas":
|
31 |
demo, controller = ui_eas(model_name, savedir_sample)
|
32 |
else:
|
cogvideox/api/api.py
CHANGED
@@ -68,6 +68,20 @@ def save_base64_video(base64_string):
|
|
68 |
|
69 |
return file_path
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
|
72 |
@app.post("/cogvideox_fun/infer_forward")
|
73 |
def _infer_forward_api(
|
@@ -77,7 +91,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
|
|
77 |
lora_model_path = datas.get('lora_model_path', 'none')
|
78 |
lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
|
79 |
prompt_textbox = datas.get('prompt_textbox', None)
|
80 |
-
negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange
|
81 |
sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
|
82 |
sample_step_slider = datas.get('sample_step_slider', 30)
|
83 |
resize_method = datas.get('resize_method', "Generate by")
|
@@ -93,6 +107,8 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
|
|
93 |
start_image = datas.get('start_image', None)
|
94 |
end_image = datas.get('end_image', None)
|
95 |
validation_video = datas.get('validation_video', None)
|
|
|
|
|
96 |
denoise_strength = datas.get('denoise_strength', 0.70)
|
97 |
seed_textbox = datas.get("seed_textbox", 43)
|
98 |
|
@@ -109,6 +125,12 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
|
|
109 |
if validation_video is not None:
|
110 |
validation_video = save_base64_video(validation_video)
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
try:
|
113 |
save_sample_path, comment = controller.generate(
|
114 |
"",
|
@@ -131,6 +153,8 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
|
|
131 |
start_image,
|
132 |
end_image,
|
133 |
validation_video,
|
|
|
|
|
134 |
denoise_strength,
|
135 |
seed_textbox,
|
136 |
is_api = True,
|
|
|
68 |
|
69 |
return file_path
|
70 |
|
71 |
+
def save_base64_image(base64_string):
|
72 |
+
video_data = base64.b64decode(base64_string)
|
73 |
+
|
74 |
+
md5_hash = hashlib.md5(video_data).hexdigest()
|
75 |
+
filename = f"{md5_hash}.jpg"
|
76 |
+
|
77 |
+
temp_dir = tempfile.gettempdir()
|
78 |
+
file_path = os.path.join(temp_dir, filename)
|
79 |
+
|
80 |
+
with open(file_path, 'wb') as video_file:
|
81 |
+
video_file.write(video_data)
|
82 |
+
|
83 |
+
return file_path
|
84 |
+
|
85 |
def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
|
86 |
@app.post("/cogvideox_fun/infer_forward")
|
87 |
def _infer_forward_api(
|
|
|
91 |
lora_model_path = datas.get('lora_model_path', 'none')
|
92 |
lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
|
93 |
prompt_textbox = datas.get('prompt_textbox', None)
|
94 |
+
negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
|
95 |
sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
|
96 |
sample_step_slider = datas.get('sample_step_slider', 30)
|
97 |
resize_method = datas.get('resize_method', "Generate by")
|
|
|
107 |
start_image = datas.get('start_image', None)
|
108 |
end_image = datas.get('end_image', None)
|
109 |
validation_video = datas.get('validation_video', None)
|
110 |
+
validation_video_mask = datas.get('validation_video_mask', None)
|
111 |
+
control_video = datas.get('control_video', None)
|
112 |
denoise_strength = datas.get('denoise_strength', 0.70)
|
113 |
seed_textbox = datas.get("seed_textbox", 43)
|
114 |
|
|
|
125 |
if validation_video is not None:
|
126 |
validation_video = save_base64_video(validation_video)
|
127 |
|
128 |
+
if validation_video_mask is not None:
|
129 |
+
validation_video_mask = save_base64_image(validation_video_mask)
|
130 |
+
|
131 |
+
if control_video is not None:
|
132 |
+
control_video = save_base64_video(control_video)
|
133 |
+
|
134 |
try:
|
135 |
save_sample_path, comment = controller.generate(
|
136 |
"",
|
|
|
153 |
start_image,
|
154 |
end_image,
|
155 |
validation_video,
|
156 |
+
validation_video_mask,
|
157 |
+
control_video,
|
158 |
denoise_strength,
|
159 |
seed_textbox,
|
160 |
is_api = True,
|
cogvideox/api/post_infer.py
CHANGED
@@ -33,7 +33,7 @@ def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
|
|
33 |
"lora_model_path": "none",
|
34 |
"lora_alpha_slider": 0.55,
|
35 |
"prompt_textbox": "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
36 |
-
"negative_prompt_textbox": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange
|
37 |
"sampler_dropdown": "Euler",
|
38 |
"sample_step_slider": 50,
|
39 |
"width_slider": 672,
|
|
|
33 |
"lora_model_path": "none",
|
34 |
"lora_alpha_slider": 0.55,
|
35 |
"prompt_textbox": "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
36 |
+
"negative_prompt_textbox": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ",
|
37 |
"sampler_dropdown": "Euler",
|
38 |
"sample_step_slider": 50,
|
39 |
"width_slider": 672,
|
cogvideox/data/dataset_image_video.py
CHANGED
@@ -322,3 +322,225 @@ class ImageVideoDataset(Dataset):
|
|
322 |
|
323 |
return sample
|
324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
|
323 |
return sample
|
324 |
|
325 |
+
|
326 |
+
class ImageVideoControlDataset(Dataset):
|
327 |
+
def __init__(
|
328 |
+
self,
|
329 |
+
ann_path, data_root=None,
|
330 |
+
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
|
331 |
+
image_sample_size=512,
|
332 |
+
video_repeat=0,
|
333 |
+
text_drop_ratio=-1,
|
334 |
+
enable_bucket=False,
|
335 |
+
video_length_drop_start=0.1,
|
336 |
+
video_length_drop_end=0.9,
|
337 |
+
enable_inpaint=False,
|
338 |
+
):
|
339 |
+
# Loading annotations from files
|
340 |
+
print(f"loading annotations from {ann_path} ...")
|
341 |
+
if ann_path.endswith('.csv'):
|
342 |
+
with open(ann_path, 'r') as csvfile:
|
343 |
+
dataset = list(csv.DictReader(csvfile))
|
344 |
+
elif ann_path.endswith('.json'):
|
345 |
+
dataset = json.load(open(ann_path))
|
346 |
+
|
347 |
+
self.data_root = data_root
|
348 |
+
|
349 |
+
# It's used to balance num of images and videos.
|
350 |
+
self.dataset = []
|
351 |
+
for data in dataset:
|
352 |
+
if data.get('type', 'image') != 'video':
|
353 |
+
self.dataset.append(data)
|
354 |
+
if video_repeat > 0:
|
355 |
+
for _ in range(video_repeat):
|
356 |
+
for data in dataset:
|
357 |
+
if data.get('type', 'image') == 'video':
|
358 |
+
self.dataset.append(data)
|
359 |
+
del dataset
|
360 |
+
|
361 |
+
self.length = len(self.dataset)
|
362 |
+
print(f"data scale: {self.length}")
|
363 |
+
# TODO: enable bucket training
|
364 |
+
self.enable_bucket = enable_bucket
|
365 |
+
self.text_drop_ratio = text_drop_ratio
|
366 |
+
self.enable_inpaint = enable_inpaint
|
367 |
+
|
368 |
+
self.video_length_drop_start = video_length_drop_start
|
369 |
+
self.video_length_drop_end = video_length_drop_end
|
370 |
+
|
371 |
+
# Video params
|
372 |
+
self.video_sample_stride = video_sample_stride
|
373 |
+
self.video_sample_n_frames = video_sample_n_frames
|
374 |
+
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
375 |
+
self.video_transforms = transforms.Compose(
|
376 |
+
[
|
377 |
+
transforms.Resize(min(self.video_sample_size)),
|
378 |
+
transforms.CenterCrop(self.video_sample_size),
|
379 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
380 |
+
]
|
381 |
+
)
|
382 |
+
|
383 |
+
# Image params
|
384 |
+
self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
|
385 |
+
self.image_transforms = transforms.Compose([
|
386 |
+
transforms.Resize(min(self.image_sample_size)),
|
387 |
+
transforms.CenterCrop(self.image_sample_size),
|
388 |
+
transforms.ToTensor(),
|
389 |
+
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
|
390 |
+
])
|
391 |
+
|
392 |
+
self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
|
393 |
+
|
394 |
+
def get_batch(self, idx):
|
395 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
396 |
+
video_id, control_video_id, text = data_info['file_path'], data_info['control_file_path'], data_info['text']
|
397 |
+
|
398 |
+
if data_info.get('type', 'image')=='video':
|
399 |
+
if self.data_root is None:
|
400 |
+
video_dir = video_id
|
401 |
+
else:
|
402 |
+
video_dir = os.path.join(self.data_root, video_id)
|
403 |
+
|
404 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
405 |
+
min_sample_n_frames = min(
|
406 |
+
self.video_sample_n_frames,
|
407 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
408 |
+
)
|
409 |
+
if min_sample_n_frames == 0:
|
410 |
+
raise ValueError(f"No Frames in video.")
|
411 |
+
|
412 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
413 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
414 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
415 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
416 |
+
|
417 |
+
try:
|
418 |
+
sample_args = (video_reader, batch_index)
|
419 |
+
pixel_values = func_timeout(
|
420 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
421 |
+
)
|
422 |
+
resized_frames = []
|
423 |
+
for i in range(len(pixel_values)):
|
424 |
+
frame = pixel_values[i]
|
425 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
426 |
+
resized_frames.append(resized_frame)
|
427 |
+
pixel_values = np.array(resized_frames)
|
428 |
+
except FunctionTimedOut:
|
429 |
+
raise ValueError(f"Read {idx} timeout.")
|
430 |
+
except Exception as e:
|
431 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
432 |
+
|
433 |
+
if not self.enable_bucket:
|
434 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
435 |
+
pixel_values = pixel_values / 255.
|
436 |
+
del video_reader
|
437 |
+
else:
|
438 |
+
pixel_values = pixel_values
|
439 |
+
|
440 |
+
if not self.enable_bucket:
|
441 |
+
pixel_values = self.video_transforms(pixel_values)
|
442 |
+
|
443 |
+
# Random use no text generation
|
444 |
+
if random.random() < self.text_drop_ratio:
|
445 |
+
text = ''
|
446 |
+
|
447 |
+
if self.data_root is None:
|
448 |
+
control_video_id = control_video_id
|
449 |
+
else:
|
450 |
+
control_video_id = os.path.join(self.data_root, control_video_id)
|
451 |
+
|
452 |
+
with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
|
453 |
+
try:
|
454 |
+
sample_args = (control_video_reader, batch_index)
|
455 |
+
control_pixel_values = func_timeout(
|
456 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
457 |
+
)
|
458 |
+
resized_frames = []
|
459 |
+
for i in range(len(control_pixel_values)):
|
460 |
+
frame = control_pixel_values[i]
|
461 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
462 |
+
resized_frames.append(resized_frame)
|
463 |
+
control_pixel_values = np.array(resized_frames)
|
464 |
+
except FunctionTimedOut:
|
465 |
+
raise ValueError(f"Read {idx} timeout.")
|
466 |
+
except Exception as e:
|
467 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
468 |
+
|
469 |
+
if not self.enable_bucket:
|
470 |
+
control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
|
471 |
+
control_pixel_values = control_pixel_values / 255.
|
472 |
+
del control_video_reader
|
473 |
+
else:
|
474 |
+
control_pixel_values = control_pixel_values
|
475 |
+
|
476 |
+
if not self.enable_bucket:
|
477 |
+
control_pixel_values = self.video_transforms(control_pixel_values)
|
478 |
+
return pixel_values, control_pixel_values, text, "video"
|
479 |
+
else:
|
480 |
+
image_path, text = data_info['file_path'], data_info['text']
|
481 |
+
if self.data_root is not None:
|
482 |
+
image_path = os.path.join(self.data_root, image_path)
|
483 |
+
image = Image.open(image_path).convert('RGB')
|
484 |
+
if not self.enable_bucket:
|
485 |
+
image = self.image_transforms(image).unsqueeze(0)
|
486 |
+
else:
|
487 |
+
image = np.expand_dims(np.array(image), 0)
|
488 |
+
|
489 |
+
if random.random() < self.text_drop_ratio:
|
490 |
+
text = ''
|
491 |
+
|
492 |
+
if self.data_root is None:
|
493 |
+
control_image_id = control_image_id
|
494 |
+
else:
|
495 |
+
control_image_id = os.path.join(self.data_root, control_image_id)
|
496 |
+
|
497 |
+
control_image = Image.open(control_image_id).convert('RGB')
|
498 |
+
if not self.enable_bucket:
|
499 |
+
control_image = self.image_transforms(control_image).unsqueeze(0)
|
500 |
+
else:
|
501 |
+
control_image = np.expand_dims(np.array(control_image), 0)
|
502 |
+
return image, control_image, text, 'image'
|
503 |
+
|
504 |
+
def __len__(self):
|
505 |
+
return self.length
|
506 |
+
|
507 |
+
def __getitem__(self, idx):
|
508 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
509 |
+
data_type = data_info.get('type', 'image')
|
510 |
+
while True:
|
511 |
+
sample = {}
|
512 |
+
try:
|
513 |
+
data_info_local = self.dataset[idx % len(self.dataset)]
|
514 |
+
data_type_local = data_info_local.get('type', 'image')
|
515 |
+
if data_type_local != data_type:
|
516 |
+
raise ValueError("data_type_local != data_type")
|
517 |
+
|
518 |
+
pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
|
519 |
+
sample["pixel_values"] = pixel_values
|
520 |
+
sample["control_pixel_values"] = control_pixel_values
|
521 |
+
sample["text"] = name
|
522 |
+
sample["data_type"] = data_type
|
523 |
+
sample["idx"] = idx
|
524 |
+
|
525 |
+
if len(sample) > 0:
|
526 |
+
break
|
527 |
+
except Exception as e:
|
528 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
529 |
+
idx = random.randint(0, self.length-1)
|
530 |
+
|
531 |
+
if self.enable_inpaint and not self.enable_bucket:
|
532 |
+
mask = get_random_mask(pixel_values.size())
|
533 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
|
534 |
+
sample["mask_pixel_values"] = mask_pixel_values
|
535 |
+
sample["mask"] = mask
|
536 |
+
|
537 |
+
clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
|
538 |
+
clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
|
539 |
+
sample["clip_pixel_values"] = clip_pixel_values
|
540 |
+
|
541 |
+
ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
|
542 |
+
if (mask == 1).all():
|
543 |
+
ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
|
544 |
+
sample["ref_pixel_values"] = ref_pixel_values
|
545 |
+
|
546 |
+
return sample
|
cogvideox/models/transformer3d.py
CHANGED
@@ -27,7 +27,7 @@ from diffusers.utils import is_torch_version, logging
|
|
27 |
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
28 |
from diffusers.models.attention import Attention, FeedForward
|
29 |
from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
30 |
-
from diffusers.models.embeddings import
|
31 |
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
32 |
from diffusers.models.modeling_utils import ModelMixin
|
33 |
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
@@ -35,6 +35,44 @@ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
|
35 |
|
36 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
@maybe_allow_in_graph
|
40 |
class CogVideoXBlock(nn.Module):
|
@@ -239,6 +277,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
|
239 |
spatial_interpolation_scale: float = 1.875,
|
240 |
temporal_interpolation_scale: float = 1.0,
|
241 |
use_rotary_positional_embeddings: bool = False,
|
|
|
242 |
):
|
243 |
super().__init__()
|
244 |
inner_dim = num_attention_heads * attention_head_dim
|
@@ -414,6 +453,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
|
414 |
timestep: Union[int, float, torch.LongTensor],
|
415 |
timestep_cond: Optional[torch.Tensor] = None,
|
416 |
inpaint_latents: Optional[torch.Tensor] = None,
|
|
|
417 |
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
418 |
return_dict: bool = True,
|
419 |
):
|
@@ -432,6 +472,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
|
432 |
# 2. Patch embedding
|
433 |
if inpaint_latents is not None:
|
434 |
hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
|
|
|
|
|
435 |
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
436 |
|
437 |
# 3. Position embedding
|
|
|
27 |
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
28 |
from diffusers.models.attention import Attention, FeedForward
|
29 |
from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
30 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
31 |
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
32 |
from diffusers.models.modeling_utils import ModelMixin
|
33 |
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
|
|
35 |
|
36 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
|
38 |
+
class CogVideoXPatchEmbed(nn.Module):
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
patch_size: int = 2,
|
42 |
+
in_channels: int = 16,
|
43 |
+
embed_dim: int = 1920,
|
44 |
+
text_embed_dim: int = 4096,
|
45 |
+
bias: bool = True,
|
46 |
+
) -> None:
|
47 |
+
super().__init__()
|
48 |
+
self.patch_size = patch_size
|
49 |
+
|
50 |
+
self.proj = nn.Conv2d(
|
51 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
52 |
+
)
|
53 |
+
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
54 |
+
|
55 |
+
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
56 |
+
r"""
|
57 |
+
Args:
|
58 |
+
text_embeds (`torch.Tensor`):
|
59 |
+
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
60 |
+
image_embeds (`torch.Tensor`):
|
61 |
+
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
62 |
+
"""
|
63 |
+
text_embeds = self.text_proj(text_embeds)
|
64 |
+
|
65 |
+
batch, num_frames, channels, height, width = image_embeds.shape
|
66 |
+
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
67 |
+
image_embeds = self.proj(image_embeds)
|
68 |
+
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
69 |
+
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
70 |
+
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
71 |
+
|
72 |
+
embeds = torch.cat(
|
73 |
+
[text_embeds, image_embeds], dim=1
|
74 |
+
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
75 |
+
return embeds
|
76 |
|
77 |
@maybe_allow_in_graph
|
78 |
class CogVideoXBlock(nn.Module):
|
|
|
277 |
spatial_interpolation_scale: float = 1.875,
|
278 |
temporal_interpolation_scale: float = 1.0,
|
279 |
use_rotary_positional_embeddings: bool = False,
|
280 |
+
add_noise_in_inpaint_model: bool = False,
|
281 |
):
|
282 |
super().__init__()
|
283 |
inner_dim = num_attention_heads * attention_head_dim
|
|
|
453 |
timestep: Union[int, float, torch.LongTensor],
|
454 |
timestep_cond: Optional[torch.Tensor] = None,
|
455 |
inpaint_latents: Optional[torch.Tensor] = None,
|
456 |
+
control_latents: Optional[torch.Tensor] = None,
|
457 |
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
458 |
return_dict: bool = True,
|
459 |
):
|
|
|
472 |
# 2. Patch embedding
|
473 |
if inpaint_latents is not None:
|
474 |
hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
|
475 |
+
if control_latents is not None:
|
476 |
+
hidden_states = torch.concat([hidden_states, control_latents], 2)
|
477 |
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
478 |
|
479 |
# 3. Position embedding
|
cogvideox/pipeline/pipeline_cogvideox_control.py
ADDED
@@ -0,0 +1,843 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import inspect
|
17 |
+
import math
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from einops import rearrange
|
24 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
25 |
+
|
26 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
27 |
+
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
28 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
29 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
30 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
31 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
33 |
+
from diffusers.video_processor import VideoProcessor
|
34 |
+
from diffusers.image_processor import VaeImageProcessor
|
35 |
+
from einops import rearrange
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
|
41 |
+
EXAMPLE_DOC_STRING = """
|
42 |
+
Examples:
|
43 |
+
```python
|
44 |
+
>>> import torch
|
45 |
+
>>> from diffusers import CogVideoX_Fun_Pipeline
|
46 |
+
>>> from diffusers.utils import export_to_video
|
47 |
+
|
48 |
+
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
49 |
+
>>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
|
50 |
+
>>> prompt = (
|
51 |
+
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
52 |
+
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
53 |
+
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
54 |
+
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
55 |
+
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
56 |
+
... "atmosphere of this unique musical performance."
|
57 |
+
... )
|
58 |
+
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
59 |
+
>>> export_to_video(video, "output.mp4", fps=8)
|
60 |
+
```
|
61 |
+
"""
|
62 |
+
|
63 |
+
|
64 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
65 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
66 |
+
tw = tgt_width
|
67 |
+
th = tgt_height
|
68 |
+
h, w = src
|
69 |
+
r = h / w
|
70 |
+
if r > (th / tw):
|
71 |
+
resize_height = th
|
72 |
+
resize_width = int(round(th / h * w))
|
73 |
+
else:
|
74 |
+
resize_width = tw
|
75 |
+
resize_height = int(round(tw / w * h))
|
76 |
+
|
77 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
78 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
79 |
+
|
80 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
81 |
+
|
82 |
+
|
83 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
84 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and hand back the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or the scheduler's own
    spacing for `num_inference_steps`. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the `device` argument.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's `timesteps` attribute after the call, and the number of inference
        steps (the length of that schedule when a custom schedule was supplied, otherwise `num_inference_steps`).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler pick its own spacing.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedules are only legal when `set_timesteps` exposes the matching keyword.
    supported = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
141 |
+
|
142 |
+
|
143 |
+
@dataclass
class CogVideoX_Fun_PipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated videos; the field is named `videos` (plural), which is the name callers must use.
    videos: torch.Tensor
|
156 |
+
|
157 |
+
|
158 |
+
class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
|
159 |
+
r"""
|
160 |
+
Pipeline for text-to-video generation using CogVideoX.
|
161 |
+
|
162 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
163 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
164 |
+
|
165 |
+
Args:
|
166 |
+
vae ([`AutoencoderKL`]):
|
167 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
168 |
+
text_encoder ([`T5EncoderModel`]):
|
169 |
+
Frozen text-encoder. CogVideoX_Fun uses
|
170 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
171 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
172 |
+
tokenizer (`T5Tokenizer`):
|
173 |
+
Tokenizer of class
|
174 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
175 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
176 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
177 |
+
scheduler ([`SchedulerMixin`]):
|
178 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
179 |
+
"""
|
180 |
+
|
181 |
+
_optional_components = []
|
182 |
+
model_cpu_offload_seq = "text_encoder->vae->transformer->vae"
|
183 |
+
|
184 |
+
_callback_tensor_inputs = [
|
185 |
+
"latents",
|
186 |
+
"prompt_embeds",
|
187 |
+
"negative_prompt_embeds",
|
188 |
+
]
|
189 |
+
|
190 |
+
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    vae: AutoencoderKLCogVideoX,
    transformer: CogVideoXTransformer3DModel,
    scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
    """Register the sub-models and build the image/video pre-processors.

    Args:
        tokenizer: Tokenizer used to turn prompts into token ids.
        text_encoder: Text encoder that produces the prompt embeddings.
        vae: Video VAE used to encode control videos and decode latents.
        transformer: Denoising transformer.
        scheduler: Scheduler driving the denoising loop.
    """
    super().__init__()

    self.register_modules(
        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
    )

    # Fall back to fixed factors (8 spatial, 4 temporal) when no VAE is registered.
    vae_module = getattr(self, "vae", None)
    self.vae_scale_factor_spatial = (
        2 ** (len(vae_module.config.block_out_channels) - 1) if vae_module is not None else 8
    )
    self.vae_scale_factor_temporal = (
        vae_module.config.temporal_compression_ratio if vae_module is not None else 4
    )

    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    # Previously recomputed from `self.vae.config` without the None guard, which crashed when the VAE
    # was absent even though the factors above were guarded. The value is identical when a VAE exists.
    self.vae_scale_factor = self.vae_scale_factor_spatial
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.mask_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
    )
|
217 |
+
|
218 |
+
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize `prompt` and run it through the text encoder.

    Args:
        prompt: A single prompt or a list of prompts.
        num_videos_per_prompt: How many times each prompt's embedding is duplicated.
        max_sequence_length: Length the token ids are padded / truncated to.
        device: Device for the embeddings; defaults to `self._execution_device`.
        dtype: Dtype for the embeddings; defaults to `self.text_encoder.dtype`.

    Returns:
        Tensor of shape `(len(prompt) * num_videos_per_prompt, seq_len, hidden)`.
    """
    target_device = device or self._execution_device
    target_dtype = dtype or self.text_encoder.dtype

    prompts = [prompt] if isinstance(prompt, str) else prompt
    num_prompts = len(prompts)

    padded_ids = self.tokenizer(
        prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    ).input_ids
    # Second pass without truncation, only used to report what was cut off.
    full_ids = self.tokenizer(prompts, padding="longest", return_tensors="pt").input_ids

    was_truncated = full_ids.shape[-1] >= padded_ids.shape[-1] and not torch.equal(padded_ids, full_ids)
    if was_truncated:
        removed_text = self.tokenizer.batch_decode(full_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    embeds = self.text_encoder(padded_ids.to(target_device))[0]
    embeds = embeds.to(dtype=target_dtype, device=target_device)

    # Tile along the sequence axis and fold the copies back into the batch axis
    # (repeat + view instead of repeat_interleave, as in the original code).
    _, seq_len, _ = embeds.shape
    tiled = embeds.repeat(1, num_videos_per_prompt, 1)
    return tiled.view(num_prompts * num_videos_per_prompt, seq_len, -1)
|
259 |
+
|
260 |
+
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the positive and (optionally) the negative prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to encode. May be `None` when `prompt_embeds` is supplied.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) not to guide the generation. An empty string is used when omitted. Only encoded when
            `do_classifier_free_guidance` is true and `negative_prompt_embeds` is not supplied.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether negative embeddings are needed.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos generated per prompt; embeddings are duplicated accordingly.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed prompt embeddings; returned unchanged when given.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative embeddings; returned unchanged when given.
        max_sequence_length (`int`, defaults to 226):
            Token length forwarded to the T5 encoding helper.
        device (`torch.device`, *optional*):
            Device for the embeddings; defaults to `self._execution_device`.
        dtype (`torch.dtype`, *optional*):
            Dtype forwarded to the T5 encoding helper.

    Returns:
        `(prompt_embeds, negative_prompt_embeds)`; the second item stays `None` when guidance is disabled and no
        negative embeddings were passed in.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are of different types after normalisation to lists.
        ValueError: If `negative_prompt` does not match the batch size of `prompt`.
    """
    device = device or self._execution_device

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = prompt_embeds.shape[0] if prompt is None else len(prompt)

    if prompt_embeds is None:
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    # Nothing more to do unless guidance needs negative embeddings that were not supplied.
    if not do_classifier_free_guidance or negative_prompt_embeds is not None:
        return prompt_embeds, negative_prompt_embeds

    negative_prompt = negative_prompt or ""
    if isinstance(negative_prompt, str):
        negative_prompt = batch_size * [negative_prompt]

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds = self._get_t5_prompt_embeds(
        prompt=negative_prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )
    return prompt_embeds, negative_prompt_embeds
|
340 |
+
|
341 |
+
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Create (or adopt) the initial latents and scale them for the scheduler.

    Args:
        batch_size: Number of latent samples.
        num_channels_latents: Channel count of each latent frame.
        num_frames: Number of video frames; reduced by `self.vae_scale_factor_temporal` for the latent.
        height: Frame height in pixels; divided by `self.vae_scale_factor_spatial`.
        width: Frame width in pixels; divided by `self.vae_scale_factor_spatial`.
        dtype: Dtype of freshly sampled latents.
        device: Device of the latents.
        generator: A generator, or a list with one generator per sample.
        latents: Optional pre-made latents; only moved to `device` when given.

    Returns:
        Latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial
    # Layout: [batch, frames, channels, height, width]
    shape = (batch_size, latent_frames, num_channels_latents, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # The scheduler expects the initial noise scaled by its own standard deviation.
    return latents * self.scheduler.init_noise_sigma
|
365 |
+
|
366 |
+
def prepare_control_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    """Encode the mask and/or control video into VAE latents.

    Each input is moved to `device` in the VAE's dtype, encoded one batch item at a time, reduced to the mode of the
    returned distribution, and multiplied by `self.vae.config.scaling_factor`.

    `batch_size`, `height`, `width`, `dtype`, `generator` and `do_classifier_free_guidance` are accepted for
    signature compatibility but are not used by this method.

    Args:
        mask: Pixel-space tensor to encode, or `None`.
        masked_image: Pixel-space tensor to encode, or `None`.
        device: Device the inputs are moved to before encoding.

    Returns:
        `(mask_latents, masked_image_latents)`; each entry is `None` when the corresponding input was `None`.
    """

    def _encode_per_sample(pixels):
        # One sample per VAE call, as in the original loop (chunk size fixed at 1).
        pixels = pixels.to(device=device, dtype=self.vae.dtype)
        chunk = 1
        encoded = [
            self.vae.encode(pixels[start : start + chunk])[0].mode()
            for start in range(0, pixels.shape[0], chunk)
        ]
        return torch.cat(encoded, dim=0) * self.vae.config.scaling_factor

    mask_latents = None if mask is None else _encode_per_sample(mask)
    masked_image_latents = None if masked_image is None else _encode_per_sample(masked_image)
    return mask_latents, masked_image_latents
|
400 |
+
|
401 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """Decode latents into frames with values in `[0, 1]`.

    Args:
        latents: Latent tensor laid out as `[batch_size, num_frames, num_channels, height, width]`.

    Returns:
        The decoded frames as a float32 NumPy array (note: the return annotation says `torch.Tensor`, but the
        value is converted with `.numpy()` before being returned).
    """
    # The VAE wants channels before frames: [batch_size, num_channels, num_frames, height, width].
    channel_first = latents.permute(0, 2, 1, 3, 4)
    unscaled = 1 / self.vae.config.scaling_factor * channel_first

    decoded = self.vae.decode(unscaled).sample
    pixels = (decoded / 2 + 0.5).clamp(0, 1)

    # Always cast to float32: cheap, and works when the model runs in bfloat16.
    return pixels.cpu().float().numpy()
|
410 |
+
|
411 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
412 |
+
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only forwarded when the scheduler's
    `step` method declares a parameter of that name. `eta` is the DDIM η (https://arxiv.org/abs/2010.02502) and is
    expected to lie in [0, 1].

    Args:
        generator: Random generator to forward when supported.
        eta: DDIM eta value to forward when supported.

    Returns:
        A dict containing `"eta"` and/or `"generator"`, or an empty dict.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
|
428 |
+
|
429 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
430 |
+
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call and raise on the first problem found.

    Checks, in order: spatial size divisibility by 8, callback tensor names, that exactly one of
    `prompt` / `prompt_embeds` is given and `prompt` is a `str` or `list`, that text prompts are not
    mixed with `negative_prompt_embeds`, and that pre-computed embeddings share one shape.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    # Exactly one source for the positive conditioning.
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # Pre-computed negative embeddings may not be combined with text prompts.
    if has_prompt and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
|
480 |
+
|
481 |
+
def fuse_qkv_projections(self) -> None:
    r"""Enables fused QKV projections.

    Sets `self.fusing_transformer` to `True` and then delegates to
    `self.transformer.fuse_qkv_projections()`.
    """
    # The flag is set before delegating; `unfuse_qkv_projections` reads it to decide whether there is
    # anything to undo.
    self.fusing_transformer = True
    self.transformer.fuse_qkv_projections()
|
485 |
+
|
486 |
+
def unfuse_qkv_projections(self) -> None:
    r"""Disable QKV projection fusion if enabled.

    When the transformer was never fused, a warning is logged and nothing else happens. Otherwise the call is
    delegated to `self.transformer.unfuse_qkv_projections()` and `self.fusing_transformer` is reset to `False`.
    """
    # `fusing_transformer` is created by `fuse_qkv_projections`, not by `__init__`, so a pipeline that was
    # never fused may not have the attribute at all. Treat "missing" as "not fused" instead of raising
    # AttributeError.
    if not getattr(self, "fusing_transformer", False):
        logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        return

    self.transformer.unfuse_qkv_projections()
    self.fusing_transformer = False
|
493 |
+
|
494 |
+
def _prepare_rotary_positional_embeddings(
    self,
    height: int,
    width: int,
    num_frames: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the 3D rotary position embedding tables for the transformer.

    Args:
        height: Frame height in pixels.
        width: Frame width in pixels.
        num_frames: Temporal size passed to the rotary embedding helper.
        device: Device the returned tensors are moved to.

    Returns:
        `(freqs_cos, freqs_sin)` on `device`.
    """
    # Number of pixels covered by one transformer patch along a spatial axis.
    pixels_per_patch = self.vae_scale_factor_spatial * self.transformer.config.patch_size

    grid_height, grid_width = height // pixels_per_patch, width // pixels_per_patch
    # Reference grid derived from the hard-coded 720x480 base resolution.
    base_size_width, base_size_height = 720 // pixels_per_patch, 480 // pixels_per_patch

    crop_region = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=self.transformer.config.attention_head_dim,
        crops_coords=crop_region,
        grid_size=(grid_height, grid_width),
        temporal_size=num_frames,
        use_real=True,
    )
    return freqs_cos.to(device=device), freqs_sin.to(device=device)
|
520 |
+
|
521 |
+
@property
def guidance_scale(self):
    """Guidance scale stored by the most recent `__call__` (read-only)."""
    # `_guidance_scale` is assigned inside `__call__`, not in `__init__`.
    return self._guidance_scale
|
524 |
+
|
525 |
+
@property
def num_timesteps(self):
    """Number of timesteps recorded by the most recent `__call__` (read-only)."""
    # `_num_timesteps` is assigned inside `__call__`, not in `__init__`.
    return self._num_timesteps
|
528 |
+
|
529 |
+
@property
def interrupt(self):
    """Current value of the `_interrupt` flag (read-only)."""
    # `__call__` resets `_interrupt` to False when it starts; it is not set in `__init__`.
    return self._interrupt
|
532 |
+
|
533 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
534 |
+
def get_timesteps(self, num_inference_steps, strength, device):
    """Return the tail of the scheduler's schedule selected by `strength`.

    Args:
        num_inference_steps: Total number of steps in the configured schedule.
        strength: Fraction of the schedule to run; `int(num_inference_steps * strength)` steps are kept,
            capped at `num_inference_steps`.
        device: Unused by this method.

    Returns:
        `(timesteps, steps)`: the remaining slice of `self.scheduler.timesteps` and how many inference steps
        it represents.
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    # Higher-order schedulers hold `order` entries per inference step.
    remaining = self.scheduler.timesteps[first_step * self.scheduler.order :]
    return remaining, num_inference_steps - first_step
|
542 |
+
|
543 |
+
@torch.no_grad()
|
544 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
545 |
+
def __call__(
|
546 |
+
self,
|
547 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
548 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
549 |
+
height: int = 480,
|
550 |
+
width: int = 720,
|
551 |
+
video: Union[torch.FloatTensor] = None,
|
552 |
+
control_video: Union[torch.FloatTensor] = None,
|
553 |
+
num_frames: int = 49,
|
554 |
+
num_inference_steps: int = 50,
|
555 |
+
timesteps: Optional[List[int]] = None,
|
556 |
+
guidance_scale: float = 6,
|
557 |
+
use_dynamic_cfg: bool = False,
|
558 |
+
num_videos_per_prompt: int = 1,
|
559 |
+
eta: float = 0.0,
|
560 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
561 |
+
latents: Optional[torch.FloatTensor] = None,
|
562 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
563 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
564 |
+
output_type: str = "numpy",
|
565 |
+
return_dict: bool = False,
|
566 |
+
callback_on_step_end: Optional[
|
567 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
568 |
+
] = None,
|
569 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
570 |
+
max_sequence_length: int = 226,
|
571 |
+
comfyui_progressbar: bool = False,
|
572 |
+
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
|
573 |
+
"""
|
574 |
+
Function invoked when calling the pipeline for generation.
|
575 |
+
|
576 |
+
Args:
|
577 |
+
prompt (`str` or `List[str]`, *optional*):
|
578 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
579 |
+
instead.
|
580 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
581 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
582 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
583 |
+
less than `1`).
|
584 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
585 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
586 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
587 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
588 |
+
num_frames (`int`, defaults to `48`):
|
589 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
590 |
+
contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
|
591 |
+
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
|
592 |
+
needs to be satisfied is that of divisibility mentioned above.
|
593 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
594 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
595 |
+
expense of slower inference.
|
596 |
+
timesteps (`List[int]`, *optional*):
|
597 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
598 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
599 |
+
passed will be used. Must be in descending order.
|
600 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
601 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
602 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
603 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
604 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
605 |
+
usually at the expense of lower image quality.
|
606 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
607 |
+
The number of videos to generate per prompt.
|
608 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
609 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
610 |
+
to make generation deterministic.
|
611 |
+
latents (`torch.FloatTensor`, *optional*):
|
612 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
613 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
614 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
615 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
616 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
617 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
618 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
619 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
620 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
621 |
+
argument.
|
622 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
623 |
+
The output format of the generate image. Choose between
|
624 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
625 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
626 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
627 |
+
of a plain tuple.
|
628 |
+
callback_on_step_end (`Callable`, *optional*):
|
629 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
630 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
631 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
632 |
+
`callback_on_step_end_tensor_inputs`.
|
633 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
634 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
635 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
636 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
637 |
+
max_sequence_length (`int`, defaults to `226`):
|
638 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
639 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
640 |
+
|
641 |
+
Examples:
|
642 |
+
|
643 |
+
Returns:
|
644 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
|
645 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
|
646 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
647 |
+
"""
|
648 |
+
|
649 |
+
if num_frames > 49:
|
650 |
+
raise ValueError(
|
651 |
+
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
652 |
+
)
|
653 |
+
|
654 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
655 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
656 |
+
|
657 |
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
658 |
+
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
659 |
+
num_videos_per_prompt = 1
|
660 |
+
|
661 |
+
# 1. Check inputs. Raise error if not correct
|
662 |
+
self.check_inputs(
|
663 |
+
prompt,
|
664 |
+
height,
|
665 |
+
width,
|
666 |
+
negative_prompt,
|
667 |
+
callback_on_step_end_tensor_inputs,
|
668 |
+
prompt_embeds,
|
669 |
+
negative_prompt_embeds,
|
670 |
+
)
|
671 |
+
self._guidance_scale = guidance_scale
|
672 |
+
self._interrupt = False
|
673 |
+
|
674 |
+
# 2. Default call parameters
|
675 |
+
if prompt is not None and isinstance(prompt, str):
|
676 |
+
batch_size = 1
|
677 |
+
elif prompt is not None and isinstance(prompt, list):
|
678 |
+
batch_size = len(prompt)
|
679 |
+
else:
|
680 |
+
batch_size = prompt_embeds.shape[0]
|
681 |
+
|
682 |
+
device = self._execution_device
|
683 |
+
|
684 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
685 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
686 |
+
# corresponds to doing no classifier free guidance.
|
687 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
688 |
+
|
689 |
+
# 3. Encode input prompt
|
690 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
691 |
+
prompt,
|
692 |
+
negative_prompt,
|
693 |
+
do_classifier_free_guidance,
|
694 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
695 |
+
prompt_embeds=prompt_embeds,
|
696 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
697 |
+
max_sequence_length=max_sequence_length,
|
698 |
+
device=device,
|
699 |
+
)
|
700 |
+
if do_classifier_free_guidance:
|
701 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
702 |
+
|
703 |
+
# 4. Prepare timesteps
|
704 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
705 |
+
self._num_timesteps = len(timesteps)
|
706 |
+
if comfyui_progressbar:
|
707 |
+
from comfy.utils import ProgressBar
|
708 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
709 |
+
|
710 |
+
# 5. Prepare latents.
|
711 |
+
latent_channels = self.vae.config.latent_channels
|
712 |
+
latents = self.prepare_latents(
|
713 |
+
batch_size * num_videos_per_prompt,
|
714 |
+
latent_channels,
|
715 |
+
num_frames,
|
716 |
+
height,
|
717 |
+
width,
|
718 |
+
prompt_embeds.dtype,
|
719 |
+
device,
|
720 |
+
generator,
|
721 |
+
latents,
|
722 |
+
)
|
723 |
+
if comfyui_progressbar:
|
724 |
+
pbar.update(1)
|
725 |
+
|
726 |
+
if control_video is not None:
|
727 |
+
video_length = control_video.shape[2]
|
728 |
+
control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
729 |
+
control_video = control_video.to(dtype=torch.float32)
|
730 |
+
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
731 |
+
else:
|
732 |
+
control_video = None
|
733 |
+
control_video_latents = self.prepare_control_latents(
|
734 |
+
None,
|
735 |
+
control_video,
|
736 |
+
batch_size,
|
737 |
+
height,
|
738 |
+
width,
|
739 |
+
prompt_embeds.dtype,
|
740 |
+
device,
|
741 |
+
generator,
|
742 |
+
do_classifier_free_guidance
|
743 |
+
)[1]
|
744 |
+
control_video_latents_input = (
|
745 |
+
torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
|
746 |
+
)
|
747 |
+
control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
|
748 |
+
|
749 |
+
if comfyui_progressbar:
|
750 |
+
pbar.update(1)
|
751 |
+
|
752 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
753 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
754 |
+
|
755 |
+
# 7. Create rotary embeds if required
|
756 |
+
image_rotary_emb = (
|
757 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
758 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
759 |
+
else None
|
760 |
+
)
|
761 |
+
|
762 |
+
# 8. Denoising loop
|
763 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
764 |
+
|
765 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
766 |
+
# for DPM-solver++
|
767 |
+
old_pred_original_sample = None
|
768 |
+
for i, t in enumerate(timesteps):
|
769 |
+
if self.interrupt:
|
770 |
+
continue
|
771 |
+
|
772 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
773 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
774 |
+
|
775 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
776 |
+
timestep = t.expand(latent_model_input.shape[0])
|
777 |
+
|
778 |
+
# predict noise model_output
|
779 |
+
noise_pred = self.transformer(
|
780 |
+
hidden_states=latent_model_input,
|
781 |
+
encoder_hidden_states=prompt_embeds,
|
782 |
+
timestep=timestep,
|
783 |
+
image_rotary_emb=image_rotary_emb,
|
784 |
+
return_dict=False,
|
785 |
+
control_latents=control_latents,
|
786 |
+
)[0]
|
787 |
+
noise_pred = noise_pred.float()
|
788 |
+
|
789 |
+
# perform guidance
|
790 |
+
if use_dynamic_cfg:
|
791 |
+
self._guidance_scale = 1 + guidance_scale * (
|
792 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
793 |
+
)
|
794 |
+
if do_classifier_free_guidance:
|
795 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
796 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
797 |
+
|
798 |
+
# compute the previous noisy sample x_t -> x_t-1
|
799 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
800 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
801 |
+
else:
|
802 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
803 |
+
noise_pred,
|
804 |
+
old_pred_original_sample,
|
805 |
+
t,
|
806 |
+
timesteps[i - 1] if i > 0 else None,
|
807 |
+
latents,
|
808 |
+
**extra_step_kwargs,
|
809 |
+
return_dict=False,
|
810 |
+
)
|
811 |
+
latents = latents.to(prompt_embeds.dtype)
|
812 |
+
|
813 |
+
# call the callback, if provided
|
814 |
+
if callback_on_step_end is not None:
|
815 |
+
callback_kwargs = {}
|
816 |
+
for k in callback_on_step_end_tensor_inputs:
|
817 |
+
callback_kwargs[k] = locals()[k]
|
818 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
819 |
+
|
820 |
+
latents = callback_outputs.pop("latents", latents)
|
821 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
822 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
823 |
+
|
824 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
825 |
+
progress_bar.update()
|
826 |
+
if comfyui_progressbar:
|
827 |
+
pbar.update(1)
|
828 |
+
|
829 |
+
if output_type == "numpy":
|
830 |
+
video = self.decode_latents(latents)
|
831 |
+
elif not output_type == "latent":
|
832 |
+
video = self.decode_latents(latents)
|
833 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
834 |
+
else:
|
835 |
+
video = latents
|
836 |
+
|
837 |
+
# Offload all models
|
838 |
+
self.maybe_free_model_hooks()
|
839 |
+
|
840 |
+
if not return_dict:
|
841 |
+
video = torch.from_numpy(video)
|
842 |
+
|
843 |
+
return CogVideoX_Fun_PipelineOutput(videos=video)
|
cogvideox/pipeline/pipeline_cogvideox_inpaint.py
CHANGED
@@ -177,6 +177,19 @@ def resize_mask(mask, latent, process_first_frame_only=True):
|
|
177 |
return resized_mask
|
178 |
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
@dataclass
|
181 |
class CogVideoX_Fun_PipelineOutput(BaseOutput):
|
182 |
r"""
|
@@ -444,7 +457,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
|
|
444 |
return outputs
|
445 |
|
446 |
def prepare_mask_latents(
|
447 |
-
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
448 |
):
|
449 |
# resize the mask to latents shape as we concatenate the mask to the latents
|
450 |
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
@@ -463,6 +476,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
|
|
463 |
mask = mask * self.vae.config.scaling_factor
|
464 |
|
465 |
if masked_image is not None:
|
|
|
|
|
466 |
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
467 |
bs = 1
|
468 |
new_mask_pixel_values = []
|
@@ -650,6 +665,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
|
|
650 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
651 |
max_sequence_length: int = 226,
|
652 |
strength: float = 1,
|
|
|
653 |
comfyui_progressbar: bool = False,
|
654 |
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
|
655 |
"""
|
@@ -866,6 +882,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
|
|
866 |
device,
|
867 |
generator,
|
868 |
do_classifier_free_guidance,
|
|
|
869 |
)
|
870 |
mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
|
871 |
mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
|
|
|
177 |
return resized_mask
|
178 |
|
179 |
|
180 |
+
def add_noise_to_reference_video(image, ratio=None):
    """Add per-sample Gaussian noise to ``image``, leaving ``-1`` entries untouched.

    One noise scale (sigma) is used for each entry along the first dimension:

    * ``ratio is None``: sigma is sampled as ``exp(x)`` with ``x ~ N(-3.0, 0.5)``,
      independently for each entry of the first dimension.
    * otherwise: sigma is ``ratio`` for every entry.

    Elements of ``image`` that are exactly ``-1`` receive zero noise, so they are
    returned unchanged.

    Args:
        image: Floating-point tensor whose first dimension indexes samples.
            Presumably a ``(b, c, f, h, w)`` video in this pipeline (confirm
            against the caller); any rank >= 1 is accepted.
        ratio: Noise standard deviation applied to every sample, or ``None`` to
            sample one sigma per sample as described above.

    Returns:
        A new tensor with the same shape as ``image``; the input tensor is not
        modified in place.
    """
    batch_shape = (image.shape[0],)
    if ratio is None:
        sigma = torch.normal(mean=-3.0, std=0.5, size=batch_shape).to(image.device)
        sigma = torch.exp(sigma).to(image.dtype)
    else:
        sigma = torch.ones(batch_shape).to(image.device, image.dtype) * ratio

    # Append one singleton axis per non-batch dimension so that sigma broadcasts
    # against `image` whatever its rank. For a 5-D input this is exactly
    # sigma[:, None, None, None, None]; a fixed number of axes would silently
    # produce a wrongly-shaped result for any other rank.
    trailing_axes = (None,) * (image.dim() - 1)
    image_noise = torch.randn_like(image) * sigma[(slice(None),) + trailing_axes]

    # Entries equal to -1 must stay exactly -1, so their noise is zeroed out.
    image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise)
    return image + image_noise
|
191 |
+
|
192 |
+
|
193 |
@dataclass
|
194 |
class CogVideoX_Fun_PipelineOutput(BaseOutput):
|
195 |
r"""
|
|
|
457 |
return outputs
|
458 |
|
459 |
def prepare_mask_latents(
|
460 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
|
461 |
):
|
462 |
# resize the mask to latents shape as we concatenate the mask to the latents
|
463 |
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
|
|
476 |
mask = mask * self.vae.config.scaling_factor
|
477 |
|
478 |
if masked_image is not None:
|
479 |
+
if self.transformer.config.add_noise_in_inpaint_model:
|
480 |
+
masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
|
481 |
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
482 |
bs = 1
|
483 |
new_mask_pixel_values = []
|
|
|
665 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
666 |
max_sequence_length: int = 226,
|
667 |
strength: float = 1,
|
668 |
+
noise_aug_strength: float = 0.0563,
|
669 |
comfyui_progressbar: bool = False,
|
670 |
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
|
671 |
"""
|
|
|
882 |
device,
|
883 |
generator,
|
884 |
do_classifier_free_guidance,
|
885 |
+
noise_aug_strength=noise_aug_strength,
|
886 |
)
|
887 |
mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
|
888 |
mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
|
cogvideox/ui/ui.py
CHANGED
@@ -30,6 +30,8 @@ from cogvideox.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
|
30 |
from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
|
31 |
from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
|
32 |
from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
|
|
|
|
|
33 |
from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
|
34 |
CogVideoX_Fun_Pipeline_Inpaint
|
35 |
from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
|
@@ -58,7 +60,7 @@ css = """
|
|
58 |
}
|
59 |
"""
|
60 |
|
61 |
-
class
|
62 |
def __init__(self, low_gpu_memory_mode, weight_dtype):
|
63 |
# config dirs
|
64 |
self.basedir = os.getcwd()
|
@@ -68,6 +70,7 @@ class CogVideoX_I2VController:
|
|
68 |
self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
|
69 |
self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
|
70 |
self.savedir_sample = os.path.join(self.savedir, "sample")
|
|
|
71 |
os.makedirs(self.savedir, exist_ok=True)
|
72 |
|
73 |
self.diffusion_transformer_list = []
|
@@ -102,6 +105,9 @@ class CogVideoX_I2VController:
|
|
102 |
personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
|
103 |
self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
|
104 |
|
|
|
|
|
|
|
105 |
def update_diffusion_transformer(self, diffusion_transformer_dropdown):
|
106 |
print("Update diffusion transformer")
|
107 |
if diffusion_transformer_dropdown == "none":
|
@@ -118,16 +124,25 @@ class CogVideoX_I2VController:
|
|
118 |
).to(self.weight_dtype)
|
119 |
|
120 |
# Get pipeline
|
121 |
-
if self.
|
122 |
-
self.
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
else:
|
130 |
-
self.pipeline =
|
131 |
diffusion_transformer_dropdown,
|
132 |
vae=self.vae,
|
133 |
transformer=self.transformer,
|
@@ -191,6 +206,8 @@ class CogVideoX_I2VController:
|
|
191 |
start_image,
|
192 |
end_image,
|
193 |
validation_video,
|
|
|
|
|
194 |
denoise_strength,
|
195 |
seed_textbox,
|
196 |
is_api = False,
|
@@ -208,20 +225,34 @@ class CogVideoX_I2VController:
|
|
208 |
if self.lora_model_path != lora_model_dropdown:
|
209 |
print("Update lora model")
|
210 |
self.update_lora_model(lora_model_dropdown)
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
if resize_method == "Resize according to Reference":
|
213 |
-
if start_image is None and validation_video is None:
|
214 |
if is_api:
|
215 |
return "", f"Please upload an image when using \"Resize according to Reference\"."
|
216 |
else:
|
217 |
raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
|
218 |
|
219 |
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
223 |
else:
|
224 |
-
original_width, original_height =
|
225 |
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
|
226 |
height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
|
227 |
|
@@ -255,75 +286,91 @@ class CogVideoX_I2VController:
|
|
255 |
generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
|
256 |
|
257 |
try:
|
258 |
-
if self.
|
259 |
-
if
|
260 |
-
if
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
269 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
break
|
271 |
-
else:
|
272 |
-
_partial_video_length = partial_video_length
|
273 |
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
with torch.no_grad():
|
280 |
-
sample = self.pipeline(
|
281 |
-
prompt_textbox,
|
282 |
-
negative_prompt = negative_prompt_textbox,
|
283 |
-
num_inference_steps = sample_step_slider,
|
284 |
-
guidance_scale = cfg_scale_slider,
|
285 |
-
width = width_slider,
|
286 |
-
height = height_slider,
|
287 |
-
num_frames = _partial_video_length,
|
288 |
-
generator = generator,
|
289 |
-
|
290 |
-
video = input_video,
|
291 |
-
mask_video = input_video_mask,
|
292 |
-
strength = 1,
|
293 |
-
).videos
|
294 |
-
|
295 |
-
if init_frames != 0:
|
296 |
-
mix_ratio = torch.from_numpy(
|
297 |
-
np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
|
298 |
-
).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
299 |
-
|
300 |
-
new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
|
301 |
-
sample[:, :, :overlap_video_length] * mix_ratio
|
302 |
-
new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
|
303 |
|
304 |
-
|
|
|
|
|
|
|
|
|
|
|
305 |
else:
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
|
|
|
|
|
|
|
|
319 |
else:
|
320 |
-
if validation_video is not None:
|
321 |
-
input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
|
322 |
-
strength = denoise_strength
|
323 |
-
else:
|
324 |
-
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
|
325 |
-
strength = 1
|
326 |
-
|
327 |
sample = self.pipeline(
|
328 |
prompt_textbox,
|
329 |
negative_prompt = negative_prompt_textbox,
|
@@ -332,13 +379,11 @@ class CogVideoX_I2VController:
|
|
332 |
width = width_slider,
|
333 |
height = height_slider,
|
334 |
num_frames = length_slider if not is_image else 1,
|
335 |
-
generator = generator
|
336 |
-
|
337 |
-
video = input_video,
|
338 |
-
mask_video = input_video_mask,
|
339 |
-
strength = strength,
|
340 |
).videos
|
341 |
else:
|
|
|
|
|
342 |
sample = self.pipeline(
|
343 |
prompt_textbox,
|
344 |
negative_prompt = negative_prompt_textbox,
|
@@ -347,7 +392,9 @@ class CogVideoX_I2VController:
|
|
347 |
width = width_slider,
|
348 |
height = height_slider,
|
349 |
num_frames = length_slider if not is_image else 1,
|
350 |
-
generator = generator
|
|
|
|
|
351 |
).videos
|
352 |
except Exception as e:
|
353 |
gc.collect()
|
@@ -422,7 +469,7 @@ class CogVideoX_I2VController:
|
|
422 |
|
423 |
|
424 |
def ui(low_gpu_memory_mode, weight_dtype):
|
425 |
-
controller =
|
426 |
|
427 |
with gr.Blocks(css=css) as demo:
|
428 |
gr.Markdown(
|
@@ -437,7 +484,20 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
437 |
with gr.Column(variant="panel"):
|
438 |
gr.Markdown(
|
439 |
"""
|
440 |
-
### 1. Model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
"""
|
442 |
)
|
443 |
with gr.Row():
|
@@ -488,12 +548,12 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
488 |
with gr.Column(variant="panel"):
|
489 |
gr.Markdown(
|
490 |
"""
|
491 |
-
###
|
492 |
"""
|
493 |
)
|
494 |
|
495 |
prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
|
496 |
-
negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange
|
497 |
|
498 |
with gr.Row():
|
499 |
with gr.Column():
|
@@ -522,7 +582,7 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
522 |
partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
|
523 |
|
524 |
source_method = gr.Radio(
|
525 |
-
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
|
526 |
value="Text to Video (文本到视频)",
|
527 |
show_label=False,
|
528 |
)
|
@@ -535,7 +595,7 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
535 |
template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
|
536 |
def select_template(evt: gr.SelectData):
|
537 |
text = {
|
538 |
-
"asset/1.png": "The dog is
|
539 |
"asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
540 |
"asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
541 |
"asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
@@ -557,13 +617,36 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
557 |
end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
|
558 |
|
559 |
with gr.Column(visible = False) as video_to_video_col:
|
560 |
-
|
561 |
-
|
562 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
)
|
564 |
-
denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=0.95, step=0.01)
|
565 |
|
566 |
-
cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=
|
567 |
|
568 |
with gr.Row():
|
569 |
seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
|
@@ -585,6 +668,12 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
585 |
interactive=False
|
586 |
)
|
587 |
|
|
|
|
|
|
|
|
|
|
|
|
|
588 |
def upload_generation_method(generation_method):
|
589 |
if generation_method == "Video Generation":
|
590 |
return [gr.update(visible=True, maximum=49, value=49), gr.update(visible=False), gr.update(visible=False)]
|
@@ -598,13 +687,18 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
598 |
|
599 |
def upload_source_method(source_method):
|
600 |
if source_method == "Text to Video (文本到视频)":
|
601 |
-
return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
602 |
elif source_method == "Image to Video (图片到视频)":
|
603 |
-
return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
|
|
|
|
|
604 |
else:
|
605 |
-
return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
|
606 |
source_method.change(
|
607 |
-
upload_source_method, source_method, [
|
|
|
|
|
|
|
608 |
)
|
609 |
|
610 |
def upload_resize_method(resize_method):
|
@@ -639,6 +733,8 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
639 |
start_image,
|
640 |
end_image,
|
641 |
validation_video,
|
|
|
|
|
642 |
denoise_strength,
|
643 |
seed_textbox,
|
644 |
],
|
@@ -647,8 +743,8 @@ def ui(low_gpu_memory_mode, weight_dtype):
|
|
647 |
return demo, controller
|
648 |
|
649 |
|
650 |
-
class
|
651 |
-
def __init__(self, model_name, savedir_sample, low_gpu_memory_mode, weight_dtype):
|
652 |
# Basic dir
|
653 |
self.basedir = os.getcwd()
|
654 |
self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
|
@@ -658,6 +754,7 @@ class CogVideoX_I2VController_Modelscope:
|
|
658 |
os.makedirs(self.savedir_sample, exist_ok=True)
|
659 |
|
660 |
# model path
|
|
|
661 |
self.weight_dtype = weight_dtype
|
662 |
|
663 |
self.vae = AutoencoderKLCogVideoX.from_pretrained(
|
@@ -672,16 +769,25 @@ class CogVideoX_I2VController_Modelscope:
|
|
672 |
).to(self.weight_dtype)
|
673 |
|
674 |
# Get pipeline
|
675 |
-
if
|
676 |
-
self.
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
else:
|
684 |
-
self.pipeline =
|
685 |
model_name,
|
686 |
vae=self.vae,
|
687 |
transformer=self.transformer,
|
@@ -733,6 +839,8 @@ class CogVideoX_I2VController_Modelscope:
|
|
733 |
start_image,
|
734 |
end_image,
|
735 |
validation_video,
|
|
|
|
|
736 |
denoise_strength,
|
737 |
seed_textbox,
|
738 |
is_api = False,
|
@@ -747,25 +855,48 @@ class CogVideoX_I2VController_Modelscope:
|
|
747 |
if self.lora_model_path != lora_model_dropdown:
|
748 |
print("Update lora model")
|
749 |
self.update_lora_model(lora_model_dropdown)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
750 |
|
751 |
if resize_method == "Resize according to Reference":
|
752 |
-
if start_image is None and validation_video is None:
|
753 |
-
|
|
|
|
|
|
|
754 |
|
755 |
-
aspect_ratio_sample_size
|
756 |
-
|
757 |
-
|
758 |
-
|
|
|
|
|
759 |
else:
|
760 |
-
original_width, original_height =
|
761 |
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
|
762 |
height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
|
763 |
|
764 |
if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
|
765 |
-
|
766 |
-
|
|
|
|
|
|
|
767 |
if start_image is None and end_image is not None:
|
768 |
-
|
|
|
|
|
|
|
769 |
|
770 |
is_image = True if generation_method == "Image Generation" else False
|
771 |
|
@@ -779,13 +910,42 @@ class CogVideoX_I2VController_Modelscope:
|
|
779 |
generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
|
780 |
|
781 |
try:
|
782 |
-
if self.
|
783 |
-
if
|
784 |
-
|
785 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
786 |
else:
|
787 |
-
|
788 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
789 |
|
790 |
sample = self.pipeline(
|
791 |
prompt_textbox,
|
@@ -797,20 +957,7 @@ class CogVideoX_I2VController_Modelscope:
|
|
797 |
num_frames = length_slider if not is_image else 1,
|
798 |
generator = generator,
|
799 |
|
800 |
-
|
801 |
-
mask_video = input_video_mask,
|
802 |
-
strength = strength,
|
803 |
-
).videos
|
804 |
-
else:
|
805 |
-
sample = self.pipeline(
|
806 |
-
prompt_textbox,
|
807 |
-
negative_prompt = negative_prompt_textbox,
|
808 |
-
num_inference_steps = sample_step_slider,
|
809 |
-
guidance_scale = cfg_scale_slider,
|
810 |
-
width = width_slider,
|
811 |
-
height = height_slider,
|
812 |
-
num_frames = length_slider if not is_image else 1,
|
813 |
-
generator = generator
|
814 |
).videos
|
815 |
except Exception as e:
|
816 |
gc.collect()
|
@@ -866,8 +1013,8 @@ class CogVideoX_I2VController_Modelscope:
|
|
866 |
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
|
867 |
|
868 |
|
869 |
-
def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype):
|
870 |
-
controller =
|
871 |
|
872 |
with gr.Blocks(css=css) as demo:
|
873 |
gr.Markdown(
|
@@ -882,7 +1029,20 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
|
882 |
with gr.Column(variant="panel"):
|
883 |
gr.Markdown(
|
884 |
"""
|
885 |
-
### 1. Model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
886 |
"""
|
887 |
)
|
888 |
with gr.Row():
|
@@ -919,12 +1079,12 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
|
919 |
with gr.Column(variant="panel"):
|
920 |
gr.Markdown(
|
921 |
"""
|
922 |
-
###
|
923 |
"""
|
924 |
)
|
925 |
|
926 |
prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
|
927 |
-
negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange
|
928 |
|
929 |
with gr.Row():
|
930 |
with gr.Column():
|
@@ -953,7 +1113,7 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
|
953 |
partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
|
954 |
|
955 |
source_method = gr.Radio(
|
956 |
-
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
|
957 |
value="Text to Video (文本到视频)",
|
958 |
show_label=False,
|
959 |
)
|
@@ -964,7 +1124,7 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
|
964 |
template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
|
965 |
def select_template(evt: gr.SelectData):
|
966 |
text = {
|
967 |
-
"asset/1.png": "The dog is
|
968 |
"asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
969 |
"asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
970 |
"asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
@@ -986,13 +1146,36 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
|
986 |
end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
|
987 |
|
988 |
with gr.Column(visible = False) as video_to_video_col:
|
989 |
-
|
990 |
-
|
991 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
992 |
)
|
993 |
-
denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=0.95, step=0.01)
|
994 |
|
995 |
-
cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=
|
996 |
|
997 |
with gr.Row():
|
998 |
seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
|
@@ -1025,13 +1208,18 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
|
1025 |
|
1026 |
def upload_source_method(source_method):
|
1027 |
if source_method == "Text to Video (文本到视频)":
|
1028 |
-
return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
1029 |
elif source_method == "Image to Video (图片到视频)":
|
1030 |
-
return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
|
|
|
|
|
1031 |
else:
|
1032 |
-
return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
|
1033 |
source_method.change(
|
1034 |
-
upload_source_method, source_method, [
|
|
|
|
|
|
|
1035 |
)
|
1036 |
|
1037 |
def upload_resize_method(resize_method):
|
@@ -1066,6 +1254,8 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
|
1066 |
start_image,
|
1067 |
end_image,
|
1068 |
validation_video,
|
|
|
|
|
1069 |
denoise_strength,
|
1070 |
seed_textbox,
|
1071 |
],
|
@@ -1080,7 +1270,7 @@ def post_eas(
|
|
1080 |
prompt_textbox, negative_prompt_textbox,
|
1081 |
sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
|
1082 |
base_resolution, generation_method, length_slider, cfg_scale_slider,
|
1083 |
-
start_image, end_image, validation_video, denoise_strength, seed_textbox,
|
1084 |
):
|
1085 |
if start_image is not None:
|
1086 |
with open(start_image, 'rb') as file:
|
@@ -1100,6 +1290,12 @@ def post_eas(
|
|
1100 |
validation_video_encoded_content = base64.b64encode(file_content)
|
1101 |
validation_video = validation_video_encoded_content.decode('utf-8')
|
1102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1103 |
datas = {
|
1104 |
"base_model_path": base_model_dropdown,
|
1105 |
"lora_model_path": lora_model_dropdown,
|
@@ -1118,6 +1314,7 @@ def post_eas(
|
|
1118 |
"start_image": start_image,
|
1119 |
"end_image": end_image,
|
1120 |
"validation_video": validation_video,
|
|
|
1121 |
"denoise_strength": denoise_strength,
|
1122 |
"seed_textbox": seed_textbox,
|
1123 |
}
|
@@ -1131,7 +1328,7 @@ def post_eas(
|
|
1131 |
return outputs
|
1132 |
|
1133 |
|
1134 |
-
class
|
1135 |
def __init__(self, model_name, savedir_sample):
|
1136 |
self.savedir_sample = savedir_sample
|
1137 |
os.makedirs(self.savedir_sample, exist_ok=True)
|
@@ -1156,6 +1353,7 @@ class CogVideoX_I2VController_EAS:
|
|
1156 |
start_image,
|
1157 |
end_image,
|
1158 |
validation_video,
|
|
|
1159 |
denoise_strength,
|
1160 |
seed_textbox
|
1161 |
):
|
@@ -1167,7 +1365,7 @@ class CogVideoX_I2VController_EAS:
|
|
1167 |
prompt_textbox, negative_prompt_textbox,
|
1168 |
sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
|
1169 |
base_resolution, generation_method, length_slider, cfg_scale_slider,
|
1170 |
-
start_image, end_image, validation_video, denoise_strength,
|
1171 |
seed_textbox
|
1172 |
)
|
1173 |
try:
|
@@ -1201,7 +1399,7 @@ class CogVideoX_I2VController_EAS:
|
|
1201 |
|
1202 |
|
1203 |
def ui_eas(model_name, savedir_sample):
|
1204 |
-
controller =
|
1205 |
|
1206 |
with gr.Blocks(css=css) as demo:
|
1207 |
gr.Markdown(
|
@@ -1216,7 +1414,7 @@ def ui_eas(model_name, savedir_sample):
|
|
1216 |
with gr.Column(variant="panel"):
|
1217 |
gr.Markdown(
|
1218 |
"""
|
1219 |
-
### 1. Model checkpoints.
|
1220 |
"""
|
1221 |
)
|
1222 |
with gr.Row():
|
@@ -1258,7 +1456,7 @@ def ui_eas(model_name, savedir_sample):
|
|
1258 |
)
|
1259 |
|
1260 |
prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
|
1261 |
-
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange
|
1262 |
|
1263 |
with gr.Row():
|
1264 |
with gr.Column():
|
@@ -1295,7 +1493,7 @@ def ui_eas(model_name, savedir_sample):
|
|
1295 |
template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
|
1296 |
def select_template(evt: gr.SelectData):
|
1297 |
text = {
|
1298 |
-
"asset/1.png": "The dog is
|
1299 |
"asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
1300 |
"asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
1301 |
"asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
@@ -1317,13 +1515,25 @@ def ui_eas(model_name, savedir_sample):
|
|
1317 |
end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
|
1318 |
|
1319 |
with gr.Column(visible = False) as video_to_video_col:
|
1320 |
-
|
1321 |
-
|
1322 |
-
|
1323 |
-
|
1324 |
-
|
1325 |
-
|
1326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1327 |
|
1328 |
with gr.Row():
|
1329 |
seed_textbox = gr.Textbox(label="Seed", value=43)
|
@@ -1347,7 +1557,7 @@ def ui_eas(model_name, savedir_sample):
|
|
1347 |
|
1348 |
def upload_generation_method(generation_method):
|
1349 |
if generation_method == "Video Generation":
|
1350 |
-
return gr.update(visible=True, minimum=5, maximum=
|
1351 |
elif generation_method == "Image Generation":
|
1352 |
return gr.update(minimum=1, maximum=1, value=1, interactive=False)
|
1353 |
generation_method.change(
|
@@ -1356,13 +1566,13 @@ def ui_eas(model_name, savedir_sample):
|
|
1356 |
|
1357 |
def upload_source_method(source_method):
|
1358 |
if source_method == "Text to Video (文本到视频)":
|
1359 |
-
return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
1360 |
elif source_method == "Image to Video (图片到视频)":
|
1361 |
-
return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
|
1362 |
else:
|
1363 |
-
return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
|
1364 |
source_method.change(
|
1365 |
-
upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video]
|
1366 |
)
|
1367 |
|
1368 |
def upload_resize_method(resize_method):
|
@@ -1395,6 +1605,7 @@ def ui_eas(model_name, savedir_sample):
|
|
1395 |
start_image,
|
1396 |
end_image,
|
1397 |
validation_video,
|
|
|
1398 |
denoise_strength,
|
1399 |
seed_textbox,
|
1400 |
],
|
|
|
30 |
from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
|
31 |
from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
|
32 |
from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
|
33 |
+
from cogvideox.pipeline.pipeline_cogvideox_control import \
|
34 |
+
CogVideoX_Fun_Pipeline_Control
|
35 |
from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
|
36 |
CogVideoX_Fun_Pipeline_Inpaint
|
37 |
from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
|
|
|
60 |
}
|
61 |
"""
|
62 |
|
63 |
+
class CogVideoX_Fun_Controller:
|
64 |
def __init__(self, low_gpu_memory_mode, weight_dtype):
|
65 |
# config dirs
|
66 |
self.basedir = os.getcwd()
|
|
|
70 |
self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
|
71 |
self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
|
72 |
self.savedir_sample = os.path.join(self.savedir, "sample")
|
73 |
+
self.model_type = "Inpaint"
|
74 |
os.makedirs(self.savedir, exist_ok=True)
|
75 |
|
76 |
self.diffusion_transformer_list = []
|
|
|
105 |
personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
|
106 |
self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
|
107 |
|
108 |
+
def update_model_type(self, model_type):
|
109 |
+
self.model_type = model_type
|
110 |
+
|
111 |
def update_diffusion_transformer(self, diffusion_transformer_dropdown):
|
112 |
print("Update diffusion transformer")
|
113 |
if diffusion_transformer_dropdown == "none":
|
|
|
124 |
).to(self.weight_dtype)
|
125 |
|
126 |
# Get pipeline
|
127 |
+
if self.model_type == "Inpaint":
|
128 |
+
if self.transformer.config.in_channels != self.vae.config.latent_channels:
|
129 |
+
self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
|
130 |
+
diffusion_transformer_dropdown,
|
131 |
+
vae=self.vae,
|
132 |
+
transformer=self.transformer,
|
133 |
+
scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
|
134 |
+
torch_dtype=self.weight_dtype
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
|
138 |
+
diffusion_transformer_dropdown,
|
139 |
+
vae=self.vae,
|
140 |
+
transformer=self.transformer,
|
141 |
+
scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
|
142 |
+
torch_dtype=self.weight_dtype
|
143 |
+
)
|
144 |
else:
|
145 |
+
self.pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
|
146 |
diffusion_transformer_dropdown,
|
147 |
vae=self.vae,
|
148 |
transformer=self.transformer,
|
|
|
206 |
start_image,
|
207 |
end_image,
|
208 |
validation_video,
|
209 |
+
validation_video_mask,
|
210 |
+
control_video,
|
211 |
denoise_strength,
|
212 |
seed_textbox,
|
213 |
is_api = False,
|
|
|
225 |
if self.lora_model_path != lora_model_dropdown:
|
226 |
print("Update lora model")
|
227 |
self.update_lora_model(lora_model_dropdown)
|
228 |
+
|
229 |
+
if control_video is not None and self.model_type == "Inpaint":
|
230 |
+
if is_api:
|
231 |
+
return "", f"If specifying the control video, please set the model_type == \"Control\". "
|
232 |
+
else:
|
233 |
+
raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
|
234 |
+
|
235 |
+
if control_video is None and self.model_type == "Control":
|
236 |
+
if is_api:
|
237 |
+
return "", f"If set the model_type == \"Control\", please specifying the control video. "
|
238 |
+
else:
|
239 |
+
raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
|
240 |
+
|
241 |
if resize_method == "Resize according to Reference":
|
242 |
+
if start_image is None and validation_video is None and control_video is None:
|
243 |
if is_api:
|
244 |
return "", f"Please upload an image when using \"Resize according to Reference\"."
|
245 |
else:
|
246 |
raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
|
247 |
|
248 |
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
|
249 |
+
if self.model_type == "Inpaint":
|
250 |
+
if validation_video is not None:
|
251 |
+
original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
|
252 |
+
else:
|
253 |
+
original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
|
254 |
else:
|
255 |
+
original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
|
256 |
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
|
257 |
height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
|
258 |
|
|
|
286 |
generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
|
287 |
|
288 |
try:
|
289 |
+
if self.model_type == "Inpaint":
|
290 |
+
if self.transformer.config.in_channels != self.vae.config.latent_channels:
|
291 |
+
if generation_method == "Long Video Generation":
|
292 |
+
if validation_video is not None:
|
293 |
+
raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
|
294 |
+
init_frames = 0
|
295 |
+
last_frames = init_frames + partial_video_length
|
296 |
+
while init_frames < length_slider:
|
297 |
+
if last_frames >= length_slider:
|
298 |
+
_partial_video_length = length_slider - init_frames
|
299 |
+
_partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
|
300 |
+
|
301 |
+
if _partial_video_length <= 0:
|
302 |
+
break
|
303 |
+
else:
|
304 |
+
_partial_video_length = partial_video_length
|
305 |
+
|
306 |
+
if last_frames >= length_slider:
|
307 |
+
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
|
308 |
+
else:
|
309 |
+
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
|
310 |
+
|
311 |
+
with torch.no_grad():
|
312 |
+
sample = self.pipeline(
|
313 |
+
prompt_textbox,
|
314 |
+
negative_prompt = negative_prompt_textbox,
|
315 |
+
num_inference_steps = sample_step_slider,
|
316 |
+
guidance_scale = cfg_scale_slider,
|
317 |
+
width = width_slider,
|
318 |
+
height = height_slider,
|
319 |
+
num_frames = _partial_video_length,
|
320 |
+
generator = generator,
|
321 |
+
|
322 |
+
video = input_video,
|
323 |
+
mask_video = input_video_mask,
|
324 |
+
strength = 1,
|
325 |
+
).videos
|
326 |
|
327 |
+
if init_frames != 0:
|
328 |
+
mix_ratio = torch.from_numpy(
|
329 |
+
np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
|
330 |
+
).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
331 |
+
|
332 |
+
new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
|
333 |
+
sample[:, :, :overlap_video_length] * mix_ratio
|
334 |
+
new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
|
335 |
+
|
336 |
+
sample = new_sample
|
337 |
+
else:
|
338 |
+
new_sample = sample
|
339 |
+
|
340 |
+
if last_frames >= length_slider:
|
341 |
break
|
|
|
|
|
342 |
|
343 |
+
start_image = [
|
344 |
+
Image.fromarray(
|
345 |
+
(sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
|
346 |
+
) for _index in range(-overlap_video_length, 0)
|
347 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
|
349 |
+
init_frames = init_frames + _partial_video_length - overlap_video_length
|
350 |
+
last_frames = init_frames + _partial_video_length
|
351 |
+
else:
|
352 |
+
if validation_video is not None:
|
353 |
+
input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
|
354 |
+
strength = denoise_strength
|
355 |
else:
|
356 |
+
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
|
357 |
+
strength = 1
|
358 |
+
|
359 |
+
sample = self.pipeline(
|
360 |
+
prompt_textbox,
|
361 |
+
negative_prompt = negative_prompt_textbox,
|
362 |
+
num_inference_steps = sample_step_slider,
|
363 |
+
guidance_scale = cfg_scale_slider,
|
364 |
+
width = width_slider,
|
365 |
+
height = height_slider,
|
366 |
+
num_frames = length_slider if not is_image else 1,
|
367 |
+
generator = generator,
|
368 |
+
|
369 |
+
video = input_video,
|
370 |
+
mask_video = input_video_mask,
|
371 |
+
strength = strength,
|
372 |
+
).videos
|
373 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
sample = self.pipeline(
|
375 |
prompt_textbox,
|
376 |
negative_prompt = negative_prompt_textbox,
|
|
|
379 |
width = width_slider,
|
380 |
height = height_slider,
|
381 |
num_frames = length_slider if not is_image else 1,
|
382 |
+
generator = generator
|
|
|
|
|
|
|
|
|
383 |
).videos
|
384 |
else:
|
385 |
+
input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
|
386 |
+
|
387 |
sample = self.pipeline(
|
388 |
prompt_textbox,
|
389 |
negative_prompt = negative_prompt_textbox,
|
|
|
392 |
width = width_slider,
|
393 |
height = height_slider,
|
394 |
num_frames = length_slider if not is_image else 1,
|
395 |
+
generator = generator,
|
396 |
+
|
397 |
+
control_video = input_video,
|
398 |
).videos
|
399 |
except Exception as e:
|
400 |
gc.collect()
|
|
|
469 |
|
470 |
|
471 |
def ui(low_gpu_memory_mode, weight_dtype):
|
472 |
+
controller = CogVideoX_Fun_Controller(low_gpu_memory_mode, weight_dtype)
|
473 |
|
474 |
with gr.Blocks(css=css) as demo:
|
475 |
gr.Markdown(
|
|
|
484 |
with gr.Column(variant="panel"):
|
485 |
gr.Markdown(
|
486 |
"""
|
487 |
+
### 1. CogVideoX-Fun Model Type (CogVideoX-Fun模型的种类,正常模型还是控制模型).
|
488 |
+
"""
|
489 |
+
)
|
490 |
+
with gr.Row():
|
491 |
+
model_type = gr.Dropdown(
|
492 |
+
label="The model type of CogVideoX-Fun (CogVideoX-Fun模型的种类,正常模型还是控制模型)",
|
493 |
+
choices=["Inpaint", "Control"],
|
494 |
+
value="Inpaint",
|
495 |
+
interactive=True,
|
496 |
+
)
|
497 |
+
|
498 |
+
gr.Markdown(
|
499 |
+
"""
|
500 |
+
### 2. Model checkpoints (模型路径).
|
501 |
"""
|
502 |
)
|
503 |
with gr.Row():
|
|
|
548 |
with gr.Column(variant="panel"):
|
549 |
gr.Markdown(
|
550 |
"""
|
551 |
+
### 3. Configs for Generation (生成参数配置).
|
552 |
"""
|
553 |
)
|
554 |
|
555 |
prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
|
556 |
+
negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
|
557 |
|
558 |
with gr.Row():
|
559 |
with gr.Column():
|
|
|
582 |
partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
|
583 |
|
584 |
source_method = gr.Radio(
|
585 |
+
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"],
|
586 |
value="Text to Video (文本到视频)",
|
587 |
show_label=False,
|
588 |
)
|
|
|
595 |
template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
|
596 |
def select_template(evt: gr.SelectData):
|
597 |
text = {
|
598 |
+
"asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
599 |
"asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
600 |
"asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
601 |
"asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
|
|
617 |
end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
|
618 |
|
619 |
with gr.Column(visible = False) as video_to_video_col:
|
620 |
+
with gr.Row():
|
621 |
+
validation_video = gr.Video(
|
622 |
+
label="The video to convert (视频转视频的参考视频)", show_label=True,
|
623 |
+
elem_id="v2v", sources="upload",
|
624 |
+
)
|
625 |
+
with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
|
626 |
+
gr.Markdown(
|
627 |
+
"""
|
628 |
+
- Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
|
629 |
+
- (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
|
630 |
+
"""
|
631 |
+
)
|
632 |
+
validation_video_mask = gr.Image(
|
633 |
+
label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
|
634 |
+
show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
|
635 |
+
)
|
636 |
+
denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
|
637 |
+
|
638 |
+
with gr.Column(visible = False) as control_video_col:
|
639 |
+
gr.Markdown(
|
640 |
+
"""
|
641 |
+
Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
|
642 |
+
"""
|
643 |
+
)
|
644 |
+
control_video = gr.Video(
|
645 |
+
label="The control video (用于提供控制信号的video)", show_label=True,
|
646 |
+
elem_id="v2v_control", sources="upload",
|
647 |
)
|
|
|
648 |
|
649 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
|
650 |
|
651 |
with gr.Row():
|
652 |
seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
|
|
|
668 |
interactive=False
|
669 |
)
|
670 |
|
671 |
+
model_type.change(
|
672 |
+
fn=controller.update_model_type,
|
673 |
+
inputs=[model_type],
|
674 |
+
outputs=[]
|
675 |
+
)
|
676 |
+
|
677 |
def upload_generation_method(generation_method):
|
678 |
if generation_method == "Video Generation":
|
679 |
return [gr.update(visible=True, maximum=49, value=49), gr.update(visible=False), gr.update(visible=False)]
|
|
|
687 |
|
688 |
def upload_source_method(source_method):
|
689 |
if source_method == "Text to Video (文本到视频)":
|
690 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
691 |
elif source_method == "Image to Video (图片到视频)":
|
692 |
+
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
693 |
+
elif source_method == "Video to Video (视频到视频)":
|
694 |
+
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
|
695 |
else:
|
696 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
|
697 |
source_method.change(
|
698 |
+
upload_source_method, source_method, [
|
699 |
+
image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
|
700 |
+
validation_video, validation_video_mask, control_video
|
701 |
+
]
|
702 |
)
|
703 |
|
704 |
def upload_resize_method(resize_method):
|
|
|
733 |
start_image,
|
734 |
end_image,
|
735 |
validation_video,
|
736 |
+
validation_video_mask,
|
737 |
+
control_video,
|
738 |
denoise_strength,
|
739 |
seed_textbox,
|
740 |
],
|
|
|
743 |
return demo, controller
|
744 |
|
745 |
|
746 |
+
class CogVideoX_Fun_Controller_Modelscope:
|
747 |
+
def __init__(self, model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype):
|
748 |
# Basic dir
|
749 |
self.basedir = os.getcwd()
|
750 |
self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
|
|
|
754 |
os.makedirs(self.savedir_sample, exist_ok=True)
|
755 |
|
756 |
# model path
|
757 |
+
self.model_type = model_type
|
758 |
self.weight_dtype = weight_dtype
|
759 |
|
760 |
self.vae = AutoencoderKLCogVideoX.from_pretrained(
|
|
|
769 |
).to(self.weight_dtype)
|
770 |
|
771 |
# Get pipeline
|
772 |
+
if model_type == "Inpaint":
|
773 |
+
if self.transformer.config.in_channels != self.vae.config.latent_channels:
|
774 |
+
self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
|
775 |
+
model_name,
|
776 |
+
vae=self.vae,
|
777 |
+
transformer=self.transformer,
|
778 |
+
scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
|
779 |
+
torch_dtype=self.weight_dtype
|
780 |
+
)
|
781 |
+
else:
|
782 |
+
self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
|
783 |
+
model_name,
|
784 |
+
vae=self.vae,
|
785 |
+
transformer=self.transformer,
|
786 |
+
scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
|
787 |
+
torch_dtype=self.weight_dtype
|
788 |
+
)
|
789 |
else:
|
790 |
+
self.pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
|
791 |
model_name,
|
792 |
vae=self.vae,
|
793 |
transformer=self.transformer,
|
|
|
839 |
start_image,
|
840 |
end_image,
|
841 |
validation_video,
|
842 |
+
validation_video_mask,
|
843 |
+
control_video,
|
844 |
denoise_strength,
|
845 |
seed_textbox,
|
846 |
is_api = False,
|
|
|
855 |
if self.lora_model_path != lora_model_dropdown:
|
856 |
print("Update lora model")
|
857 |
self.update_lora_model(lora_model_dropdown)
|
858 |
+
|
859 |
+
if control_video is not None and self.model_type == "Inpaint":
|
860 |
+
if is_api:
|
861 |
+
return "", f"If specifying the control video, please set the model_type == \"Control\". "
|
862 |
+
else:
|
863 |
+
raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
|
864 |
+
|
865 |
+
if control_video is None and self.model_type == "Control":
|
866 |
+
if is_api:
|
867 |
+
return "", f"If set the model_type == \"Control\", please specifying the control video. "
|
868 |
+
else:
|
869 |
+
raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
|
870 |
|
871 |
if resize_method == "Resize according to Reference":
|
872 |
+
if start_image is None and validation_video is None and control_video is None:
|
873 |
+
if is_api:
|
874 |
+
return "", f"Please upload an image when using \"Resize according to Reference\"."
|
875 |
+
else:
|
876 |
+
raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
|
877 |
|
878 |
+
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
|
879 |
+
if self.model_type == "Inpaint":
|
880 |
+
if validation_video is not None:
|
881 |
+
original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
|
882 |
+
else:
|
883 |
+
original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
|
884 |
else:
|
885 |
+
original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
|
886 |
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
|
887 |
height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
|
888 |
|
889 |
if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
|
890 |
+
if is_api:
|
891 |
+
return "", f"Please select an image to video pretrained model while using image to video."
|
892 |
+
else:
|
893 |
+
raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
|
894 |
+
|
895 |
if start_image is None and end_image is not None:
|
896 |
+
if is_api:
|
897 |
+
return "", f"If specifying the ending image of the video, please specify a starting image of the video."
|
898 |
+
else:
|
899 |
+
raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
|
900 |
|
901 |
is_image = True if generation_method == "Image Generation" else False
|
902 |
|
|
|
910 |
generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
|
911 |
|
912 |
try:
|
913 |
+
if self.model_type == "Inpaint":
|
914 |
+
if self.transformer.config.in_channels != self.vae.config.latent_channels:
|
915 |
+
if validation_video is not None:
|
916 |
+
input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
|
917 |
+
strength = denoise_strength
|
918 |
+
else:
|
919 |
+
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
|
920 |
+
strength = 1
|
921 |
+
|
922 |
+
sample = self.pipeline(
|
923 |
+
prompt_textbox,
|
924 |
+
negative_prompt = negative_prompt_textbox,
|
925 |
+
num_inference_steps = sample_step_slider,
|
926 |
+
guidance_scale = cfg_scale_slider,
|
927 |
+
width = width_slider,
|
928 |
+
height = height_slider,
|
929 |
+
num_frames = length_slider if not is_image else 1,
|
930 |
+
generator = generator,
|
931 |
+
|
932 |
+
video = input_video,
|
933 |
+
mask_video = input_video_mask,
|
934 |
+
strength = strength,
|
935 |
+
).videos
|
936 |
else:
|
937 |
+
sample = self.pipeline(
|
938 |
+
prompt_textbox,
|
939 |
+
negative_prompt = negative_prompt_textbox,
|
940 |
+
num_inference_steps = sample_step_slider,
|
941 |
+
guidance_scale = cfg_scale_slider,
|
942 |
+
width = width_slider,
|
943 |
+
height = height_slider,
|
944 |
+
num_frames = length_slider if not is_image else 1,
|
945 |
+
generator = generator
|
946 |
+
).videos
|
947 |
+
else:
|
948 |
+
input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
|
949 |
|
950 |
sample = self.pipeline(
|
951 |
prompt_textbox,
|
|
|
957 |
num_frames = length_slider if not is_image else 1,
|
958 |
generator = generator,
|
959 |
|
960 |
+
control_video = input_video,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
961 |
).videos
|
962 |
except Exception as e:
|
963 |
gc.collect()
|
|
|
1013 |
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
|
1014 |
|
1015 |
|
1016 |
+
def ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype):
|
1017 |
+
controller = CogVideoX_Fun_Controller_Modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
|
1018 |
|
1019 |
with gr.Blocks(css=css) as demo:
|
1020 |
gr.Markdown(
|
|
|
1029 |
with gr.Column(variant="panel"):
|
1030 |
gr.Markdown(
|
1031 |
"""
|
1032 |
+
### 1. CogVideoX-Fun Model Type (CogVideoX-Fun模型的种类,正常模型还是控制模型).
|
1033 |
+
"""
|
1034 |
+
)
|
1035 |
+
with gr.Row():
|
1036 |
+
model_type = gr.Dropdown(
|
1037 |
+
label="The model type of CogVideoX-Fun (CogVideoX-Fun模型的种类,正常模型还是控制模型)",
|
1038 |
+
choices=[model_type],
|
1039 |
+
value=model_type,
|
1040 |
+
interactive=False,
|
1041 |
+
)
|
1042 |
+
|
1043 |
+
gr.Markdown(
|
1044 |
+
"""
|
1045 |
+
### 2. Model checkpoints (模型路径).
|
1046 |
"""
|
1047 |
)
|
1048 |
with gr.Row():
|
|
|
1079 |
with gr.Column(variant="panel"):
|
1080 |
gr.Markdown(
|
1081 |
"""
|
1082 |
+
### 3. Configs for Generation (生成参数配置).
|
1083 |
"""
|
1084 |
)
|
1085 |
|
1086 |
prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
|
1087 |
+
negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
|
1088 |
|
1089 |
with gr.Row():
|
1090 |
with gr.Column():
|
|
|
1113 |
partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
|
1114 |
|
1115 |
source_method = gr.Radio(
|
1116 |
+
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"],
|
1117 |
value="Text to Video (文本到视频)",
|
1118 |
show_label=False,
|
1119 |
)
|
|
|
1124 |
template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
|
1125 |
def select_template(evt: gr.SelectData):
|
1126 |
text = {
|
1127 |
+
"asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
1128 |
"asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
1129 |
"asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
1130 |
"asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
|
|
1146 |
end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
|
1147 |
|
1148 |
with gr.Column(visible = False) as video_to_video_col:
|
1149 |
+
with gr.Row():
|
1150 |
+
validation_video = gr.Video(
|
1151 |
+
label="The video to convert (视频转视频的参考视频)", show_label=True,
|
1152 |
+
elem_id="v2v", sources="upload",
|
1153 |
+
)
|
1154 |
+
with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
|
1155 |
+
gr.Markdown(
|
1156 |
+
"""
|
1157 |
+
- Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
|
1158 |
+
- (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
|
1159 |
+
"""
|
1160 |
+
)
|
1161 |
+
validation_video_mask = gr.Image(
|
1162 |
+
label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
|
1163 |
+
show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
|
1164 |
+
)
|
1165 |
+
denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
|
1166 |
+
|
1167 |
+
with gr.Column(visible = False) as control_video_col:
|
1168 |
+
gr.Markdown(
|
1169 |
+
"""
|
1170 |
+
Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
|
1171 |
+
"""
|
1172 |
+
)
|
1173 |
+
control_video = gr.Video(
|
1174 |
+
label="The control video (用于提供控制信号的video)", show_label=True,
|
1175 |
+
elem_id="v2v_control", sources="upload",
|
1176 |
)
|
|
|
1177 |
|
1178 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
|
1179 |
|
1180 |
with gr.Row():
|
1181 |
seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
|
|
|
1208 |
|
1209 |
def upload_source_method(source_method):
    """Build the component updates for a change of the source-method selector.

    The returned list has eight entries: visibility updates for the
    image-to-video, video-to-video and control-video columns, followed by
    value updates for start_image, end_image, validation_video,
    validation_video_mask and control_video. Inputs belonging to the selected
    mode are left untouched; all other inputs are reset to ``None``. Any value
    other than the three named modes selects the control-video column.
    """
    text_mode = "Text to Video (文本到视频)"
    image_mode = "Image to Video (图片到视频)"
    video_mode = "Video to Video (视频到视频)"

    show_image = source_method == image_mode
    show_video = source_method == video_mode
    show_control = source_method not in (text_mode, image_mode, video_mode)

    def keep_or_clear(active):
        # An empty update leaves the component as is; value=None clears it.
        return gr.update() if active else gr.update(value=None)

    return [
        gr.update(visible=show_image),
        gr.update(visible=show_video),
        gr.update(visible=show_control),
        keep_or_clear(show_image),
        keep_or_clear(show_image),
        keep_or_clear(show_video),
        keep_or_clear(show_video),
        keep_or_clear(show_control),
    ]
|
1218 |
source_method.change(
|
1219 |
+
upload_source_method, source_method, [
|
1220 |
+
image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
|
1221 |
+
validation_video, validation_video_mask, control_video
|
1222 |
+
]
|
1223 |
)
|
1224 |
|
1225 |
def upload_resize_method(resize_method):
|
|
|
1254 |
start_image,
|
1255 |
end_image,
|
1256 |
validation_video,
|
1257 |
+
validation_video_mask,
|
1258 |
+
control_video,
|
1259 |
denoise_strength,
|
1260 |
seed_textbox,
|
1261 |
],
|
|
|
1270 |
prompt_textbox, negative_prompt_textbox,
|
1271 |
sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
|
1272 |
base_resolution, generation_method, length_slider, cfg_scale_slider,
|
1273 |
+
start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
|
1274 |
):
|
1275 |
if start_image is not None:
|
1276 |
with open(start_image, 'rb') as file:
|
|
|
1290 |
validation_video_encoded_content = base64.b64encode(file_content)
|
1291 |
validation_video = validation_video_encoded_content.decode('utf-8')
|
1292 |
|
1293 |
+
if validation_video_mask is not None:
|
1294 |
+
with open(validation_video_mask, 'rb') as file:
|
1295 |
+
file_content = file.read()
|
1296 |
+
validation_video_mask_encoded_content = base64.b64encode(file_content)
|
1297 |
+
validation_video_mask = validation_video_mask_encoded_content.decode('utf-8')
|
1298 |
+
|
1299 |
datas = {
|
1300 |
"base_model_path": base_model_dropdown,
|
1301 |
"lora_model_path": lora_model_dropdown,
|
|
|
1314 |
"start_image": start_image,
|
1315 |
"end_image": end_image,
|
1316 |
"validation_video": validation_video,
|
1317 |
+
"validation_video_mask": validation_video_mask,
|
1318 |
"denoise_strength": denoise_strength,
|
1319 |
"seed_textbox": seed_textbox,
|
1320 |
}
|
|
|
1328 |
return outputs
|
1329 |
|
1330 |
|
1331 |
+
class CogVideoX_Fun_Controller_EAS:
|
1332 |
def __init__(self, model_name, savedir_sample):
|
1333 |
self.savedir_sample = savedir_sample
|
1334 |
os.makedirs(self.savedir_sample, exist_ok=True)
|
|
|
1353 |
start_image,
|
1354 |
end_image,
|
1355 |
validation_video,
|
1356 |
+
validation_video_mask,
|
1357 |
denoise_strength,
|
1358 |
seed_textbox
|
1359 |
):
|
|
|
1365 |
prompt_textbox, negative_prompt_textbox,
|
1366 |
sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
|
1367 |
base_resolution, generation_method, length_slider, cfg_scale_slider,
|
1368 |
+
start_image, end_image, validation_video, validation_video_mask, denoise_strength,
|
1369 |
seed_textbox
|
1370 |
)
|
1371 |
try:
|
|
|
1399 |
|
1400 |
|
1401 |
def ui_eas(model_name, savedir_sample):
|
1402 |
+
controller = CogVideoX_Fun_Controller_EAS(model_name, savedir_sample)
|
1403 |
|
1404 |
with gr.Blocks(css=css) as demo:
|
1405 |
gr.Markdown(
|
|
|
1414 |
with gr.Column(variant="panel"):
|
1415 |
gr.Markdown(
|
1416 |
"""
|
1417 |
+
### 1. Model checkpoints (模型路径).
|
1418 |
"""
|
1419 |
)
|
1420 |
with gr.Row():
|
|
|
1456 |
)
|
1457 |
|
1458 |
prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
|
1459 |
+
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
|
1460 |
|
1461 |
with gr.Row():
|
1462 |
with gr.Column():
|
|
|
1493 |
template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
|
1494 |
def select_template(evt: gr.SelectData):
|
1495 |
text = {
|
1496 |
+
"asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
1497 |
"asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
1498 |
"asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
1499 |
"asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
|
|
1515 |
end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
|
1516 |
|
1517 |
with gr.Column(visible = False) as video_to_video_col:
|
1518 |
+
with gr.Row():
|
1519 |
+
validation_video = gr.Video(
|
1520 |
+
label="The video to convert (视频转视频的参考视频)", show_label=True,
|
1521 |
+
elem_id="v2v", sources="upload",
|
1522 |
+
)
|
1523 |
+
with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
|
1524 |
+
gr.Markdown(
|
1525 |
+
"""
|
1526 |
+
- Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
|
1527 |
+
- (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
|
1528 |
+
"""
|
1529 |
+
)
|
1530 |
+
validation_video_mask = gr.Image(
|
1531 |
+
label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
|
1532 |
+
show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
|
1533 |
+
)
|
1534 |
+
denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
|
1535 |
+
|
1536 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
|
1537 |
|
1538 |
with gr.Row():
|
1539 |
seed_textbox = gr.Textbox(label="Seed", value=43)
|
|
|
1557 |
|
1558 |
def upload_generation_method(generation_method):
    """Return the slider update matching the chosen generation method.

    "Video Generation" yields an interactive, visible 5..49 range with value
    49; "Image Generation" pins the slider to 1 and makes it non-interactive.
    Any other value returns ``None``.
    """
    slider_settings = (
        ("Video Generation", dict(visible=True, minimum=5, maximum=49, value=49, interactive=True)),
        ("Image Generation", dict(minimum=1, maximum=1, value=1, interactive=False)),
    )
    for method_name, settings in slider_settings:
        if generation_method == method_name:
            return gr.update(**settings)
    return None
|
1563 |
generation_method.change(
|
|
|
1566 |
|
1567 |
def upload_source_method(source_method):
    """Build the component updates for a change of the source-method selector.

    The returned list has six entries: visibility updates for the
    image-to-video and video-to-video columns, followed by value updates for
    start_image, end_image, validation_video and validation_video_mask.
    Inputs belonging to the selected mode are left untouched; the others are
    reset to ``None``. Any value other than the text and image modes selects
    the video-to-video column.
    """
    text_mode = "Text to Video (文本到视频)"
    image_mode = "Image to Video (图片到视频)"

    show_image = source_method == image_mode
    show_video = source_method not in (text_mode, image_mode)

    def keep_or_clear(active):
        # An empty update leaves the component as is; value=None clears it.
        return gr.update() if active else gr.update(value=None)

    return [
        gr.update(visible=show_image),
        gr.update(visible=show_video),
        keep_or_clear(show_image),
        keep_or_clear(show_image),
        keep_or_clear(show_video),
        keep_or_clear(show_video),
    ]
|
1574 |
source_method.change(
|
1575 |
+
upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
|
1576 |
)
|
1577 |
|
1578 |
def upload_resize_method(resize_method):
|
|
|
1605 |
start_image,
|
1606 |
end_image,
|
1607 |
validation_video,
|
1608 |
+
validation_video_mask,
|
1609 |
denoise_strength,
|
1610 |
seed_textbox,
|
1611 |
],
|
cogvideox/utils/utils.py
CHANGED
@@ -166,16 +166,27 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
|
|
166 |
|
167 |
return input_video, input_video_mask, clip_image
|
168 |
|
169 |
-
def get_video_to_video_latent(input_video_path, video_length, sample_size):
|
170 |
-
if
|
171 |
cap = cv2.VideoCapture(input_video_path)
|
172 |
input_video = []
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
while True:
|
174 |
ret, frame = cap.read()
|
175 |
if not ret:
|
176 |
break
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
179 |
cap.release()
|
180 |
else:
|
181 |
input_video = input_video_path
|
@@ -183,7 +194,15 @@ def get_video_to_video_latent(input_video_path, video_length, sample_size):
|
|
183 |
input_video = torch.from_numpy(np.array(input_video))[:video_length]
|
184 |
input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
|
185 |
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
return input_video, input_video_mask, None
|
|
|
166 |
|
167 |
return input_video, input_video_mask, clip_image
|
168 |
|
169 |
+
def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None):
    """Turn a reference video (and optional mask image) into input tensors.

    Args:
        input_video_path: Path of a video file, decoded with OpenCV, or an
            already decoded stack of frames that ``np.array`` converts to a
            ``(frames, height, width, channels)`` array. Pre-decoded frames
            are used as given; they are neither resized nor subsampled.
        video_length: Maximum number of frames kept from the start.
        sample_size: ``(height, width)`` that decoded frames and the mask
            image are resized to.
        fps: Optional target frame rate. When given, only every
            ``original_fps // fps``-th decoded frame is kept (at least every
            frame, see the clamp below).
        validation_video_mask: Optional path of a mask image. Pixels of its
            grayscale version below 240 become 0, all others 255, and the
            result is repeated for every frame.

    Returns:
        ``(input_video, input_video_mask, None)``. ``input_video`` is laid out
        as ``(1, channels, frames, height, width)`` with values divided by
        255. ``input_video_mask`` has one channel, the same frame count and
        dtype as ``input_video``, and holds 0 or 255; without a mask image it
        is 255 everywhere.
    """
    if isinstance(input_video_path, str):
        cap = cv2.VideoCapture(input_video_path)
        input_video = []
        try:
            original_fps = cap.get(cv2.CAP_PROP_FPS)
            # Clamp to 1: a target fps above the source fps (or a source that
            # reports 0 fps) would otherwise give a stride of 0 and make the
            # modulo below raise ZeroDivisionError.
            frame_skip = 1 if fps is None else max(1, int(original_fps // fps))

            # Frames beyond video_length are sliced away below, so decoding
            # can stop as soon as enough frames have been collected.
            frame_limit = None
            if isinstance(video_length, int) and video_length >= 0:
                frame_limit = video_length

            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % frame_skip == 0:
                    frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
                    input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    if frame_limit is not None and len(input_video) >= frame_limit:
                        break

                frame_count += 1
        finally:
            # Release the capture even if decoding or resizing fails.
            cap.release()
    else:
        input_video = input_video_path

    input_video = torch.from_numpy(np.array(input_video))[:video_length]
    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255

    if validation_video_mask is not None:
        validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
        input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)

        # (H, W) -> (1, 1, 1, H, W), then repeat along the frame axis.
        input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
        input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
        input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
    else:
        # No mask image: the whole single-channel mask is set to 255.
        input_video_mask = torch.full_like(input_video[:, :1], 255)

    return input_video, input_video_mask, None
|