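# Prompt-guided audio editing with AudioLDM2: the script inverts an input recording
# into the diffusion model's noise space under a source prompt, then re-generates it
# under a target prompt and writes the edited audio to disk. Intended to be run as a CLI.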
import os
import sys
import time
import tqdm
import torch
import logging
import librosa
import argparse
import scipy.signal
import logging.handlers

import numpy as np
import soundfile as sf

from torch import inference_mode
from distutils.util import strtobool

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.audioldm2.utils import load_audio
from main.library.audioldm2.models import load_model
config = Config()
translations = config.translations

logger = logging.getLogger(__name__)
logger.propagate = False

for l in ["torch", "httpx", "httpcore", "diffusers", "transformers"]:
    logging.getLogger(l).setLevel(logging.ERROR)

if logger.hasHandlers(): logger.handlers.clear()
else:
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)
    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "audioldm2.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)
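
# Command-line arguments: audio paths, AudioLDM2 checkpoint, source/target prompts,
# diffusion step count, classifier-free guidance scales, the t_start skip percentage,
# and whether to batch conditional/unconditional UNet passes to save compute.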
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./output.wav")
    parser.add_argument("--export_format", type=str, default="wav")
    parser.add_argument("--sample_rate", type=int, default=44100)
    parser.add_argument("--audioldm_model", type=str, default="audioldm2-music")
    parser.add_argument("--source_prompt", type=str, default="")
    parser.add_argument("--target_prompt", type=str, default="")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--cfg_scale_src", type=float, default=3.5)
    parser.add_argument("--cfg_scale_tar", type=float, default=12)
    parser.add_argument("--t_start", type=int, default=45)
    parser.add_argument("--save_compute", type=lambda x: bool(strtobool(x)), default=False)
    return parser.parse_args()
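
# main(): logs the run configuration, writes a PID file so the caller can track or stop
# the job, runs the edit, and reports the elapsed time and output location.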
def main():
    args = parse_arguments()
    input_path, output_path, export_format, sample_rate, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute = args.input_path, args.output_path, args.export_format, args.sample_rate, args.audioldm_model, args.source_prompt, args.target_prompt, args.steps, args.cfg_scale_src, args.cfg_scale_tar, args.t_start, args.save_compute

    log_data = {translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_name']: audioldm_model, translations['export_format']: export_format, translations['sample_rate']: sample_rate, translations['steps']: steps, translations['source_prompt']: source_prompt, translations['target_prompt']: target_prompt, translations['cfg_scale_src']: cfg_scale_src, translations['cfg_scale_tar']: cfg_scale_tar, translations['t_start']: t_start, translations['save_compute']: save_compute}
    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    start_time = time.time()
    logger.info(translations["start_edit"].format(input_path=input_path))

    pid_path = os.path.join("assets", "audioldm2_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        edit(input_path, output_path, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute, sample_rate, config.device, export_format=export_format)
    except Exception as e:
        logger.error(translations["error_edit"].format(e=e))
        import traceback
        logger.debug(traceback.format_exc())

    logger.info(translations["edit_success"].format(time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
    with inference_mode():
        w0 = ldm_stable.vae_encode(x0)

    _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1, prompts=[prompt_src], cfg_scales=[cfg_scale_src], num_inference_steps=num_diffusion_steps, numerical_fix=True, duration=duration, save_compute=save_compute)
    return zs, wts, extra_info
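
# low_pass_filter(): zero-phase 4th-order Butterworth low-pass, applied to the decoded
# waveform (presumably to suppress high-frequency vocoder and resampling artifacts).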
def low_pass_filter(audio, cutoff=7500, sr=16000):
    b, a = scipy.signal.butter(4, cutoff / (sr / 2), btype='low')
    return scipy.signal.filtfilt(b, a, audio)
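
# sample(): replays the reverse diffusion from step tstart under the target prompt with
# the stored noise maps, decodes the edited latent back to a waveform, optionally
# resamples it to the requested rate, low-pass filters it, and writes a two-channel file.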
def sample(output_audio, sr, ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute, export_format = "wav"):
    tstart = torch.tensor(tstart, dtype=torch.int32)
    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart, etas=1., prompts=[prompt_tar], neg_prompts=[""], cfg_scales=[cfg_scale_tar], zs=zs[:int(tstart)], duration=duration, extra_info=extra_info, save_compute=save_compute)

    with inference_mode():
        x0_dec = ldm_stable.vae_decode(w0.to(torch.float16 if config.is_half else torch.float32))

    if x0_dec.dim() < 4: x0_dec = x0_dec[None, :, :, :]

    with torch.no_grad():
        audio = ldm_stable.decode_to_mel(x0_dec.to(torch.float16 if config.is_half else torch.float32))

    audio = audio.float().squeeze().cpu().numpy()
    orig_sr = 16000

    if sr != 16000 and sr > 0:
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr, res_type="soxr_vhq")
        orig_sr = sr

    audio = low_pass_filter(audio, 7500, orig_sr)
    sf.write(output_audio, np.tile(audio, (2, 1)).T, orig_sr, format=export_format)
    return output_audio
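
# edit(): end-to-end pipeline. Loads the AudioLDM2 checkpoint, sets the scheduler to
# `steps` timesteps, loads the input audio as a latent tensor, inverts it under the
# source prompt, and re-samples it under the target prompt. t_start is given as a
# percentage and converted to an absolute step index (t_start / 100 * steps).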
def edit(input_audio, output_audio, model_id, source_prompt = "", target_prompt = "", steps = 200, cfg_scale_src = 3.5, cfg_scale_tar = 12, t_start = 45, save_compute = True, sr = 44100, device = "cpu", export_format = "wav"):
    ldm_stable = load_model(model_id, device=device)
    ldm_stable.model.scheduler.set_timesteps(steps, device=device)

    x0, duration = load_audio(input_audio, ldm_stable.get_melspectrogram(), device=device)
    zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src, duration=duration, save_compute=save_compute)

    return sample(output_audio, sr, ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt, tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration, save_compute=save_compute, export_format=export_format)
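
# inversion_forward_process(): DDPM inversion. Samples a full noising trajectory xts from
# x0, then walks the scheduler timesteps computing the (classifier-free guided) noise
# prediction at each step and solving for the per-step noise map z that maps x_t to
# x_{t-1}; zs and xts are reused by the reverse process so the input can be reproduced
# except where the target prompt pulls the result elsewhere.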
def inversion_forward_process(model, x0, etas = None, prompts = [""], cfg_scales = [3.5], num_inference_steps = 50, numerical_fix = False, duration = None, first_order = False, save_compute = True):
    if len(prompts) > 1 or prompts[0] != "":
        text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
        uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
    else: uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=False)

    timesteps = model.model.scheduler.timesteps.to(model.device)
    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps

    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
    zs = torch.zeros(size=variance_noise_shape, device=model.device)
    extra_info = [None] * len(zs)

    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps)}

    xt = x0
    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration, save_compute=save_compute and prompts[0] != "")

    for t in tqdm.tqdm(timesteps, desc=translations["inverting"], ncols=100, unit="a"):
        idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
        xt = xts[idx + 1][None]
        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)

        with torch.no_grad():
            if save_compute and prompts[0] != "":
                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
                out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                if len(prompts) > 1 or prompts[0] != "": cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample

        if len(prompts) > 1 or prompts[0] != "": noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
        else: noise_pred = out

        xtm1 = xts[idx][None]
        z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t, eta=etas[idx], numerical_fix=numerical_fix, first_order=first_order)
        zs[idx] = z
        xts[idx] = xtm1
        extra_info[idx] = extra

    if zs is not None: zs[0] = torch.zeros_like(zs[0])
    return xt, zs, xts, extra_info
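
# inversion_reverse_process(): the editing pass. Starting from the stored latent at
# timestep tstart, it denoises with classifier-free guidance toward the target prompt
# while re-injecting the per-step noise maps zs recorded during inversion, which keeps
# content the prompt does not target close to the original recording.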
def inversion_reverse_process(model, xT, tstart, etas = 0, prompts = [""], neg_prompts = [""], cfg_scales = None, zs = None, duration = None, first_order = False, extra_info = None, save_compute = True):
    text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
    uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(neg_prompts, negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)

    xt = xT[tstart.max()].unsqueeze(0)

    if etas is None: etas = 0
    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps
    assert len(etas) == model.model.scheduler.num_inference_steps

    timesteps = model.model.scheduler.timesteps.to(model.device)
    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}

    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]], audio_end_in_s=duration, save_compute=save_compute)

    for t in tqdm.tqdm(timesteps[-zs.shape[0]:], desc=translations["editing"], ncols=100, unit="a"):
        idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)

        with torch.no_grad():
            if save_compute:
                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
                uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                uncond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample

        z = zs[idx] if zs is not None else None
        noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
        xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z.unsqueeze(0), eta=etas[idx], first_order=first_order)

    return xt, zs
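
# Example invocation (script name, paths, and prompts below are illustrative, not part of
# the project; all flags come from parse_arguments above):
#   python audioldm2_edit.py --input_path input.wav --output_path output.wav \
#       --source_prompt "a solo piano recording" --target_prompt "an orchestral arrangement" \
#       --steps 200 --cfg_scale_src 3.5 --cfg_scale_tar 12 --t_start 45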
if __name__ == "__main__": main()