import os
import sys
import torch
import librosa
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window
from librosa.util import pad_center
from diffusers import DDIMScheduler, AudioLDM2Pipeline
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from transformers import RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer
sys.path.append(os.getcwd())
from main.configs.config import Config
from main.library.utils import check_audioldm2
config = Config()
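# Base wrapper around the diffusion pipeline: it declares the interface a concrete backend must provide
# (model loading, VAE/mel conversions, text encoding, UNet forward) and implements the DDIM-inversion
# helpers used by the AudioLDM2 subclass below.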
class Pipeline(torch.nn.Module):
def __init__(self, model_id, device, double_precision = False, token = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_id = model_id
self.device = device
self.double_precision = double_precision
self.token = token
def load_scheduler(self):
pass
def get_melspectrogram(self):
pass
def vae_encode(self, x):
pass
def vae_decode(self, x):
pass
def decode_to_mel(self, x):
pass
def setup_extra_inputs(self, *args, **kwargs):
pass
def encode_text(self, prompts, **kwargs):
pass
def get_variance(self, timestep, prev_timestep):
pass
def get_alpha_prod_t_prev(self, prev_timestep):
pass
def get_noise_shape(self, x0, num_steps):
return (num_steps, self.model.unet.config.in_channels, x0.shape[-2], x0.shape[-1])
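# Sample the full forward trajectory directly from x0: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps,
# with fresh Gaussian noise at every scheduler timestep; index 0 of the returned tensor holds x0 itself.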
def sample_xts_from_x0(self, x0, num_inference_steps = 50):
alpha_bar = self.model.scheduler.alphas_cumprod
sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
timesteps = self.model.scheduler.timesteps.to(self.device)
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
xts = torch.zeros(self.get_noise_shape(x0, num_inference_steps + 1)).to(x0.device)
xts[0] = x0
for t in reversed(timesteps):
idx = num_inference_steps - t_to_idx[int(t)]
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
return xts
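# DDIM inversion: given consecutive latents x_t, x_{t-1} and the model's noise prediction, recover the noise
# map z that reproduces x_{t-1} from the posterior mean mu_xt. With numerical_fix the returned x_{t-1} is
# re-synthesized as mu_xt + eta * sigma * z so the stored trajectory stays consistent with the extracted z.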
def get_zs_from_xts(self, xt, xtm1, noise_pred, t, eta = 0, numerical_fix = True, **kwargs):
alpha_bar = self.model.scheduler.alphas_cumprod
if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred
prev_timestep = t - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
variance = self.get_variance(t, prev_timestep)
if self.model.scheduler.config.prediction_type == 'epsilon': random_noise_pred = noise_pred
elif self.model.scheduler.config.prediction_type == 'v_prediction': random_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * random_noise_pred)
z = (xtm1 - mu_xt) / (eta * variance ** 0.5)
if numerical_fix: xtm1 = mu_xt + (eta * variance ** 0.5)*z
return z, xtm1, None
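# DDIM reverse step in which the stochastic term can be driven by an externally supplied variance_noise
# (e.g. the z maps extracted by get_zs_from_xts) instead of freshly drawn noise.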
def reverse_step_with_custom_noise(self, model_output, timestep, sample, variance_noise = None, eta = 0, **kwargs):
prev_timestep = timestep - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
beta_prod_t = 1 - alpha_prod_t
if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
variance = self.get_variance(timestep, prev_timestep)
if self.model.scheduler.config.prediction_type == 'epsilon': model_output_direction = model_output
elif self.model.scheduler.config.prediction_type == 'v_prediction': model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction)
if eta > 0:
if variance_noise is None: variance_noise = torch.randn(model_output.shape, device=self.device)
prev_sample = prev_sample + (eta * variance ** (0.5) * variance_noise)
return prev_sample
def unet_forward(self, sample, timestep, encoder_hidden_states, class_labels = None, timestep_cond = None, attention_mask = None, cross_attention_kwargs = None, added_cond_kwargs = None, down_block_additional_residuals = None, mid_block_additional_residual = None, encoder_attention_mask = None, replace_h_space = None, replace_skip_conns = None, return_dict = True, zero_out_resconns = None):
pass
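# STFT implemented as a 1-D convolution against a fixed Fourier basis (the conv-based STFT commonly used in
# neural vocoder front ends); transform() returns magnitude and phase.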
class STFT(torch.nn.Module):
def __init__(self, fft_size, hop_size, window_size, window_type="hann"):
super().__init__()
self.fft_size = fft_size
self.hop_size = hop_size
self.window_size = window_size
self.window_type = window_type
scale = fft_size / hop_size
fourier_basis = np.fft.fft(np.eye(fft_size))
cutoff = fft_size // 2 + 1
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window_type:
assert fft_size >= window_size
fft_window = torch.from_numpy(pad_center(get_window(window_type, window_size, fftbins=True), size=fft_size)).float()
forward_basis *= fft_window
inverse_basis *= fft_window
# register the bases as buffers so they follow the module's device and dtype
self.register_buffer("forward_basis", forward_basis)
self.register_buffer("inverse_basis", inverse_basis)
def transform(self, signal):
batch_size, num_samples = signal.shape
transformed_signal = F.conv1d(F.pad(signal.view(batch_size, 1, num_samples).unsqueeze(1), (self.fft_size // 2, self.fft_size // 2, 0, 0), mode="reflect").squeeze(1), self.forward_basis, stride=self.hop_size, padding=0).cpu()
cutoff = self.fft_size // 2 + 1
real_part, imag_part = transformed_signal[:, :cutoff, :], transformed_signal[:, cutoff:, :]
return torch.sqrt(real_part ** 2 + imag_part ** 2), torch.atan2(imag_part, real_part)
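# Mel front end: STFT magnitudes are projected through a librosa mel filter bank and log-compressed,
# with values clamped at 1e-5 before the log.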
class MelSpectrogramProcessor(torch.nn.Module):
def __init__(self, fft_size, hop_size, window_size, num_mel_bins, sample_rate, fmin, fmax):
super().__init__()
self.num_mel_bins = num_mel_bins
self.sample_rate = sample_rate
self.stft_processor = STFT(fft_size, hop_size, window_size)
self.register_buffer("mel_filter", torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mel_bins, fmin=fmin, fmax=fmax)).float())
def compute_mel_spectrogram(self, waveform, normalization_fn=torch.log):
assert torch.min(waveform) >= -1
assert torch.max(waveform) <= 1
magnitudes, _ = self.stft_processor.transform(waveform)
return normalization_fn(torch.clamp(torch.matmul(self.mel_filter, magnitudes), min=1e-5))
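# Concrete backend built on diffusers' AudioLDM2Pipeline; weights are loaded from a local folder and
# cast to float16 when config.is_half is set.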
class AudioLDM2(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True, torch_dtype=torch.float16 if config.is_half else torch.float32).to(self.device)
def load_scheduler(self):
self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
def get_melspectrogram(self):
return MelSpectrogramProcessor(fft_size=1024, hop_size=160, window_size=1024, num_mel_bins=64, sample_rate=16000, fmin=0, fmax=8000)
def vae_encode(self, x):
if x.shape[2] % 4: x = F.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
output = (self.model.vae.encode(x.half() if config.is_half else x.float()).latent_dist.mode() * self.model.vae.config.scaling_factor)
return output.half() if config.is_half else output.float()
def vae_decode(self, x):
return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
def decode_to_mel(self, x):
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().to(torch.float16 if config.is_half else torch.float32)).detach()
if len(tmp.shape) == 1: tmp = tmp.unsqueeze(0)
return tmp
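# Text conditioning, mirroring AudioLDM2Pipeline.encode_prompt: both tokenizer/text-encoder pairs
# (CLAP and the second text encoder) are run, their outputs are fused by the projection model, and the
# pipeline's language model generates the final prompt embeddings.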
def encode_text(self, prompts, negative = False, save_compute = False, cond_length = 0, **kwargs):
tokenizers, text_encoders = [self.model.tokenizer, self.model.tokenizer_2], [self.model.text_encoder, self.model.text_encoder_2]
prompt_embeds_list, attention_mask_list = [], []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
text_inputs = tokenizer(prompts, padding="max_length" if (save_compute and negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True, max_length=tokenizer.model_max_length if (not save_compute) or ((not negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))) else cond_length, truncation=True, return_tensors="pt")
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])  # decoded overflow text is computed but unused; upstream diffusers logs a truncation warning at this point
text_input_ids = text_input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
with torch.no_grad():
if text_encoder.config.model_type == "clap":
prompt_embeds = text_encoder.get_text_features(text_input_ids, attention_mask=attention_mask)
prompt_embeds = prompt_embeds[:, None, :]
attention_mask = attention_mask.new_ones((len(prompts), 1))
else: prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)[0]
prompt_embeds_list.append(prompt_embeds)
attention_mask_list.append(attention_mask)
projection_output = self.model.projection_model(hidden_states=prompt_embeds_list[0], hidden_states_1=prompt_embeds_list[1], attention_mask=attention_mask_list[0], attention_mask_1=attention_mask_list[1])
generated_prompt_embeds = self.model.generate_language_model(projection_output.hidden_states, attention_mask=projection_output.attention_mask, max_new_tokens=None)
prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder_2.dtype, device=self.device)
return generated_prompt_embeds.to(dtype=self.model.language_model.dtype, device=self.device), prompt_embeds, (attention_mask.to(device=self.device) if attention_mask is not None else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=self.device))
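# DDIM posterior variance: ((1 - alpha_bar_prev) / (1 - alpha_bar_t)) * (1 - alpha_bar_t / alpha_bar_prev).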
def get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
return ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * (1 - alpha_prod_t / alpha_prod_t_prev)
def get_alpha_prod_t_prev(self, prev_timestep):
return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.model.scheduler.final_alpha_cumprod
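# Re-implementation of the AudioLDM2 UNet forward pass that also exposes the mid-block activation (h_space)
# and the per-level skip connections, and lets callers replace or zero them out
# (replace_h_space, replace_skip_conns, zero_out_resconns).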
def unet_forward(self, sample, timestep, encoder_hidden_states, timestep_cond = None, class_labels = None, attention_mask = None, encoder_attention_mask = None, return_dict = True, cross_attention_kwargs = None, mid_block_additional_residual = None, replace_h_space = None, replace_skip_conns = None, zero_out_resconns = None):
encoder_hidden_states_1 = class_labels
class_labels = None
encoder_attention_mask_1 = encoder_attention_mask
encoder_attention_mask = None
default_overall_up_factor = 2 ** self.model.unet.num_upsamplers
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): forward_upsample_size = True
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
if encoder_attention_mask_1 is not None:
encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)
timesteps = timestep
if not torch.is_tensor(timesteps):
is_mps = sample.device.type == "mps"
dtype = (torch.float16 if is_mps else torch.float32) if isinstance(timestep, float) else (torch.int16 if is_mps else torch.int32)
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device)
emb = self.model.unet.time_embedding(self.model.unet.time_proj(timesteps.expand(sample.shape[0])).to(dtype=sample.dtype), timestep_cond)
aug_emb = None
if self.model.unet.class_embedding is not None:
if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.model.unet.config.class_embed_type == "timestep": class_labels = self.model.unet.time_proj(class_labels).to(dtype=sample.dtype)
class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)
if self.model.unet.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1)
else: emb = emb + class_emb
emb = emb + aug_emb if aug_emb is not None else emb
if self.model.unet.time_embed_act is not None: emb = self.model.unet.time_embed_act(emb)
sample = self.model.unet.conv_in(sample)
down_block_res_samples = (sample,)
for downsample_block in self.model.unet.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
if self.model.unet.mid_block is not None: sample = self.model.unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
if replace_h_space is None: h_space = sample.clone()
else:
h_space = replace_h_space
sample = replace_h_space.clone()
if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual
extracted_res_conns = {}
for i, upsample_block in enumerate(self.model.unet.up_blocks):
is_final_block = i == len(self.model.unet.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if replace_skip_conns is not None and replace_skip_conns.get(i): res_samples = replace_skip_conns.get(i)
if zero_out_resconns is not None:
if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or (type(zero_out_resconns) is list and i in zero_out_resconns): res_samples = [torch.zeros_like(x) for x in res_samples]
extracted_res_conns[i] = res_samples
if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
else: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)
if self.model.unet.conv_norm_out: sample = self.model.unet.conv_act(self.model.unet.conv_norm_out(sample))
sample = self.model.unet.conv_out(sample)
if not return_dict: return (sample,)
return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
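# Convenience loader: runs check_audioldm2 on the model name, builds the AudioLDM2 wrapper from
# assets/models/audioldm2/<model> and switches the scheduler to DDIM.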
def load_model(model, device):
check_audioldm2(model)
ldm_stable = AudioLDM2(model_id=os.path.join("assets", "models", "audioldm2", model), device=device, double_precision=False)
ldm_stable.load_scheduler()
if torch.cuda.is_available(): torch.cuda.empty_cache()
elif torch.backends.mps.is_available(): torch.mps.empty_cache()
return ldm_stable
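# Usage sketch: the checkpoint folder name "audioldm2-music" is an assumption; point it at whatever folder
# exists under assets/models/audioldm2 locally. Guarded so nothing runs on import.
if __name__ == "__main__":
    example_device = "cuda" if torch.cuda.is_available() else "cpu"
    ldm = load_model("audioldm2-music", device=example_device)  # hypothetical folder name
    mel_processor = ldm.get_melspectrogram()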