Spaces:
Running
Running
#coding=utf-8 | |
import logging | |
from pathlib import Path | |
import torch | |
import torchaudio | |
from third_party.MMAudio.mmaudio.eval_utils import ModelConfig, all_model_cfg, generate, load_video, make_video, setup_eval_logging | |
from third_party.MMAudio.mmaudio.model.flow_matching import FlowMatching | |
from third_party.MMAudio.mmaudio.model.networks import MMAudio, get_my_mmaudio | |
from third_party.MMAudio.mmaudio.model.utils.features_utils import FeaturesUtils | |
class V2A_MMAudio:
    """Video-to-audio (V2A) generator backed by an MMAudio model.

    Construction selects a device, loads the requested MMAudio variant and its
    feature extractors, and prepares a flow-matching sampler.
    ``generate_audio`` then produces a ``.wav`` file and a muxed ``.mp4`` for
    a given input video.
    """

    def __init__(self,
                 variant: str = "large_44k",
                 num_steps: int = 25,
                 full_precision: bool = False):
        """Load the MMAudio network, sampler and feature utilities.

        Args:
            variant: Key into ``all_model_cfg`` naming the model to load.
            num_steps: Number of steps handed to the Euler flow-matching
                sampler.
            full_precision: Use ``torch.float32`` when True, otherwise
                ``torch.bfloat16``.

        Raises:
            ValueError: If ``variant`` is not a key of ``all_model_cfg``.
        """
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)
        self.log.info("The V2A model uses MMAudio %s, init...", variant)

        # Prefer CUDA, then Apple MPS, and only then fall back to the CPU.
        if torch.cuda.is_available():
            self.device = 'cuda'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.device = 'cpu'
            self.log.warning('CUDA/MPS are not available, running on CPU')

        self.dtype = torch.float32 if full_precision else torch.bfloat16

        if variant not in all_model_cfg:
            raise ValueError(f'Unknown model variant: {variant}')
        self.model: ModelConfig = all_model_cfg[variant]
        self.model.download_if_needed()

        # Build the network on the target device/dtype, then load its weights.
        self.net: MMAudio = get_my_mmaudio(self.model.model_name)
        self.net = self.net.to(self.device, self.dtype).eval()
        state_dict = torch.load(self.model.model_path,
                                map_location=self.device,
                                weights_only=True)
        self.net.load_weights(state_dict)

        # Flow Matching
        self.fm = FlowMatching(min_sigma=0,
                               inference_mode='euler',
                               num_steps=num_steps)

        # Feature utils setup
        self.feature_utils = FeaturesUtils(
            tod_vae_ckpt=self.model.vae_path,
            synchformer_ckpt=self.model.synchformer_ckpt,
            enable_conditions=True,
            mode=self.model.mode,
            bigvgan_vocoder_ckpt=self.model.bigvgan_16k_path,
            need_vae_encoder=False,
        )
        self.feature_utils = self.feature_utils.to(self.device, self.dtype).eval()

    # Generation is inference only: run it with autograd disabled so that no
    # graph/activations are retained across the sampling steps.
    @torch.inference_mode()
    def generate_audio(self,
                       video_path,
                       output_dir,
                       prompt: str = '',
                       negative_prompt: str = '',
                       duration: int = 10,
                       seed: int = 42,
                       cfg_strength: float = 4.5,
                       mask_away_clip: bool = False,
                       is_postp: bool = False):
        """Generate audio for a video and write the audio and muxed video.

        Args:
            video_path: Input video (``str`` or ``Path``); ``~`` is expanded.
            output_dir: Directory for the results (``str`` or ``Path``); it is
                created if missing.
            prompt: Text prompt passed to the generator.
            negative_prompt: Negative text prompt passed to the generator.
            duration: Requested duration handed to ``load_video``; the
                duration reported back by ``load_video`` is the one used for
                generation.
            seed: Seed for the ``torch.Generator`` used for sampling.
            cfg_strength: Guidance strength forwarded to ``generate``.
            mask_away_clip: When True, no CLIP frames are passed to the
                generator (``None`` is passed instead).
            is_postp: Selects the output file tag: ``neg`` when True,
                ``step1`` otherwise.

        Returns:
            ``(audio_save_path, video_save_path)`` as ``Path`` objects inside
            ``output_dir``.
        """
        video_path = Path(video_path).expanduser()
        output_dir = Path(output_dir).expanduser()

        self.log.info("Loading video: %s", video_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        video_info = load_video(video_path, duration)
        duration = video_info.duration_sec

        # Setup random generator for reproducibility
        rng = torch.Generator(device=self.device)
        rng.manual_seed(seed)

        # Add a leading batch dimension of size 1 to the frame tensors.
        clip_frames = None if mask_away_clip else video_info.clip_frames.unsqueeze(0)
        sync_frames = video_info.sync_frames.unsqueeze(0)

        # NOTE: seq_cfg is the object stored on the model config, so setting
        # the duration here modifies that config object in place.
        seq_cfg = self.model.seq_cfg
        seq_cfg.duration = duration
        self.net.update_seq_lengths(seq_cfg.latent_seq_len,
                                    seq_cfg.clip_seq_len,
                                    seq_cfg.sync_seq_len)

        self.log.info('Prompt: %s', prompt)
        self.log.info('Negative prompt: %s', negative_prompt)
        self.log.info("Generating Audio...")
        audios = generate(
            clip_frames,
            sync_frames,
            [prompt],
            negative_text=[negative_prompt],
            feature_utils=self.feature_utils,
            net=self.net,
            fm=self.fm,
            rng=rng,
            cfg_strength=cfg_strength,
        )
        # Only one prompt was submitted, so keep the first (only) batch item.
        audio = audios.float().cpu()[0]

        tag = 'neg' if is_postp else 'step1'
        audio_save_path = output_dir / f'{video_path.stem}.{tag}.wav'
        video_save_path = output_dir / f'{video_path.stem}.{tag}.mp4'

        self.log.info("Saving generated audio and video to %s", output_dir)
        torchaudio.save(str(audio_save_path), audio, seq_cfg.sampling_rate)
        self.log.info('Audio saved to %s', audio_save_path)
        make_video(video_info,
                   str(video_save_path),
                   audio,
                   sampling_rate=seq_cfg.sampling_rate)
        self.log.info('Video saved to %s', video_save_path)

        return audio_save_path, video_save_path
# Example usage (kept for reference):
#
# def main():
#     # Initialise logging (if you have a logger.py, do this only once)
#     setup_eval_logging()
#     # Initialise the model
#     v2a_model = V2A_MMAudio(
#         variant="large_44k_v2",   # name of the model variant
#         num_steps=25,             # number of sampling steps
#         full_precision=False,     # whether to use full precision
#     )
#     # Video path (replace with your real path)
#     video_path = "ZxiXftx2EMg_000477.mp4"
#     # Output directory
#     output_dir = "outputs"
#     # Prompts (control the generated content)
#     prompt = ""
#     negative_prompt = ""
#     # Generate audio + video
#     audio_save_path, video_save_path = v2a_model.generate_audio(
#         video_path=video_path,
#         output_dir=output_dir,
#         prompt=prompt,
#         negative_prompt=negative_prompt,
#         duration=10,              # seconds
#         seed=42,                  # random seed (a generate_audio argument, not a constructor one)
#         cfg_strength=4.5,         # guidance strength
#         mask_away_clip=False,     # whether to drop the CLIP frames
#     )
#
# if __name__ == "__main__":
#     main()