nekoniii3 committed
Commit 1678478
1 Parent(s): cb3ed34

create new

Files changed (2)
  1. app.py +61 -0
  2. inference.py +320 -0
app.py ADDED
import gradio as gr
import os
import datetime
import inference

example1 = ["sample_data/ref1.jpg", "sample_data/ano.mp3"]
example2 = ["sample_data/ref2.jpg", "sample_data/rakugo.mp3"]


def fix_face_video(input_image, input_audio):

    # Debugging: print the OS release, the installed packages, and the CUDA toolkit version.
    import subprocess

    cmd = ["lsb_release", "-a"]
    result = subprocess.run(cmd, capture_output=True)
    print(result.stdout.decode("utf-8"))

    cmd = ["pip", "list"]
    result = subprocess.run(cmd, capture_output=True)
    print(result.stdout.decode("utf-8"))

    cmd = ["nvcc", "-V"]
    result = subprocess.run(cmd, capture_output=True)
    print(result.stdout.decode("utf-8"))

    # Build a timestamped output path (JST, UTC+9), e.g. ./output/20240601/fix_face_140305.mp4.
    dt = datetime.datetime.now() + datetime.timedelta(hours=9)
    fol_name = dt.strftime("%Y%m%d")
    file_name = dt.strftime("%H%M%S")

    out_dir = "./output/" + fol_name
    os.makedirs(out_dir, exist_ok=True)  # the dated folder must exist before the video is written
    out_video = out_dir + "/fix_face_" + file_name + ".mp4"

    inference.fix_face(input_image, input_audio, out_video)

    return out_video


image = gr.Image(label="画像(image)", type="filepath")
audio = gr.File(label="音声(audio)", file_types=[".mp3", ".MP3"])
out_video = gr.Video(label="Fix Face Video")
btn = gr.Button("送信", variant="primary")

title = "V_Express"
description = "<div style='text-align: center;'><h3>画像と音声だけで生成できます。(Using only images and audio)"
description += "<br>This app uses V-Express: \"https://github.com/tencent-ailab/V-Express\"</h3></div>"

demo = gr.Interface(
    fn=fix_face_video,
    inputs=[image, audio],
    examples=[example1, example2],
    outputs=[out_video],
    title=title,
    submit_btn=btn,
    clear_btn=None,
    description=description,
    allow_flagging="never",
)

demo.queue()
demo.launch(share=True, debug=True)
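For a quick check outside the Gradio UI, the same generation path can be exercised from a small standalone script. This is a minimal sketch, assuming the sample files referenced by the examples above and the checkpoints expected by inference.py are present in the Space; the script name smoke_test.py is hypothetical.

# smoke_test.py (hypothetical): call the V-Express wrapper directly, bypassing Gradio.
import os
import inference

out_dir = "./output/smoke_test"
os.makedirs(out_dir, exist_ok=True)  # inference.fix_face does not create output directories
out_video = os.path.join(out_dir, "fix_face_test.mp4")

# Same reference image / audio pair as example1 in app.py.
inference.fix_face("sample_data/ref1.jpg", "sample_data/ano.mp3", out_video)
print("Generated:", out_video)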
inference.py ADDED
import argparse

import os
import cv2
import numpy as np
import torch
import torchaudio.functional
import torchvision.io
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from transformers import CLIPVisionModelWithProjection, Wav2Vec2Model, Wav2Vec2Processor

from modules import UNet2DConditionModel, UNet3DConditionModel, VKpsGuider, AudioProjection
from pipelines import VExpressPipeline
from pipelines.utils import draw_kps_image, save_video
from pipelines.utils import retarget_kps

import spaces

# Dummy class that replaces the original command-line arguments.
class args_dum:

    def __init__(self):
        self.unet_config_path = './model_ckpts/stable-diffusion-v1-5/unet/config.json'
        self.vae_path = './model_ckpts/sd-vae-ft-mse/'
        self.audio_encoder_path = './model_ckpts/wav2vec2-base-960h/'
        self.insightface_model_path = './model_ckpts/insightface_models/'
        self.denoising_unet_path = './model_ckpts/v-express/denoising_unet.pth'
        self.reference_net_path = './model_ckpts/v-express/reference_net.pth'
        self.v_kps_guider_path = './model_ckpts/v-express/v_kps_guider.pth'
        self.audio_projection_path = './model_ckpts/v-express/audio_projection.pth'
        self.motion_module_path = './model_ckpts/v-express/motion_module.pth'
        self.retarget_strategy = 'fix_face'  # fix_face, no_retarget, offset_retarget, naive_retarget
        self.device = 'cuda'
        self.gpu_id = 0
        self.dtype = 'fp16'
        self.num_pad_audio_frames = 2
        self.standard_audio_sampling_rate = 16000
        self.reference_image_path = './test_samples/short_case/tys/ref.jpg'
        self.audio_path = './test_samples/short_case/tys/aud.mp3'
        self.kps_path = './test_samples/emo/talk_emotion/kps.pth'
        self.output_path = './output/short_case/talk_tys_fix_face.mp4'
        self.image_width = 512
        self.image_height = 512
        self.fps = 30.0
        self.seed = 42
        self.num_inference_steps = 25
        self.guidance_scale = 3.5
        self.context_frames = 12
        self.context_stride = 1
        self.context_overlap = 4
        self.reference_attention_weight = 0.95
        self.audio_attention_weight = 3.0

# def parse_args():
#     parser = argparse.ArgumentParser()

#     parser.add_argument('--unet_config_path', type=str, default='./model_ckpts/stable-diffusion-v1-5/unet/config.json')
#     parser.add_argument('--vae_path', type=str, default='./model_ckpts/sd-vae-ft-mse/')
#     parser.add_argument('--audio_encoder_path', type=str, default='./model_ckpts/wav2vec2-base-960h/')
#     parser.add_argument('--insightface_model_path', type=str, default='./model_ckpts/insightface_models/')

#     parser.add_argument('--denoising_unet_path', type=str, default='./model_ckpts/v-express/denoising_unet.pth')
#     parser.add_argument('--reference_net_path', type=str, default='./model_ckpts/v-express/reference_net.pth')
#     parser.add_argument('--v_kps_guider_path', type=str, default='./model_ckpts/v-express/v_kps_guider.pth')
#     parser.add_argument('--audio_projection_path', type=str, default='./model_ckpts/v-express/audio_projection.pth')
#     parser.add_argument('--motion_module_path', type=str, default='./model_ckpts/v-express/motion_module.pth')

#     parser.add_argument('--retarget_strategy', type=str, default='fix_face')  # fix_face, no_retarget, offset_retarget, naive_retarget

#     parser.add_argument('--device', type=str, default='cuda')
#     parser.add_argument('--gpu_id', type=int, default=0)
#     parser.add_argument('--dtype', type=str, default='fp16')

#     parser.add_argument('--num_pad_audio_frames', type=int, default=2)
#     parser.add_argument('--standard_audio_sampling_rate', type=int, default=16000)

#     parser.add_argument('--reference_image_path', type=str, default='./test_samples/emo/talk_emotion/ref.jpg')
#     parser.add_argument('--audio_path', type=str, default='./test_samples/emo/talk_emotion/aud.mp3')
#     parser.add_argument('--kps_path', type=str, default='./test_samples/emo/talk_emotion/kps.pth')
#     parser.add_argument('--output_path', type=str, default='./output/emo/talk_emotion.mp4')

#     parser.add_argument('--image_width', type=int, default=512)
#     parser.add_argument('--image_height', type=int, default=512)
#     parser.add_argument('--fps', type=float, default=30.0)
#     parser.add_argument('--seed', type=int, default=42)
#     parser.add_argument('--num_inference_steps', type=int, default=25)
#     parser.add_argument('--guidance_scale', type=float, default=3.5)
#     parser.add_argument('--context_frames', type=int, default=12)
#     parser.add_argument('--context_stride', type=int, default=1)
#     parser.add_argument('--context_overlap', type=int, default=4)
#     parser.add_argument('--reference_attention_weight', default=0.95, type=float)
#     parser.add_argument('--audio_attention_weight', default=3., type=float)

#     args = parser.parse_args()

#     return args


def load_reference_net(unet_config_path, reference_net_path, dtype, device):
    reference_net = UNet2DConditionModel.from_config(unet_config_path).to(dtype=dtype, device=device)
    reference_net.load_state_dict(torch.load(reference_net_path, map_location="cpu"), strict=False)
    print(f'Loaded weights of Reference Net from {reference_net_path}.')
    return reference_net


def load_denoising_unet(unet_config_path, denoising_unet_path, motion_module_path, dtype, device):
    inference_config_path = './inference_v2.yaml'
    inference_config = OmegaConf.load(inference_config_path)
    denoising_unet = UNet3DConditionModel.from_config_2d(
        unet_config_path,
        unet_additional_kwargs=inference_config.unet_additional_kwargs,
    ).to(dtype=dtype, device=device)
    denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
    print(f'Loaded weights of Denoising U-Net from {denoising_unet_path}.')

    denoising_unet.load_state_dict(torch.load(motion_module_path, map_location="cpu"), strict=False)
    print(f'Loaded weights of Denoising U-Net Motion Module from {motion_module_path}.')

    return denoising_unet


def load_v_kps_guider(v_kps_guider_path, dtype, device):
    v_kps_guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
    v_kps_guider.load_state_dict(torch.load(v_kps_guider_path, map_location="cpu"))
    print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.')
    return v_kps_guider


def load_audio_projection(
    audio_projection_path,
    dtype,
    device,
    inp_dim: int,
    mid_dim: int,
    out_dim: int,
    inp_seq_len: int,
    out_seq_len: int,
):
    audio_projection = AudioProjection(
        dim=mid_dim,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=out_seq_len,
        embedding_dim=inp_dim,
        output_dim=out_dim,
        ff_mult=4,
        max_seq_len=inp_seq_len,
    ).to(dtype=dtype, device=device)
    audio_projection.load_state_dict(torch.load(audio_projection_path, map_location='cpu'))
    print(f'Loaded weights of Audio Projection from {audio_projection_path}.')
    return audio_projection


def get_scheduler():
    inference_config_path = './inference_v2.yaml'
    inference_config = OmegaConf.load(inference_config_path)
    scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
    scheduler = DDIMScheduler(**scheduler_kwargs)
    return scheduler

@spaces.GPU
def fix_face(image, audio, out_path):
    # args = parse_args()
    args = args_dum()

    # Override only the per-run paths; everything else keeps the args_dum defaults.
    args.reference_image_path = image
    args.audio_path = audio
    args.output_path = out_path

    # test
    # print(args)
    # return

    device = torch.device(f'{args.device}:{args.gpu_id}' if args.device == 'cuda' else args.device)
    dtype = torch.float16 if args.dtype == 'fp16' else torch.float32

    vae_path = args.vae_path
    audio_encoder_path = args.audio_encoder_path

    vae = AutoencoderKL.from_pretrained(vae_path).to(dtype=dtype, device=device)
    audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path).to(dtype=dtype, device=device)
    audio_processor = Wav2Vec2Processor.from_pretrained(audio_encoder_path)

    unet_config_path = args.unet_config_path
    reference_net_path = args.reference_net_path
    denoising_unet_path = args.denoising_unet_path
    v_kps_guider_path = args.v_kps_guider_path
    audio_projection_path = args.audio_projection_path
    motion_module_path = args.motion_module_path

    scheduler = get_scheduler()
    reference_net = load_reference_net(unet_config_path, reference_net_path, dtype, device)
    denoising_unet = load_denoising_unet(unet_config_path, denoising_unet_path, motion_module_path, dtype, device)
    v_kps_guider = load_v_kps_guider(v_kps_guider_path, dtype, device)
    audio_projection = load_audio_projection(
        audio_projection_path,
        dtype,
        device,
        inp_dim=denoising_unet.config.cross_attention_dim,
        mid_dim=denoising_unet.config.cross_attention_dim,
        out_dim=denoising_unet.config.cross_attention_dim,
        inp_seq_len=2 * (2 * args.num_pad_audio_frames + 1),
        out_seq_len=2 * args.num_pad_audio_frames + 1,
    )

    if is_xformers_available():
        reference_net.enable_xformers_memory_efficient_attention()
        denoising_unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

    generator = torch.manual_seed(args.seed)
    pipeline = VExpressPipeline(
        vae=vae,
        reference_net=reference_net,
        denoising_unet=denoising_unet,
        v_kps_guider=v_kps_guider,
        audio_processor=audio_processor,
        audio_encoder=audio_encoder,
        audio_projection=audio_projection,
        scheduler=scheduler,
    ).to(dtype=dtype, device=device)

    # Face detector used to extract the three facial keypoints from the reference image.
    app = FaceAnalysis(
        providers=['CUDAExecutionProvider' if args.device == 'cuda' else 'CPUExecutionProvider'],
        provider_options=[{'device_id': args.gpu_id}] if args.device == 'cuda' else [],
        root=args.insightface_model_path,
    )
    app.prepare(ctx_id=0, det_size=(args.image_height, args.image_width))

    reference_image = Image.open(args.reference_image_path).convert('RGB')
    reference_image = reference_image.resize((args.image_height, args.image_width))

    reference_image_for_kps = cv2.imread(args.reference_image_path)
    reference_image_for_kps = cv2.resize(reference_image_for_kps, (args.image_height, args.image_width))
    reference_kps = app.get(reference_image_for_kps)[0].kps[:3]

    # Read the audio track and resample it to the rate expected by wav2vec2.
    _, audio_waveform, meta_info = torchvision.io.read_video(args.audio_path, pts_unit='sec')
    audio_sampling_rate = meta_info['audio_fps']
    print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.')
    if audio_sampling_rate != args.standard_audio_sampling_rate:
        audio_waveform = torchaudio.functional.resample(
            audio_waveform,
            orig_freq=audio_sampling_rate,
            new_freq=args.standard_audio_sampling_rate,
        )
    audio_waveform = audio_waveform.mean(dim=0)

    duration = audio_waveform.shape[0] / args.standard_audio_sampling_rate
    video_length = int(duration * args.fps)
    print(f'The corresponding video length is {video_length}.')

    # Optional driving kps sequence (only needed by the retargeting strategies).
    if args.kps_path != "":
        assert os.path.exists(args.kps_path), f'{args.kps_path} does not exist'
        kps_sequence = torch.tensor(torch.load(args.kps_path))  # [len, 3, 2]
        print(f'The original length of kps sequence is {kps_sequence.shape[0]}.')
        kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear')
        kps_sequence = kps_sequence.permute(2, 0, 1)
        print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.')

    retarget_strategy = args.retarget_strategy
    if retarget_strategy == 'fix_face':
        # Keep the head pose fixed: repeat the reference keypoints for every frame.
        kps_sequence = torch.tensor([reference_kps] * video_length)
    elif retarget_strategy == 'no_retarget':
        kps_sequence = kps_sequence
    elif retarget_strategy == 'offset_retarget':
        kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=True)
    elif retarget_strategy == 'naive_retarget':
        kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=False)
    else:
        raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.')

    # Render one V-Kps guidance image per output frame.
    kps_images = []
    for i in range(video_length):
        kps_image = np.zeros_like(reference_image_for_kps)
        kps_image = draw_kps_image(kps_image, kps_sequence[i])
        kps_images.append(Image.fromarray(kps_image))

    vae_scale_factor = 8
    latent_height = args.image_height // vae_scale_factor
    latent_width = args.image_width // vae_scale_factor

    latent_shape = (1, 4, video_length, latent_height, latent_width)
    vae_latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    video_latents = pipeline(
        vae_latents=vae_latents,
        reference_image=reference_image,
        kps_images=kps_images,
        audio_waveform=audio_waveform,
        width=args.image_width,
        height=args.image_height,
        video_length=video_length,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        context_frames=args.context_frames,
        context_stride=args.context_stride,
        context_overlap=args.context_overlap,
        reference_attention_weight=args.reference_attention_weight,
        audio_attention_weight=args.audio_attention_weight,
        num_pad_audio_frames=args.num_pad_audio_frames,
        generator=generator,
    ).video_latents

    video_tensor = pipeline.decode_latents(video_latents)
    if isinstance(video_tensor, np.ndarray):
        video_tensor = torch.from_numpy(video_tensor)

    # Mux the generated frames with the input audio and write the mp4.
    save_video(video_tensor, args.audio_path, args.output_path, args.fps)
    print(f'The generated video has been saved at {args.output_path}.')


# if __name__ == '__main__':
#     main()
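The original CLI entry point is kept above only as the commented-out parse_args/main stubs. If a command line were wanted again, a compact wrapper appended to inference.py could forward just the per-run paths to fix_face() and leave every other setting at the args_dum defaults. This is a hedged sketch, not part of the commit: the argument names are taken from the commented parse_args block, the default output path is illustrative, and it assumes an environment where the imports at the top of inference.py (including spaces) resolve.

# Hypothetical CLI wrapper appended to inference.py (argparse and os are already imported above).
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--reference_image_path', type=str, required=True)
    parser.add_argument('--audio_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, default='./output/cli/fix_face.mp4')
    cli_args = parser.parse_args()

    # fix_face() does not create the output directory itself.
    os.makedirs(os.path.dirname(cli_args.output_path), exist_ok=True)
    fix_face(cli_args.reference_image_path, cli_args.audio_path, cli_args.output_path)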