import os

import pandas as pd
import soundfile as sf
import torch
from tqdm import tqdm

from spk_ext import se_extractor
from utils import minmax_norm_diff, reverse_minmax_norm_diff


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on the findings of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed]
    (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guided prediction so its std matches the text-conditional one
    # (fixes the "overexposure" artifact described in the paper).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Mix the rescaled and original guided predictions by factor `guidance_rescale`
    # to avoid overly flat outputs (the paper's "plain looking" images).
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


@torch.no_grad()
def inference_timbre(gen_shape, text,
                     model, scheduler,
                     guidance_scale=5, guidance_rescale=0.7,
                     ddim_steps=50, eta=1, random_seed=2023,
                     device='cuda'):
    """Sample a speaker (timbre) embedding from the diffusion model,
    conditioned on an encoded text prompt, with classifier-free guidance."""
    text, text_mask = text
    model.eval()

    generator = torch.Generator(device=device).manual_seed(random_seed)
    scheduler.set_timesteps(ddim_steps)

    # Start from pure Gaussian noise of the requested shape.
    noise = torch.randn(gen_shape, generator=generator, device=device)
    latents = noise

    for t in scheduler.timesteps:
        latents = scheduler.scale_model_input(latents, t)

        if guidance_scale:
            # Classifier-free guidance: combine conditional and unconditional
            # predictions (cfg_prob=1.0 drops the text condition entirely).
            output_text = model(latents, t, text, text_mask, train_cfg=False)
            output_uncond = model(latents, t, text, text_mask, train_cfg=True, cfg_prob=1.0)

            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
            if guidance_rescale > 0.0:
                output_pred = rescale_noise_cfg(output_pred, output_text,
                                                guidance_rescale=guidance_rescale)
        else:
            output_pred = model(latents, t, text, text_mask, train_cfg=False)

        latents = scheduler.step(model_output=output_pred, timestep=t, sample=latents,
                                 eta=eta, generator=generator).prev_sample

    # pred = reverse_minmax_norm_diff(latents, vmin=0.0, vmax=0.5)
    # pred = torch.clip(pred, min=0.0, max=0.5)

    return latents


@torch.no_grad()
def eval_plugin_light(vc_model, text_model,
                      timbre_model, timbre_scheduler, timbre_shape,
                      val_meta, val_folder,
                      guidance_scale=3, guidance_rescale=0.7,
                      ddim_steps=50, eta=1, random_seed=2024,
                      device='cuda',
                      epoch=0, save_path='logs/eval/', val_num=10, sr=24000):
    """Run validation: for each row of `val_meta`, sample a target speaker
    embedding from the text prompt and convert the source audio to it."""
    tokenizer, text_encoder = text_model

    df = pd.read_csv(val_meta)

    save_path = save_path + str(epoch) + '/'
    os.makedirs(save_path, exist_ok=True)

    step = 0
    for i in range(len(df)):
        row = df.iloc[i]

        source_path = val_folder + row['path']
        prompt = [row['prompt']]

        # Encode the text prompt; the encoder's last hidden state conditions
        # the timbre diffusion model.
        text_batch = tokenizer(prompt,
                               max_length=32,
                               padding='max_length', truncation=True,
                               return_tensors="pt")
        text, text_mask = (text_batch.input_ids.to(device),
                           text_batch.attention_mask.to(device))
        text = text_encoder(input_ids=text, attention_mask=text_mask)[0]

        # Sample the target speaker embedding from the prompt.
        spk_embed = inference_timbre(timbre_shape, [text, text_mask],
                                     timbre_model, timbre_scheduler,
                                     guidance_scale=guidance_scale,
                                     guidance_rescale=guidance_rescale,
                                     ddim_steps=ddim_steps, eta=eta,
                                     random_seed=random_seed,
                                     device=device)

        # Extract the source speaker embedding and run voice conversion.
        source_se = se_extractor(source_path, vc_model).to(device)

        encode_message = "@MyShell"
        vc_model.convert(
            audio_src_path=source_path,
            src_se=source_se,
            tgt_se=spk_embed,
            # Note: prompts containing path separators or other characters
            # invalid in filenames would break this naive output path.
            output_path=save_path + f'{step}_{prompt[0]}' + '.wav',
            message=encode_message)

        step += 1
        if step >= val_num:
            break
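

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the file does not ship an entry point, so
# everything below is an assumption). The DDIM scheduler and Flan-T5 encoder
# are plausible but unconfirmed choices for this project, and
# `load_timbre_model` plus all checkpoint/CSV/folder paths are hypothetical
# placeholders for the repo's real assets.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from diffusers import DDIMScheduler                      # assumed scheduler backend
    from transformers import AutoTokenizer, T5EncoderModel   # assumed text encoder
    from openvoice.api import ToneColorConverter             # OpenVoice converter

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Text model: the (tokenizer, encoder) pair unpacked by eval_plugin_light.
    tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
    text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base').to(device).eval()

    # Voice-conversion model (OpenVoice); both paths are placeholders.
    vc_model = ToneColorConverter('checkpoints/converter/config.json', device=device)
    vc_model.load_ckpt('checkpoints/converter/checkpoint.pth')

    # Timbre diffusion model: project-specific, loaded here via a
    # hypothetical helper; replace with the repo's actual loading code.
    timbre_model = load_timbre_model('checkpoints/timbre.pt', device)  # hypothetical
    timbre_scheduler = DDIMScheduler(num_train_timesteps=1000)         # assumed config

    eval_plugin_light(vc_model, (tokenizer, text_encoder),
                      timbre_model, timbre_scheduler,
                      timbre_shape=(1, 256, 1),  # assumed OpenVoice-style embedding shape
                      val_meta='data/val_meta.csv', val_folder='data/val/',
                      save_path='logs/eval/', val_num=10)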