# coding=utf-8
import logging
import os
from pathlib import Path

import soundfile as sf
import torch
import torchvision
from huggingface_hub import snapshot_download
from moviepy.editor import AudioFileClip, VideoFileClip
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from third_party.FoleyCrafter.foleycrafter.models.onset import torch_utils
from third_party.FoleyCrafter.foleycrafter.models.time_detector.model import VideoOnsetNet
from third_party.FoleyCrafter.foleycrafter.pipelines.auffusion_pipeline import Generator, denormalize_spectrogram
from third_party.FoleyCrafter.foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy

# Frame preprocessing for the onset (timestamp) detector.
vision_transform_list = [
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.CenterCrop((112, 112)),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
video_transform = torchvision.transforms.Compose(vision_transform_list)

model_base_dir = "pretrained/v2a/foleycrafter"


class V2A_FoleyCrafter:
    def __init__(
        self,
        pretrained_model_name_or_path: str = f"{model_base_dir}/checkpoints/auffusion",
        ckpt: str = f"{model_base_dir}/checkpoints",
    ):
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)
        self.log.info("The V2A model uses FoleyCrafter, init...")

        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.log.warning("CUDA/MPS are not available, running on CPU")

        # Download the base model if it is not available locally.
        if not os.path.isdir(pretrained_model_name_or_path):
            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)

        # Checkpoint paths
        temporal_ckpt_path = os.path.join(ckpt, "temporal_adapter.ckpt")

        # Load the vocoder (spectrogram -> waveform).
        self.vocoder = Generator.from_pretrained(ckpt, subfolder="vocoder").to(self.device)

        # Load the onset/timestamp detector.
        time_detector_ckpt = os.path.join(ckpt, "timestamp_detector.pth.tar")
        self.time_detector = VideoOnsetNet(False)
        self.time_detector, _ = torch_utils.load_model(
            time_detector_ckpt, self.time_detector, device=self.device, strict=True
        )

        # Build the FoleyCrafter pipeline and load the adapters.
        self.pipe = build_foleycrafter().to(self.device)

        # Load the temporal adapter into the ControlNet. Use a separate variable so the
        # `ckpt` directory path is not shadowed by the loaded state dict.
        temporal_ckpt = torch.load(temporal_ckpt_path, map_location="cpu")
        if "state_dict" in temporal_ckpt.keys():
            temporal_ckpt = temporal_ckpt["state_dict"]
        load_gligen_ckpt = {}
        for key, value in temporal_ckpt.items():
            if key.startswith("module."):
                load_gligen_ckpt[key[len("module."):]] = value
            else:
                load_gligen_ckpt[key] = value
        m, u = self.pipe.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
        self.log.info(f"ControlNet missing keys: {len(m)}; unexpected keys: {len(u)}")

        # Load the semantic adapter (IP-Adapter weights).
        self.pipe.load_ip_adapter(
            os.path.join(ckpt, "semantic"), subfolder="", weight_name="semantic_adapter.bin", image_encoder_folder=None
        )
        # The IP-Adapter scale and the random seed are set per call in generate_audio().
        self.generator = torch.Generator(device=self.device)

        self.image_processor = CLIPImageProcessor()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter", subfolder="models/image_encoder"
        ).to(self.device)

    @torch.no_grad()
    def generate_audio(
        self,
        video_path,
        output_dir,
        prompt: str = "",
        negative_prompt: str = "",
        seed: int = 42,
        temporal_scale: float = 0.2,
        semantic_scale: float = 1.0,
        is_postp=False,
    ):
        self.pipe.set_ip_adapter_scale(semantic_scale)
        self.generator.manual_seed(seed)

        video_path = Path(video_path).expanduser()
        output_dir = Path(output_dir).expanduser()
        self.log.info(f"Loading video: {video_path}")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Sample up to 150 frames from the input video.
        frames, duration = read_frames_with_moviepy(video_path, max_frame_nums=150)

        # Predict per-frame onset probabilities with the timestamp detector.
        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2)
        time_frames = video_transform(time_frames)
        time_frames = {"frames": time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4).to(self.device)}
        preds = self.time_detector(time_frames)
        preds = torch.sigmoid(preds)

        # Convert onset probabilities into a +1/-1 time condition over the 1024 spectrogram columns.
        time_condition = [
            -1 if preds[0][int(i / (1024 / 10 * duration) * 150)] < 0.5 else 1
            for i in range(int(1024 / 10 * duration))
        ]
        time_condition = time_condition + [-1] * (1024 - len(time_condition))
        # w -> b c h w
        time_condition = (
            torch.FloatTensor(time_condition)
            .unsqueeze(0)
            .unsqueeze(0)
            .unsqueeze(0)
            .repeat(1, 1, 256, 1)
            .to(self.device)
        )

        # Encode the frames with CLIP and average them into a single semantic (IP-Adapter) embedding.
        images = self.image_processor(images=frames, return_tensors="pt").to(self.device)
        image_embeddings = self.image_encoder(**images).image_embeds
        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
        neg_image_embeddings = torch.zeros_like(image_embeddings)
        image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)

        # Generate the spectrogram with the FoleyCrafter pipeline.
        sample = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            ip_adapter_image_embeds=image_embeddings,
            image=time_condition,
            controlnet_conditioning_scale=temporal_scale,
            num_inference_steps=25,
            height=256,
            width=1024,
            output_type="pt",
            generator=self.generator,
        )

        # Decode the spectrogram to a 16 kHz waveform and trim to the video duration.
        audio_img = sample.images[0]
        audio = denormalize_spectrogram(audio_img)
        audio = self.vocoder.inference(audio, lengths=160000)[0]
        audio = audio[: int(duration * 16000)]

        if is_postp:
            audio_save_path = output_dir / f"{video_path.stem}.neg.wav"
            video_save_path = output_dir / f"{video_path.stem}.neg.mp4"
        else:
            audio_save_path = output_dir / f"{video_path.stem}.step1.wav"
            video_save_path = output_dir / f"{video_path.stem}.step1.mp4"

        self.log.info(f"Saving generated audio and video to {output_dir}")
        sf.write(audio_save_path, audio, 16000)

        # Mux the generated audio back onto the input video (moviepy expects string paths).
        audio = AudioFileClip(str(audio_save_path))
        video = VideoFileClip(str(video_path))
        duration = min(audio.duration, video.duration)
        audio = audio.subclip(0, duration)
        video.audio = audio
        video = video.subclip(0, duration)
        video.write_videofile(str(video_save_path))
        self.log.info(f"Video saved to {video_save_path}")

        return audio_save_path, video_save_path
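

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the input video and output
    # directory below are hypothetical placeholders. It assumes the FoleyCrafter checkpoints
    # are laid out under pretrained/v2a/foleycrafter as in the default constructor arguments.
    logging.basicConfig(level=logging.INFO)

    v2a = V2A_FoleyCrafter()
    audio_path, video_path = v2a.generate_audio(
        video_path="examples/input.mp4",  # hypothetical input video
        output_dir="outputs/v2a",         # hypothetical output directory
        prompt="",                        # optional text prompt steering the generated sound
        seed=42,
    )
    print(f"audio: {audio_path}\nvideo: {video_path}")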