File size: 4,398 Bytes
e73da9c
 
 
 
7c56def
511e6ea
e73da9c
f8afdbc
7c56def
e73da9c
 
 
 
 
 
 
7c56def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e73da9c
7c56def
 
 
 
 
e73da9c
7c56def
 
 
 
e73da9c
7c56def
 
 
 
 
 
 
 
 
 
 
 
 
e73da9c
7c56def
 
 
 
 
e73da9c
7c56def
 
 
e73da9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import torch
from typing import Optional, List, Tuple, NamedTuple, Union
from models import PipelineWrapper
import torchaudio
from audioldm.utils import get_duration

# Optional upper bound on the duration of audio loaded by load_audio.
# It is compared against durations derived from the input file
# (get_duration(...) or num_samples / sample_rate); None disables the cap.
MAX_DURATION = None


class PromptEmbeddings(NamedTuple):
    """Bundle of the three values produced by one text-encoding call.

    ``get_text_embeddings`` fills an instance from the 3-tuple returned by
    ``ldm_stable.encode_text``, in the same order as the fields below.
    """
    # First value returned by encode_text.
    embedding_hidden_states: torch.Tensor
    # Second value returned by encode_text. The "lables" spelling is part of
    # the public field name and is used by callers as a keyword.
    embedding_class_lables: torch.Tensor
    # Third value returned by encode_text; presumably a boolean mask over the
    # prompt tokens -- TODO confirm against PipelineWrapper.encode_text.
    boolean_prompt_mask: torch.Tensor


def load_audio(audio_path: Union[str, np.ndarray, torch.Tensor], fn_STFT, left: int = 0, right: int = 0,
               device: Optional[torch.device] = None,
               return_wav: bool = False, stft: bool = False, model_sr: Optional[int] = None) -> Tuple:
    """Load audio as a spectrogram (``stft=True``) or as a normalized waveform.

    Spectrogram mode (``stft=True``):
        * If ``audio_path`` is a ``str``, the file is converted with
          ``audioldm.audio.wav_to_fbank`` using ``fn_STFT``; the duration comes
          from ``get_duration`` and is capped by the module-level
          ``MAX_DURATION`` when that is set.
        * Otherwise ``audio_path`` is treated as an already computed 3-D
          spectrogram (tensor or NumPy array). No file is read, so the
          returned duration (and waveform, if requested) is ``None``.
        * ``left`` / ``right`` entries are trimmed from the two ends of the
          last axis, clamped so that at least one entry always remains.
        * The result gets a leading batch axis and is moved to ``device``.

    Waveform mode (``stft=False``):
        * The file is read with ``torchaudio.load``, resampled to ``model_sr``
          when it differs (the file's own rate is kept when ``model_sr`` is
          ``None``), mean-centred, peak-normalized to 0.5 and optionally
          cropped to ``MAX_DURATION``.
        * ``fn_STFT``, ``left``, ``right``, ``device`` and ``return_wav`` are
          not used in this mode.

    Returns:
        ``(mel, model_sr, duration)`` or, with ``return_wav=True`` in
        spectrogram mode, ``(mel, 16000, duration, wav)``; in waveform mode
        ``(waveform, sample_rate, duration)``.
    """
    if stft:  # AudioLDM/tango loading to spectrogram
        if isinstance(audio_path, str):
            # Imported lazily so waveform-only callers do not need audioldm.audio.
            import audioldm
            import audioldm.audio

            duration = get_duration(audio_path)
            if MAX_DURATION is not None:
                duration = min(duration, MAX_DURATION)

            # 102.4 is the hard-coded frames-per-unit-duration factor used for target_length.
            mel, _, wav = audioldm.audio.wav_to_fbank(audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
            mel = mel.unsqueeze(0)
        else:
            # Precomputed spectrogram: accept NumPy arrays as well as tensors.
            # There is no source file, so duration and waveform are unknown;
            # previously these names were unbound and the return statement crashed.
            mel = torch.as_tensor(audio_path)
            duration = None
            wav = None

        # The unpacking also enforces that the spectrogram is 3-D.
        _, _, w = mel.shape
        left = min(left, w - 1)
        right = min(right, w - left - 1)
        mel = mel[:, :, left:w - right]
        mel = mel.unsqueeze(0).to(device)

        if return_wav:
            return mel, 16000, duration, wav

        return mel, model_sr, duration

    waveform, sr = torchaudio.load(audio_path)
    # Fall back to the file's native rate instead of crashing on model_sr=None.
    target_sr = sr if model_sr is None else model_sr
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)

    # Remove the DC offset, then scale so the peak magnitude is 0.5
    # (the epsilon guards against division by zero on silent input).
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    waveform = (waveform * 0.5).to(torch.float32)

    if MAX_DURATION is not None:
        capped = min(waveform.shape[-1] / target_sr, MAX_DURATION)
        waveform = waveform[:, :int(capped * target_sr)]

    # Report the duration of what is actually returned (after any cropping).
    duration = waveform.shape[-1] / target_sr
    return waveform, target_sr, duration


def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int:
    """Convert an audio length into a spectrogram height the model can process.

    The length is divided by the vocoder's upsample factor (product of its
    ``upsample_rates`` over its ``sampling_rate``) and truncated to an int.
    If the result is not a multiple of ``ldm_stable.model.vae_scale_factor``
    it is rounded up to the next multiple and a notice is printed.

    Args:
        length: Audio length (the printed notice refers to it as seconds).
            ``None`` selects the default derived from the UNet's
            ``sample_size``.
        ldm_stable: Wrapper exposing ``model.vocoder``, ``model.unet`` and
            ``model.vae_scale_factor``.

    Returns:
        The height, always a multiple of ``vae_scale_factor``.
    """
    pipe = ldm_stable.model
    vocoder_cfg = pipe.vocoder.config
    upsample_factor = np.prod(vocoder_cfg.upsample_rates) / vocoder_cfg.sampling_rate

    if length is None:
        # Default: the length that corresponds to the UNet's native sample size.
        length = pipe.unet.config.sample_size * pipe.vae_scale_factor * upsample_factor

    scale = pipe.vae_scale_factor
    height = int(length / upsample_factor)

    # Already aligned to the VAE scale factor: nothing to adjust or report.
    if height % scale == 0:
        return height

    height = int(np.ceil(height / scale)) * scale
    notice = (
        f"Audio length in seconds {length} is increased to {height * upsample_factor} "
        f"so that it can be handled by the model. It will be cut to {length} after the "
        f"denoising process."
    )
    print(notice)
    return height


def get_text_embeddings(target_prompt: List[str], target_neg_prompt: List[str], ldm_stable: PipelineWrapper
                        ) -> Tuple[torch.Tensor, PromptEmbeddings, PromptEmbeddings]:
    """Encode the target and negative prompts with ``ldm_stable.encode_text``.

    Each call to ``encode_text`` must return a 3-tuple, which is wrapped in a
    ``PromptEmbeddings``. The target prompt is encoded first, then the
    negative prompt.

    Returns:
        A tuple ``(class_labels, text_emb, uncond_emb)`` where ``class_labels``
        is the second value returned for the target prompt (the same object
        stored in ``text_emb.embedding_class_lables``).
    """
    def _encode(prompts: List[str]) -> PromptEmbeddings:
        # Unpacking enforces the expected 3-tuple from encode_text.
        hidden_states, class_labels, prompt_mask = ldm_stable.encode_text(prompts)
        return PromptEmbeddings(
            embedding_hidden_states=hidden_states,
            embedding_class_lables=class_labels,
            boolean_prompt_mask=prompt_mask,
        )

    text_emb = _encode(target_prompt)
    uncond_emb = _encode(target_neg_prompt)
    return text_emb.embedding_class_lables, text_emb, uncond_emb