Spanicin committed
Commit db66a7a · verified · 1 parent: 3a34467

Upload app_hallo.py

Files changed (1): app_hallo.py +381 -0
app_hallo.py ADDED
@@ -0,0 +1,381 @@
+ import argparse
+ import os
+
+ import torch
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from omegaconf import OmegaConf
+ from torch import nn
+
+ from hallo.animate.face_animate import FaceAnimatePipeline
+ from hallo.datasets.audio_processor import AudioProcessor
+ from hallo.datasets.image_processor import ImageProcessor
+ from hallo.models.audio_proj import AudioProjModel
+ from hallo.models.face_locator import FaceLocator
+ from hallo.models.image_proj import ImageProjModel
+ from hallo.models.unet_2d_condition import UNet2DConditionModel
+ from hallo.models.unet_3d import UNet3DConditionModel
+ from hallo.utils.config import filter_non_none
+ from hallo.utils.util import tensor_to_video
+
+ from flask import Flask, request, jsonify
+ import tempfile
+ import uuid
+
+ app = Flask(__name__)
+ TEMP_DIR = None
+
+ class Net(nn.Module):
+     """
+     The Net class bundles all the modules needed for the inference process.
+
+     Args:
+         reference_unet (UNet2DConditionModel): The 2D UNet that encodes the reference image.
+         denoising_unet (UNet3DConditionModel): The 3D UNet that denoises the video latents.
+         face_locator (FaceLocator): The model that locates the face region in the input image.
+         imageproj (nn.Module): The projection model for the source image embedding.
+         audioproj (nn.Module): The projection model for the audio embeddings.
+     """
+     def __init__(
+         self,
+         reference_unet: UNet2DConditionModel,
+         denoising_unet: UNet3DConditionModel,
+         face_locator: FaceLocator,
+         imageproj,
+         audioproj,
+     ):
+         super().__init__()
+         self.reference_unet = reference_unet
+         self.denoising_unet = denoising_unet
+         self.face_locator = face_locator
+         self.imageproj = imageproj
+         self.audioproj = audioproj
+
+     def forward(self):
+         """
+         Empty forward that satisfies the nn.Module interface; the sub-modules are called individually.
+         """
+
+     def get_modules(self):
+         """
+         Simple method to avoid the too-few-public-methods pylint error.
+         """
+         return {
+             "reference_unet": self.reference_unet,
+             "denoising_unet": self.denoising_unet,
+             "face_locator": self.face_locator,
+             "imageproj": self.imageproj,
+             "audioproj": self.audioproj,
+         }
+
+ class AnimationConfig:
+     """Minimal stand-in for the argparse namespace consumed by inference_process."""
+     def __init__(self, driven_audio_path, source_image_path, result_folder):
+         # Attribute names must match the config keys read in inference_process after the OmegaConf merge.
+         self.config = './configs/inference/default.yaml'  # assumed path to the repo's default inference config
+         self.driving_audio = driven_audio_path
+         self.source_image = source_image_path
+         self.save_path = result_folder
+         self.output = os.path.join(result_folder, 'output.mp4')
+         self.checkpoint_dir = './checkpoints'
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def process_audio_emb(audio_emb):
+     """
+     Process the audio embedding so that each frame is stacked with its neighbouring frames.
+
+     Parameters:
+         audio_emb (torch.Tensor): The audio embedding tensor to process.
+
+     Returns:
+         audio_emb (torch.Tensor): The stacked tensor of per-frame context windows.
+     """
+     concatenated_tensors = []
+
+     for i in range(audio_emb.shape[0]):
+         vectors_to_concat = [
+             audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
+         concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
+
+     audio_emb = torch.stack(concatenated_tensors, dim=0)
+
+     return audio_emb
+
+
+
+ def inference_process(args: argparse.Namespace):
+     """
+     Perform inference processing.
+
+     Args:
+         args (argparse.Namespace): Command-line arguments.
+
+     This function initializes the configuration for the inference process. It sets up the necessary
+     modules and variables to prepare for the upcoming inference steps.
+     """
+     # 1. init config
+     cli_args = filter_non_none(vars(args))
+     config = OmegaConf.load(args.config)
+     config = OmegaConf.merge(config, cli_args)
+     source_image_path = config.source_image
+     driving_audio_path = config.driving_audio
+     save_path = config.save_path
+     if not os.path.exists(save_path):
+         os.makedirs(save_path)
+     motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]
+
+     # 2. runtime variables
+     device = torch.device(
+         "cuda") if torch.cuda.is_available() else torch.device("cpu")
+     if config.weight_dtype == "fp16":
+         weight_dtype = torch.float16
+     elif config.weight_dtype == "bf16":
+         weight_dtype = torch.bfloat16
+     elif config.weight_dtype == "fp32":
+         weight_dtype = torch.float32
+     else:
+         weight_dtype = torch.float32
+
+     # 3. prepare inference data
+     # 3.1 prepare source image, face mask, face embeddings
+     img_size = (config.data.source_image.width,
+                 config.data.source_image.height)
+     clip_length = config.data.n_sample_frames
+     face_analysis_model_path = config.face_analysis.model_path
+     with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
+         source_image_pixels, \
+             source_image_face_region, \
+             source_image_face_emb, \
+             source_image_full_mask, \
+             source_image_face_mask, \
+             source_image_lip_mask = image_processor.preprocess(
+                 source_image_path, save_path, config.face_expand_ratio)
+
+     # 3.2 prepare audio embeddings
+     sample_rate = config.data.driving_audio.sample_rate
+     assert sample_rate == 16000, "audio sample rate must be 16000"
+     fps = config.data.export_video.fps
+     wav2vec_model_path = config.wav2vec.model_path
+     wav2vec_only_last_features = config.wav2vec.features == "last"
+     audio_separator_model_file = config.audio_separator.model_path
+     with AudioProcessor(
+         sample_rate,
+         fps,
+         wav2vec_model_path,
+         wav2vec_only_last_features,
+         os.path.dirname(audio_separator_model_file),
+         os.path.basename(audio_separator_model_file),
+         os.path.join(save_path, "audio_preprocess")
+     ) as audio_processor:
+         audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)
+
+     # 4. build modules
+     sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
+     if config.enable_zero_snr:
+         sched_kwargs.update(
+             rescale_betas_zero_snr=True,
+             timestep_spacing="trailing",
+             prediction_type="v_prediction",
+         )
+     val_noise_scheduler = DDIMScheduler(**sched_kwargs)
+     sched_kwargs.update({"beta_schedule": "scaled_linear"})
+
+     vae = AutoencoderKL.from_pretrained(config.vae.model_path)
+     reference_unet = UNet2DConditionModel.from_pretrained(
+         config.base_model_path, subfolder="unet")
+     denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+         config.base_model_path,
+         config.motion_module_path,
+         subfolder="unet",
+         unet_additional_kwargs=OmegaConf.to_container(
+             config.unet_additional_kwargs),
+         use_landmark=False,
+     )
+     face_locator = FaceLocator(conditioning_embedding_channels=320)
+     image_proj = ImageProjModel(
+         cross_attention_dim=denoising_unet.config.cross_attention_dim,
+         clip_embeddings_dim=512,
+         clip_extra_context_tokens=4,
+     )
+
+     audio_proj = AudioProjModel(
+         seq_len=5,
+         blocks=12,  # use the hidden states of 12 wav2vec layers
+         channels=768,  # audio embedding channels
+         intermediate_dim=512,
+         output_dim=768,
+         context_tokens=32,
+     ).to(device=device, dtype=weight_dtype)
+
+     audio_ckpt_dir = config.audio_ckpt_dir
+
+
+     # Freeze all modules; this script only runs inference.
+     vae.requires_grad_(False)
+     image_proj.requires_grad_(False)
+     reference_unet.requires_grad_(False)
+     denoising_unet.requires_grad_(False)
+     face_locator.requires_grad_(False)
+     audio_proj.requires_grad_(False)
+
+     reference_unet.enable_gradient_checkpointing()
+     denoising_unet.enable_gradient_checkpointing()
+
+     net = Net(
+         reference_unet,
+         denoising_unet,
+         face_locator,
+         image_proj,
+         audio_proj,
+     )
+
+     m, u = net.load_state_dict(
+         torch.load(
+             os.path.join(audio_ckpt_dir, "net.pth"),
+             map_location="cpu",
+         ),
+     )
+     assert len(m) == 0 and len(u) == 0, "Failed to load the checkpoint correctly."
+     print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))
+
+     # 5. inference
+     pipeline = FaceAnimatePipeline(
+         vae=vae,
+         reference_unet=net.reference_unet,
+         denoising_unet=net.denoising_unet,
+         face_locator=net.face_locator,
+         scheduler=val_noise_scheduler,
+         image_proj=net.imageproj,
+     )
+     pipeline.to(device=device, dtype=weight_dtype)
+
+     audio_emb = process_audio_emb(audio_emb)
+
+     source_image_pixels = source_image_pixels.unsqueeze(0)
+     source_image_face_region = source_image_face_region.unsqueeze(0)
+     source_image_face_emb = source_image_face_emb.reshape(1, -1)
+     source_image_face_emb = torch.tensor(source_image_face_emb)
+
+     source_image_full_mask = [
+         (mask.repeat(clip_length, 1))
+         for mask in source_image_full_mask
+     ]
+     source_image_face_mask = [
+         (mask.repeat(clip_length, 1))
+         for mask in source_image_face_mask
+     ]
+     source_image_lip_mask = [
+         (mask.repeat(clip_length, 1))
+         for mask in source_image_lip_mask
+     ]
+
+
+     times = audio_emb.shape[0] // clip_length
+
+     tensor_result = []
+
+     generator = torch.manual_seed(42)
+
+     for t in range(times):
+         print(f"[{t+1}/{times}]")
+
+         if len(tensor_result) == 0:
+             # The first iteration
+             motion_zeros = source_image_pixels.repeat(
+                 config.data.n_motion_frames, 1, 1, 1)
+             motion_zeros = motion_zeros.to(
+                 dtype=source_image_pixels.dtype, device=source_image_pixels.device)
+             pixel_values_ref_img = torch.cat(
+                 [source_image_pixels, motion_zeros], dim=0)  # concat the ref image and the first motion frames
+         else:
+             motion_frames = tensor_result[-1][0]
+             motion_frames = motion_frames.permute(1, 0, 2, 3)
+             motion_frames = motion_frames[0 - config.data.n_motion_frames:]
+             motion_frames = motion_frames * 2.0 - 1.0
+             motion_frames = motion_frames.to(
+                 dtype=source_image_pixels.dtype, device=source_image_pixels.device)
+             pixel_values_ref_img = torch.cat(
+                 [source_image_pixels, motion_frames], dim=0)  # concat the ref image and the motion frames
+
+         pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
+
+         audio_tensor = audio_emb[
+             t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
+         ]
+         audio_tensor = audio_tensor.unsqueeze(0)
+         audio_tensor = audio_tensor.to(
+             device=net.audioproj.device, dtype=net.audioproj.dtype)
+         audio_tensor = net.audioproj(audio_tensor)
+
+         pipeline_output = pipeline(
+             ref_image=pixel_values_ref_img,
+             audio_tensor=audio_tensor,
+             face_emb=source_image_face_emb,
+             face_mask=source_image_face_region,
+             pixel_values_full_mask=source_image_full_mask,
+             pixel_values_face_mask=source_image_face_mask,
+             pixel_values_lip_mask=source_image_lip_mask,
+             width=img_size[0],
+             height=img_size[1],
+             video_length=clip_length,
+             num_inference_steps=config.inference_steps,
+             guidance_scale=config.cfg_scale,
+             generator=generator,
+             motion_scale=motion_scale,
+         )
+
+         tensor_result.append(pipeline_output.videos)
+
+     tensor_result = torch.cat(tensor_result, dim=2)
+     tensor_result = tensor_result.squeeze(0)
+     tensor_result = tensor_result[:, :audio_length]
+
+     output_file = config.output
+     # save the result after all iterations
+     tensor_to_video(tensor_result, output_file, driving_audio_path)
+     return output_file
+
+ def create_temp_dir():
+     return tempfile.TemporaryDirectory()
+
+ def save_uploaded_file(file, filename, temp_dir):
+     unique_filename = str(uuid.uuid4()) + "_" + filename
+     file_path = os.path.join(temp_dir.name, unique_filename)
+     file.save(file_path)
+     return file_path
+
+ @app.route('/run', methods=['POST'])
+ def generate_video():
+     global TEMP_DIR
+     TEMP_DIR = create_temp_dir()
+     if request.method == 'POST':
+         source_image = request.files['source_image']
+         # text_prompt = request.form['text_prompt']
+         # print('Input text prompt: ', text_prompt)
+         # text_prompt = text_prompt.strip()
+         # if not text_prompt:
+         #     return jsonify({'error': 'Input text prompt cannot be blank'}), 400
+         driving_audio = request.files['driving_audio']
+         source_image_path = save_uploaded_file(source_image, 'source_image.png', TEMP_DIR)
+         print(source_image_path)
+         driving_audio_path = save_uploaded_file(driving_audio, 'driving_audio.wav', TEMP_DIR)
+         print(driving_audio_path)
+         output_path = TEMP_DIR.name
+
+         args = AnimationConfig(
+             driven_audio_path=driving_audio_path,
+             source_image_path=source_image_path,
+             result_folder=output_path)
+
+         try:
+             # Run the inference process
+             output_file = inference_process(args)
+             return jsonify({"message": "Inference completed successfully", "output_file": os.path.abspath(output_file)})
+         except Exception as e:
+             return jsonify({"error": "Inference failed", "details": str(e)}), 500
+
+
+
+ @app.route("/health", methods=["GET"])
+ def health_status():
+     response = {"online": "true"}
+     return jsonify(response)
+
+ if __name__ == '__main__':
+     app.run(debug=True)
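
For reference, a minimal client sketch for exercising the two endpoints above, assuming the service is running locally on Flask's default port 5000 and that face.png and speech.wav are placeholder input files:

    import requests

    # Confirm the service is up before submitting a job.
    print(requests.get("http://127.0.0.1:5000/health").json())

    # Post a source portrait and a 16 kHz driving audio clip to /run.
    with open("face.png", "rb") as img, open("speech.wav", "rb") as wav:
        resp = requests.post(
            "http://127.0.0.1:5000/run",
            files={"source_image": img, "driving_audio": wav},
        )
    print(resp.status_code, resp.json())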