import os
import torch
import json
from torchvision.transforms import v2
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
import folder_paths
import comfy.model_management as mm
from comfy.utils import load_torch_file
script_directory = os.path.dirname(os.path.abspath(__file__))
if not "mmaudio" in folder_paths.folder_names_and_paths:
folder_paths.add_model_folder_path("mmaudio", os.path.join(folder_paths.models_dir, "mmaudio"))
from .mmaudio.eval_utils import generate
from .mmaudio.model.flow_matching import FlowMatching
from .mmaudio.model.networks import MMAudio
from .mmaudio.model.utils.features_utils import FeaturesUtils
from .mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K
from .mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
from .mmaudio.ext.synchformer import Synchformer
from .mmaudio.ext.autoencoder import AutoEncoderModule
from open_clip import CLIP
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
def process_video_tensor(video_tensor: torch.Tensor, duration_sec: float) -> tuple[torch.Tensor, torch.Tensor, float]:
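    """Split a ComfyUI image batch into MMAudio's two conditioning streams:
    384x384 frames for CLIP (treated as 8 fps) and normalized 224x224 center
    crops for Synchformer (treated as 25 fps). If the batch holds fewer frames
    than the requested duration implies, the duration is truncated to match.
    """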
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
clip_transform = v2.Compose([
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
v2.ToPILImage(),
v2.ToTensor(),
v2.ConvertImageDtype(torch.float32),
])
sync_transform = v2.Compose([
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
v2.CenterCrop(_SYNC_SIZE),
v2.ToPILImage(),
v2.ToTensor(),
v2.ConvertImageDtype(torch.float32),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Assuming video_tensor is in the shape (frames, height, width, channels)
total_frames = video_tensor.shape[0]
clip_frames_count = int(_CLIP_FPS * duration_sec)
sync_frames_count = int(_SYNC_FPS * duration_sec)
# Adjust duration if there are not enough frames
if total_frames < clip_frames_count:
        log.warning(f'Clip video is too short: {total_frames / _CLIP_FPS:.2f} < {duration_sec:.2f} sec, truncating')
clip_frames_count = total_frames
duration_sec = total_frames / _CLIP_FPS
if total_frames < sync_frames_count:
        log.warning(f'Sync video is too short: {total_frames / _SYNC_FPS:.2f} < {duration_sec:.2f} sec, truncating')
sync_frames_count = total_frames
duration_sec = total_frames / _SYNC_FPS
clip_frames = video_tensor[:clip_frames_count]
sync_frames = video_tensor[:sync_frames_count]
clip_frames = clip_frames.permute(0, 3, 1, 2)
sync_frames = sync_frames.permute(0, 3, 1, 2)
clip_frames = torch.stack([clip_transform(frame) for frame in clip_frames])
sync_frames = torch.stack([sync_transform(frame) for frame in sync_frames])
clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
return clip_frames, sync_frames, duration_sec
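# A quick sanity check of the preprocessing above (hypothetical input: an
# 8-second clip captured at 25 fps, so the fps assumptions line up exactly):
#
#   frames = torch.rand(200, 480, 640, 3)                  # ComfyUI IMAGE layout
#   clip_f, sync_f, dur = process_video_tensor(frames, 8.0)
#   # clip_f: (64, 3, 384, 384) at 8 fps, sync_f: (200, 3, 224, 224) at 25 fps, dur == 8.0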
#region Model loading
class MMAudioModelLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mmaudio_model": (folder_paths.get_filename_list("mmaudio"), {"tooltip": "These models are loaded from the 'ComfyUI/models/mmaudio' -folder",}),
"base_precision": (["fp16", "fp32", "bf16"], {"default": "fp16"}),
},
}
RETURN_TYPES = ("MMAUDIO_MODEL",)
RETURN_NAMES = ("mmaudio_model", )
FUNCTION = "loadmodel"
CATEGORY = "MMAudio"
def loadmodel(self, mmaudio_model, base_precision):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[base_precision]
mmaudio_model_path = folder_paths.get_full_path_or_raise("mmaudio", mmaudio_model)
mmaudio_sd = load_torch_file(mmaudio_model_path, device=offload_device)
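        # Infer the architecture from the filename: 'small' and 'large' differ in hidden width, depth, and head count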
if "small" in mmaudio_model:
num_heads = 7
model = MMAudio(
latent_dim=40,
clip_dim=1024,
sync_dim=768,
text_dim=1024,
hidden_dim=64 * num_heads,
depth=12,
fused_depth=8,
num_heads=num_heads,
latent_seq_len=345,
clip_seq_len=64,
sync_seq_len=192
)
elif "large" in mmaudio_model:
num_heads = 14
model = MMAudio(latent_dim=40,
clip_dim=1024,
sync_dim=768,
text_dim=1024,
hidden_dim=64 * num_heads,
depth=21,
fused_depth=14,
num_heads=num_heads,
latent_seq_len=345,
clip_seq_len=64,
sync_seq_len=192,
v2=True
            )
        else:
            raise ValueError(f"Unable to infer model size ('small' or 'large') from filename: {mmaudio_model}")
model = model.eval().to(device=device, dtype=base_dtype)
model.load_weights(mmaudio_sd)
log.info(f'Loaded MMAudio model weights from {mmaudio_model_path}')
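        # Pick the sequence config for the checkpoint's sample rate, also inferred from the filename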
if "44" in mmaudio_model:
model.seq_cfg = CONFIG_44K
elif "16" in mmaudio_model:
model.seq_cfg = CONFIG_16K
return (model,)
#region Features Utils
class MMAudioVoCoderLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"vocoder_model": (folder_paths.get_filename_list("mmaudio"), {"tooltip": "These models are loaded from 'ComfyUI/models/mmaudio'"}),
},
}
RETURN_TYPES = ("VOCODER_MODEL",)
RETURN_NAMES = ("mmaudio_vocoder", )
FUNCTION = "loadmodel"
CATEGORY = "MMAudio"
def loadmodel(self, vocoder_model):
from .mmaudio.ext.bigvgan import BigVGAN
vocoder_model_path = folder_paths.get_full_path_or_raise("mmaudio", vocoder_model)
vocoder_model = BigVGAN.from_pretrained(vocoder_model_path).eval()
        return (vocoder_model,)
class MMAudioFeatureUtilsLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"vae_model": (folder_paths.get_filename_list("mmaudio"), {"tooltip": "These models are loaded from 'ComfyUI/models/mmaudio'"}),
"synchformer_model": (folder_paths.get_filename_list("mmaudio"), {"tooltip": "These models are loaded from 'ComfyUI/models/mmaudio'"}),
"clip_model": (folder_paths.get_filename_list("mmaudio"), {"tooltip": "These models are loaded from 'ComfyUI/models/mmaudio'"}),
},
"optional": {
"bigvgan_vocoder_model": ("VOCODER_MODEL", {"tooltip": "These models are loaded from 'ComfyUI/models/mmaudio'"}),
"mode": (["16k", "44k"], {"default": "44k"}),
"precision": (["fp16", "fp32", "bf16"],
{"default": "fp16"}
),
}
}
RETURN_TYPES = ("MMAUDIO_FEATUREUTILS",)
RETURN_NAMES = ("mmaudio_featureutils", )
FUNCTION = "loadmodel"
CATEGORY = "MMAudio"
def loadmodel(self, vae_model, precision, synchformer_model, clip_model, mode, bigvgan_vocoder_model=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
#synchformer
synchformer_path = folder_paths.get_full_path_or_raise("mmaudio", synchformer_model)
synchformer_sd = load_torch_file(synchformer_path, device=offload_device)
synchformer = Synchformer()
synchformer.load_state_dict(synchformer_sd)
synchformer = synchformer.eval().to(device=device, dtype=dtype)
#vae
download_path = folder_paths.get_folder_paths("mmaudio")[0]
nvidia_bigvgan_vocoder_path = os.path.join(download_path, "nvidia", "bigvgan_v2_44khz_128band_512x")
if mode == "44k":
if not os.path.exists(nvidia_bigvgan_vocoder_path):
log.info(f"Downloading nvidia bigvgan vocoder model to: {nvidia_bigvgan_vocoder_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="nvidia/bigvgan_v2_44khz_128band_512x",
ignore_patterns=["*3m*",],
local_dir=nvidia_bigvgan_vocoder_path,
local_dir_use_symlinks=False,
)
bigvgan_vocoder = BigVGANv2.from_pretrained(nvidia_bigvgan_vocoder_path).eval().to(device=device, dtype=dtype)
        else:
            if bigvgan_vocoder_model is None:
                raise ValueError("bigvgan_vocoder_model must be provided for 16k mode")
            bigvgan_vocoder = bigvgan_vocoder_model
vae_path = folder_paths.get_full_path_or_raise("mmaudio", vae_model)
vae_sd = load_torch_file(vae_path, device=offload_device)
vae = AutoEncoderModule(
vae_state_dict=vae_sd,
bigvgan_vocoder=bigvgan_vocoder,
mode=mode
)
vae = vae.eval().to(device=device, dtype=dtype)
#clip
clip_path = os.path.join(download_path, "apple", "DFN5B-CLIP-ViT-H-14-384")
if not os.path.exists(clip_path):
log.info(f"Downloading Apple DFN5B-CLIP-ViT-H-14-384 model to: {clip_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="apple/DFN5B-CLIP-ViT-H-14-384",
ignore_patterns=["pytorch_model.bin"],
local_dir=clip_path,
local_dir_use_symlinks=False,
)
clip_model_path = folder_paths.get_full_path_or_raise("mmaudio", clip_model)
clip_config_path = os.path.join(script_directory, "configs", "DFN5B-CLIP-ViT-H-14-384.json")
with open(clip_config_path) as f:
clip_config = json.load(f)
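        # Instantiate CLIP on the meta device, then materialize each parameter straight to the target device/dtype from the checkpoint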
with init_empty_weights():
clip_model = CLIP(**clip_config["model_cfg"]).eval()
        clip_sd = load_torch_file(clip_model_path, device=offload_device)
for name, param in clip_model.named_parameters():
set_module_tensor_to_device(clip_model, name, device=device, dtype=dtype, value=clip_sd[name])
clip_model.to(device=device, dtype=dtype)
feature_utils = FeaturesUtils(vae=vae,
synchformer=synchformer,
enable_conditions=True,
clip_model=clip_model)
return (feature_utils,)
#region sampling
class MMAudioSampler:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mmaudio_model": ("MMAUDIO_MODEL",),
"feature_utils": ("MMAUDIO_FEATUREUTILS",),
"duration": ("FLOAT", {"default": 8, "step": 0.01, "tooltip": "Duration of the audio in seconds"}),
"steps": ("INT", {"default": 25, "step": 1, "tooltip": "Number of steps to interpolate"}),
"cfg": ("FLOAT", {"default": 4.5, "step": 0.1, "tooltip": "Strength of the conditioning"}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"prompt": ("STRING", {"default": "", "multiline": True} ),
"negative_prompt": ("STRING", {"default": "", "multiline": True} ),
"mask_away_clip": ("BOOLEAN", {"default": False, "tooltip": "If true, the clip video will be masked away"}),
"force_offload": ("BOOLEAN", {"default": True, "tooltip": "If true, the model will be offloaded to the offload device"}),
},
"optional": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio", )
FUNCTION = "sample"
CATEGORY = "MMAudio"
def sample(self, mmaudio_model, seed, feature_utils, duration, steps, cfg, prompt, negative_prompt, mask_away_clip, force_offload, images=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
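        # Dedicated, seeded RNG so sampling is reproducible for a given seed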
rng = torch.Generator(device=device)
rng.manual_seed(seed)
seq_cfg = mmaudio_model.seq_cfg
if images is not None:
images = images.to(device=device)
clip_frames, sync_frames, duration = process_video_tensor(images, duration)
print("clip_frames", clip_frames.shape, "sync_frames", sync_frames.shape, "duration", duration)
if mask_away_clip:
clip_frames = None
else:
clip_frames = clip_frames.unsqueeze(0)
sync_frames = sync_frames.unsqueeze(0)
else:
clip_frames = None
sync_frames = None
seq_cfg.duration = duration
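        # Propagate the (possibly truncated) duration into the latent/CLIP/sync sequence lengths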
mmaudio_model.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
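        # Euler flow-matching scheduler; 'steps' sets the number of integration steps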
scheduler = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=steps)
feature_utils.to(device)
mmaudio_model.to(device)
audios = generate(clip_frames,
sync_frames, [prompt],
negative_text=[negative_prompt],
feature_utils=feature_utils,
net=mmaudio_model,
fm=scheduler,
rng=rng,
cfg_strength=cfg)
if force_offload:
mmaudio_model.to(offload_device)
feature_utils.to(offload_device)
mm.soft_empty_cache()
waveform = audios.float().cpu()
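        # Package into ComfyUI's AUDIO format: waveform tensor plus sample rate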
        audio = {
            "waveform": waveform,
            "sample_rate": seq_cfg.sampling_rate,
        }
return (audio,)
NODE_CLASS_MAPPINGS = {
"MMAudioModelLoader": MMAudioModelLoader,
"MMAudioFeatureUtilsLoader": MMAudioFeatureUtilsLoader,
"MMAudioSampler": MMAudioSampler,
"MMAudioVoCoderLoader": MMAudioVoCoderLoader,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MMAudioModelLoader": "MMAudio ModelLoader",
"MMAudioFeatureUtilsLoader": "MMAudio FeatureUtilsLoader",
"MMAudioSampler": "MMAudio Sampler",
"MMAudioVoCoderLoader": "MMAudio VoCoderLoader",
}