svjack committed
Commit ef78ab1 · verified · 1 Parent(s): 0e3ae41

Upload consisid_preview_script.py

Files changed (1):
  consisid_preview_script.py +236 -0
consisid_preview_script.py ADDED
@@ -0,0 +1,236 @@
+ import argparse
+ import os
+ import math
+ import time
+ import numpy as np
+ import random
+ import threading
+ from typing import List, Union  # used in the save_video type hints
+ 
+ import PIL.Image  # the PIL name itself is referenced in the save_video type hints
+ from PIL import Image, ImageOps
+ from moviepy.editor import VideoFileClip
+ from datetime import datetime, timedelta
+ from huggingface_hub import hf_hub_download, snapshot_download
+ 
+ import insightface
+ from insightface.app import FaceAnalysis
+ from facexlib.parsing import init_parsing_model
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ 
+ import torch
+ from diffusers import CogVideoXDPMScheduler
+ from diffusers.utils import load_image
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.training_utils import free_memory
+ 
+ from util.utils import *
+ from util.rife_model import load_rife_model, rife_inference_with_latents
+ from models.utils import process_face_embeddings
+ from models.transformer_consisid import ConsisIDTransformer3DModel
+ from models.pipeline_consisid import ConsisIDPipeline
+ from models.eva_clip import create_model_and_transforms
+ from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+ from models.eva_clip.utils_qformer import resize_numpy_image_long
+ 
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ def main():
+     parser = argparse.ArgumentParser(description="ConsisID Command Line Interface")
+     parser.add_argument("image_path", type=str, help="Path to the input image")
+     parser.add_argument("prompt", type=str, help="Prompt text for the generation")
+     parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
+     parser.add_argument("--guidance_scale", type=float, default=7.0, help="Guidance scale")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed for generation")
+     parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the output video")
+     args = parser.parse_args()
+ 
+     # Download the auxiliary checkpoints: the Real-ESRGAN upscaler, the RIFE
+     # frame-interpolation model, and the ConsisID-preview weights.
+     hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
+     snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
+     snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
+ 
+     model_path = "BestWishYsh/ConsisID-preview"
+     lora_path = None
+     lora_rank = 128
+     dtype = torch.bfloat16
+ 
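+     # Prefer the EMA weights of the transformer when the checkpoint provides them.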
+     if os.path.exists(os.path.join(model_path, "transformer_ema")):
+         subfolder = "transformer_ema"
+     else:
+         subfolder = "transformer"
+ 
+     transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder)
+     scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
+ 
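+     # The keypoint-conditioning flag is optional in the checkpoint config; default to False when absent.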
+     try:
+         is_kps = transformer.config.is_kps
+     except Exception:
+         is_kps = False
+ 
+     # 1. load face helper models
+     face_helper = FaceRestoreHelper(
+         upscale_factor=1,
+         face_size=512,
+         crop_ratio=(1, 1),
+         det_model='retinaface_resnet50',
+         save_ext='png',
+         device=device,
+         model_rootpath=os.path.join(model_path, "face_encoder"),
+     )
+     face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder"))
+     face_helper.face_det.eval()
+     face_helper.face_parse.eval()
+ 
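+     # 2. load the EVA-CLIP visual encoder used for identity features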
+     model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True)
+     face_clip_model = model.visual
+     face_clip_model.eval()
+ 
+     eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
+     eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
+     if not isinstance(eva_transform_mean, (list, tuple)):
+         eva_transform_mean = (eva_transform_mean,) * 3
+     if not isinstance(eva_transform_std, (list, tuple)):
+         eva_transform_std = (eva_transform_std,) * 3
+ 
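+     # 3. load the insightface detection / recognition models (antelopev2)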
+     face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider'])
+     handler_ante = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
+     face_main_model.prepare(ctx_id=0, det_size=(640, 640))
+     handler_ante.prepare(ctx_id=0)
+ 
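+     # Move all models to the target device before assembling the pipeline.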
+     face_clip_model.to(device, dtype=dtype)
+     face_helper.face_det.to(device)
+     face_helper.face_parse.to(device)
+     transformer.to(device, dtype=dtype)
+     free_memory()
+ 
+     pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, scheduler=scheduler, torch_dtype=dtype)
+     # If you are using a LoRA, load and fuse it here.
+     if lora_path:
+         pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
+         pipe.fuse_lora(lora_scale=1 / lora_rank)
+ 
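+     # If the scheduler was trained with a learned variance, fall back to "fixed_small".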
+     scheduler_args = {}
+     if "variance_type" in pipe.scheduler.config:
+         variance_type = pipe.scheduler.config.variance_type
+         if variance_type in ["learned", "learned_range"]:
+             variance_type = "fixed_small"
+         scheduler_args["variance_type"] = variance_type
+ 
+     pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
+     pipe.to(device)
+ 
+     os.makedirs(args.output_dir, exist_ok=True)
+ 
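+     # Post-processing models (upscaling and frame interpolation); they are loaded
+     # here but not applied to the output in this script.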
+     upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
+     frame_interpolation_model = load_rife_model("model_rife")
+ 
+     def infer(
+         prompt: str,
+         image_input: str,
+         num_inference_steps: int,
+         guidance_scale: float,
+         seed: int = 42,
+     ):
+         if seed == -1:
+             seed = random.randint(0, 2**8 - 1)
+ 
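+         # Extract the identity condition (ArcFace embedding + EVA-CLIP hidden
+         # states), the aligned face crop, and optional facial keypoints.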
+         id_image = np.array(ImageOps.exif_transpose(Image.open(image_input)).convert("RGB"))
+         id_image = resize_numpy_image_long(id_image, 1024)
+         id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
+             face_helper, face_clip_model, handler_ante,
+             eva_transform_mean, eva_transform_std,
+             face_main_model, device, dtype, id_image,
+             original_id_image=id_image, is_align_face=True,
+             cal_uncond=False,
+         )
+ 
+         if is_kps:
+             kps_cond = face_kps
+         else:
+             kps_cond = None
+ 
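+         # Convert the aligned face crop tensor (C, H, W) back to a uint8 PIL image.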
+         tensor = align_crop_face_image.cpu().detach()
+         tensor = tensor.squeeze()
+         tensor = tensor.permute(1, 2, 0)
+         tensor = tensor.numpy() * 255
+         tensor = tensor.astype(np.uint8)
+         image = ImageOps.exif_transpose(Image.fromarray(tensor))
+ 
+         prompt = prompt.strip('"')
+ 
+         generator = torch.Generator(device).manual_seed(seed) if seed else None
+ 
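+         # Run the ConsisID pipeline; the identity is injected through id_cond / id_vit_hidden.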
+         video_pt = pipe(
+             prompt=prompt,
+             image=image,
+             num_videos_per_prompt=1,
+             num_inference_steps=num_inference_steps,
+             num_frames=49,
+             use_dynamic_cfg=False,
+             guidance_scale=guidance_scale,
+             generator=generator,
+             id_vit_hidden=id_vit_hidden,
+             id_cond=id_cond,
+             kps_cond=kps_cond,
+             output_type="pt",
+         ).frames
+ 
+         free_memory()
+         return (video_pt, seed)
+ 
+     def save_video(tensor: Union[List[np.ndarray], List[PIL.Image.Image]], fps: int = 8, output_dir="output"):
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         video_path = f"./{output_dir}/{timestamp}.mp4"
+         os.makedirs(os.path.dirname(video_path), exist_ok=True)
+         export_to_video(tensor, video_path, fps=fps)
+         return video_path
+ 
+     def convert_to_gif(video_path):
+         clip = VideoFileClip(video_path)
+         gif_path = video_path.replace(".mp4", ".gif")
+         clip.write_gif(gif_path, fps=8)
+         return gif_path
+ 
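+     # Background cleanup: every 10 minutes, delete files in output_dir older than 10 minutes.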
+     def delete_old_files():
+         while True:
+             now = datetime.now()
+             cutoff = now - timedelta(minutes=10)
+             directories = [args.output_dir]
+ 
+             for directory in directories:
+                 for filename in os.listdir(directory):
+                     file_path = os.path.join(directory, filename)
+                     if os.path.isfile(file_path):
+                         file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
+                         if file_mtime < cutoff:
+                             os.remove(file_path)
+             time.sleep(600)
+ 
+     threading.Thread(target=delete_old_files, daemon=True).start()
+ 
+     latents, seed = infer(
+         args.prompt,
+         args.image_path,
+         num_inference_steps=args.num_inference_steps,
+         guidance_scale=args.guidance_scale,
+         seed=args.seed,
+     )
+ 
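+     # Convert each batch of frame tensors (F, C, H, W) into a list of PIL images.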
+     batch_size = latents.shape[0]
+     batch_video_frames = []
+     for batch_idx in range(batch_size):
+         pt_image = latents[batch_idx]
+         pt_image = torch.stack([pt_image[i] for i in range(pt_image.shape[0])])
+ 
+         image_np = VaeImageProcessor.pt_to_numpy(pt_image)
+         image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+         batch_video_frames.append(image_pil)
+ 
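+     # fps = ceil((num_frames - 1) / 6), i.e. 8 fps for the 49 generated frames (~6 s clip).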
+     video_path = save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6), output_dir=args.output_dir)
+     gif_path = convert_to_gif(video_path)
+ 
+     print(f"Video saved to: {video_path}")
+     print(f"GIF saved to: {gif_path}")
+ 
+ 
+ if __name__ == "__main__":
+     main()