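"""Inference pipeline for AudioSR: audio loading, mel/fbank feature
preparation, latent-diffusion model loading, and waveform super-resolution."""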
import os
import re
import random

import yaml
import torch
import torchaudio
import numpy as np

import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
from audiosr.utils import (
    default_audioldm_config,
    download_checkpoint,
    read_audio_file,
    lowpass_filtering_prepare_inference,
    wav_feature_extraction,
)
def seed_everything(seed):
    """Seed all RNGs used during sampling so generation is reproducible."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-select kernels non-deterministically,
    # which would defeat the deterministic flag above, so keep it off.
    torch.backends.cudnn.benchmark = False
def text2phoneme(data):
    # Strip any markup tags, then run the English text cleaner to get phonemes.
    return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"])
def text_to_filename(text):
    # Sanitize free text into a filesystem-safe name.
    return text.replace(" ", "_").replace("'", "_").replace('"', "_")
def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
    """Compute a 128-bin Kaldi-style fbank, length-matched to log_mel_spec
    and normalized with fixed dataset statistics."""
    norm_mean = -4.2677393
    norm_std = 4.5689974

    # The fbank features are computed at 16 kHz; resample if needed.
    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # Cut or zero-pad the fbank to the same number of frames as the mel spectrogram.
    TARGET_LEN = log_mel_spec.size(0)
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)
    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
    """Build a single-item batch from an audio file for the diffusion sampler."""
    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)

    batch = {
        "waveform": torch.FloatTensor(waveform),
        "stft": torch.FloatTensor(stft),
        "log_mel_spec": torch.FloatTensor(log_mel_spec),
        "sampling_rate": 48000,
    }

    # Add the low-passed conditioning signal, then extract its mel features.
    batch.update(lowpass_filtering_prepare_inference(batch))
    assert "waveform_lowpass" in batch.keys()

    lowpass_mel, lowpass_stft = wav_feature_extraction(
        batch["waveform_lowpass"], target_frame
    )
    batch["lowpass_mel"] = lowpass_mel

    # Prepend a batch dimension to every tensor entry.
    for k in batch.keys():
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].float().unsqueeze(0)

    return batch, duration
def round_up_duration(duration):
    # Round to the nearest 2.5 s multiple, then add one extra 2.5 s chunk.
    return int(round(duration / 2.5) + 1) * 2.5
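# For example, round_up_duration(4.0) computes round(4.0 / 2.5) + 1 = 2 + 1
# chunks, i.e. 7.5 seconds, so the sampler always has headroom past the
# actual clip length.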
def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
    # Pick the best available device when none is specified.
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioSR: %s" % model_name)
    print("Loading model on %s" % device)

    # Only download the checkpoint when no local path was provided.
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)

    if config is not None:
        assert type(config) is str
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    config["model"]["params"]["device"] = device

    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    return latent_diffusion
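# A minimal sketch of loading the model; "basic" and the automatic device
# selection come from the defaults above, and the checkpoint is downloaded
# on first use:
#
#     model = build_model(model_name="basic", device="auto")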
def super_resolution(
    latent_diffusion,
    input_file,
    seed=42,
    ddim_steps=200,
    guidance_scale=3.5,
    latent_t_per_second=12.8,
    config=None,
):
    """Run diffusion-based super-resolution on input_file and return the waveform."""
    seed_everything(int(seed))
    batch, duration = make_batch_for_super_resolution(input_file)

    with torch.no_grad():
        waveform = latent_diffusion.generate_batch(
            batch,
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            duration=duration,
        )

    return waveform
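# A minimal end-to-end sketch, assuming a local "input.wav" exists and that
# generate_batch returns a waveform at the batch's 48 kHz rate; the output
# handling below is illustrative, not part of this module's API.
if __name__ == "__main__":
    model = build_model(model_name="basic", device="auto")
    out = super_resolution(model, "input.wav", seed=42, ddim_steps=50)
    print(type(out), getattr(out, "shape", None))  # inspect before saving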