import os
import torch
import torch.utils.checkpoint
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2

from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor

from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def test(
    pipe,
    config,
    wav_enc,
    audio_pe,
    audio2bucket,
    image_encoder,
    width,
    height,
    batch,
):
    # Add a leading batch dimension -> (1, B, C, H, W) and move tensors to the pipeline device.
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).to(pipe.device).float()

    ref_img = batch['ref_img']
    clip_img = batch['clip_images']
    face_mask = batch['face_mask']
    image_embeds = image_encoder(clip_img).image_embeds

    audio_feature = batch['audio_feature']
    audio_len = batch['audio_len']
    step = int(config.step)

    # Window size changed from 3000 to 16000 (1-second chunks).
    window = 16000
    audio_prompts = []
    last_audio_prompts = []
    for i in range(0, audio_feature.shape[-1], window):
        audio_clip_chunk = audio_feature[:, :, i:i + window]
        # Whisper encoder: per-layer hidden states feed the token adapter,
        # the last hidden state feeds the motion-bucket predictor.
        audio_prompt = wav_enc.encoder(audio_clip_chunk, output_hidden_states=True).hidden_states
        last_audio_prompt = wav_enc.encoder(audio_clip_chunk).last_hidden_state
        last_audio_prompt = last_audio_prompt.unsqueeze(-2)
        audio_prompt = torch.stack(audio_prompt, dim=2)
        audio_prompts.append(audio_prompt)
        last_audio_prompts.append(last_audio_prompt)

    # Fail fast if no audio features were extracted.
    if len(audio_prompts) == 0:
        raise ValueError(
            "[ERROR] No speech recognized from the audio. "
            "Please provide a valid speech audio (with clear voice)."
        )

    audio_prompts = torch.cat(audio_prompts, dim=1)
    audio_prompts = audio_prompts[:, :audio_len * 2]
    # Zero-pad both ends so the sliding windows below stay in range.
    audio_prompts = torch.cat(
        [torch.zeros_like(audio_prompts[:, :4]), audio_prompts, torch.zeros_like(audio_prompts[:, :6])],
        dim=1,
    )

    last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
    last_audio_prompts = last_audio_prompts[:, :audio_len * 2]
    last_audio_prompts = torch.cat(
        [torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:, :26])],
        dim=1,
    )

    ref_tensor_list = []
    audio_tensor_list = []
    uncond_audio_tensor_list = []
    motion_buckets = []
    for i in tqdm(range(audio_len // step)):
        audio_clip = audio_prompts[:, i * 2 * step:i * 2 * step + 10].unsqueeze(0)
        audio_clip_for_bucket = last_audio_prompts[:, i * 2 * step:i * 2 * step + 50].unsqueeze(0)
        motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
        motion_bucket = motion_bucket * 16 + 16
        motion_buckets.append(motion_bucket[0])

        cond_audio_clip = audio_pe(audio_clip).squeeze(0)
        uncond_audio_clip = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)

        ref_tensor_list.append(ref_img[0])
        audio_tensor_list.append(cond_audio_clip[0])
        uncond_audio_tensor_list.append(uncond_audio_clip[0])

    video = pipe(
        ref_img,
        clip_img,
        face_mask,
        audio_tensor_list,
        uncond_audio_tensor_list,
        motion_buckets,
        height=height,
        width=width,
        num_frames=len(audio_tensor_list),
        decode_chunk_size=config.decode_chunk_size,
        motion_bucket_scale=config.motion_bucket_scale,
        fps=config.fps,
        noise_aug_strength=config.noise_aug_strength,
        min_guidance_scale1=config.min_appearance_guidance_scale,
        max_guidance_scale1=config.max_appearance_guidance_scale,
        min_guidance_scale2=config.audio_guidance_scale,
        max_guidance_scale2=config.audio_guidance_scale,
        overlap=config.overlap,
        shift_offset=config.shift_offset,
        frames_per_batch=config.n_sample_frames,
        num_inference_steps=config.num_inference_steps,
        i2i_noise_strength=config.i2i_noise_strength,
    ).frames

    # Map frames from [-1, 1] to [0, 1] and move them to CPU.
    video = (video * 0.5 + 0.5).clamp(0, 1)
    video = torch.cat([video.to(pipe.device)], dim=0).cpu()

    return video


class Sonic:
    config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
    config = OmegaConf.load(config_file)

    def __init__(self,
                 device_id=0,
                 enable_interpolate_frame=True,
                 ):
        config = self.config
        config.use_interframe = enable_interpolate_frame

        device = f'cuda:{device_id}' if device_id > -1 else 'cpu'

        config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)

        # VAE
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            config.pretrained_model_name_or_path, subfolder="vae", variant="fp16")

        # Scheduler
        val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
            config.pretrained_model_name_or_path, subfolder="scheduler")

        # CLIP vision encoder
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            config.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")

        # UNet
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            config.pretrained_model_name_or_path, subfolder="unet", variant="fp16")

        # Attach the IP-Adapter layers used for audio conditioning.
        add_ip_adapters(unet, [32], [config.ip_audio_scale])

        audio2token = AudioProjModel(
            seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=1024, context_tokens=32,
        ).to(device)
        audio2bucket = Audio2bucketModel(
            seq_len=50, blocks=1, channels=384, clip_channels=1024, intermediate_dim=1024, output_dim=1, context_tokens=2,
        ).to(device)

        # Load local checkpoints.
        unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
        audio2token_checkpoint_path = os.path.join(BASE_DIR, config.audio2token_checkpoint_path)
        audio2bucket_checkpoint_path = os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path)

        unet.load_state_dict(
            torch.load(unet_checkpoint_path, map_location="cpu"),
            strict=True,
        )
        audio2token.load_state_dict(
            torch.load(audio2token_checkpoint_path, map_location="cpu"),
            strict=True,
        )
        audio2bucket.load_state_dict(
            torch.load(audio2bucket_checkpoint_path, map_location="cpu"),
            strict=True,
        )

        # Resolve the weight dtype.
        if config.weight_dtype == "fp16":
            weight_dtype = torch.float16
        elif config.weight_dtype == "fp32":
            weight_dtype = torch.float32
        elif config.weight_dtype == "bf16":
            weight_dtype = torch.bfloat16
        else:
            raise ValueError(f"Do not support weight dtype: {config.weight_dtype}")

        # Whisper audio encoder (frozen, inference only).
        whisper = WhisperModel.from_pretrained(
            os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')
        ).to(device).eval()
        whisper.requires_grad_(False)

        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')
        )

        # Face detector used for alignment and cropping.
        det_path = os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt')
        self.face_det = AlignImage(device, det_path=det_path)

        # RIFE model for intermediate-frame interpolation.
        if config.use_interframe:
            rife = RIFEModel(device=device)
            rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))
            self.rife = rife

        # Cast models to the target dtype.
        image_encoder.to(weight_dtype)
        vae.to(weight_dtype)
        unet.to(weight_dtype)

        # Initialize the SonicPipeline.
        pipe = SonicPipeline(
            unet=unet,
            image_encoder=image_encoder,
            vae=vae,
            scheduler=val_noise_scheduler,
        )
        pipe = pipe.to(device=device, dtype=weight_dtype)

        self.pipe = pipe
        self.whisper = whisper
        self.audio2token = audio2token
        self.audio2bucket = audio2bucket
        self.image_encoder = image_encoder
        self.device = device

        print('Sonic init done')

    def preprocess(self, image_path, expand_ratio=1.0):
        face_image = cv2.imread(image_path)
        h, w = face_image.shape[:2]
        _, _, bboxes = self.face_det(face_image, maxface=True)
        face_num = len(bboxes)
        bbox_s = None
        if face_num > 0:
            x1, y1, ww, hh = bboxes[0]
            x2, y2 = x1 + ww, y1 + hh
            bbox = x1, y1, x2, y2
            # Note: the callee's keyword is spelled "expand_radio".
            bbox_s = process_bbox(bbox, expand_radio=expand_ratio, height=h, width=w)

        return {
            'face_num': face_num,
            'crop_bbox': bbox_s,
        }

    def crop_image(self, input_image_path, output_image_path, crop_bbox):
        face_image = cv2.imread(input_image_path)
        crop_image = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
        cv2.imwrite(output_image_path, crop_image)

    @torch.no_grad()
    def process(self,
                image_path,
                audio_path,
                output_path,
                min_resolution=512,
                inference_steps=25,
                dynamic_scale=1.0,
                keep_resolution=False,
                seed=None):
        config = self.config
        device = self.device
        pipe = self.pipe
        whisper = self.whisper
        audio2token = self.audio2token
        audio2bucket = self.audio2bucket
        image_encoder = self.image_encoder

        # Seed and per-run overrides ("is not None" so that seed=0 is honored).
        if seed is not None:
            config.seed = seed
        config.num_inference_steps = inference_steps
        config.motion_bucket_scale = dynamic_scale
        seed_everything(config.seed)

        video_path = output_path.replace('.mp4', '_noaudio.mp4')
        audio_video_path = output_path

        # Audio + image -> input tensors.
        test_data = image_audio_to_tensor(
            self.face_det,
            self.feature_extractor,
            image_path,
            audio_path,
            limit=-1,  # use the full audio
            image_size=min_resolution,
            area=config.area,
        )
        if test_data is None:
            return -1

        height, width = test_data['ref_img'].shape[-2:]
        if keep_resolution:
            imSrc_ = Image.open(image_path).convert('RGB')
            raw_w, raw_h = imSrc_.size
            resolution = f'{raw_w // 2 * 2}x{raw_h // 2 * 2}'
        else:
            resolution = f'{width}x{height}'

        # Run inference via test() defined above.
        video = test(
            pipe,
            config,
            wav_enc=whisper,
            audio_pe=audio2token,
            audio2bucket=audio2bucket,
            image_encoder=image_encoder,
            width=width,
            height=height,
            batch=test_data,
        )

        # Insert a RIFE-interpolated frame between each pair of consecutive frames.
        if config.use_interframe:
            rife = self.rife
            out = video.to(device)
            results = []
            video_len = out.shape[2]
            for idx in tqdm(range(video_len - 1), ncols=0):
                I1 = out[:, :, idx]
                I2 = out[:, :, idx + 1]
                middle = rife.inference(I1, I2).clamp(0, 1).detach()
                results.append(out[:, :, idx])
                results.append(middle)
            results.append(out[:, :, video_len - 1])
            video = torch.stack(results, 2).cpu()

        # Save the video without audio.
        save_videos_grid(video, video_path, n_rows=video.shape[0],
                         fps=config.fps * (2 if config.use_interframe else 1))

        # Mux the audio track into the final mp4, then remove the silent intermediate file.
        os.system(
            f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} "
            f"-vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'"
        )
        return 0
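

# Usage sketch (illustrative, not part of the original pipeline code): the paths and
# device id below are placeholder assumptions; this only shows how preprocess() and
# process() are intended to be chained.
if __name__ == "__main__":
    sonic = Sonic(device_id=0, enable_interpolate_frame=True)

    image_path = "examples/image/face.png"   # placeholder input portrait
    audio_path = "examples/wav/speech.wav"   # placeholder driving speech
    output_path = "outputs/result.mp4"       # placeholder output location

    # Only run the full pipeline if a face was detected in the portrait.
    info = sonic.preprocess(image_path)
    if info['face_num'] > 0:
        sonic.process(
            image_path,
            audio_path,
            output_path,
            min_resolution=512,
            inference_steps=25,
            dynamic_scale=1.0,
        )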