svjack committed · verified
Commit 50edba0 · 1 Parent(s): fdb5dd2

Create consisid_preview_script_offload_multi.py

consisid_preview_script_offload_multi.py ADDED
@@ -0,0 +1,227 @@
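+ # ConsisID preview inference script with CPU offloading: generates one or more
+ # identity-consistent videos from a reference face image and a text prompt.
+ #
+ # Example invocation (arguments mirror the argparse definition in main(); the
+ # image path and prompt below are illustrative):
+ #   python consisid_preview_script_offload_multi.py face.jpg "A person smiles at the camera" --num_videos 2
+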
+ import os
+ import math
+ import time
+ import random
+ import threading
+ import argparse
+ from typing import List, Union
+ from datetime import datetime, timedelta
+
+ import numpy as np
+ from PIL import Image, ImageOps
+ from moviepy.editor import VideoFileClip
+ from huggingface_hub import hf_hub_download, snapshot_download
+
+ import insightface
+ from insightface.app import FaceAnalysis
+ from facexlib.parsing import init_parsing_model
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
+ import torch
+ from diffusers import CogVideoXDPMScheduler
+ from diffusers.utils import export_to_video, load_image
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.training_utils import free_memory
+
+ from util.utils import *
+ from util.rife_model import load_rife_model, rife_inference_with_latents
+ from models.utils import process_face_embeddings
+ from models.transformer_consisid import ConsisIDTransformer3DModel
+ from models.pipeline_consisid import ConsisIDPipeline
+ from models.eva_clip import create_model_and_transforms
+ from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+ from models.eva_clip.utils_qformer import resize_numpy_image_long
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def main():
+     parser = argparse.ArgumentParser(description="ConsisID Command Line Interface")
+     parser.add_argument("image_path", type=str, help="Path to the input image")
+     parser.add_argument("prompt", type=str, help="Prompt text for the generation")
+     parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
+     parser.add_argument("--guidance_scale", type=float, default=7.0, help="Guidance scale")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed for generation; -1 selects a random seed")
+     parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the output video")
+     parser.add_argument("--num_videos", type=int, default=1, help="Number of videos to generate")
+     args = parser.parse_args()
+
+     # Download the auxiliary models (Real-ESRGAN upscaler, RIFE frame interpolator) and the ConsisID checkpoint
+     hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esrgan")
+     snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
+     snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
+
+     model_path = "BestWishYsh/ConsisID-preview"
+     lora_path = None
+     lora_rank = 128
+     dtype = torch.bfloat16
+
+     # Prefer the EMA transformer weights when the checkpoint ships them
+     if os.path.exists(os.path.join(model_path, "transformer_ema")):
+         subfolder = "transformer_ema"
+     else:
+         subfolder = "transformer"
+
+     transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder)
+     scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
+
+     # Older checkpoints may not expose the keypoint-conditioning flag
+     try:
+         is_kps = transformer.config.is_kps
+     except AttributeError:
+         is_kps = False
+
+     # 1. load face helper models
+     face_helper = FaceRestoreHelper(
+         upscale_factor=1,
+         face_size=512,
+         crop_ratio=(1, 1),
+         det_model='retinaface_resnet50',
+         save_ext='png',
+         device=device,
+         model_rootpath=os.path.join(model_path, "face_encoder"),
+     )
+     face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder"))
+     face_helper.face_det.eval()
+     face_helper.face_parse.eval()
+
+     # 2. load the EVA-CLIP face encoder
+     model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True)
+     face_clip_model = model.visual
+     face_clip_model.eval()
+
+     eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
+     eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
+     if not isinstance(eva_transform_mean, (list, tuple)):
+         eva_transform_mean = (eva_transform_mean,) * 3
+     if not isinstance(eva_transform_std, (list, tuple)):
+         eva_transform_std = (eva_transform_std,) * 3
+
+     # 3. load the insightface (antelopev2) detection and recognition models
+     face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider'])
+     handler_ante = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
+     face_main_model.prepare(ctx_id=0, det_size=(640, 640))
+     handler_ante.prepare(ctx_id=0)
+
+     face_clip_model.to(device, dtype=dtype)
+     face_helper.face_det.to(device)
+     face_helper.face_parse.to(device)
+     transformer.to(device, dtype=dtype)
+     free_memory()
+
+     pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, scheduler=scheduler, torch_dtype=dtype)
+     # If a LoRA checkpoint is provided, load and fuse its weights
+     if lora_path:
+         pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
+         pipe.fuse_lora(lora_scale=1 / lora_rank)
+
+     scheduler_args = {}
+     if "variance_type" in pipe.scheduler.config:
+         variance_type = pipe.scheduler.config.variance_type
+         if variance_type in ["learned", "learned_range"]:
+             variance_type = "fixed_small"
+         scheduler_args["variance_type"] = variance_type
+
+     pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
+     # pipe.to(device) is intentionally skipped: CPU offloading manages device placement
+
+     # Sequential CPU offload moves weights to the GPU layer by layer and supersedes
+     # per-model offload, so only the stronger of the two options is enabled here
+     pipe.enable_sequential_cpu_offload()
+     pipe.vae.enable_slicing()
+     pipe.vae.enable_tiling()
+
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     # Loaded for optional post-processing (upscaling / frame interpolation); not applied in this script
+     upscale_model = load_sd_upscale("model_real_esrgan/RealESRGAN_x4.pth", device)
+     frame_interpolation_model = load_rife_model("model_rife")
+
+     def infer(
+         prompt: str,
+         image_input: str,
+         num_inference_steps: int,
+         guidance_scale: float,
+         seed: int = 42,
+     ):
+         if seed == -1:
+             seed = random.randint(0, 2**8 - 1)
+
+         # Extract the identity embeddings from the reference image
+         id_image = np.array(ImageOps.exif_transpose(Image.open(image_input)).convert("RGB"))
+         id_image = resize_numpy_image_long(id_image, 1024)
+         id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
+             face_helper, face_clip_model, handler_ante,
+             eva_transform_mean, eva_transform_std,
+             face_main_model, device, dtype, id_image,
+             original_id_image=id_image, is_align_face=True,
+             cal_uncond=False,
+         )
+
+         kps_cond = face_kps if is_kps else None
+
+         # Convert the aligned face crop (CHW float tensor in [0, 1]) to a PIL image
+         tensor = align_crop_face_image.cpu().detach()
+         tensor = tensor.squeeze()
+         tensor = tensor.permute(1, 2, 0)
+         tensor = tensor.numpy() * 255
+         tensor = tensor.astype(np.uint8)
+         image = ImageOps.exif_transpose(Image.fromarray(tensor))
+
+         prompt = prompt.strip('"')
+
+         generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+
+         video_pt = pipe(
+             prompt=prompt,
+             image=image,
+             num_videos_per_prompt=1,
+             num_inference_steps=num_inference_steps,
+             num_frames=49,
+             use_dynamic_cfg=False,
+             guidance_scale=guidance_scale,
+             generator=generator,
+             id_vit_hidden=id_vit_hidden,
+             id_cond=id_cond,
+             kps_cond=kps_cond,
+             output_type="pt",
+         ).frames
+
+         free_memory()
+         return video_pt, seed
+
+     def save_video(tensor: Union[List[np.ndarray], List[Image.Image]], fps: int = 8, output_dir: str = "output"):
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         video_path = os.path.join(output_dir, f"{timestamp}.mp4")
+         os.makedirs(os.path.dirname(video_path), exist_ok=True)
+         export_to_video(tensor, video_path, fps=fps)
+         return video_path
+
+     def convert_to_gif(video_path):
+         clip = VideoFileClip(video_path)
+         gif_path = video_path.replace(".mp4", ".gif")
+         clip.write_gif(gif_path, fps=8)
+         return gif_path
+
+     for i in range(args.num_videos):
+         seed = random.randint(0, 2**8 - 1) if args.seed == -1 else args.seed + i
+         latents, seed = infer(
+             args.prompt,
+             args.image_path,
+             num_inference_steps=args.num_inference_steps,
+             guidance_scale=args.guidance_scale,
+             seed=seed,
+         )
+
+         # Convert the frame tensors to lists of PIL images
+         batch_size = latents.shape[0]
+         batch_video_frames = []
+         for batch_idx in range(batch_size):
+             pt_image = latents[batch_idx]
+             pt_image = torch.stack([pt_image[j] for j in range(pt_image.shape[0])])
+
+             image_np = VaeImageProcessor.pt_to_numpy(pt_image)
+             image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+             batch_video_frames.append(image_pil)
+
+         # 49 frames spread over ~6 seconds gives an fps of 8
+         video_path = save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6), output_dir=args.output_dir)
+         gif_path = convert_to_gif(video_path)
+
+         print(f"Video {i+1} saved to: {video_path}")
+         print(f"GIF {i+1} saved to: {gif_path}")
+
+ if __name__ == "__main__":
+     main()