# DeepSound-V1/v2a_models/v2a_mmaudio.py
#coding=utf-8
import logging
from pathlib import Path
import torch
import torchaudio
from third_party.MMAudio.mmaudio.eval_utils import ModelConfig, all_model_cfg, generate, load_video, make_video, setup_eval_logging
from third_party.MMAudio.mmaudio.model.flow_matching import FlowMatching
from third_party.MMAudio.mmaudio.model.networks import MMAudio, get_my_mmaudio
from third_party.MMAudio.mmaudio.model.utils.features_utils import FeaturesUtils
class V2A_MMAudio:
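    """Video-to-audio (V2A) generation wrapper around MMAudio.

    Loads a pretrained MMAudio variant together with its feature-extraction
    utilities and a flow-matching sampler, and exposes a single
    ``generate_audio`` call that writes the generated audio (and a muxed
    video) to disk.
    """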
    def __init__(self,
                 variant: str = "large_44k",
                 num_steps: int = 25,
                 full_precision: bool = False):
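        """
        Args:
            variant: MMAudio model variant key from ``all_model_cfg``
                (e.g. "large_44k" or "large_44k_v2").
            num_steps: number of Euler steps for flow-matching sampling.
            full_precision: run in float32 instead of bfloat16.
        """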
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)
        self.log.info(f"The V2A model uses MMAudio {variant}, initializing...")
        # Pick the best available device.
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = 'cuda'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.log.warning('CUDA/MPS are not available, running on CPU')
        self.dtype = torch.float32 if full_precision else torch.bfloat16
        if variant not in all_model_cfg:
            raise ValueError(f'Unknown model variant: {variant}')
        self.model: ModelConfig = all_model_cfg[variant]
        self.model.download_if_needed()
        self.net: MMAudio = get_my_mmaudio(self.model.model_name).to(self.device, self.dtype).eval()
        self.net.load_weights(torch.load(self.model.model_path, map_location=self.device, weights_only=True))
        # Flow-matching sampler (Euler integration)
        self.fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
        # Feature utils setup: the VAE encoder is not needed at inference time
        self.feature_utils = FeaturesUtils(tod_vae_ckpt=self.model.vae_path,
                                           synchformer_ckpt=self.model.synchformer_ckpt,
                                           enable_conditions=True,
                                           mode=self.model.mode,
                                           bigvgan_vocoder_ckpt=self.model.bigvgan_16k_path,
                                           need_vae_encoder=False)
        self.feature_utils = self.feature_utils.to(self.device, self.dtype).eval()
    @torch.no_grad()
    def generate_audio(self,
                       video_path,
                       output_dir,
                       prompt: str = '',
                       negative_prompt: str = '',
                       duration: int = 10,
                       seed: int = 42,
                       cfg_strength: float = 4.5,
                       mask_away_clip: bool = False,
                       is_postp: bool = False):
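        """Generate audio for a video and mux the result into a new video file.

        Args:
            video_path: path to the input video.
            output_dir: directory for the generated .wav and .mp4 files.
            prompt: text prompt conditioning the generation.
            negative_prompt: text the generation should steer away from.
            duration: requested duration in seconds; the effective value is
                taken from the loaded clip.
            seed: random seed for reproducible sampling.
            cfg_strength: classifier-free guidance strength.
            mask_away_clip: if True, drop the CLIP frame stream and condition
                on synchronization features and text only.
            is_postp: if True, outputs are named '{stem}.neg.*' instead of
                '{stem}.step1.*'.

        Returns:
            Tuple of (audio_save_path, video_save_path).
        """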
        video_path = Path(video_path).expanduser()
        output_dir = Path(output_dir).expanduser()
        self.log.info(f"Loading video: {video_path}")
        output_dir.mkdir(parents=True, exist_ok=True)
        video_info = load_video(video_path, duration)
        clip_frames = video_info.clip_frames
        sync_frames = video_info.sync_frames
        # The loaded clip may be shorter than requested; use its actual duration
        duration = video_info.duration_sec
        # Setup random generator for reproducibility
        rng = torch.Generator(device=self.device)
        rng.manual_seed(seed)
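        # If requested, mask away the CLIP stream so generation is conditioned
        # on synchronization features and text only; otherwise add a batch
        # dimension to both visual feature streams.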
        if mask_away_clip:
            clip_frames = None
        else:
            clip_frames = clip_frames.unsqueeze(0)
        sync_frames = sync_frames.unsqueeze(0)
        # Update sequence lengths to match the requested duration
        seq_cfg = self.model.seq_cfg
        seq_cfg.duration = duration
        self.net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
        self.log.info(f'Prompt: {prompt}')
        self.log.info(f'Negative prompt: {negative_prompt}')
        self.log.info("Generating audio...")
        audios = generate(clip_frames,
                          sync_frames,
                          [prompt],
                          negative_text=[negative_prompt],
                          feature_utils=self.feature_utils,
                          net=self.net,
                          fm=self.fm,
                          rng=rng,
                          cfg_strength=cfg_strength)
        audio = audios.float().cpu()[0]
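        # Output naming encodes the pipeline stage: '.neg' when called from
        # the post-processing step, '.step1' for the first generation pass.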
        if is_postp:
            audio_save_path = output_dir / f'{video_path.stem}.neg.wav'
            video_save_path = output_dir / f'{video_path.stem}.neg.mp4'
        else:
            audio_save_path = output_dir / f'{video_path.stem}.step1.wav'
            video_save_path = output_dir / f'{video_path.stem}.step1.mp4'
        self.log.info(f"Saving generated audio and video to {output_dir}")
        torchaudio.save(str(audio_save_path), audio, seq_cfg.sampling_rate)
        self.log.info(f'Audio saved to {audio_save_path}')
        make_video(video_info, str(video_save_path), audio, sampling_rate=seq_cfg.sampling_rate)
        self.log.info(f'Video saved to {video_save_path}')
        return audio_save_path, video_save_path
# def main():
#     # Initialize logging (if you have a logger.py, do this once at startup)
#     setup_eval_logging()
#     # Initialize the model (note: __init__ does not take a seed;
#     # the seed is passed to generate_audio instead)
#     v2a_model = V2A_MMAudio(
#         variant="large_44k_v2",   # model variant name
#         num_steps=25,             # number of sampling steps
#         full_precision=False      # whether to run in full precision
#     )
#     # Video path (replace with your real path)
#     video_path = "ZxiXftx2EMg_000477.mp4"
#     # Output directory
#     output_dir = "outputs"
#     # Prompts (control the generated content)
#     prompt = ""
#     negative_prompt = ""
#     # Generate audio + video
#     audio_save_path, video_save_path = v2a_model.generate_audio(
#         video_path=video_path,
#         output_dir=output_dir,
#         prompt=prompt,
#         negative_prompt=negative_prompt,
#         duration=10,              # seconds
#         seed=42,                  # random seed
#         cfg_strength=4.5,         # guidance strength
#         mask_away_clip=False      # whether to mask away the CLIP stream
#     )
# if __name__ == "__main__":
#     main()