|
import os
|
|
import random
|
|
import pandas as pd
|
|
import torch
|
|
import librosa
|
|
import numpy as np
|
|
import soundfile as sf
|
|
from tqdm import tqdm
|
|
from utils import scale_shift_re
|
|
|
|
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Blend a classifier-free-guided prediction with a copy rescaled to the text branch's std.

    Implements the guidance rescale from Section 3.4 of "Common Diffusion Noise Schedules
    and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf).

    Standard deviations are computed per sample, over every dimension except the first.

    Args:
        noise_cfg: Guided prediction (output of classifier-free guidance).
        noise_pred_text: Text-conditioned prediction whose per-sample std is the target.
        guidance_rescale: Weight given to the rescaled prediction. The remaining
            ``1 - guidance_rescale`` is given to ``noise_cfg`` as-is.

    Returns:
        ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``, where
        ``rescaled = noise_cfg * std(noise_pred_text) / std(noise_cfg)``.
    """
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))

    target_std = noise_pred_text.std(dim=text_axes, keepdim=True)
    current_std = noise_cfg.std(dim=cfg_axes, keepdim=True)

    rescaled = noise_cfg * (target_std / current_std)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
|
|
|
|
|
@torch.no_grad()
def inference(autoencoder, unet, gt, gt_mask,
              tokenizer, text_encoder,
              params, noise_scheduler,
              text_raw, neg_text=None,
              audio_frames=500,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2024,
              device='cuda',
              ):
    """Sample latents with the diffusion model and decode them into audio.

    Args:
        autoencoder: Decoder, called as ``autoencoder(embedding=latents)``.
        unet: Denoiser, called as ``unet(x, t, context, context_mask=..., cls_token=None,
            gt=..., mae_mask_infer=...)``. It returns a 2-tuple whose first item is the
            prediction handed to the scheduler.
        gt: Optional reference latents for masked generation, or None.
        gt_mask: Boolean mask matching ``gt`` (used only when ``gt`` is not None).
            Positions where it is False are copied from ``gt`` into the result.
        tokenizer: Text tokenizer. When None, no text conditioning is used and
            classifier-free guidance is disabled.
        text_encoder: Model whose ``last_hidden_state`` is used as the text context.
        params: Config dict. Reads ``['text_encoder']['max_length']``,
            ``['model']['out_chans']``, ``['autoencoder']['scale']`` and
            ``['autoencoder']['shift']``.
        noise_scheduler: Scheduler exposing ``set_timesteps``, ``timesteps``,
            ``scale_model_input`` and ``step(...).prev_sample``.
        text_raw: List of prompt strings.
        neg_text: Prompts for the unconditional CFG branch. Defaults to ``[""]``.
        audio_frames: Length of the sampled latent sequence.
        guidance_scale: CFG weight. A falsy value disables CFG.
        guidance_rescale: If > 0, applies ``rescale_noise_cfg`` with this factor.
        ddim_steps: Number of scheduler timesteps.
        eta: Forwarded to ``noise_scheduler.step``.
        random_seed: Seed for the sampling generator. None draws a fresh seed.
        device: Device for all created tensors and the generator.

    Returns:
        The result of ``autoencoder(embedding=...)`` on the final latents.
    """
    if neg_text is None:
        neg_text = [""]

    if tokenizer is not None:
        max_length = params['text_encoder']['max_length']

        text_batch = tokenizer(text_raw,
                               max_length=max_length,
                               padding="max_length", truncation=True, return_tensors="pt")
        text = text_batch.input_ids.to(device)
        text_mask = text_batch.attention_mask.to(device).bool()
        text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state

        uncond_text_batch = tokenizer(neg_text,
                                      max_length=max_length,
                                      padding="max_length", truncation=True, return_tensors="pt")
        uncond_text = uncond_text_batch.input_ids.to(device)
        uncond_text_mask = uncond_text_batch.attention_mask.to(device).bool()
        uncond_text = text_encoder(input_ids=uncond_text,
                                   attention_mask=uncond_text_mask).last_hidden_state
    else:
        text, text_mask = None, None
        # No conditional/unconditional pair exists without text, so CFG is impossible.
        guidance_scale = None

    codec_dim = params['model']['out_chans']
    unet.eval()

    if random_seed is not None:
        generator = torch.Generator(device=device).manual_seed(random_seed)
    else:
        generator = torch.Generator(device=device)
        generator.seed()

    noise_scheduler.set_timesteps(ddim_steps)

    noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
    # Scheduler convention: the starting sample is unit noise times init_noise_sigma
    # (1.0 for DDIM, so this is a no-op there).
    latents = noise * getattr(noise_scheduler, 'init_noise_sigma', 1.0)

    if guidance_scale:
        # The CFG conditioning batch is identical at every step; build it once.
        # NOTE(review): the CFG batch pairs len(text_raw) prompts with len(neg_text)
        # negatives and a batch-1 latent; this only lines up for a single prompt.
        text_combined = torch.cat([text, uncond_text], dim=0)
        text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
        if gt is not None:
            gt_combined = torch.cat([gt, gt], dim=0)
            gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
        else:
            gt_combined = None
            gt_mask_combined = None

    for t in noise_scheduler.timesteps:
        # Only the model sees the scaled input; `latents` must stay unscaled because
        # scheduler.step() expects the raw current sample.
        model_input = noise_scheduler.scale_model_input(latents, t)

        if guidance_scale:
            model_input_combined = torch.cat([model_input, model_input], dim=0)
            output_combined, _ = unet(model_input_combined, t, text_combined,
                                      context_mask=text_mask_combined,
                                      cls_token=None, gt=gt_combined,
                                      mae_mask_infer=gt_mask_combined)
            output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)

            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
            if guidance_rescale > 0.0:
                output_pred = rescale_noise_cfg(output_pred, output_text,
                                                guidance_rescale=guidance_rescale)
        else:
            output_pred, _ = unet(model_input, t, text, context_mask=text_mask,
                                  cls_token=None, gt=gt, mae_mask_infer=gt_mask)

        latents = noise_scheduler.step(model_output=output_pred, timestep=t,
                                       sample=latents,
                                       eta=eta, generator=generator).prev_sample

    pred = scale_shift_re(latents, params['autoencoder']['scale'],
                          params['autoencoder']['shift'])
    if gt is not None:
        # Positions outside the mask keep the reference content.
        # NOTE(review): gt is copied in after scale_shift_re; confirm gt lives in the
        # same (un-normalized) latent space as `pred`.
        pred[~gt_mask] = gt[~gt_mask]
    pred_wav = autoencoder(embedding=pred)
    return pred_wav
|
|
|
|
|
|
def _safe_filename(caption, max_bytes=200):
    """Return ``caption`` as a single file-name component: no path separators, bounded length."""
    name = str(caption)
    # A separator inside a caption would otherwise point into a non-existent sub-directory.
    for sep in (os.sep, os.altsep, '\0'):
        if sep:
            name = name.replace(sep, '_')
    # File systems commonly cap a name at 255 bytes; leave room for suffixes like '_gt.wav'.
    encoded = name.encode('utf-8')
    if len(encoded) > max_bytes:
        name = encoded[:max_bytes].decode('utf-8', errors='ignore')
    return name


def _random_span_mask(shape, span_frac, num_spans, rng, device):
    """Return a bool (B, D, L) mask with ``num_spans`` random spans of the last axis set True."""
    B, D, L = shape
    span_len = int(L * span_frac)
    mask = torch.zeros(B, D, L, dtype=torch.bool, device=device)
    # Spans are drawn independently and may overlap.
    for _ in range(num_spans):
        start = rng.randint(0, L - span_len)
        mask[:, :, start:start + span_len] = True
    return mask


@torch.no_grad()
def eval_udit(autoencoder, unet,
              tokenizer, text_encoder,
              params, noise_scheduler,
              val_df, subset,
              audio_frames, mae=False,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2023,
              device='cuda',
              epoch=0, save_path='logs/eval/', val_num=5):
    """Generate audio for the first ``val_num`` rows of a CSV split and save them as WAV files.

    Args:
        autoencoder, unet, tokenizer, text_encoder, params, noise_scheduler:
            Forwarded to ``inference``. ``autoencoder`` is also used to encode reference
            audio when ``mae`` is True.
        val_df: Path to a CSV with at least ``split`` and ``caption`` columns. When ``mae``
            is True it also needs ``audio_path`` and ``audio_length``.
        subset: Value of the ``split`` column to evaluate.
        audio_frames: Latent length passed to ``inference``.
        mae: If True, load each row's reference audio and encode it. Two random spans,
            each 20% of the latent length, are then regenerated. The peak-normalized
            reference is also saved as ``<caption>_gt.wav``.
        guidance_scale, guidance_rescale, ddim_steps, eta, device: Forwarded to ``inference``.
        random_seed: Seed for ``inference`` sampling and for the MAE span positions.
        epoch: Outputs are written to ``os.path.join(save_path, str(epoch))``.
        save_path: Root output directory.
        val_num: Maximum number of rows to process. Values <= 0 process none.
    """
    val_df = pd.read_csv(val_df)
    val_df = val_df[val_df['split'] == subset]
    if mae:
        # Masked generation needs a reference clip; drop rows without audio.
        val_df = val_df[val_df['audio_length'] != 0]
    val_df = val_df.head(max(val_num, 0))

    out_dir = os.path.join(save_path, str(epoch))
    os.makedirs(out_dir, exist_ok=True)

    target_sr = params['data']['sr']
    # Local RNG: span positions are reproducible for a given seed and the global
    # `random` state is left untouched.
    mask_rng = random.Random(random_seed)

    for _, row in tqdm(val_df.iterrows(), total=len(val_df)):
        text = [row['caption']]
        out_stem = os.path.join(out_dir, _safe_filename(text[0]))

        if mae:
            audio_path = params['data']['val_dir'] + str(row['audio_path'])
            gt, sr = librosa.load(audio_path, sr=target_sr)
            gt = gt / (np.max(np.abs(gt)) + 1e-9)
            sf.write(out_stem + '_gt.wav', gt, samplerate=target_sr)

            # Reference audio is zero-padded or trimmed to exactly 10 seconds.
            num_samples = 10 * sr
            if len(gt) < num_samples:
                gt = np.pad(gt, (0, num_samples - len(gt)), 'constant')
            else:
                gt = gt[:num_samples]

            gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
            gt = autoencoder(audio=gt)
            gt_mask = _random_span_mask(gt.shape, 0.2, 2, mask_rng, device)
        else:
            gt = None
            gt_mask = None

        pred = inference(autoencoder, unet, gt, gt_mask,
                         tokenizer, text_encoder,
                         params, noise_scheduler,
                         text, neg_text=None,
                         audio_frames=audio_frames,
                         guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
                         ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
                         device=device)

        pred = pred.cpu().numpy().squeeze(0).squeeze(0)
        sf.write(out_stem + '.wav', pred, samplerate=target_sr)
|
|
|