openfree committed
Commit ec118f6 · verified · 1 Parent(s): 2399e79

Update sonic.py

Files changed (1):
  1. sonic.py +295 -292
sonic.py CHANGED
@@ -1,12 +1,8 @@
- # sonic.py
- # ---------------------------------------------------------------------
- # Sonic – single-image + speech → talking-head video (offline edition)
- # ---------------------------------------------------------------------
- import os, math
- from typing import Dict, Any, List
-
  import torch
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm
  import cv2
@@ -17,311 +13,318 @@ from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatur
  
  from src.utils.util import save_videos_grid, seed_everything
  from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
- from src.models.base.unet_spatio_temporal_condition import (
-     UNetSpatioTemporalConditionModel,
-     add_ip_adapters,
- )
  from src.pipelines.pipeline_sonic import SonicPipeline
  from src.models.audio_adapter.audio_proj import AudioProjModel
  from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
  from src.utils.RIFE.RIFE_HDv3 import RIFEModel
  from src.dataset.face_align.align import AlignImage
  
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  
  
- # ------------------------------------------------------------------ #
- #  Helper: locate the diffusers model directory automatically        #
- # ------------------------------------------------------------------ #
- def _locate_diffusers_dir(root: str) -> str:
-     """
-     Walk the subdirectories of `root` and return the folder that actually
-     contains a diffusers snapshot (model_index.json or config.json).
-     Raise an error if none is found.
-     """
-     for cur, _dirs, files in os.walk(root):
-         if {"model_index.json", "config.json"} & set(files):
-             return cur
-     raise FileNotFoundError(
-         f"[ERROR] No diffusers model files found under '{root}'. "
-         "Check that the checkpoint was downloaded correctly."
-     )
-
-
- # ------------------------------------------------------------------ #
- #  Internal helper for video generation                              #
- # ------------------------------------------------------------------ #
- def _gen_video_tensor(
-     pipe: SonicPipeline,
-     cfg: OmegaConf,
-     wav_enc: WhisperModel,
-     audio_pe: AudioProjModel,
-     audio2bucket: Audio2bucketModel,
-     image_encoder: CLIPVisionModelWithProjection,
-     width: int,
-     height: int,
-     batch: Dict[str, torch.Tensor],
- ) -> torch.Tensor:
-     """
-     Single image + audio features → video tensor (C,T,H,W).
-     """
-
-     # -------- add the batch dimension ------------------------------
      for k, v in batch.items():
          if isinstance(v, torch.Tensor):
              batch[k] = v.unsqueeze(0).to(pipe.device).float()
  
-     ref_img      = batch["ref_img"]           # (1,C,H,W)
-     clip_img     = batch["clip_images"]
-     face_mask    = batch["face_mask"]
-     image_embeds = image_encoder(clip_img).image_embeds
-
-     audio_feat: torch.Tensor = batch["audio_feature"]   # (1, 80, T)
-     audio_len:  int          = int(batch["audio_len"])  # scalar
-     step:       int          = int(cfg.step)
-
-     # if step exceeds the total length, clamp it to at least 1
-     if audio_len < step:
-         step = max(1, audio_len)
-
-     # -------- run the Whisper encoder in 1-second chunks -----------
-     window = 16_000  # 1-s chunk
-     aud_prompts:  List[torch.Tensor] = []
-     last_prompts: List[torch.Tensor] = []
-
-     for i in range(0, audio_feat.shape[-1], window):
-         chunk = audio_feat[:, :, i : i + window]
-
-         # all hidden-states / the last hidden-state
-         layers: List[torch.Tensor] = wav_enc.encoder(
-             chunk, output_hidden_states=True
-         ).hidden_states
-         last_hidden = wav_enc.encoder(chunk).last_hidden_state    # (1, 80, 384)
-
-         # Whisper has 6 layers → truncate to the 5 that AudioProj expects
-         prompt = torch.stack(layers, dim=2)[:, :, :5]             # (1,80,5,384)
-         aud_prompts.append(prompt)
-         last_prompts.append(last_hidden.unsqueeze(-2))            # (1,80,1,384)
-
-     if len(aud_prompts) == 0:
-         raise ValueError("[ERROR] No speech recognised in the provided audio.")
-
-     # concatenate, then apply the padding rule
-     aud_prompts  = torch.cat(aud_prompts,  dim=1)   # (1, 80*…, 5, 384)
-     last_prompts = torch.cat(last_prompts, dim=1)   # (1, 80*…, 1, 384)
-
-     aud_prompts = torch.cat(
-         [torch.zeros_like(aud_prompts[:, :4]), aud_prompts, torch.zeros_like(aud_prompts[:, :6])],
-         dim=1,
-     )
-     last_prompts = torch.cat(
-         [torch.zeros_like(last_prompts[:, :24]), last_prompts, torch.zeros_like(last_prompts[:, :26])],
-         dim=1,
-     )
-
-     # -------- cut clips with f=10 / w=5 ----------------------------
-     ref_list, aud_list, uncond_list, mb_list = [], [], [], []
-
-     total_tokens = aud_prompts.shape[1]
-     n_chunks     = max(1, math.ceil(total_tokens / (2 * step)))
-
-     for i in tqdm(range(n_chunks), desc="audio-chunks", ncols=0):
-         s = i * 2 * step
-
-         cond_clip = aud_prompts[:, s : s + 10]           # (1,10,5,384)
-         if cond_clip.shape[1] < 10:                      # pad at the end
-             pad = torch.zeros_like(cond_clip[:, : 10 - cond_clip.shape[1]])
-             cond_clip = torch.cat([cond_clip, pad], dim=1)
-
-         bucket_clip = last_prompts[:, s : s + 50]        # (1,50,1,384)
-         if bucket_clip.shape[1] < 50:
-             pad = torch.zeros_like(bucket_clip[:, : 50 - bucket_clip.shape[1]])
-             bucket_clip = torch.cat([bucket_clip, pad], dim=1)
-
-         # convert to 5-D (bz,f,w,b,c)
-         cond_clip   = cond_clip.unsqueeze(3)             # (1,10,5,1,384)
-         bucket_clip = bucket_clip.unsqueeze(3)           # (1,50,1,1,384)
-         uncond_clip = torch.zeros_like(cond_clip)
-
-         motion_bucket = audio2bucket(bucket_clip, image_embeds) * 16 + 16
-
-         ref_list.append(ref_img[0])
-         aud_list.append(audio_pe(cond_clip).squeeze(0)[0])       # (ctx,1024)
-         uncond_list.append(audio_pe(uncond_clip).squeeze(0)[0])  # (ctx,1024)
-         mb_list.append(motion_bucket[0])
-
-     # -------- run the UNet pipeline --------------------------------
-     video = (
-         pipe(
-             ref_img,
-             clip_img,
-             face_mask,
-             aud_list,
-             uncond_list,
-             mb_list,
-             height=height,
-             width=width,
-             num_frames=len(aud_list),
-             decode_chunk_size=cfg.decode_chunk_size,
-             motion_bucket_scale=cfg.motion_bucket_scale,
-             fps=cfg.fps,
-             noise_aug_strength=cfg.noise_aug_strength,
-             min_guidance_scale1=cfg.min_appearance_guidance_scale,
-             max_guidance_scale1=cfg.max_appearance_guidance_scale,
-             min_guidance_scale2=cfg.audio_guidance_scale,
-             max_guidance_scale2=cfg.audio_guidance_scale,
-             overlap=cfg.overlap,
-             shift_offset=cfg.shift_offset,
-             frames_per_batch=cfg.n_sample_frames,
-             num_inference_steps=cfg.num_inference_steps,
-             i2i_noise_strength=cfg.i2i_noise_strength,
-         ).frames
-         * 0.5
-         + 0.5
-     ).clamp(0, 1)
-
-     # (B,C,T,H,W) → (C,T,H,W)
-     return video.to(pipe.device).squeeze(0).cpu()
-
-
- # ------------------------------------------------------------------ #
- #  Sonic – main class                                                #
- # ------------------------------------------------------------------ #
- class Sonic:
-     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
-     config      = OmegaConf.load(config_file)
-
-     def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
-         cfg = self.config
-         cfg.use_interframe = enable_interpolate_frame
-
-         # parent folder of the diffusers model (local download path)
-         self.diffusers_root = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
-         self.device = (
-             f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
          )
  
-         self._load_models(cfg)
-         print("Sonic init done")
-
-     # -------------------------------------------------------------- #
-     def _load_models(self, cfg):
-         # dtype
-         dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
-
-         diff_root = _locate_diffusers_dir(self.diffusers_root)
-
-         # diffusers modules
-         vae   = AutoencoderKLTemporalDecoder.from_pretrained(diff_root, subfolder="vae", variant="fp16")
-         sched = EulerDiscreteScheduler.from_pretrained(diff_root, subfolder="scheduler")
-         img_e = CLIPVisionModelWithProjection.from_pretrained(diff_root, subfolder="image_encoder", variant="fp16")
-         unet  = UNetSpatioTemporalConditionModel.from_pretrained(diff_root, subfolder="unet", variant="fp16")
-         add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
-
-         # audio adapters
-         a2t = AudioProjModel(seq_len=10, blocks=5, channels=384,
-                              intermediate_dim=1024, output_dim=1024, context_tokens=32).to(self.device)
-         a2b = Audio2bucketModel(seq_len=50, blocks=1, channels=384,
-                                 clip_channels=1024, intermediate_dim=1024, output_dim=1,
-                                 context_tokens=2).to(self.device)
-
-         # load checkpoints
-         a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
-         a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
-         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
-
-         # Whisper
-         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
          whisper.requires_grad_(False)
  
-         # image / face detection / interpolation
-         self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
-         self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
-         if cfg.use_interframe:
-             self.rife = RIFEModel(device=self.device)
-             self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
-
-         # apply dtype
-         for m in (vae, img_e, unet):
-             m.to(dtype)
-
-         self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(self.device, dtype=dtype)
-         self.image_encoder = img_e
-         self.audio2token   = a2t
-         self.audio2bucket  = a2b
-         self.whisper       = whisper
-
-     # -------------------------------------------------------------- #
-     def preprocess(self, image_path: str, expand_ratio: float = 1.0) -> Dict[str, Any]:
-         img = cv2.imread(image_path)
-         h, w = img.shape[:2]
-         _, _, bboxes = self.face_det(img, maxface=True)
-         if bboxes:
              x1, y1, ww, hh = bboxes[0]
-             crop = process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)
-             return {"face_num": 1, "crop_bbox": crop}
-         return {"face_num": 0, "crop_bbox": None}
  
-     # -------------------------------------------------------------- #
      @torch.no_grad()
-     def process(
-         self,
-         image_path: str,
-         audio_path: str,
-         output_path: str,
-         min_resolution: int = 512,
-         inference_steps: int = 25,
-         dynamic_scale: float = 1.0,
-         keep_resolution: bool = False,
-         seed: int | None = None,
-     ) -> int:
-         cfg = self.config
-         if seed is not None:
-             cfg.seed = seed
-         cfg.num_inference_steps = inference_steps
-         cfg.motion_bucket_scale = dynamic_scale
-         seed_everything(cfg.seed)
-
-         # convert the image and audio to tensors
-         data = image_audio_to_tensor(
-             self.face_det,
-             self.feature_extractor,
-             image_path,
-             audio_path,
-             limit=-1,
-             image_size=min_resolution,
-             area=cfg.area,
-         )
-         if data is None:
              return -1
-
-         h, w = data["ref_img"].shape[-2:]
          if keep_resolution:
-             im = Image.open(image_path)
-             resolution = f"{(im.width // 2) * 2}x{(im.height // 2) * 2}"
          else:
-             resolution = f"{w}x{h}"
-
-         # generate the video tensor
-         video = _gen_video_tensor(
-             self.pipe, cfg, self.whisper, self.audio2token, self.audio2bucket,
-             self.image_encoder, w, h, data,
-         )
-
-         # interpolate intermediate frames
-         if cfg.use_interframe:
-             out = video.to(self.device)
-             frames = []
-             for i in tqdm(range(out.shape[1] - 1), desc="interpolate", ncols=0):
-                 frames.extend([out[:, i], self.rife.inference(out[:, i], out[:, i + 1]).clamp(0, 1)])
-             frames.append(out[:, -1])
-             video = torch.stack(frames, 1).cpu()   # (C,T',H,W)
-
-         # save
-         tmp = output_path.replace(".mp4", "_noaudio.mp4")
-         save_videos_grid(video.unsqueeze(0), tmp, n_rows=1, fps=cfg.fps * (2 if cfg.use_interframe else 1))
-         os.system(
-             f"ffmpeg -loglevel error -y -i '{tmp}' -i '{audio_path}' -s {resolution} "
-             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}'"
-         )
-         os.remove(tmp)
          return 0
 
 
+ import os
  import torch
+ import torch.utils.checkpoint
  from PIL import Image
+ import numpy as np
  from omegaconf import OmegaConf
  from tqdm import tqdm
  import cv2
  
  from src.utils.util import save_videos_grid, seed_everything
  from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
+ from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
  from src.pipelines.pipeline_sonic import SonicPipeline
  from src.models.audio_adapter.audio_proj import AudioProjModel
  from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
  from src.utils.RIFE.RIFE_HDv3 import RIFEModel
  from src.dataset.face_align.align import AlignImage
  
  
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  
+ def test(
+     pipe,
+     config,
+     wav_enc,
+     audio_pe,
+     audio2bucket,
+     image_encoder,
+     width,
+     height,
+     batch
+ ):
      for k, v in batch.items():
          if isinstance(v, torch.Tensor):
              batch[k] = v.unsqueeze(0).to(pipe.device).float()
+     ref_img = batch['ref_img']
+     clip_img = batch['clip_images']
+     face_mask = batch['face_mask']
+     image_embeds = image_encoder(
+         clip_img
+     ).image_embeds
+
+     audio_feature = batch['audio_feature']
+     audio_len = batch['audio_len']
+     step = int(config.step)
+
+     window = 3000
+     audio_prompts = []
+     last_audio_prompts = []
+     for i in range(0, audio_feature.shape[-1], window):
+         audio_prompt = wav_enc.encoder(audio_feature[:,:,i:i+window], output_hidden_states=True).hidden_states
+         last_audio_prompt = wav_enc.encoder(audio_feature[:,:,i:i+window]).last_hidden_state
+         last_audio_prompt = last_audio_prompt.unsqueeze(-2)
+         audio_prompt = torch.stack(audio_prompt, dim=2)
+         audio_prompts.append(audio_prompt)
+         last_audio_prompts.append(last_audio_prompt)
+
+     audio_prompts = torch.cat(audio_prompts, dim=1)
+     audio_prompts = audio_prompts[:,:audio_len*2]
+     audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:,:4]), audio_prompts, torch.zeros_like(audio_prompts[:,:6])], 1)
+
+     last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
+     last_audio_prompts = last_audio_prompts[:,:audio_len*2]
+     last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:,:24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:,:26])], 1)
+
+
+     ref_tensor_list = []
+     audio_tensor_list = []
+     uncond_audio_tensor_list = []
+     motion_buckets = []
+     for i in tqdm(range(audio_len//step)):
+
+
+         audio_clip = audio_prompts[:,i*2*step:i*2*step+10].unsqueeze(0)
+         audio_clip_for_bucket = last_audio_prompts[:,i*2*step:i*2*step+50].unsqueeze(0)
+         motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
+         motion_bucket = motion_bucket * 16 + 16
+         motion_buckets.append(motion_bucket[0])
+
+         cond_audio_clip = audio_pe(audio_clip).squeeze(0)
+         uncond_audio_clip = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)
+
+         ref_tensor_list.append(ref_img[0])
+         audio_tensor_list.append(cond_audio_clip[0])
+         uncond_audio_tensor_list.append(uncond_audio_clip[0])
+
+     video = pipe(
+         ref_img,
+         clip_img,
+         face_mask,
+         audio_tensor_list,
+         uncond_audio_tensor_list,
+         motion_buckets,
+         height=height,
+         width=width,
+         num_frames=len(audio_tensor_list),
+         decode_chunk_size=config.decode_chunk_size,
+         motion_bucket_scale=config.motion_bucket_scale,
+         fps=config.fps,
+         noise_aug_strength=config.noise_aug_strength,
+         min_guidance_scale1=config.min_appearance_guidance_scale,  # 1.0,
+         max_guidance_scale1=config.max_appearance_guidance_scale,
+         min_guidance_scale2=config.audio_guidance_scale,  # 1.0,
+         max_guidance_scale2=config.audio_guidance_scale,
+         overlap=config.overlap,
+         shift_offset=config.shift_offset,
+         frames_per_batch=config.n_sample_frames,
+         num_inference_steps=config.num_inference_steps,
+         i2i_noise_strength=config.i2i_noise_strength
+     ).frames
+
+
+     # Concat it with pose tensor
+     # pose_tensor = torch.stack(pose_tensor_list,1).unsqueeze(0)
+     video = (video*0.5 + 0.5).clamp(0, 1)
+     video = torch.cat([video.to(pipe.device)], dim=0).cpu()
+
+     return video
+
+
+ class Sonic():
+     config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
+     config = OmegaConf.load(config_file)
+
+     def __init__(self,
+                  device_id=0,
+                  enable_interpolate_frame=True,
+                  ):
+
+         config = self.config
+         config.use_interframe = enable_interpolate_frame
+
+         device = 'cuda:{}'.format(device_id) if device_id > -1 else 'cpu'
+
+         config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)
+
+         vae = AutoencoderKLTemporalDecoder.from_pretrained(
+             config.pretrained_model_name_or_path,
+             subfolder="vae",
+             variant="fp16")
+
+         val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+             config.pretrained_model_name_or_path,
+             subfolder="scheduler")
+
+         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+             config.pretrained_model_name_or_path,
+             subfolder="image_encoder",
+             variant="fp16")
+         unet = UNetSpatioTemporalConditionModel.from_pretrained(
+             config.pretrained_model_name_or_path,
+             subfolder="unet",
+             variant="fp16")
+         add_ip_adapters(unet, [32], [config.ip_audio_scale])
+
+         audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=1024, context_tokens=32).to(device)
+         audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024, intermediate_dim=1024, output_dim=1, context_tokens=2).to(device)
+
+         unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
+         audio2token_checkpoint_path = os.path.join(BASE_DIR, config.audio2token_checkpoint_path)
+         audio2bucket_checkpoint_path = os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path)
+
+         unet.load_state_dict(
+             torch.load(unet_checkpoint_path, map_location="cpu"),
+             strict=True,
+         )
+
+         audio2token.load_state_dict(
+             torch.load(audio2token_checkpoint_path, map_location="cpu"),
+             strict=True,
+         )
  
+         audio2bucket.load_state_dict(
+             torch.load(audio2bucket_checkpoint_path, map_location="cpu"),
+             strict=True,
          )
+
+
+         if config.weight_dtype == "fp16":
+             weight_dtype = torch.float16
+         elif config.weight_dtype == "fp32":
+             weight_dtype = torch.float32
+         elif config.weight_dtype == "bf16":
+             weight_dtype = torch.bfloat16
+         else:
+             raise ValueError(
+                 f"Do not support weight dtype: {config.weight_dtype} during training"
+             )
  
+         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
+
          whisper.requires_grad_(False)
  
+         self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))
+
+         det_path = os.path.join(BASE_DIR, os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt'))
+         self.face_det = AlignImage(device, det_path=det_path)
+         if config.use_interframe:
+             rife = RIFEModel(device=device)
+             rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))
+             self.rife = rife
+
+
+         image_encoder.to(weight_dtype)
+         vae.to(weight_dtype)
+         unet.to(weight_dtype)
+
+         pipe = SonicPipeline(
+             unet=unet,
+             image_encoder=image_encoder,
+             vae=vae,
+             scheduler=val_noise_scheduler,
+         )
+         pipe = pipe.to(device=device, dtype=weight_dtype)
+
+
+         self.pipe = pipe
+         self.whisper = whisper
+         self.audio2token = audio2token
+         self.audio2bucket = audio2bucket
+         self.image_encoder = image_encoder
+         self.device = device
+
+         print('init done')
+
+
+     def preprocess(self,
+                    image_path, expand_ratio=1.0):
+         face_image = cv2.imread(image_path)
+         h, w = face_image.shape[:2]
+         _, _, bboxes = self.face_det(face_image, maxface=True)
+         face_num = len(bboxes)
+         bbox = []
+         if face_num > 0:
              x1, y1, ww, hh = bboxes[0]
+             x2, y2 = x1 + ww, y1 + hh
+             bbox = x1, y1, x2, y2
+         bbox_s = process_bbox(bbox, expand_radio=expand_ratio, height=h, width=w)
+
+         return {
+             'face_num': face_num,
+             'crop_bbox': bbox_s,
+         }
+
+     def crop_image(self,
+                    input_image_path,
+                    output_image_path,
+                    crop_bbox):
+         face_image = cv2.imread(input_image_path)
+         crop_image = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
+         cv2.imwrite(output_image_path, crop_image)
  
      @torch.no_grad()
+     def process(self,
+                 image_path,
+                 audio_path,
+                 output_path,
+                 min_resolution=512,
+                 inference_steps=25,
+                 dynamic_scale=1.0,
+                 keep_resolution=False,
+                 seed=None):
+
+         config = self.config
+         device = self.device
+         pipe = self.pipe
+         whisper = self.whisper
+         audio2token = self.audio2token
+         audio2bucket = self.audio2bucket
+         image_encoder = self.image_encoder
+
+         # specific parameters
+         if seed:
+             config.seed = seed
+
+         config.num_inference_steps = inference_steps
+
+         config.motion_bucket_scale = dynamic_scale
+
+         seed_everything(config.seed)
+
+         video_path = output_path.replace('.mp4', '_noaudio.mp4')
+         audio_video_path = output_path
+
+         imSrc_ = Image.open(image_path).convert('RGB')
+         raw_w, raw_h = imSrc_.size
+
+         test_data = image_audio_to_tensor(self.face_det, self.feature_extractor, image_path, audio_path, limit=config.frame_num, image_size=min_resolution, area=config.area)
+         if test_data is None:
              return -1
+         height, width = test_data['ref_img'].shape[-2:]
          if keep_resolution:
+             resolution = f'{raw_w//2*2}x{raw_h//2*2}'
          else:
+             resolution = f'{width}x{height}'
+
+         video = test(
+             pipe,
+             config,
+             wav_enc=whisper,
+             audio_pe=audio2token,
+             audio2bucket=audio2bucket,
+             image_encoder=image_encoder,
+             width=width,
+             height=height,
+             batch=test_data,
+         )
+
+         if config.use_interframe:
+             rife = self.rife
+             out = video.to(device)
+             results = []
+             video_len = out.shape[2]
+             for idx in tqdm(range(video_len-1), ncols=0):
+                 I1 = out[:, :, idx]
+                 I2 = out[:, :, idx+1]
+                 middle = rife.inference(I1, I2).clamp(0, 1).detach()
+                 results.append(out[:, :, idx])
+                 results.append(middle)
+             results.append(out[:, :, video_len-1])
+             video = torch.stack(results, 2).cpu()
+
+         save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * 2 if config.use_interframe else config.fps)
+         os.system(f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} -vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'")
          return 0
+
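
For orientation, a minimal, hypothetical usage sketch of the Sonic class as defined in the new version of sonic.py above; the image, audio, and output paths are illustrative placeholders, and it assumes the config and checkpoints referenced in config/inference/sonic.yaml have been downloaded locally:

    # Hypothetical driver script; paths are placeholders, not part of the commit.
    from sonic import Sonic

    model = Sonic(device_id=0, enable_interpolate_frame=True)

    # Optional sanity check: preprocess() reports how many faces were detected.
    info = model.preprocess("examples/face.png", expand_ratio=1.0)
    if info["face_num"] > 0:
        # process() renders the talking-head video and muxes the audio track in via ffmpeg.
        ret = model.process(
            image_path="examples/face.png",
            audio_path="examples/speech.wav",
            output_path="outputs/result.mp4",
            min_resolution=512,
            inference_steps=25,
            dynamic_scale=1.0,
            keep_resolution=False,
            seed=42,
        )
        print("done" if ret == 0 else "no face/audio detected")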