File size: 4,947 Bytes
0dabde8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import torch
import soundfile as sf
import pandas as pd
import librosa
from utils import minmax_norm_diff, reverse_minmax_norm_diff, scale_shift_re
from freevc_wrapper import convert
import time


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend a guided prediction with a copy of itself rescaled to the std of the text prediction.

    Based on Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps
    are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: Prediction obtained after applying guidance.
        noise_pred_text: Text-conditioned prediction; supplies the reference std.
        guidance_rescale: Weight given to the rescaled prediction; the remaining
            `1 - guidance_rescale` goes to the untouched `noise_cfg`.

    Returns:
        The weighted mix of the rescaled and the original `noise_cfg`.
    """
    def _std_over_non_leading_dims(x):
        # Reduce over every dim except dim 0; keepdim so the result broadcasts back.
        reduce_dims = list(range(1, x.ndim))
        return x.std(dim=reduce_dims, keepdim=True)

    # Ratio that brings the guided prediction's std to that of the text prediction.
    std_ratio = _std_over_non_leading_dims(noise_pred_text) / _std_over_non_leading_dims(noise_cfg)
    rescaled = noise_cfg * std_ratio

    # Mix the rescaled and the original guided prediction.
    blended = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
    return blended


@torch.no_grad()
def inference_timbre(gen_shape, text,
                     model, scheduler,
                     guidance_scale=5, guidance_rescale=0.7,
                     ddim_steps=50, eta=1, random_seed=2023,
                     device='cuda',
                     ):
    """
    Run the scheduler's denoising loop from random noise and return the post-processed result.

    Args:
        gen_shape: Shape of the initial noise tensor passed to `torch.randn`.
        text: Pair `(text, text_mask)`; both are forwarded to `model` on every step.
        model: Denoising network, called as
            `model(latents, t, text, text_mask, train_cfg=..., [cfg_prob=...])`.
            It is switched to eval mode.
        scheduler: Object exposing `set_timesteps`, `timesteps`,
            `scale_model_input` and `step(...).prev_sample`
            (looks like a diffusers-style scheduler - confirm).
        guidance_scale: Guidance weight; a falsy value disables guidance and
            only the `train_cfg=False` pass is run.
        guidance_rescale: When > 0, the guided output is passed through
            `rescale_noise_cfg` with this factor.
        ddim_steps: Number of steps given to `scheduler.set_timesteps`.
        eta: Forwarded to `scheduler.step`.
        random_seed: Seed for the generator; `None` draws a fresh seed.
        device: Device for the generator and the initial noise.

    Returns:
        `scale_shift_re(latents, 20, -0.035)` clipped to the range [0.0, 0.5].
    """
    text, text_mask = text
    model.eval()

    generator = torch.Generator(device=device)
    if random_seed is not None:
        generator.manual_seed(random_seed)
    else:
        generator.seed()

    scheduler.set_timesteps(ddim_steps)

    # Start from pure noise.
    latents = torch.randn(gen_shape, generator=generator, device=device)

    for t in scheduler.timesteps:
        # Only the model input is scaled; `latents` itself stays unscaled so that
        # `scheduler.step` receives the actual current sample. When
        # `scale_model_input` is the identity this is the same as scaling in place.
        model_input = scheduler.scale_model_input(latents, t)

        output_text = model(model_input, t, text, text_mask, train_cfg=False)
        if guidance_scale:
            # NOTE(review): train_cfg=True with cfg_prob=1.0 presumably yields the
            # unconditional prediction - confirm against the model implementation.
            output_uncond = model(model_input, t, text, text_mask, train_cfg=True, cfg_prob=1.0)

            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
            if guidance_rescale > 0.0:
                output_pred = rescale_noise_cfg(output_pred, output_text,
                                                guidance_rescale=guidance_rescale)
        else:
            output_pred = output_text

        latents = scheduler.step(model_output=output_pred, timestep=t, sample=latents,
                                 eta=eta, generator=generator).prev_sample

    # NOTE(review): 20 and -0.035 are presumably the scale/shift used when the
    # training targets were normalized - confirm against `scale_shift_re`.
    pred = scale_shift_re(latents, 20, -0.035)
    pred = torch.clip(pred, min=0.0, max=0.5)
    return pred


@torch.no_grad()
def eval_plugin(freevc, cmodel, text_model,
                timbre_model, timbre_scheduler, timbre_shape,
                val_meta, val_folder,
                guidance_scale=3, guidance_rescale=0.7,
                ddim_steps=50, eta=1, random_seed=2024,
                device='cuda',
                epoch=0, save_path='logs/eval/', val_num=10, sr=16000):
    """
    Convert up to `val_num` validation clips and write the results as wav files.

    For each row of the metadata CSV, the source audio is loaded, passed through
    `cmodel`, combined with an embedding produced by `inference_timbre`, and
    converted with `convert(freevc, ...)`. Outputs are written to
    `save_path + str(epoch) + '/'`.

    Args:
        freevc: Passed through to `convert`.
        cmodel: Called on the loaded audio; its `last_hidden_state` is used.
        text_model: Pair `(tokenizer, text_encoder)`.
        timbre_model, timbre_scheduler, timbre_shape: Forwarded to `inference_timbre`.
        val_meta: Path of a CSV file with a `path` column.
        val_folder: Prefix string-concatenated in front of each row's `path`.
        guidance_scale, guidance_rescale, ddim_steps, eta, random_seed:
            Forwarded to `inference_timbre`.
        device: Device the tensors are moved to.
        epoch: Used as the name of the output sub-directory.
        save_path: Output root; concatenated as a string, so it should end with
            a path separator.
        val_num: Maximum number of clips to process; values <= 0 process nothing.
        sr: Sample rate written into the output wav files.
    """
    tokenizer, text_encoder = text_model

    df = pd.read_csv(val_meta)

    save_path = save_path + str(epoch) + '/'
    os.makedirs(save_path, exist_ok=True)

    for step, (_, row) in enumerate(df.iterrows()):
        # Check the limit before doing any work so val_num <= 0 writes no files.
        if step >= val_num:
            break

        source_path = val_folder + row['path']
        # NOTE(review): a fixed prompt is used for every clip instead of the
        # per-row row['prompt'] - confirm this is intended.
        prompt = ["female's voice"]
        text_batch = tokenizer(prompt,
                               max_length=32,
                               padding='max_length', truncation=True, return_tensors="pt")
        text = text_batch.input_ids.to(device)
        text_mask = text_batch.attention_mask.to(device)
        text = text_encoder(input_ids=text, attention_mask=text_mask)[0]

        # NOTE(review): the load rate is fixed at 16000 and does not follow the
        # `sr` argument - confirm which rate `cmodel` expects.
        audio_clip = librosa.load(source_path, sr=16000)[0]
        audio_clip = torch.tensor(audio_clip).unsqueeze(0).to(device)

        content = cmodel(audio_clip).last_hidden_state.transpose(1, 2).to(device)

        spk_embed = inference_timbre(timbre_shape, [text, text_mask],
                                     timbre_model, timbre_scheduler,
                                     guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
                                     ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
                                     device=device)
        spk_embed = spk_embed.squeeze(-1)

        # The sample rate returned by `convert` is not used; the file is written
        # with `sr` (NOTE(review): confirm the two always agree).
        output, _ = convert(freevc, content, spk_embed)
        sf.write(save_path + f'{step}_{prompt[0]}' + '.wav', output, samplerate=sr)