svjack committed on
Commit
fdb5dd2
·
verified ·
1 Parent(s): faf87a1

Upload consisid_preview_script_offload.py

Browse files
Files changed (1) hide show
  1. consisid_preview_script_offload.py +241 -0
consisid_preview_script_offload.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import time
4
+ import numpy as np
5
+ import random
6
+ import threading
7
+ from PIL import Image, ImageOps
8
+ from moviepy.editor import VideoFileClip
9
+ from datetime import datetime, timedelta
10
+ from huggingface_hub import hf_hub_download, snapshot_download
11
+
12
+ import insightface
13
+ from insightface.app import FaceAnalysis
14
+ from facexlib.parsing import init_parsing_model
15
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
16
+
17
+ import torch
18
+ from diffusers import CogVideoXDPMScheduler
19
+ from diffusers.utils import load_image
20
+ from diffusers.image_processor import VaeImageProcessor
21
+ from diffusers.training_utils import free_memory
22
+
23
+ from util.utils import *
24
+ from util.rife_model import load_rife_model, rife_inference_with_latents
25
+ from models.utils import process_face_embeddings
26
+ from models.transformer_consisid import ConsisIDTransformer3DModel
27
+ from models.pipeline_consisid import ConsisIDPipeline
28
+ from models.eva_clip import create_model_and_transforms
29
+ from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
30
+ from models.eva_clip.utils_qformer import resize_numpy_image_long
31
+
32
+ import argparse
33
+
34
# Pick the torch device once at import time: use CUDA when torch reports it
# as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
35
+
36
def main():
    """Command-line entry point: generate one video from a face image and a prompt.

    Steps, in order:
      1. Parse the CLI arguments (``image_path``, ``prompt``, ``--num_inference_steps``,
         ``--guidance_scale``, ``--seed``, ``--output_dir``).
      2. Download the Real-ESRGAN, RIFE and ConsisID-preview checkpoints.
      3. Load the transformer, scheduler, face helper models and build the
         ``ConsisIDPipeline`` with CPU offload, VAE slicing and VAE tiling enabled.
      4. Run a single inference, save the first video of the batch as ``.mp4``
         and convert it to ``.gif``.
      5. Print both output paths.

    Side effects: downloads model files into the working directory, creates
    ``args.output_dir``, writes the video/GIF there, and starts a daemon thread
    that deletes files older than 10 minutes from ``args.output_dir``.
    """
    parser = argparse.ArgumentParser(description="ConsisID Command Line Interface")
    parser.add_argument("image_path", type=str, help="Path to the input image")
    parser.add_argument("prompt", type=str, help="Prompt text for the generation")
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("--guidance_scale", type=float, default=7.0, help="Guidance scale")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for generation")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the output video")
    args = parser.parse_args()

    # Download models
    hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
    snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
    snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")

    model_path = "BestWishYsh/ConsisID-preview"
    lora_path = None
    lora_rank = 128
    dtype = torch.bfloat16

    # Use the EMA transformer weights when the checkpoint ships them.
    if os.path.exists(os.path.join(model_path, "transformer_ema")):
        subfolder = "transformer_ema"
    else:
        subfolder = "transformer"

    transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder)
    scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

    # Checkpoints without an ``is_kps`` config entry are treated as not using
    # keypoint conditioning. Only the "entry is missing" errors are caught so
    # unrelated failures are not silently hidden.
    try:
        is_kps = transformer.config.is_kps
    except (AttributeError, KeyError):
        is_kps = False

    # 1. load face helper models
    face_helper = FaceRestoreHelper(
        upscale_factor=1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        save_ext='png',
        device=device,
        model_rootpath=os.path.join(model_path, "face_encoder")
    )
    # Clear the helper's parser before installing the bisenet one
    # (presumably to release the default parser first -- TODO confirm).
    face_helper.face_parse = None
    face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder"))
    face_helper.face_det.eval()
    face_helper.face_parse.eval()

    model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True)
    face_clip_model = model.visual
    face_clip_model.eval()

    # Normalisation statistics for the CLIP face encoder; a scalar value is
    # expanded to one entry per colour channel.
    eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
    eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
    if not isinstance(eva_transform_mean, (list, tuple)):
        eva_transform_mean = (eva_transform_mean,) * 3
    if not isinstance(eva_transform_std, (list, tuple)):
        eva_transform_std = (eva_transform_std,) * 3

    # NOTE(review): the ONNX providers are hard-coded to CUDA even when
    # ``device`` resolved to "cpu" -- confirm whether a CPU fallback is wanted.
    face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider'])
    handler_ante = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
    face_main_model.prepare(ctx_id=0, det_size=(640, 640))
    handler_ante.prepare(ctx_id=0)

    face_clip_model.to(device, dtype=dtype)
    face_helper.face_det.to(device)
    face_helper.face_parse.to(device)
    transformer.to(device, dtype=dtype)
    free_memory()

    pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, scheduler=scheduler, torch_dtype=dtype)
    # If you're using with lora, add this code
    if lora_path:
        pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
        pipe.fuse_lora(lora_scale=1 / lora_rank)

    # Rebuild the scheduler from its config, replacing learned variance types
    # with "fixed_small".
    scheduler_args = {}
    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type

    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
    #pipe.to(device)

    # NOTE(review): both offload modes are enabled back to back; presumably
    # only the sequential one ends up in effect -- verify against the
    # installed diffusers version before dropping either call.
    pipe.enable_model_cpu_offload()
    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): these two models are loaded but never used below in this
    # script; kept because their loaders are defined outside this file.
    upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
    frame_interpolation_model = load_rife_model("model_rife")

    def infer(
        prompt: str,
        image_input: str,
        num_inference_steps: int,
        guidance_scale: float,
        seed: int = 42,
    ):
        """Run the ConsisID pipeline once for a single identity image.

        Args:
            prompt: Text prompt; surrounding double quotes are stripped.
            image_input: Path of the identity image to open.
            num_inference_steps: Forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.
            seed: Generator seed. ``-1`` requests a random seed.

        Returns:
            A ``(frames, seed)`` tuple: the ``.frames`` attribute of the
            pipeline output (requested with ``output_type="pt"``) and the seed
            that was actually used.
        """
        if seed == -1:
            # NOTE(review): this range allows only 256 distinct random seeds.
            seed = random.randint(0, 2**8 - 1)

        id_image = np.array(ImageOps.exif_transpose(Image.open(image_input)).convert("RGB"))
        id_image = resize_numpy_image_long(id_image, 1024)
        id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
            face_helper, face_clip_model, handler_ante,
            eva_transform_mean, eva_transform_std,
            face_main_model, device, dtype, id_image,
            original_id_image=id_image, is_align_face=True,
            cal_uncond=False,
        )

        kps_cond = face_kps if is_kps else None

        # Turn the aligned face crop into a PIL image for the pipeline
        # (assumes a single channel-first image with values in [0, 1] --
        # TODO confirm against process_face_embeddings).
        tensor = align_crop_face_image.cpu().detach()
        tensor = tensor.squeeze()
        tensor = tensor.permute(1, 2, 0)
        tensor = tensor.numpy() * 255
        tensor = tensor.astype(np.uint8)
        image = ImageOps.exif_transpose(Image.fromarray(tensor))

        prompt = prompt.strip('"')

        # Compare against None rather than truthiness so that seed 0 still
        # produces a seeded, reproducible generator.
        generator = torch.Generator(device).manual_seed(seed) if seed is not None else None

        video_pt = pipe(
            prompt=prompt,
            image=image,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=49,
            use_dynamic_cfg=False,
            guidance_scale=guidance_scale,
            generator=generator,
            id_vit_hidden=id_vit_hidden,
            id_cond=id_cond,
            kps_cond=kps_cond,
            output_type="pt",
        ).frames

        free_memory()
        return (video_pt, seed)

    # The annotation names (Union, List, PIL) are not imported explicitly in
    # this file; presumably they come from ``from util.utils import *`` --
    # verify.
    def save_video(tensor: Union[List[np.ndarray], List[PIL.Image.Image]], fps: int = 8, output_dir = "output"):
        """Write the frames to ``<output_dir>/<timestamp>.mp4`` and return that path.

        The path is built with ``os.path.join`` so that an absolute
        ``output_dir`` is honoured instead of being re-rooted under the
        current working directory.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        target_dir = output_dir or "."
        os.makedirs(target_dir, exist_ok=True)
        video_path = os.path.join(target_dir, f"{timestamp}.mp4")
        export_to_video(tensor, video_path, fps=fps)
        return video_path

    def convert_to_gif(video_path):
        """Write a GIF (8 fps) next to ``video_path`` and return the GIF path."""
        # Swap only the file extension; a plain str.replace would also rewrite
        # any ".mp4" that appears earlier in the path.
        gif_path = os.path.splitext(video_path)[0] + ".gif"
        clip = VideoFileClip(video_path)
        try:
            clip.write_gif(gif_path, fps=8)
        finally:
            # Release the clip's reader even if GIF encoding fails.
            clip.close()
        return gif_path

    def delete_old_files():
        """Forever: delete files older than 10 minutes from the output dir, then sleep 600 s.

        NOTE(review): the first sweep runs as soon as the thread starts, so
        results of earlier runs stored in ``args.output_dir`` are removed --
        confirm this is intended for a one-shot CLI.
        """
        while True:
            cutoff = datetime.now() - timedelta(minutes=10)
            directories = [args.output_dir]

            for directory in directories:
                try:
                    filenames = os.listdir(directory)
                except OSError:
                    # Best-effort cleanup: directory vanished or is unreadable.
                    continue
                for filename in filenames:
                    file_path = os.path.join(directory, filename)
                    try:
                        if not os.path.isfile(file_path):
                            continue
                        file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                        if file_mtime < cutoff:
                            os.remove(file_path)
                    except OSError:
                        # The file may disappear between listing and removal;
                        # skip it instead of killing the cleanup thread.
                        continue
            time.sleep(600)

    threading.Thread(target=delete_old_files, daemon=True).start()

    video_batch, seed = infer(
        args.prompt,
        args.image_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
    )

    # Convert every video of the batch into a list of PIL frames.
    batch_video_frames = [
        VaeImageProcessor.numpy_to_pil(VaeImageProcessor.pt_to_numpy(video_batch[batch_idx]))
        for batch_idx in range(video_batch.shape[0])
    ]

    # Only the first video of the batch is saved. The fps is derived from the
    # frame count and clamped to at least 1 so a very short clip cannot
    # request 0 fps.
    first_video = batch_video_frames[0]
    video_path = save_video(first_video, fps=max(1, math.ceil((len(first_video) - 1) / 6)), output_dir=args.output_dir)
    gif_path = convert_to_gif(video_path)

    print(f"Video saved to: {video_path}")
    print(f"GIF saved to: {gif_path}")
239
+
240
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()