Spaces:
Running
Running
File size: 5,924 Bytes
1fd4e9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
#coding=utf-8
import logging
from pathlib import Path
import torch
import torchaudio
from third_party.MMAudio.mmaudio.eval_utils import ModelConfig, all_model_cfg, generate, load_video, make_video, setup_eval_logging
from third_party.MMAudio.mmaudio.model.flow_matching import FlowMatching
from third_party.MMAudio.mmaudio.model.networks import MMAudio, get_my_mmaudio
from third_party.MMAudio.mmaudio.model.utils.features_utils import FeaturesUtils
class V2A_MMAudio:
    """Video-to-audio generator wrapping an MMAudio network.

    Construction picks a compute device, loads the network weights for the
    requested variant, and prepares the flow-matching sampler and feature
    utilities. ``generate_audio`` then produces a ``.wav`` file and a muxed
    ``.mp4`` file for a given input video.
    """

    def __init__(self,
                 variant: str = "large_44k",
                 num_steps: int = 25,
                 full_precision: bool = False,):
        """Load the MMAudio model and its helpers.

        Args:
            variant: Key into ``all_model_cfg`` selecting the model config.
            num_steps: Number of steps handed to the ``FlowMatching`` sampler.
            full_precision: Use ``torch.float32`` when true, otherwise
                ``torch.bfloat16``.

        Raises:
            ValueError: If ``variant`` is not a key of ``all_model_cfg``.
        """
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)
        self.log.info(f"The V2A model uses MMAudio {variant}, init...")

        # Prefer CUDA, then MPS; otherwise stay on CPU and warn about it.
        if torch.cuda.is_available():
            self.device = 'cuda'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.device = 'cpu'
            self.log.warning('CUDA/MPS are not available, running on CPU')

        if full_precision:
            self.dtype = torch.float32
        else:
            self.dtype = torch.bfloat16

        if variant not in all_model_cfg:
            raise ValueError(f'Unknown model variant: {variant}')
        self.model: ModelConfig = all_model_cfg[variant]
        self.model.download_if_needed()

        # Build the network on the target device/dtype, then load its weights.
        network = get_my_mmaudio(self.model.model_name)
        network = network.to(self.device, self.dtype)
        self.net: MMAudio = network.eval()
        state_dict = torch.load(self.model.model_path,
                                map_location=self.device,
                                weights_only=True)
        self.net.load_weights(state_dict)

        # Flow-matching sampler used at generation time.
        self.fm = FlowMatching(min_sigma=0,
                               inference_mode='euler',
                               num_steps=num_steps)

        # Feature utilities, configured from the model config's checkpoints.
        feature_kwargs = dict(
            tod_vae_ckpt=self.model.vae_path,
            synchformer_ckpt=self.model.synchformer_ckpt,
            enable_conditions=True,
            mode=self.model.mode,
            bigvgan_vocoder_ckpt=self.model.bigvgan_16k_path,
            need_vae_encoder=False,
        )
        self.feature_utils = FeaturesUtils(**feature_kwargs)
        self.feature_utils = self.feature_utils.to(self.device, self.dtype).eval()

    @torch.no_grad()
    def generate_audio(self,
                       video_path,
                       output_dir,
                       prompt: str = '',
                       negative_prompt: str = '',
                       duration: int = 10,
                       seed: int = 42,
                       cfg_strength: float = 4.5,
                       mask_away_clip: bool = False,
                       is_postp=False,):
        """Generate audio for a video and write the audio and muxed video.

        Args:
            video_path: Path to the input video (``~`` is expanded).
            output_dir: Directory for the outputs; created if missing.
            prompt: Text prompt passed to ``generate``.
            negative_prompt: Passed to ``generate`` as ``negative_text``.
            duration: Duration passed to ``load_video``; the value used for
                generation is the ``duration_sec`` that ``load_video`` reports.
            seed: Seed for the ``torch.Generator`` handed to ``generate``.
            cfg_strength: Passed through to ``generate``.
            mask_away_clip: When true, ``None`` is passed instead of the
                clip frames.
            is_postp: When true, outputs are named ``<stem>.neg.*``;
                otherwise ``<stem>.step1.*``.

        Returns:
            A tuple ``(audio_save_path, video_save_path)`` of ``Path`` objects.
        """
        video_path = Path(video_path).expanduser()
        output_dir = Path(output_dir).expanduser()

        self.log.info(f"Loading video: {video_path}")
        output_dir.mkdir(parents=True, exist_ok=True)

        video_info = load_video(video_path, duration)
        clip_frames = video_info.clip_frames
        sync_frames = video_info.sync_frames

        # Seeded generator so repeated calls with the same seed are reproducible.
        rng = torch.Generator(device=self.device).manual_seed(seed)

        # Add a leading batch dimension; clip frames may be dropped entirely.
        clip_frames = None if mask_away_clip else clip_frames.unsqueeze(0)
        sync_frames = sync_frames.unsqueeze(0)

        # NOTE: seq_cfg is the object held by self.model (no copy is made), so
        # assigning the duration below updates that config object in place.
        seq_cfg = self.model.seq_cfg
        seq_cfg.duration = video_info.duration_sec
        self.net.update_seq_lengths(seq_cfg.latent_seq_len,
                                    seq_cfg.clip_seq_len,
                                    seq_cfg.sync_seq_len)

        self.log.info(f'Prompt: {prompt}')
        self.log.info(f'Negative prompt: {negative_prompt}')
        self.log.info("Generating Audio...")

        audios = generate(clip_frames,
                          sync_frames,
                          [prompt],
                          negative_text=[negative_prompt],
                          feature_utils=self.feature_utils,
                          net=self.net,
                          fm=self.fm,
                          rng=rng,
                          cfg_strength=cfg_strength)
        # Keep only the first (and only) item of the batch, as float32 on CPU.
        audio = audios.float().cpu()[0]

        tag = 'neg' if is_postp else 'step1'
        audio_save_path = output_dir / f'{video_path.stem}.{tag}.wav'
        video_save_path = output_dir / f'{video_path.stem}.{tag}.mp4'

        self.log.info(f"Saving generated audio and video to {output_dir}")
        sample_rate = seq_cfg.sampling_rate
        torchaudio.save(str(audio_save_path), audio, sample_rate)
        self.log.info(f'Audio saved to {audio_save_path}')

        make_video(video_info, str(video_save_path), audio, sampling_rate=sample_rate)
        self.log.info(f'Video saved to {video_save_path}')

        return audio_save_path, video_save_path
# def main():
#     # Initialize logging (if you have a logger.py, do this only once)
#     setup_eval_logging()
#     # Initialize the model
#     v2a_model = V2A_MMAudio(
#         variant="large_44k_v2",  # model variant name
#         num_steps=25,            # number of sampling steps
#         full_precision=False     # whether to use full precision
#     )
#     # Video path (replace with your real path)
#     video_path = "ZxiXftx2EMg_000477.mp4"
#     # Output directory
#     output_dir = "outputs"
#     # Prompts (control the generated content)
#     prompt = ""
#     negative_prompt = ""
#     # Generate audio + video
#     audio_save_path, video_save_path = v2a_model.generate_audio(
#         video_path=video_path,
#         output_dir=output_dir,
#         prompt=prompt,
#         negative_prompt=negative_prompt,
#         duration=10,           # seconds
#         seed=42,               # random seed (accepted by generate_audio, not the constructor)
#         cfg_strength=4.5,      # guidance strength
#         mask_away_clip=False   # whether to remove clip
#     )
# if __name__ == "__main__":
#     main()