File size: 4,604 Bytes
bd3a23c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import torch
import soundfile as sf
import pandas as pd
from tqdm import tqdm
from utils import minmax_norm_diff, reverse_minmax_norm_diff
from spk_ext import se_extractor


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale a classifier-free-guidance prediction toward the spread of the conditional one.

    Based on Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps are
    Flawed" (https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: guided prediction. Statistics are taken over every dim except dim 0.
        noise_pred_text: conditional prediction whose per-sample standard deviation
            is the rescaling target.
        guidance_rescale: blend weight of the rescaled prediction. The original
            ``noise_cfg`` receives weight ``1 - guidance_rescale``.

    Returns:
        The blended tensor.
    """
    def _per_sample_std(tensor):
        # Standard deviation over every dim except the leading one; dims kept for broadcasting.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    target_std = _per_sample_std(noise_pred_text)
    current_std = _per_sample_std(noise_cfg)

    # Matching the conditional std counteracts the over-amplification caused by large guidance.
    rescaled = noise_cfg * (target_std / current_std)

    # Mixing back part of the un-rescaled prediction avoids overly flat ("plain looking") outputs.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


@torch.no_grad()
def inference_timbre(gen_shape, text,

                     model, scheduler,

                     guidance_scale=5, guidance_rescale=0.7,

                     ddim_steps=50, eta=1, random_seed=2023,

                     device='cuda',

                     ):
    """
    Sample latents from random noise with a text-conditioned denoiser and a scheduler.

    Args:
        gen_shape: shape of the initial noise tensor (passed to ``torch.randn``).
        text: pair ``(text_embeddings, text_mask)`` forwarded to ``model`` on every step.
        model: denoiser called as ``model(x, t, text, text_mask, train_cfg=..., [cfg_prob=...])``.
            It is switched to eval mode as a side effect.
        scheduler: scheduler exposing ``set_timesteps``, ``timesteps``,
            ``scale_model_input`` and ``step(...).prev_sample``. It may optionally
            expose ``init_noise_sigma``.
        guidance_scale: classifier-free guidance weight. A falsy value (0/None)
            disables guidance and runs a single conditional pass per step.
        guidance_rescale: when > 0 and guidance is on, blend factor passed to
            ``rescale_noise_cfg``.
        ddim_steps: number of timesteps passed to ``scheduler.set_timesteps``.
        eta: forwarded to ``scheduler.step`` as ``eta``.
        random_seed: seed of the ``torch.Generator``, which is used both for the
            initial noise and for ``scheduler.step``.
        device: device for the generator and the latents.

    Returns:
        The latents after the final scheduler step.
    """
    text, text_mask = text
    model.eval()
    generator = torch.Generator(device=device).manual_seed(random_seed)
    scheduler.set_timesteps(ddim_steps)

    # Start from noise scaled to the scheduler's initial sigma (1.0 for DDIM-style
    # schedulers). Schedulers without the attribute keep the unscaled noise.
    latents = torch.randn(gen_shape, generator=generator, device=device)
    latents = latents * getattr(scheduler, 'init_noise_sigma', 1.0)

    for t in scheduler.timesteps:
        # The scaled tensor only feeds the model. scheduler.step must receive the
        # unscaled latents, otherwise schedulers that actually rescale their input
        # would compound the scaling every step.
        model_input = scheduler.scale_model_input(latents, t)

        if guidance_scale:
            output_text = model(model_input, t, text, text_mask, train_cfg=False)
            # NOTE(review): train_cfg=True with cfg_prob=1.0 presumably drops the text
            # condition, yielding the unconditional branch. Confirm against the model.
            output_uncond = model(model_input, t, text, text_mask, train_cfg=True, cfg_prob=1.0)

            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
            if guidance_rescale > 0.0:
                output_pred = rescale_noise_cfg(output_pred, output_text,
                                                guidance_rescale=guidance_rescale)
        else:
            output_pred = model(model_input, t, text, text_mask, train_cfg=False)

        latents = scheduler.step(model_output=output_pred, timestep=t, sample=latents,
                                 eta=eta, generator=generator).prev_sample

    return latents


@torch.no_grad()
def eval_plugin_light(vc_model, text_model,

                      timbre_model, timbre_scheduler, timbre_shape,

                      val_meta, val_folder,

                      guidance_scale=3, guidance_rescale=0.7,

                      ddim_steps=50, eta=1, random_seed=2024,

                      device='cuda',

                      epoch=0, save_path='logs/eval/', val_num=10, sr=24000):
    """
    Run voice conversion for the first ``val_num`` rows of a validation CSV and save the audio.

    For each row, the text prompt is tokenized and encoded, and a target embedding
    is sampled with ``inference_timbre``. The source embedding is extracted from
    the row's audio with ``se_extractor``, and ``vc_model.convert`` writes the
    result to ``<save_path>/<epoch>/<index>_<prompt>.wav``.

    Args:
        vc_model: object exposing ``convert(...)``. It is also passed to ``se_extractor``.
        text_model: pair ``(tokenizer, text_encoder)``.
        timbre_model, timbre_scheduler, timbre_shape: forwarded to ``inference_timbre``
            as ``model``, ``scheduler`` and ``gen_shape``.
        val_meta: CSV path read with ``pandas.read_csv``. Must contain ``path`` and
            ``prompt`` columns.
        val_folder: string prefix prepended, by plain concatenation, to each row's ``path``.
        guidance_scale, guidance_rescale, ddim_steps, eta, random_seed: forwarded to
            ``inference_timbre``. The same seed is used for every row.
        device: device for the text tensors, the sampler and the source embedding.
        epoch: name of the output subdirectory under ``save_path``.
        save_path: output root directory.
        val_num: maximum number of rows to process. Values <= 0 process none.
        sr: not used by this function; kept for interface compatibility.
    """
    tokenizer, text_encoder = text_model

    df = pd.read_csv(val_meta)

    epoch_dir = os.path.join(save_path, str(epoch))
    os.makedirs(epoch_dir, exist_ok=True)

    encode_message = "@MyShell"

    # Clamp at 0: DataFrame.head with a negative n drops rows from the end instead of selecting none.
    rows = df.head(max(val_num, 0))

    for step, (_, row) in enumerate(rows.iterrows()):
        # Kept as plain concatenation: os.path.join would discard val_folder for absolute row paths.
        source_path = val_folder + row['path']
        prompt = [row['prompt']]

        # Gradients are already disabled by the @torch.no_grad() decorator.
        text_batch = tokenizer(prompt,
                               max_length=32,
                               padding='max_length', truncation=True, return_tensors="pt")
        text = text_batch.input_ids.to(device)
        text_mask = text_batch.attention_mask.to(device)
        text = text_encoder(input_ids=text, attention_mask=text_mask)[0]

        spk_embed = inference_timbre(timbre_shape, [text, text_mask],
                                     timbre_model, timbre_scheduler,
                                     guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
                                     ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
                                     device=device)

        source_se = se_extractor(source_path, vc_model).to(device)

        # A path separator inside the prompt would otherwise point the output into a
        # non-existent subdirectory and make the write fail.
        file_stem = f'{prompt[0]}'
        for separator in filter(None, (os.sep, os.altsep)):
            file_stem = file_stem.replace(separator, '_')

        vc_model.convert(
            audio_src_path=source_path,
            src_se=source_se,
            tgt_se=spk_embed,
            output_path=os.path.join(epoch_dir, f'{step}_{file_stem}.wav'),
            message=encode_message)