import os
import random
import re

import numpy as np
import torch
import torchaudio
import yaml

import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
from audiosr.utils import (
    default_audioldm_config,
    download_checkpoint,
    read_audio_file,
    lowpass_filtering_prepare_inference,
    wav_feature_extraction,
)


def seed_everything(seed):
    """Seed every RNG in play (random, numpy, torch CPU & CUDA) for reproducibility."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # FIX: was True. benchmark=True lets cuDNN autotune convolution kernels,
    # which may select non-deterministic algorithms and silently defeat the
    # determinism requested one line above.
    torch.backends.cudnn.benchmark = False


def text2phoneme(data):
    """Strip angle-bracket tags from *data* and run the english_cleaners2 pipeline."""
    return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"])


def text_to_filename(text):
    """Make *text* filename-safe by replacing spaces and quote characters with '_'."""
    return text.replace(" ", "_").replace("'", "_").replace('"', "_")


def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
    """Compute a 128-bin Kaldi-style fbank from *waveform*, resampled to 16 kHz.

    The fbank is zero-padded or truncated along time to match
    ``log_mel_spec.size(0)`` frames, then normalised with fixed dataset
    statistics (note the deliberate ``norm_std * 2`` denominator, matching
    the AudioMAE convention).

    Returns:
        dict with key ``"ta_kaldi_fbank"`` -> FloatTensor [TARGET_LEN, 128].
    """
    norm_mean = -4.2677393
    norm_std = 4.5689974

    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    # Remove DC offset before feature extraction.
    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    TARGET_LEN = log_mel_spec.size(0)

    # Cut or pad the time axis to exactly TARGET_LEN frames.
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [TARGET_LEN, 128]


def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
    """Build a single-item inference batch from an audio file.

    Reads the audio, prepares its low-passed counterpart, and unsqueezes every
    tensor to add a batch dimension.

    Args:
        input_file: path to the audio file to super-resolve.
        waveform, fbank: unused; kept for interface compatibility with callers.

    Returns:
        (batch dict, duration in seconds).
    """
    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)

    batch = {
        "waveform": torch.FloatTensor(waveform),
        "stft": torch.FloatTensor(stft),
        "log_mel_spec": torch.FloatTensor(log_mel_spec),
        "sampling_rate": 48000,
    }

    batch.update(lowpass_filtering_prepare_inference(batch))

    assert "waveform_lowpass" in batch.keys()
    lowpass_mel, lowpass_stft = wav_feature_extraction(
        batch["waveform_lowpass"], target_frame
    )
    batch["lowpass_mel"] = lowpass_mel

    # Add a leading batch dimension to every tensor entry.
    # (isinstance instead of type(...) ==; iterate items, mutate values only.)
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = torch.FloatTensor(v).unsqueeze(0)

    return batch, duration


def round_up_duration(duration):
    """Round *duration* to the nearest 2.5 s chunk, then add one extra chunk.

    NOTE(review): this always pads at least one chunk — e.g. 2.5 -> 5.0 —
    rather than being a true ceiling; downstream code appears to rely on the
    head-room, so the formula is kept as-is.
    """
    return int(round(duration / 2.5) + 1) * 2.5


def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
    """Instantiate a LatentDiffusion model and load its checkpoint.

    Args:
        ckpt_path: path to a checkpoint; if None, one is downloaded for
            *model_name*.
        config: optional path to a YAML config file; defaults to the bundled
            config for *model_name*.
        device: None/"auto" selects CUDA, then MPS, then CPU.
        model_name: which pretrained AudioSR variant to load.

    Returns:
        The model in eval mode on *device*.
    """
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioSR: %s" % model_name)
    print("Loading model on %s" % device)

    # FIX: the ckpt_path argument was previously ignored — it was
    # unconditionally overwritten by download_checkpoint(). Only download
    # when the caller did not supply a checkpoint.
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)

    if config is not None:
        assert isinstance(config, str)
        # Context manager so the config file handle is always closed.
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # Propagate the chosen device into the model config.
    config["model"]["params"]["device"] = device

    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    return latent_diffusion


def super_resolution(
    latent_diffusion,
    input_file,
    seed=42,
    ddim_steps=200,
    guidance_scale=3.5,
    latent_t_per_second=12.8,
    config=None,
):
    """Run AudioSR super-resolution on *input_file* and return the waveform.

    Args:
        latent_diffusion: model returned by :func:`build_model`.
        input_file: path to the low-resolution input audio.
        seed: RNG seed applied before sampling.
        ddim_steps: number of DDIM sampling steps.
        guidance_scale: classifier-free guidance scale.
        latent_t_per_second, config: unused; kept for interface compatibility.

    Returns:
        The generated waveform from ``generate_batch``.
    """
    seed_everything(int(seed))

    batch, duration = make_batch_for_super_resolution(input_file, waveform=None)

    with torch.no_grad():
        waveform = latent_diffusion.generate_batch(
            batch,
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            duration=duration,
        )

    return waveform