nekoniii3 committed
Commit 81ae7f5 · 1 Parent(s): 1678478
Files changed (2)
  1. app.py +18 -10
  2. inference.py +0 -320
app.py CHANGED
@@ -1,20 +1,16 @@
 import gradio as gr
-import os
-import datetime
-import inference
+import subprocess
 
-example1 = ["sample_data/ref1.jpg", "sample_data/ano.mp3"]
-example2 = ["sample_data/ref2.jpg", "sample_data/rakugo.mp3"]
-
-def fix_face_video(input_image, input_audio):
-
-    # 調査用 (for debugging)
-    import subprocess
+def greet(name):
 
     cmd = ["lsb_release", "-a"]
     result = subprocess.run(cmd, capture_output=True)
     print(result.stdout.decode("utf-8"))
 
+    cmd = ["python", "-V"]
+    result = subprocess.run(cmd, capture_output=True)
+    print(result.stdout.decode("utf-8"))
+
     cmd = ["pip", "list"]
     result = subprocess.run(cmd, capture_output=True)
     print(result.stdout.decode("utf-8"))
@@ -23,6 +19,18 @@ def fix_face_video(input_image, input_audio):
     result = subprocess.run(cmd, capture_output=True)
     print(result.stdout.decode("utf-8"))
 
+    return "Hello " + name + "!!"
+
+demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+demo.launch()
+
+def fix_face_video(input_image, input_audio):
+
+    # 調査用 (for debugging)
+
+
+
+
 
 
 
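Note: taken together, the two hunks above temporarily replace the V-Express UI in app.py with a minimal Gradio "greet" demo whose callback only runs environment probes (lsb_release -a, python -V, pip list) and prints their output to the Space logs. A rough sketch of the resulting app.py is given below; the indentation and the probe hidden between the two hunks are assumptions, since the diff does not show them.

import gradio as gr
import subprocess

def greet(name):
    # 調査用 ("for debugging"): print runtime information to the Space logs.
    cmd = ["lsb_release", "-a"]
    result = subprocess.run(cmd, capture_output=True)
    print(result.stdout.decode("utf-8"))

    cmd = ["python", "-V"]
    result = subprocess.run(cmd, capture_output=True)
    print(result.stdout.decode("utf-8"))

    cmd = ["pip", "list"]
    result = subprocess.run(cmd, capture_output=True)
    print(result.stdout.decode("utf-8"))

    # (One further probe sits between the two hunks; its command is not visible in this view.)
    return "Hello " + name + "!!"

demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()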
inference.py DELETED
@@ -1,320 +0,0 @@
-import argparse
-
-import os
-import cv2
-import numpy as np
-import torch
-import torchaudio.functional
-import torchvision.io
-from PIL import Image
-from diffusers import AutoencoderKL, DDIMScheduler
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.torch_utils import randn_tensor
-from insightface.app import FaceAnalysis
-from omegaconf import OmegaConf
-from transformers import CLIPVisionModelWithProjection, Wav2Vec2Model, Wav2Vec2Processor
-
-from modules import UNet2DConditionModel, UNet3DConditionModel, VKpsGuider, AudioProjection
-from pipelines import VExpressPipeline
-from pipelines.utils import draw_kps_image, save_video
-from pipelines.utils import retarget_kps
-
-import spaces
-
-# 引数用ダミークラス (dummy class standing in for the CLI arguments)
-class args_dum:
-
-    def __init__(self):
-        self.unet_config_path='./model_ckpts/stable-diffusion-v1-5/unet/config.json'
-        self.vae_path='./model_ckpts/sd-vae-ft-mse/'
-        self.audio_encoder_path='./model_ckpts/wav2vec2-base-960h/'
-        self.insightface_model_path='./model_ckpts/insightface_models/'
-        self.denoising_unet_path='./model_ckpts/v-express/denoising_unet.pth'
-        self.reference_net_path='./model_ckpts/v-express/reference_net.pth'
-        self.v_kps_guider_path='./model_ckpts/v-express/v_kps_guider.pth'
-        self.audio_projection_path='./model_ckpts/v-express/audio_projection.pth'
-        self.motion_module_path='./model_ckpts/v-express/motion_module.pth'
-        self.retarget_strategy='fix_face'
-        self.device='cuda'
-        self.gpu_id=0
-        self.dtype='fp16'
-        self.num_pad_audio_frames=2
-        self.standard_audio_sampling_rate=16000
-        self.reference_image_path='./test_samples/short_case/tys/ref.jpg'
-        self.audio_path='./test_samples/short_case/tys/aud.mp3'
-        self.kps_path='./test_samples/emo/talk_emotion/kps.pth'
-        self.output_path='./output/short_case/talk_tys_fix_face.mp4'
-        self.image_width=512
-        self.image_height=512
-        self.fps=30.0
-        self.seed=42
-        self.num_inference_steps=25
-        self.guidance_scale=3.5
-        self.context_frames=12
-        self.context_stride=1
-        self.context_overlap=4
-        self.reference_attention_weight=0.95
-        self.audio_attention_weight=3.0
-
-# def parse_args():
-# parser = argparse.ArgumentParser()
-
-# parser.add_argument('--unet_config_path', type=str, default='./model_ckpts/stable-diffusion-v1-5/unet/config.json')
-# parser.add_argument('--vae_path', type=str, default='./model_ckpts/sd-vae-ft-mse/')
-# parser.add_argument('--audio_encoder_path', type=str, default='./model_ckpts/wav2vec2-base-960h/')
-# parser.add_argument('--insightface_model_path', type=str, default='./model_ckpts/insightface_models/')
-
-# parser.add_argument('--denoising_unet_path', type=str, default='./model_ckpts/v-express/denoising_unet.pth')
-# parser.add_argument('--reference_net_path', type=str, default='./model_ckpts/v-express/reference_net.pth')
-# parser.add_argument('--v_kps_guider_path', type=str, default='./model_ckpts/v-express/v_kps_guider.pth')
-# parser.add_argument('--audio_projection_path', type=str, default='./model_ckpts/v-express/audio_projection.pth')
-# parser.add_argument('--motion_module_path', type=str, default='./model_ckpts/v-express/motion_module.pth')
-
-# parser.add_argument('--retarget_strategy', type=str, default='fix_face')  # fix_face, no_retarget, offset_retarget, naive_retarget
-
-# parser.add_argument('--device', type=str, default='cuda')
-# parser.add_argument('--gpu_id', type=int, default=0)
-# parser.add_argument('--dtype', type=str, default='fp16')
-
-# parser.add_argument('--num_pad_audio_frames', type=int, default=2)
-# parser.add_argument('--standard_audio_sampling_rate', type=int, default=16000)
-
-# parser.add_argument('--reference_image_path', type=str, default='./test_samples/emo/talk_emotion/ref.jpg')
-# parser.add_argument('--audio_path', type=str, default='./test_samples/emo/talk_emotion/aud.mp3')
-# parser.add_argument('--kps_path', type=str, default='./test_samples/emo/talk_emotion/kps.pth')
-# parser.add_argument('--output_path', type=str, default='./output/emo/talk_emotion.mp4')
-
-# parser.add_argument('--image_width', type=int, default=512)
-# parser.add_argument('--image_height', type=int, default=512)
-# parser.add_argument('--fps', type=float, default=30.0)
-# parser.add_argument('--seed', type=int, default=42)
-# parser.add_argument('--num_inference_steps', type=int, default=25)
-# parser.add_argument('--guidance_scale', type=float, default=3.5)
-# parser.add_argument('--context_frames', type=int, default=12)
-# parser.add_argument('--context_stride', type=int, default=1)
-# parser.add_argument('--context_overlap', type=int, default=4)
-# parser.add_argument('--reference_attention_weight', default=0.95, type=float)
-# parser.add_argument('--audio_attention_weight', default=3., type=float)
-
-# args = parser.parse_args()
-
-# return args
-
-
-def load_reference_net(unet_config_path, reference_net_path, dtype, device):
-    reference_net = UNet2DConditionModel.from_config(unet_config_path).to(dtype=dtype, device=device)
-    reference_net.load_state_dict(torch.load(reference_net_path, map_location="cpu"), strict=False)
-    print(f'Loaded weights of Reference Net from {reference_net_path}.')
-    return reference_net
-
-
-def load_denoising_unet(unet_config_path, denoising_unet_path, motion_module_path, dtype, device):
-    inference_config_path = './inference_v2.yaml'
-    inference_config = OmegaConf.load(inference_config_path)
-    denoising_unet = UNet3DConditionModel.from_config_2d(
-        unet_config_path,
-        unet_additional_kwargs=inference_config.unet_additional_kwargs,
-    ).to(dtype=dtype, device=device)
-    denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
-    print(f'Loaded weights of Denoising U-Net from {denoising_unet_path}.')
-
-    denoising_unet.load_state_dict(torch.load(motion_module_path, map_location="cpu"), strict=False)
-    print(f'Loaded weights of Denoising U-Net Motion Module from {motion_module_path}.')
-
-    return denoising_unet
-
-
-def load_v_kps_guider(v_kps_guider_path, dtype, device):
-    v_kps_guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
-    v_kps_guider.load_state_dict(torch.load(v_kps_guider_path, map_location="cpu"))
-    print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.')
-    return v_kps_guider
-
-
-def load_audio_projection(
-        audio_projection_path,
-        dtype,
-        device,
-        inp_dim: int,
-        mid_dim: int,
-        out_dim: int,
-        inp_seq_len: int,
-        out_seq_len: int,
-):
-    audio_projection = AudioProjection(
-        dim=mid_dim,
-        depth=4,
-        dim_head=64,
-        heads=12,
-        num_queries=out_seq_len,
-        embedding_dim=inp_dim,
-        output_dim=out_dim,
-        ff_mult=4,
-        max_seq_len=inp_seq_len,
-    ).to(dtype=dtype, device=device)
-    audio_projection.load_state_dict(torch.load(audio_projection_path, map_location='cpu'))
-    print(f'Loaded weights of Audio Projection from {audio_projection_path}.')
-    return audio_projection
-
-
-def get_scheduler():
-    inference_config_path = './inference_v2.yaml'
-    inference_config = OmegaConf.load(inference_config_path)
-    scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
-    scheduler = DDIMScheduler(**scheduler_kwargs)
-    return scheduler
-
-
-@spaces.GPU
-def fix_face(image, audio, out_path):
-    # args = parse_args()
-    args = args_dum()
-
-    args.reference_image_path = image
-    args.audio_path = audio
-    args.output_path = out_path
-
-    # test
-    # print(args)
-    # return
-
-    device = torch.device(f'{args.device}:{args.gpu_id}' if args.device == 'cuda' else args.device)
-    dtype = torch.float16 if args.dtype == 'fp16' else torch.float32
-
-    vae_path = args.vae_path
-    audio_encoder_path = args.audio_encoder_path
-
-    vae = AutoencoderKL.from_pretrained(vae_path).to(dtype=dtype, device=device)
-    audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path).to(dtype=dtype, device=device)
-    audio_processor = Wav2Vec2Processor.from_pretrained(audio_encoder_path)
-
-    unet_config_path = args.unet_config_path
-    reference_net_path = args.reference_net_path
-    denoising_unet_path = args.denoising_unet_path
-    v_kps_guider_path = args.v_kps_guider_path
-    audio_projection_path = args.audio_projection_path
-    motion_module_path = args.motion_module_path
-
-    scheduler = get_scheduler()
-    reference_net = load_reference_net(unet_config_path, reference_net_path, dtype, device)
-    denoising_unet = load_denoising_unet(unet_config_path, denoising_unet_path, motion_module_path, dtype, device)
-    v_kps_guider = load_v_kps_guider(v_kps_guider_path, dtype, device)
-    audio_projection = load_audio_projection(
-        audio_projection_path,
-        dtype,
-        device,
-        inp_dim=denoising_unet.config.cross_attention_dim,
-        mid_dim=denoising_unet.config.cross_attention_dim,
-        out_dim=denoising_unet.config.cross_attention_dim,
-        inp_seq_len=2 * (2 * args.num_pad_audio_frames + 1),
-        out_seq_len=2 * args.num_pad_audio_frames + 1,
-    )
-
-    if is_xformers_available():
-        reference_net.enable_xformers_memory_efficient_attention()
-        denoising_unet.enable_xformers_memory_efficient_attention()
-    else:
-        raise ValueError("xformers is not available. Make sure it is installed correctly")
-
-    generator = torch.manual_seed(args.seed)
-    pipeline = VExpressPipeline(
-        vae=vae,
-        reference_net=reference_net,
-        denoising_unet=denoising_unet,
-        v_kps_guider=v_kps_guider,
-        audio_processor=audio_processor,
-        audio_encoder=audio_encoder,
-        audio_projection=audio_projection,
-        scheduler=scheduler,
-    ).to(dtype=dtype, device=device)
-
-    app = FaceAnalysis(
-        providers=['CUDAExecutionProvider' if args.device == 'cuda' else 'CPUExecutionProvider'],
-        provider_options=[{'device_id': args.gpu_id}] if args.device == 'cuda' else [],
-        root=args.insightface_model_path,
-    )
-    app.prepare(ctx_id=0, det_size=(args.image_height, args.image_width))
-
-    reference_image = Image.open(args.reference_image_path).convert('RGB')
-    reference_image = reference_image.resize((args.image_height, args.image_width))
-
-    reference_image_for_kps = cv2.imread(args.reference_image_path)
-    reference_image_for_kps = cv2.resize(reference_image_for_kps, (args.image_height, args.image_width))
-    reference_kps = app.get(reference_image_for_kps)[0].kps[:3]
-
-    _, audio_waveform, meta_info = torchvision.io.read_video(args.audio_path, pts_unit='sec')
-    audio_sampling_rate = meta_info['audio_fps']
-    print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.')
-    if audio_sampling_rate != args.standard_audio_sampling_rate:
-        audio_waveform = torchaudio.functional.resample(
-            audio_waveform,
-            orig_freq=audio_sampling_rate,
-            new_freq=args.standard_audio_sampling_rate,
-        )
-    audio_waveform = audio_waveform.mean(dim=0)
-
-    duration = audio_waveform.shape[0] / args.standard_audio_sampling_rate
-    video_length = int(duration * args.fps)
-    print(f'The corresponding video length is {video_length}.')
-
-    if args.kps_path != "":
-        assert os.path.exists(args.kps_path), f'{args.kps_path} does not exist'
-        kps_sequence = torch.tensor(torch.load(args.kps_path))  # [len, 3, 2]
-        print(f'The original length of kps sequence is {kps_sequence.shape[0]}.')
-        kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear')
-        kps_sequence = kps_sequence.permute(2, 0, 1)
-        print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.')
-
-    retarget_strategy = args.retarget_strategy
-    if retarget_strategy == 'fix_face':
-        kps_sequence = torch.tensor([reference_kps] * video_length)
-    elif retarget_strategy == 'no_retarget':
-        kps_sequence = kps_sequence
-    elif retarget_strategy == 'offset_retarget':
-        kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=True)
-    elif retarget_strategy == 'naive_retarget':
-        kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=False)
-    else:
-        raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.')
-
-    kps_images = []
-    for i in range(video_length):
-        kps_image = np.zeros_like(reference_image_for_kps)
-        kps_image = draw_kps_image(kps_image, kps_sequence[i])
-        kps_images.append(Image.fromarray(kps_image))
-
-    vae_scale_factor = 8
-    latent_height = args.image_height // vae_scale_factor
-    latent_width = args.image_width // vae_scale_factor
-
-    latent_shape = (1, 4, video_length, latent_height, latent_width)
-    vae_latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
-
-    video_latents = pipeline(
-        vae_latents=vae_latents,
-        reference_image=reference_image,
-        kps_images=kps_images,
-        audio_waveform=audio_waveform,
-        width=args.image_width,
-        height=args.image_height,
-        video_length=video_length,
-        num_inference_steps=args.num_inference_steps,
-        guidance_scale=args.guidance_scale,
-        context_frames=args.context_frames,
-        context_stride=args.context_stride,
-        context_overlap=args.context_overlap,
-        reference_attention_weight=args.reference_attention_weight,
-        audio_attention_weight=args.audio_attention_weight,
-        num_pad_audio_frames=args.num_pad_audio_frames,
-        generator=generator,
-    ).video_latents
-
-    video_tensor = pipeline.decode_latents(video_latents)
-    if isinstance(video_tensor, np.ndarray):
-        video_tensor = torch.from_numpy(video_tensor)
-
-    save_video(video_tensor, args.audio_path, args.output_path, args.fps)
-    print(f'The generated video has been saved at {args.output_path}.')
-
-
-# if __name__ == '__main__':
-# main()
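Note: the deleted inference.py exposed a single entry point, fix_face(image, audio, out_path), decorated with @spaces.GPU. The lines removed from app.py in this same commit (import inference, the sample_data examples, and the fix_face_video callback) suggest it was driven from a Gradio interface. The sketch below is a hypothetical reconstruction of that wiring; the output-path handling and the component types are assumptions, since the original callback body is not part of this commit.

import datetime
import os

import gradio as gr

import inference  # the module deleted above

example1 = ["sample_data/ref1.jpg", "sample_data/ano.mp3"]
example2 = ["sample_data/ref2.jpg", "sample_data/rakugo.mp3"]

def fix_face_video(input_image, input_audio):
    # Assumption: write the generated clip to a timestamped file under output/.
    os.makedirs("output", exist_ok=True)
    out_path = datetime.datetime.now().strftime("output/%Y%m%d_%H%M%S.mp4")
    inference.fix_face(input_image, input_audio, out_path)
    return out_path

demo = gr.Interface(
    fn=fix_face_video,
    inputs=[gr.Image(type="filepath"), gr.Audio(type="filepath")],
    outputs=gr.Video(),
    examples=[example1, example2],
)
demo.launch()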