#coding=utf-8
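"""Video-to-audio (V2A) generation with FoleyCrafter.

Sound onsets are predicted from the video frames with a time detector, CLIP
image embeddings drive the semantic (IP-Adapter) branch, the onset timeline
drives the temporal (ControlNet) branch, and the generated spectrogram is
vocoded to a 16 kHz waveform that is muxed back into the input video.
"""
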
import logging
import os
from pathlib import Path

import soundfile as sf
import torch
import torchvision
from huggingface_hub import snapshot_download
from moviepy.editor import AudioFileClip, VideoFileClip
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from third_party.FoleyCrafter.foleycrafter.models.onset import torch_utils
from third_party.FoleyCrafter.foleycrafter.models.time_detector.model import VideoOnsetNet
from third_party.FoleyCrafter.foleycrafter.pipelines.auffusion_pipeline import Generator, denormalize_spectrogram
from third_party.FoleyCrafter.foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy

vision_transform_list = [
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.CenterCrop((112, 112)),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
video_transform = torchvision.transforms.Compose(vision_transform_list)

model_base_dir = "pretrained/v2a/foleycrafter"


class V2A_FoleyCrafter:
    def __init__(self,
                 pretrained_model_name_or_path: str = f"{model_base_dir}/checkpoints/auffusion",
                 ckpt: str = f"{model_base_dir}/checkpoints"):
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)
        self.log.info("The V2A model uses FoleyCrafter, initializing...")

        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = 'cuda'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.log.warning('CUDA/MPS are not available, running on CPU')

        # download ckpt if it is not already a local directory
        if not os.path.isdir(pretrained_model_name_or_path):
            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
        # ckpt path
        temporal_ckpt_path = os.path.join(ckpt, "temporal_adapter.ckpt")

        # load vocoder
        self.vocoder = Generator.from_pretrained(ckpt, subfolder="vocoder").to(self.device)

        # load time_detector (onset prediction from video frames)
        time_detector_ckpt = os.path.join(ckpt, "timestamp_detector.pth.tar")
        self.time_detector = VideoOnsetNet(False)
        self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, self.time_detector, device=self.device, strict=True)

        # build the pipeline and load the adapters
        self.pipe = build_foleycrafter().to(self.device)

        # load temporal adapter (keep `ckpt` untouched: the directory path is still needed below)
        temporal_state = torch.load(temporal_ckpt_path, map_location="cpu")
        if "state_dict" in temporal_state.keys():
            temporal_state = temporal_state["state_dict"]
        load_gligen_ckpt = {}
        for key, value in temporal_state.items():
            if key.startswith("module."):
                load_gligen_ckpt[key[len("module."):]] = value
            else:
                load_gligen_ckpt[key] = value
        m, u = self.pipe.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
        self.log.info(f"### Control Net missing keys: {len(m)}; unexpected keys: {len(u)}")

        # load semantic adapter
        self.pipe.load_ip_adapter(
            os.path.join(ckpt, "semantic"), subfolder="", weight_name="semantic_adapter.bin", image_encoder_folder=None
        )
        # ip_adapter_weight = semantic_scale
        # self.pipe.set_ip_adapter_scale(ip_adapter_weight)

        self.generator = torch.Generator(device=self.device)
        # self.generator.manual_seed(seed)
        self.image_processor = CLIPImageProcessor()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter", subfolder="models/image_encoder"
        ).to(self.device)

    @torch.no_grad()
    def generate_audio(self,
                       video_path,
                       output_dir,
                       prompt: str = '',
                       negative_prompt: str = '',
                       seed: int = 42,
                       temporal_scale: float = 0.2,
                       semantic_scale: float = 1.0,
                       is_postp=False):
        self.pipe.set_ip_adapter_scale(semantic_scale)
        self.generator.manual_seed(seed)

        video_path = Path(video_path).expanduser()
        output_dir = Path(output_dir).expanduser()
        self.log.info(f"Loading video: {video_path}")
        output_dir.mkdir(parents=True, exist_ok=True)

        # read frames and predict sound onsets
        frames, duration = read_frames_with_moviepy(video_path, max_frame_nums=150)
        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2)
        time_frames = video_transform(time_frames)
        time_frames = {"frames": time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
        preds = self.time_detector(time_frames)
        preds = torch.sigmoid(preds)

        # map onset predictions onto the 1024-column spectrogram timeline
        time_condition = [
            -1 if preds[0][int(i / (1024 / 10 * duration) * 150)] < 0.5 else 1
            for i in range(int(1024 / 10 * duration))
        ]
        time_condition = time_condition + [-1] * (1024 - len(time_condition))
        # w -> b c h w
        time_condition = (
            torch.FloatTensor(time_condition)
            .unsqueeze(0)
            .unsqueeze(0)
            .unsqueeze(0)
            .repeat(1, 1, 256, 1)
            .to(self.device)
        )

        # CLIP image embeddings for the semantic (IP-Adapter) branch
        images = self.image_processor(images=frames, return_tensors="pt").to(self.device)
        image_embeddings = self.image_encoder(**images).image_embeds
        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
        neg_image_embeddings = torch.zeros_like(image_embeddings)
        image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)

        sample = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            ip_adapter_image_embeds=image_embeddings,
            image=time_condition,
            controlnet_conditioning_scale=temporal_scale,
            num_inference_steps=25,
            height=256,
            width=1024,
            output_type="pt",
            generator=self.generator,
        )

        # spectrogram -> waveform (16 kHz)
        audio_img = sample.images[0]
        audio = denormalize_spectrogram(audio_img)
        audio = self.vocoder.inference(audio, lengths=160000)[0]
        audio = audio[: int(duration * 16000)]

        if is_postp:
            audio_save_path = output_dir / f'{video_path.stem}.neg.wav'
            video_save_path = output_dir / f'{video_path.stem}.neg.mp4'
        else:
            audio_save_path = output_dir / f'{video_path.stem}.step1.wav'
            video_save_path = output_dir / f'{video_path.stem}.step1.mp4'

        self.log.info(f"Saving generated audio and video to {output_dir}")
        sf.write(audio_save_path, audio, 16000)

        # mux the generated audio back into the input video
        audio = AudioFileClip(str(audio_save_path))
        video = VideoFileClip(str(video_path))
        duration = min(audio.duration, video.duration)
        audio = audio.subclip(0, duration)
        video = video.set_audio(audio)
        video = video.subclip(0, duration)
        video.write_videofile(str(video_save_path))
        self.log.info(f'Video saved to {video_save_path}')
        return audio_save_path, video_save_path
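

# Minimal usage sketch: `input.mp4` and `outputs/` below are placeholders, and
# the FoleyCrafter checkpoints are expected under `pretrained/v2a/foleycrafter`
# (see `model_base_dir` above).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    v2a = V2A_FoleyCrafter()
    audio_path, video_with_audio_path = v2a.generate_audio(
        video_path="input.mp4",   # placeholder input clip
        output_dir="outputs",     # placeholder output directory
        prompt="",                # optional text prompt for the semantic adapter
        seed=42,
    )
    print(f"audio: {audio_path}\nvideo: {video_with_audio_path}")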