Spanicin committed (verified)
Commit bfbd052 · 1 Parent(s): d46075f

Update app_hallo.py

Files changed (1)
  1. app_hallo.py +379 -380
app_hallo.py CHANGED
@@ -1,381 +1,380 @@
- import argparse
- import os
-
- import torch
- from diffusers import AutoencoderKL, DDIMScheduler
- from omegaconf import OmegaConf
- from torch import nn
-
- from hallo.animate.face_animate import FaceAnimatePipeline
- from hallo.datasets.audio_processor import AudioProcessor
- from hallo.datasets.image_processor import ImageProcessor
- from hallo.models.audio_proj import AudioProjModel
- from hallo.models.face_locator import FaceLocator
- from hallo.models.image_proj import ImageProjModel
- from hallo.models.unet_2d_condition import UNet2DConditionModel
- from hallo.models.unet_3d import UNet3DConditionModel
- from hallo.utils.config import filter_non_none
- from hallo.utils.util import tensor_to_video
-
- from flask import Flask, request, jsonify
- import tempfile
- import uuid
-
- app = Flask(__name__)
- TEMP_DIR = None
-
- class Net(nn.Module):
-     """
-     The Net class combines all the necessary modules for the inference process.
-
-     Args:
-         reference_unet (UNet2DConditionModel): The UNet2DConditionModel used as a reference for inference.
-         denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used for denoising the input audio.
-         face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image.
-         imageproj (nn.Module): The ImageProjector model used to project the source image onto the face.
-         audioproj (nn.Module): The AudioProjector model used to project the audio embeddings onto the face.
-     """
-     def __init__(
-         self,
-         reference_unet: UNet2DConditionModel,
-         denoising_unet: UNet3DConditionModel,
-         face_locator: FaceLocator,
-         imageproj,
-         audioproj,
-     ):
-         super().__init__()
-         self.reference_unet = reference_unet
-         self.denoising_unet = denoising_unet
-         self.face_locator = face_locator
-         self.imageproj = imageproj
-         self.audioproj = audioproj
-
-     def forward(self,):
-         """
-         empty function to override abstract function of nn Module
-         """
-
-     def get_modules(self):
-         """
-         Simple method to avoid too-few-public-methods pylint error
-         """
-         return {
-             "reference_unet": self.reference_unet,
-             "denoising_unet": self.denoising_unet,
-             "face_locator": self.face_locator,
-             "imageproj": self.imageproj,
-             "audioproj": self.audioproj,
-         }
-
- class AnimationConfig:
-     def __init__(self, driven_audio_path, source_image_path, result_folder):
-         self.driven_audio = driven_audio_path
-         self.source_image = source_image_path
-         self.checkpoint_dir = './checkpoints'
-         self.result_dir = result_folder
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
- def process_audio_emb(audio_emb):
-     """
-     Process the audio embedding to concatenate with other tensors.
-
-     Parameters:
-         audio_emb (torch.Tensor): The audio embedding tensor to process.
-
-     Returns:
-         concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
-     """
-     concatenated_tensors = []
-
-     for i in range(audio_emb.shape[0]):
-         vectors_to_concat = [
-             audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)]for j in range(-2, 3)]
-         concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
-
-     audio_emb = torch.stack(concatenated_tensors, dim=0)
-
-     return audio_emb
-
-
-
- def inference_process(args: argparse.Namespace):
-     """
-     Perform inference processing.
-
-     Args:
-         args (argparse.Namespace): Command-line arguments.
-
-     This function initializes the configuration for the inference process. It sets up the necessary
-     modules and variables to prepare for the upcoming inference steps.
-     """
-     # 1. init config
-     cli_args = filter_non_none(vars(args))
-     config = OmegaConf.load(args.config)
-     config = OmegaConf.merge(config, cli_args)
-     source_image_path = config.source_image
-     driving_audio_path = config.driving_audio
-     save_path = config.save_path
-     if not os.path.exists(save_path):
-         os.makedirs(save_path)
-     motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]
-
-     # 2. runtime variables
-     device = torch.device(
-         "cuda") if torch.cuda.is_available() else torch.device("cpu")
-     if config.weight_dtype == "fp16":
-         weight_dtype = torch.float16
-     elif config.weight_dtype == "bf16":
-         weight_dtype = torch.bfloat16
-     elif config.weight_dtype == "fp32":
-         weight_dtype = torch.float32
-     else:
-         weight_dtype = torch.float32
-
-     # 3. prepare inference data
-     # 3.1 prepare source image, face mask, face embeddings
-     img_size = (config.data.source_image.width,
-                 config.data.source_image.height)
-     clip_length = config.data.n_sample_frames
-     face_analysis_model_path = config.face_analysis.model_path
-     with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
-         source_image_pixels, \
-         source_image_face_region, \
-         source_image_face_emb, \
-         source_image_full_mask, \
-         source_image_face_mask, \
-         source_image_lip_mask = image_processor.preprocess(
-             source_image_path, save_path, config.face_expand_ratio)
-
-     # 3.2 prepare audio embeddings
-     sample_rate = config.data.driving_audio.sample_rate
-     assert sample_rate == 16000, "audio sample rate must be 16000"
-     fps = config.data.export_video.fps
-     wav2vec_model_path = config.wav2vec.model_path
-     wav2vec_only_last_features = config.wav2vec.features == "last"
-     audio_separator_model_file = config.audio_separator.model_path
-     with AudioProcessor(
-         sample_rate,
-         fps,
-         wav2vec_model_path,
-         wav2vec_only_last_features,
-         os.path.dirname(audio_separator_model_file),
-         os.path.basename(audio_separator_model_file),
-         os.path.join(save_path, "audio_preprocess")
-     ) as audio_processor:
-         audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)
-
-     # 4. build modules
-     sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
-     if config.enable_zero_snr:
-         sched_kwargs.update(
-             rescale_betas_zero_snr=True,
-             timestep_spacing="trailing",
-             prediction_type="v_prediction",
-         )
-     val_noise_scheduler = DDIMScheduler(**sched_kwargs)
-     sched_kwargs.update({"beta_schedule": "scaled_linear"})
-
-     vae = AutoencoderKL.from_pretrained(config.vae.model_path)
-     reference_unet = UNet2DConditionModel.from_pretrained(
-         config.base_model_path, subfolder="unet")
-     denoising_unet = UNet3DConditionModel.from_pretrained_2d(
-         config.base_model_path,
-         config.motion_module_path,
-         subfolder="unet",
-         unet_additional_kwargs=OmegaConf.to_container(
-             config.unet_additional_kwargs),
-         use_landmark=False,
-     )
-     face_locator = FaceLocator(conditioning_embedding_channels=320)
-     image_proj = ImageProjModel(
-         cross_attention_dim=denoising_unet.config.cross_attention_dim,
-         clip_embeddings_dim=512,
-         clip_extra_context_tokens=4,
-     )
-
-     audio_proj = AudioProjModel(
-         seq_len=5,
-         blocks=12, # use 12 layers' hidden states of wav2vec
-         channels=768, # audio embedding channel
-         intermediate_dim=512,
-         output_dim=768,
-         context_tokens=32,
-     ).to(device=device, dtype=weight_dtype)
-
-     audio_ckpt_dir = config.audio_ckpt_dir
-
-
-     # Freeze
-     vae.requires_grad_(False)
-     image_proj.requires_grad_(False)
-     reference_unet.requires_grad_(False)
-     denoising_unet.requires_grad_(False)
-     face_locator.requires_grad_(False)
-     audio_proj.requires_grad_(False)
-
-     reference_unet.enable_gradient_checkpointing()
-     denoising_unet.enable_gradient_checkpointing()
-
-     net = Net(
-         reference_unet,
-         denoising_unet,
-         face_locator,
-         image_proj,
-         audio_proj,
-     )
-
-     m,u = net.load_state_dict(
-         torch.load(
-             os.path.join(audio_ckpt_dir, "net.pth"),
-             map_location="cpu",
-         ),
-     )
-     assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint."
-     print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))
-
-     # 5. inference
-     pipeline = FaceAnimatePipeline(
-         vae=vae,
-         reference_unet=net.reference_unet,
-         denoising_unet=net.denoising_unet,
-         face_locator=net.face_locator,
-         scheduler=val_noise_scheduler,
-         image_proj=net.imageproj,
-     )
-     pipeline.to(device=device, dtype=weight_dtype)
-
-     audio_emb = process_audio_emb(audio_emb)
-
-     source_image_pixels = source_image_pixels.unsqueeze(0)
-     source_image_face_region = source_image_face_region.unsqueeze(0)
-     source_image_face_emb = source_image_face_emb.reshape(1, -1)
-     source_image_face_emb = torch.tensor(source_image_face_emb)
-
-     source_image_full_mask = [
-         (mask.repeat(clip_length, 1))
-         for mask in source_image_full_mask
-     ]
-     source_image_face_mask = [
-         (mask.repeat(clip_length, 1))
-         for mask in source_image_face_mask
-     ]
-     source_image_lip_mask = [
-         (mask.repeat(clip_length, 1))
-         for mask in source_image_lip_mask
-     ]
-
-
-     times = audio_emb.shape[0] // clip_length
-
-     tensor_result = []
-
-     generator = torch.manual_seed(42)
-
-     for t in range(times):
-         print(f"[{t+1}/{times}]")
-
-         if len(tensor_result) == 0:
-             # The first iteration
-             motion_zeros = source_image_pixels.repeat(
-                 config.data.n_motion_frames, 1, 1, 1)
-             motion_zeros = motion_zeros.to(
-                 dtype=source_image_pixels.dtype, device=source_image_pixels.device)
-             pixel_values_ref_img = torch.cat(
-                 [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
-         else:
-             motion_frames = tensor_result[-1][0]
-             motion_frames = motion_frames.permute(1, 0, 2, 3)
-             motion_frames = motion_frames[0-config.data.n_motion_frames:]
-             motion_frames = motion_frames * 2.0 - 1.0
-             motion_frames = motion_frames.to(
-                 dtype=source_image_pixels.dtype, device=source_image_pixels.device)
-             pixel_values_ref_img = torch.cat(
-                 [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
-
-         pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
-
-         audio_tensor = audio_emb[
-             t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
-         ]
-         audio_tensor = audio_tensor.unsqueeze(0)
-         audio_tensor = audio_tensor.to(
-             device=net.audioproj.device, dtype=net.audioproj.dtype)
-         audio_tensor = net.audioproj(audio_tensor)
-
-         pipeline_output = pipeline(
-             ref_image=pixel_values_ref_img,
-             audio_tensor=audio_tensor,
-             face_emb=source_image_face_emb,
-             face_mask=source_image_face_region,
-             pixel_values_full_mask=source_image_full_mask,
-             pixel_values_face_mask=source_image_face_mask,
-             pixel_values_lip_mask=source_image_lip_mask,
-             width=img_size[0],
-             height=img_size[1],
-             video_length=clip_length,
-             num_inference_steps=config.inference_steps,
-             guidance_scale=config.cfg_scale,
-             generator=generator,
-             motion_scale=motion_scale,
-         )
-
-         tensor_result.append(pipeline_output.videos)
-
-     tensor_result = torch.cat(tensor_result, dim=2)
-     tensor_result = tensor_result.squeeze(0)
-     tensor_result = tensor_result[:, :audio_length]
-
-     output_file = config.output
-     # save the result after all iteration
-     tensor_to_video(tensor_result, output_file, driving_audio_path)
-     return output_file
-
- def create_temp_dir():
-     return tempfile.TemporaryDirectory()
-
- def save_uploaded_file(file, filename,TEMP_DIR):
-     unique_filename = str(uuid.uuid4()) + "_" + filename
-     file_path = os.path.join(TEMP_DIR.name, unique_filename)
-     file.save(file_path)
-     return file_path
-
- @app.route('/run', methods=['POST'])
- def generate_video():
-     global TEMP_DIR
-     TEMP_DIR = create_temp_dir()
-     if request.method == 'POST':
-         source_image = request.files['source_image']
-         # text_prompt = request.form['text_prompt']
-         # print('Input text prompt: ', text_prompt)
-         # text_prompt = text_prompt.strip()
-         # if not text_prompt:
-         # return jsonify({'error': 'Input text prompt cannot be blank'}), 400
-         driving_audio = request.files['driving_audio']
-         source_image_path = save_uploaded_file(source_image, 'source_image.png',TEMP_DIR)
-         print(source_image_path)
-         driving_audio_path = save_uploaded_file(driving_audio, 'driving_audio.wav', TEMP_DIR)
-         print(driving_audio_path)
-         output_path = TEMP_DIR.name
-
-         args = AnimationConfig(
-             driven_audio_path=driving_audio_path,
-             source_image_path=source_image_path,
-             result_folder=output_path)
-
-         try:
-             # Run the inference process
-             output_file = inference_process(args)
-             return jsonify({"message": "Inference completed successfully", "output_file": os.path.abspath(output_file)})
-         except Exception as e:
-             return jsonify({"error": "Inference failed", "details": str(e)}), 500
-
-
-
- @app.route("/health", methods=["GET"])
- def health_status():
-     response = {"online": "true"}
-     return jsonify(response)
-
- if __name__ == '__main__':
+ import argparse
+ import os
+
+ import torch
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from omegaconf import OmegaConf
+ from torch import nn
+
+ from hallo.animate.face_animate import FaceAnimatePipeline
+ from hallo.datasets.audio_processor import AudioProcessor
+ from hallo.datasets.image_processor import ImageProcessor
+ from hallo.models.audio_proj import AudioProjModel
+ from hallo.models.face_locator import FaceLocator
+ from hallo.models.image_proj import ImageProjModel
+ from hallo.models.unet_2d_condition import UNet2DConditionModel
+ from hallo.models.unet_3d import UNet3DConditionModel
+ from hallo.utils.config import filter_non_none
+ from hallo.utils.util import tensor_to_video
+
+ from flask import Flask, request, jsonify
+ import tempfile
+ import uuid
+
+ app = Flask(__name__)
+ TEMP_DIR = None
+
+ class Net(nn.Module):
+     """
+     The Net class combines all the necessary modules for the inference process.
+
+     Args:
+         reference_unet (UNet2DConditionModel): The UNet2DConditionModel used as a reference for inference.
+         denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used for denoising the input audio.
+         face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image.
+         imageproj (nn.Module): The ImageProjector model used to project the source image onto the face.
+         audioproj (nn.Module): The AudioProjector model used to project the audio embeddings onto the face.
+     """
+     def __init__(
+         self,
+         reference_unet: UNet2DConditionModel,
+         denoising_unet: UNet3DConditionModel,
+         face_locator: FaceLocator,
+         imageproj,
+         audioproj,
+     ):
+         super().__init__()
+         self.reference_unet = reference_unet
+         self.denoising_unet = denoising_unet
+         self.face_locator = face_locator
+         self.imageproj = imageproj
+         self.audioproj = audioproj
+
+     def forward(self,):
+         """
+         empty function to override abstract function of nn Module
+         """
+
+     def get_modules(self):
+         """
+         Simple method to avoid too-few-public-methods pylint error
+         """
+         return {
+             "reference_unet": self.reference_unet,
+             "denoising_unet": self.denoising_unet,
+             "face_locator": self.face_locator,
+             "imageproj": self.imageproj,
+             "audioproj": self.audioproj,
+         }
+
+ class AnimationConfig:
+     def __init__(self, driven_audio_path, source_image_path, result_folder):
+         self.driven_audio = driven_audio_path
+         self.source_image = source_image_path
+         self.output = result_folder
+         self.config: 'configs/inference/default.yaml'
+
+
+ def process_audio_emb(audio_emb):
+     """
+     Process the audio embedding to concatenate with other tensors.
+
+     Parameters:
+         audio_emb (torch.Tensor): The audio embedding tensor to process.
+
+     Returns:
+         concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
+     """
+     concatenated_tensors = []
+
+     for i in range(audio_emb.shape[0]):
+         vectors_to_concat = [
+             audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)]for j in range(-2, 3)]
+         concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
+
+     audio_emb = torch.stack(concatenated_tensors, dim=0)
+
+     return audio_emb
+
+
+
+ def inference_process(args: argparse.Namespace):
+     """
+     Perform inference processing.
+
+     Args:
+         args (argparse.Namespace): Command-line arguments.
+
+     This function initializes the configuration for the inference process. It sets up the necessary
+     modules and variables to prepare for the upcoming inference steps.
+     """
+     # 1. init config
+     cli_args = filter_non_none(vars(args))
+     config = OmegaConf.load(args.config)
+     config = OmegaConf.merge(config, cli_args)
+     source_image_path = config.source_image
+     driving_audio_path = config.driving_audio
+     save_path = config.save_path
+     if not os.path.exists(save_path):
+         os.makedirs(save_path)
+     motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]
+
+     # 2. runtime variables
+     device = torch.device(
+         "cuda") if torch.cuda.is_available() else torch.device("cpu")
+     if config.weight_dtype == "fp16":
+         weight_dtype = torch.float16
+     elif config.weight_dtype == "bf16":
+         weight_dtype = torch.bfloat16
+     elif config.weight_dtype == "fp32":
+         weight_dtype = torch.float32
+     else:
+         weight_dtype = torch.float32
+
+     # 3. prepare inference data
+     # 3.1 prepare source image, face mask, face embeddings
+     img_size = (config.data.source_image.width,
+                 config.data.source_image.height)
+     clip_length = config.data.n_sample_frames
+     face_analysis_model_path = config.face_analysis.model_path
+     with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
+         source_image_pixels, \
+         source_image_face_region, \
+         source_image_face_emb, \
+         source_image_full_mask, \
+         source_image_face_mask, \
+         source_image_lip_mask = image_processor.preprocess(
+             source_image_path, save_path, config.face_expand_ratio)
+
+     # 3.2 prepare audio embeddings
+     sample_rate = config.data.driving_audio.sample_rate
+     assert sample_rate == 16000, "audio sample rate must be 16000"
+     fps = config.data.export_video.fps
+     wav2vec_model_path = config.wav2vec.model_path
+     wav2vec_only_last_features = config.wav2vec.features == "last"
+     audio_separator_model_file = config.audio_separator.model_path
+     with AudioProcessor(
+         sample_rate,
+         fps,
+         wav2vec_model_path,
+         wav2vec_only_last_features,
+         os.path.dirname(audio_separator_model_file),
+         os.path.basename(audio_separator_model_file),
+         os.path.join(save_path, "audio_preprocess")
+     ) as audio_processor:
+         audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)
+
+     # 4. build modules
+     sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
+     if config.enable_zero_snr:
+         sched_kwargs.update(
+             rescale_betas_zero_snr=True,
+             timestep_spacing="trailing",
+             prediction_type="v_prediction",
+         )
+     val_noise_scheduler = DDIMScheduler(**sched_kwargs)
+     sched_kwargs.update({"beta_schedule": "scaled_linear"})
+
+     vae = AutoencoderKL.from_pretrained(config.vae.model_path)
+     reference_unet = UNet2DConditionModel.from_pretrained(
+         config.base_model_path, subfolder="unet")
+     denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+         config.base_model_path,
+         config.motion_module_path,
+         subfolder="unet",
+         unet_additional_kwargs=OmegaConf.to_container(
+             config.unet_additional_kwargs),
+         use_landmark=False,
+     )
+     face_locator = FaceLocator(conditioning_embedding_channels=320)
+     image_proj = ImageProjModel(
+         cross_attention_dim=denoising_unet.config.cross_attention_dim,
+         clip_embeddings_dim=512,
+         clip_extra_context_tokens=4,
+     )
+
+     audio_proj = AudioProjModel(
+         seq_len=5,
+         blocks=12, # use 12 layers' hidden states of wav2vec
+         channels=768, # audio embedding channel
+         intermediate_dim=512,
+         output_dim=768,
+         context_tokens=32,
+     ).to(device=device, dtype=weight_dtype)
+
+     audio_ckpt_dir = config.audio_ckpt_dir
+
+
+     # Freeze
+     vae.requires_grad_(False)
+     image_proj.requires_grad_(False)
+     reference_unet.requires_grad_(False)
+     denoising_unet.requires_grad_(False)
+     face_locator.requires_grad_(False)
+     audio_proj.requires_grad_(False)
+
+     reference_unet.enable_gradient_checkpointing()
+     denoising_unet.enable_gradient_checkpointing()
+
+     net = Net(
+         reference_unet,
+         denoising_unet,
+         face_locator,
+         image_proj,
+         audio_proj,
+     )
+
+     m,u = net.load_state_dict(
+         torch.load(
+             os.path.join(audio_ckpt_dir, "net.pth"),
+             map_location="cpu",
+         ),
+     )
+     assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint."
+     print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))
+
+     # 5. inference
+     pipeline = FaceAnimatePipeline(
+         vae=vae,
+         reference_unet=net.reference_unet,
+         denoising_unet=net.denoising_unet,
+         face_locator=net.face_locator,
+         scheduler=val_noise_scheduler,
+         image_proj=net.imageproj,
+     )
+     pipeline.to(device=device, dtype=weight_dtype)
+
+     audio_emb = process_audio_emb(audio_emb)
+
+     source_image_pixels = source_image_pixels.unsqueeze(0)
+     source_image_face_region = source_image_face_region.unsqueeze(0)
+     source_image_face_emb = source_image_face_emb.reshape(1, -1)
+     source_image_face_emb = torch.tensor(source_image_face_emb)
+
+     source_image_full_mask = [
+         (mask.repeat(clip_length, 1))
+         for mask in source_image_full_mask
+     ]
+     source_image_face_mask = [
+         (mask.repeat(clip_length, 1))
+         for mask in source_image_face_mask
+     ]
+     source_image_lip_mask = [
+         (mask.repeat(clip_length, 1))
+         for mask in source_image_lip_mask
+     ]
+
+
+     times = audio_emb.shape[0] // clip_length
+
+     tensor_result = []
+
+     generator = torch.manual_seed(42)
+
+     for t in range(times):
+         print(f"[{t+1}/{times}]")
+
+         if len(tensor_result) == 0:
+             # The first iteration
+             motion_zeros = source_image_pixels.repeat(
+                 config.data.n_motion_frames, 1, 1, 1)
+             motion_zeros = motion_zeros.to(
+                 dtype=source_image_pixels.dtype, device=source_image_pixels.device)
+             pixel_values_ref_img = torch.cat(
+                 [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
+         else:
+             motion_frames = tensor_result[-1][0]
+             motion_frames = motion_frames.permute(1, 0, 2, 3)
+             motion_frames = motion_frames[0-config.data.n_motion_frames:]
+             motion_frames = motion_frames * 2.0 - 1.0
+             motion_frames = motion_frames.to(
+                 dtype=source_image_pixels.dtype, device=source_image_pixels.device)
+             pixel_values_ref_img = torch.cat(
+                 [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
+
+         pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
+
+         audio_tensor = audio_emb[
+             t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
+         ]
+         audio_tensor = audio_tensor.unsqueeze(0)
+         audio_tensor = audio_tensor.to(
+             device=net.audioproj.device, dtype=net.audioproj.dtype)
+         audio_tensor = net.audioproj(audio_tensor)
+
+         pipeline_output = pipeline(
+             ref_image=pixel_values_ref_img,
+             audio_tensor=audio_tensor,
+             face_emb=source_image_face_emb,
+             face_mask=source_image_face_region,
+             pixel_values_full_mask=source_image_full_mask,
+             pixel_values_face_mask=source_image_face_mask,
+             pixel_values_lip_mask=source_image_lip_mask,
+             width=img_size[0],
+             height=img_size[1],
+             video_length=clip_length,
+             num_inference_steps=config.inference_steps,
+             guidance_scale=config.cfg_scale,
+             generator=generator,
+             motion_scale=motion_scale,
+         )
+
+         tensor_result.append(pipeline_output.videos)
+
+     tensor_result = torch.cat(tensor_result, dim=2)
+     tensor_result = tensor_result.squeeze(0)
+     tensor_result = tensor_result[:, :audio_length]
+
+     output_file = config.output
+     # save the result after all iteration
+     tensor_to_video(tensor_result, output_file, driving_audio_path)
+     return output_file
+
+ def create_temp_dir():
+     return tempfile.TemporaryDirectory()
+
+ def save_uploaded_file(file, filename,TEMP_DIR):
+     unique_filename = str(uuid.uuid4()) + "_" + filename
+     file_path = os.path.join(TEMP_DIR.name, unique_filename)
+     file.save(file_path)
+     return file_path
+
+ @app.route('/run', methods=['POST'])
+ def generate_video():
+     global TEMP_DIR
+     TEMP_DIR = create_temp_dir()
+     if request.method == 'POST':
+         source_image = request.files['source_image']
+         # text_prompt = request.form['text_prompt']
+         # print('Input text prompt: ', text_prompt)
+         # text_prompt = text_prompt.strip()
+         # if not text_prompt:
+         # return jsonify({'error': 'Input text prompt cannot be blank'}), 400
+         driving_audio = request.files['driving_audio']
+         source_image_path = save_uploaded_file(source_image, 'source_image.png',TEMP_DIR)
+         print(source_image_path)
+         driving_audio_path = save_uploaded_file(driving_audio, 'driving_audio.wav', TEMP_DIR)
+         print(driving_audio_path)
+         result_folder = TEMP_DIR.name
+
+         args = AnimationConfig(
+             driving_audio_path=driving_audio_path,
+             source_image_path=source_image_path,
+             result_folder=result_folder)
+
+         try:
+             # Run the inference process
+             output_file = inference_process(args)
+             return jsonify({"message": "Inference completed successfully", "output_file": os.path.abspath(output_file)})
+         except Exception as e:
+             return jsonify({"error": "Inference failed", "details": str(e)}), 500
+
+
+
+ @app.route("/health", methods=["GET"])
+ def health_status():
+     response = {"online": "true"}
+     return jsonify(response)
+
+ if __name__ == '__main__':
      app.run(debug=True)
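
For reference, a minimal client sketch for the two endpoints this file exposes, illustrative only and not part of the commit: it assumes the app is started on the default Flask development server address (http://127.0.0.1:5000), uses the requests library, and takes placeholder input file names; the multipart field names source_image and driving_audio mirror the request.files lookups in generate_video().

# Illustrative client sketch (assumptions: default Flask dev server address,
# placeholder input files face.png / speech.wav).
import requests

BASE_URL = "http://127.0.0.1:5000"

# GET /health should return {"online": "true"}
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# POST /run with multipart files named "source_image" and "driving_audio",
# matching the request.files keys read in generate_video().
with open("face.png", "rb") as img, open("speech.wav", "rb") as wav:
    resp = requests.post(
        f"{BASE_URL}/run",
        files={"source_image": img, "driving_audio": wav},
        timeout=3600,  # video generation can take a long time
    )
print(resp.status_code, resp.json())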