Upload app_hallo.py
app_hallo.py +381 -0
app_hallo.py
ADDED
@@ -0,0 +1,381 @@
import argparse
import os
import tempfile
import uuid

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from flask import Flask, jsonify, request
from omegaconf import OmegaConf
from torch import nn

from hallo.animate.face_animate import FaceAnimatePipeline
from hallo.datasets.audio_processor import AudioProcessor
from hallo.datasets.image_processor import ImageProcessor
from hallo.models.audio_proj import AudioProjModel
from hallo.models.face_locator import FaceLocator
from hallo.models.image_proj import ImageProjModel
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.config import filter_non_none
from hallo.utils.util import tensor_to_video

app = Flask(__name__)
TEMP_DIR = None

class Net(nn.Module):
    """
    The Net class combines all the modules needed for the inference process.

    Args:
        reference_unet (UNet2DConditionModel): The 2D UNet that encodes features from the reference image.
        denoising_unet (UNet3DConditionModel): The 3D UNet that denoises the video latents, conditioned on the audio.
        face_locator (FaceLocator): The model used to locate the face region in the source image.
        imageproj (nn.Module): The projection model that maps the source-image face embedding to context tokens.
        audioproj (nn.Module): The projection model that maps the audio embeddings to context tokens.
    """
    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        imageproj,
        audioproj,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.imageproj = imageproj
        self.audioproj = audioproj

    def forward(self):
        """
        Empty forward to satisfy the abstract interface of nn.Module; the submodules are called directly.
        """

    def get_modules(self):
        """
        Simple method to avoid the too-few-public-methods pylint error.
        """
        return {
            "reference_unet": self.reference_unet,
            "denoising_unet": self.denoising_unet,
            "face_locator": self.face_locator,
            "imageproj": self.imageproj,
            "audioproj": self.audioproj,
        }

class AnimationConfig:
    """
    Minimal stand-in for the argparse.Namespace that inference_process expects.
    Attribute names must match the configuration keys that inference_process reads
    (config, source_image, driving_audio, save_path, output, ...).
    """
    def __init__(self, driven_audio_path, source_image_path, result_folder):
        # NOTE: the config path below is an assumed default for a Hallo deployment;
        # adjust it to the inference YAML shipped with this Space.
        self.config = './configs/inference/default.yaml'
        self.driving_audio = driven_audio_path
        self.source_image = source_image_path
        self.save_path = result_folder
        self.output = os.path.join(result_folder, 'output.mp4')
        self.checkpoint_dir = './checkpoints'
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

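# Illustrative sketch of how the object above is consumed (the file names here are
# placeholders): inference_process converts it to a plain dict with vars() and merges
# it over the YAML configuration, so every attribute overrides the matching config key.
#
#     args = AnimationConfig("speech.wav", "face.png", "/tmp/hallo_out")
#     cli_args = filter_non_none(vars(args))
#     config = OmegaConf.merge(OmegaConf.load(args.config), cli_args)
#     config.driving_audio   # -> "speech.wav"
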
def process_audio_emb(audio_emb):
    """
    Process the audio embedding so that every frame carries a small temporal
    context window, ready to be concatenated with the other condition tensors.

    Parameters:
        audio_emb (torch.Tensor): The audio embedding tensor to process.

    Returns:
        audio_emb (torch.Tensor): The tensor with a stacked 5-frame window per frame.
    """
    concatenated_tensors = []

    for i in range(audio_emb.shape[0]):
        # Gather the two previous, current, and two following frames, clamping at the edges.
        vectors_to_concat = [
            audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)
        ]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))

    audio_emb = torch.stack(concatenated_tensors, dim=0)

    return audio_emb

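# Shape sketch for process_audio_emb (the trailing (12, 768) dimensions are an
# assumption matching the wav2vec settings used below; the function itself only
# cares about the first axis):
#
#     emb = torch.randn(8, 12, 768)       # 8 frames of audio features
#     windowed = process_audio_emb(emb)   # 5-frame context window per frame
#     windowed.shape                      # -> torch.Size([8, 5, 12, 768])
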
def inference_process(args: argparse.Namespace):
    """
    Perform inference processing.

    Args:
        args (argparse.Namespace): Command-line arguments.

    This function initializes the configuration for the inference process. It sets up
    the necessary modules and variables to prepare for the upcoming inference steps.
    """
    # 1. init config
    cli_args = filter_non_none(vars(args))
    config = OmegaConf.load(args.config)
    config = OmegaConf.merge(config, cli_args)
    source_image_path = config.source_image
    driving_audio_path = config.driving_audio
    save_path = config.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]

    # 2. runtime variables
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif config.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32

    # 3. prepare inference data
    # 3.1 prepare source image, face mask, face embeddings
    img_size = (config.data.source_image.width,
                config.data.source_image.height)
    clip_length = config.data.n_sample_frames
    face_analysis_model_path = config.face_analysis.model_path
    with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
        source_image_pixels, \
            source_image_face_region, \
            source_image_face_emb, \
            source_image_full_mask, \
            source_image_face_mask, \
            source_image_lip_mask = image_processor.preprocess(
                source_image_path, save_path, config.face_expand_ratio)

    # 3.2 prepare audio embeddings
    sample_rate = config.data.driving_audio.sample_rate
    assert sample_rate == 16000, "audio sample rate must be 16000"
    fps = config.data.export_video.fps
    wav2vec_model_path = config.wav2vec.model_path
    wav2vec_only_last_features = config.wav2vec.features == "last"
    audio_separator_model_file = config.audio_separator.model_path
    with AudioProcessor(
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_only_last_features,
        os.path.dirname(audio_separator_model_file),
        os.path.basename(audio_separator_model_file),
        os.path.join(save_path, "audio_preprocess")
    ) as audio_processor:
        audio_emb, audio_length = audio_processor.preprocess(
            driving_audio_path, clip_length)

    # 4. build modules
    sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
    if config.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})

    vae = AutoencoderKL.from_pretrained(config.vae.model_path)
    reference_unet = UNet2DConditionModel.from_pretrained(
        config.base_model_path, subfolder="unet")
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            config.unet_additional_kwargs),
        use_landmark=False,
    )
    face_locator = FaceLocator(conditioning_embedding_channels=320)
    image_proj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    )

    audio_proj = AudioProjModel(
        seq_len=5,
        blocks=12,  # use the hidden states of 12 wav2vec layers
        channels=768,  # audio embedding channels
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    ).to(device=device, dtype=weight_dtype)

    audio_ckpt_dir = config.audio_ckpt_dir

    # Freeze all modules; this is inference only.
    vae.requires_grad_(False)
    image_proj.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    face_locator.requires_grad_(False)
    audio_proj.requires_grad_(False)

    reference_unet.enable_gradient_checkpointing()
    denoising_unet.enable_gradient_checkpointing()

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        image_proj,
        audio_proj,
    )

    m, u = net.load_state_dict(
        torch.load(
            os.path.join(audio_ckpt_dir, "net.pth"),
            map_location="cpu",
        ),
    )
    assert len(m) == 0 and len(u) == 0, "Failed to load the correct checkpoint."
    print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))

    # 5. inference
    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=net.reference_unet,
        denoising_unet=net.denoising_unet,
        face_locator=net.face_locator,
        scheduler=val_noise_scheduler,
        image_proj=net.imageproj,
    )
    pipeline.to(device=device, dtype=weight_dtype)

    audio_emb = process_audio_emb(audio_emb)

    source_image_pixels = source_image_pixels.unsqueeze(0)
    source_image_face_region = source_image_face_region.unsqueeze(0)
    source_image_face_emb = source_image_face_emb.reshape(1, -1)
    source_image_face_emb = torch.tensor(source_image_face_emb)

    source_image_full_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_full_mask
    ]
    source_image_face_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_face_mask
    ]
    source_image_lip_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_lip_mask
    ]

    times = audio_emb.shape[0] // clip_length

    tensor_result = []

    generator = torch.manual_seed(42)

    for t in range(times):
        print(f"[{t+1}/{times}]")

        if len(tensor_result) == 0:
            # The first iteration: pad with repeated copies of the reference image.
            motion_zeros = source_image_pixels.repeat(
                config.data.n_motion_frames, 1, 1, 1)
            motion_zeros = motion_zeros.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            # concat the ref image and the first motion frames
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_zeros], dim=0)
        else:
            # Reuse the last frames of the previous clip as motion frames.
            motion_frames = tensor_result[-1][0]
            motion_frames = motion_frames.permute(1, 0, 2, 3)
            motion_frames = motion_frames[-config.data.n_motion_frames:]
            motion_frames = motion_frames * 2.0 - 1.0
            motion_frames = motion_frames.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            # concat the ref image and the motion frames
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_frames], dim=0)

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

        audio_tensor = audio_emb[
            t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
        ]
        audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.to(
            device=net.audioproj.device, dtype=net.audioproj.dtype)
        audio_tensor = net.audioproj(audio_tensor)

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=img_size[0],
            height=img_size[1],
            video_length=clip_length,
            num_inference_steps=config.inference_steps,
            guidance_scale=config.cfg_scale,
            generator=generator,
            motion_scale=motion_scale,
        )

        tensor_result.append(pipeline_output.videos)

    tensor_result = torch.cat(tensor_result, dim=2)
    tensor_result = tensor_result.squeeze(0)
    tensor_result = tensor_result[:, :audio_length]

    # save the result after all iterations
    output_file = config.output
    tensor_to_video(tensor_result, output_file, driving_audio_path)
    return output_file

def create_temp_dir():
    return tempfile.TemporaryDirectory()


def save_uploaded_file(file, filename, TEMP_DIR):
    unique_filename = str(uuid.uuid4()) + "_" + filename
    file_path = os.path.join(TEMP_DIR.name, unique_filename)
    file.save(file_path)
    return file_path

@app.route('/run', methods=['POST'])
def generate_video():
    global TEMP_DIR
    TEMP_DIR = create_temp_dir()

    source_image = request.files['source_image']
    driving_audio = request.files['driving_audio']
    source_image_path = save_uploaded_file(source_image, 'source_image.png', TEMP_DIR)
    print(source_image_path)
    driving_audio_path = save_uploaded_file(driving_audio, 'driving_audio.wav', TEMP_DIR)
    print(driving_audio_path)
    output_path = TEMP_DIR.name

    args = AnimationConfig(
        driven_audio_path=driving_audio_path,
        source_image_path=source_image_path,
        result_folder=output_path)

    try:
        # Run the inference process
        output_file = inference_process(args)
        return jsonify({"message": "Inference completed successfully", "output_file": os.path.abspath(output_file)})
    except Exception as e:
        return jsonify({"error": "Inference failed", "details": str(e)}), 500

@app.route("/health", methods=["GET"])
def health_status():
    response = {"online": "true"}
    return jsonify(response)


if __name__ == '__main__':
    app.run(debug=True)
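
A minimal client-side sketch of how the two routes above can be exercised once the app is running (this assumes the Flask development server on its default http://127.0.0.1:5000; face.png and speech.wav are placeholder file names):

import requests

# Health check
print(requests.get("http://127.0.0.1:5000/health").json())

# Submit a portrait image and a driving audio clip for animation
with open("face.png", "rb") as img, open("speech.wav", "rb") as wav:
    resp = requests.post(
        "http://127.0.0.1:5000/run",
        files={"source_image": img, "driving_audio": wav},
        timeout=None,  # inference can take several minutes
    )
print(resp.json())  # {"message": ..., "output_file": ...} on success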