import os
import sys
import torch
import librosa
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window
from librosa.util import pad_center
from diffusers import DDIMScheduler, AudioLDM2Pipeline
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from transformers import RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer
sys.path.append(os.getcwd())
from main.configs.config import Config
from main.library.utils import check_audioldm2
config = Config()
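# Abstract base wrapper around a diffusion pipeline: subclasses implement scheduler
# loading, mel-spectrogram extraction, VAE encode/decode, text encoding, and the
# DDIM/DDPM-inversion helpers defined below.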
class Pipeline(torch.nn.Module):
def __init__(self, model_id, device, double_precision = False, token = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_id = model_id
self.device = device
self.double_precision = double_precision
self.token = token
def load_scheduler(self):
pass
def get_melspectrogram(self):
pass
def vae_encode(self, x):
pass
def vae_decode(self, x):
pass
def decode_to_mel(self, x):
pass
def setup_extra_inputs(self, *args, **kwargs):
pass
def encode_text(self, prompts, **kwargs):
pass
def get_variance(self, timestep, prev_timestep):
pass
def get_alpha_prod_t_prev(self, prev_timestep):
pass
def get_noise_shape(self, x0, num_steps):
return (num_steps, self.model.unet.config.in_channels, x0.shape[-2], x0.shape[-1])
def sample_xts_from_x0(self, x0, num_inference_steps = 50):
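        # Sample a full noising trajectory directly from x0 with the closed form
        # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, one tensor per
        # timestep (xts[0] stays x0), as in edit-friendly DDPM inversion.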
alpha_bar = self.model.scheduler.alphas_cumprod
sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
timesteps = self.model.scheduler.timesteps.to(self.device)
        t_to_idx = {int(t): i for i, t in enumerate(timesteps)}
xts = torch.zeros(self.get_noise_shape(x0, num_inference_steps + 1)).to(x0.device)
xts[0] = x0
for t in reversed(timesteps):
idx = num_inference_steps - t_to_idx[int(t)]
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
return xts
def get_zs_from_xts(self, xt, xtm1, noise_pred, t, eta = 0, numerical_fix = True, **kwargs):
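        # Recover the per-step noise z that makes a stochastic DDIM step from xt land
        # exactly on the given xtm1; with numerical_fix, xtm1 is re-synthesized from
        # mu_xt and z so later steps stay consistent with the returned z.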
alpha_bar = self.model.scheduler.alphas_cumprod
if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred
prev_timestep = t - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
variance = self.get_variance(t, prev_timestep)
        if self.model.scheduler.config.prediction_type == 'epsilon': random_noise_pred = noise_pred
        elif self.model.scheduler.config.prediction_type == 'v_prediction': random_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
        mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * random_noise_pred)
        z = (xtm1 - mu_xt) / (eta * variance ** 0.5)  # requires eta > 0; eta = 0 would divide by zero
        if numerical_fix: xtm1 = mu_xt + (eta * variance ** 0.5) * z
return z, xtm1, None
def reverse_step_with_custom_noise(self, model_output, timestep, sample, variance_noise = None, eta = 0, **kwargs):
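        # Standard DDIM reverse step, except the stochastic term can be supplied
        # externally via variance_noise (used to replay the z's from get_zs_from_xts).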
prev_timestep = timestep - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
beta_prod_t = 1 - alpha_prod_t
if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
variance = self.get_variance(timestep, prev_timestep)
if self.model.scheduler.config.prediction_type == 'epsilon': model_output_direction = model_output
elif self.model.scheduler.config.prediction_type == 'v_prediction': model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction)
if eta > 0:
if variance_noise is None: variance_noise = torch.randn(model_output.shape, device=self.device)
prev_sample = prev_sample + (eta * variance ** (0.5) * variance_noise)
return prev_sample
def unet_forward(self, sample, timestep, encoder_hidden_states, class_labels = None, timestep_cond = None, attention_mask = None, cross_attention_kwargs = None, added_cond_kwargs = None, down_block_additional_residuals = None, mid_block_additional_residual = None, encoder_attention_mask = None, replace_h_space = None, replace_skip_conns = None, return_dict = True, zero_out_resconns = None):
pass
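# STFT implemented as a strided 1-D convolution against a fixed Fourier basis,
# with an optional analysis window applied to both the forward and inverse bases.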
class STFT(torch.nn.Module):
def __init__(self, fft_size, hop_size, window_size, window_type="hann"):
super().__init__()
self.fft_size = fft_size
self.hop_size = hop_size
self.window_size = window_size
self.window_type = window_type
scale = fft_size / hop_size
fourier_basis = np.fft.fft(np.eye(fft_size))
cutoff = fft_size // 2 + 1
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(np.linalg.pinv(scale * fourier_basis).T[:, None, :])
        if window_type:
            assert fft_size >= window_size
            # zero-pad the analysis window to fft_size and apply it to both bases
            fft_window = torch.from_numpy(pad_center(get_window(window_type, window_size, fftbins=True), size=fft_size)).float()
            forward_basis *= fft_window
            inverse_basis *= fft_window
        # register as buffers so the bases follow the module across .to(device) calls;
        # the previous "if not hasattr(...)" guards could never fire, so registration never ran
        self.register_buffer("forward_basis", forward_basis)
        self.register_buffer("inverse_basis", inverse_basis)
def transform(self, signal):
batch_size, num_samples = signal.shape
        # reflect-pad by half the FFT size, then compute the STFT as a strided 1-D convolution
        # against the Fourier basis; stay on the input's device (the previous trailing .cpu()
        # would break the downstream matmul with the registered mel_filter buffer)
        padded_signal = F.pad(signal.view(batch_size, 1, num_samples).unsqueeze(1), (self.fft_size // 2, self.fft_size // 2, 0, 0), mode="reflect").squeeze(1)
        transformed_signal = F.conv1d(padded_signal, self.forward_basis, stride=self.hop_size, padding=0)
cutoff = self.fft_size // 2 + 1
real_part, imag_part = transformed_signal[:, :cutoff, :], transformed_signal[:, cutoff:, :]
return torch.sqrt(real_part ** 2 + imag_part ** 2), torch.atan2(imag_part, real_part)
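# Mel front end: STFT magnitudes projected through a librosa mel filter bank,
# then log-compressed (clamped at 1e-5 to avoid log(0)).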
class MelSpectrogramProcessor(torch.nn.Module):
def __init__(self, fft_size, hop_size, window_size, num_mel_bins, sample_rate, fmin, fmax):
super().__init__()
self.num_mel_bins = num_mel_bins
self.sample_rate = sample_rate
self.stft_processor = STFT(fft_size, hop_size, window_size)
self.register_buffer("mel_filter", torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mel_bins, fmin=fmin, fmax=fmax)).float())
def compute_mel_spectrogram(self, waveform, normalization_fn=torch.log):
assert torch.min(waveform) >= -1
assert torch.max(waveform) <= 1
magnitudes, _ = self.stft_processor.transform(waveform)
return normalization_fn(torch.clamp(torch.matmul(self.mel_filter, magnitudes), min=1e-5))
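# Concrete Pipeline backed by diffusers' AudioLDM2Pipeline (VAE + dual text
# encoders + projection/language model + UNet + vocoder).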
class AudioLDM2(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True, torch_dtype=torch.float16 if config.is_half else torch.float32).to(self.device)
def load_scheduler(self):
self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
def get_melspectrogram(self):
return MelSpectrogramProcessor(fft_size=1024, hop_size=160, window_size=1024, num_mel_bins=64, sample_rate=16000, fmin=0, fmax=8000)
def vae_encode(self, x):
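        # pad the time axis up to a multiple of 4 to match the VAE's downsampling factor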
if x.shape[2] % 4: x = F.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
output = (self.model.vae.encode(x.half() if config.is_half else x.float()).latent_dist.mode() * self.model.vae.config.scaling_factor)
return output.half() if config.is_half else output.float()
def vae_decode(self, x):
return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
def decode_to_mel(self, x):
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().to(torch.float16 if config.is_half else torch.float32)).detach()
if len(tmp.shape) == 1: tmp = tmp.unsqueeze(0)
return tmp
def encode_text(self, prompts, negative = False, save_compute = False, cond_length = 0, **kwargs):
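        # AudioLDM2 text path: encode the prompt with both CLAP (tokenizer/text_encoder)
        # and T5 (tokenizer_2/text_encoder_2), fuse the two through the projection model,
        # then let the language model generate the final conditioning embeddings.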
tokenizers, text_encoders = [self.model.tokenizer, self.model.tokenizer_2], [self.model.text_encoder, self.model.text_encoder_2]
prompt_embeds_list, attention_mask_list = [], []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                prompts,
                padding="max_length" if (save_compute and negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
                max_length=tokenizer.model_max_length if (not save_compute) or ((not negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))) else cond_length,
                truncation=True,
                return_tensors="pt",
            )
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids
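            # detect prompt truncation; the decoded overflow text is computed but intentionally discarded here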
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])
text_input_ids = text_input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
with torch.no_grad():
if text_encoder.config.model_type == "clap":
prompt_embeds = text_encoder.get_text_features(text_input_ids, attention_mask=attention_mask)
prompt_embeds = prompt_embeds[:, None, :]
attention_mask = attention_mask.new_ones((len(prompts), 1))
else: prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)[0]
prompt_embeds_list.append(prompt_embeds)
attention_mask_list.append(attention_mask)
projection_output = self.model.projection_model(hidden_states=prompt_embeds_list[0], hidden_states_1=prompt_embeds_list[1], attention_mask=attention_mask_list[0], attention_mask_1=attention_mask_list[1])
generated_prompt_embeds = self.model.generate_language_model(projection_output.hidden_states, attention_mask=projection_output.attention_mask, max_new_tokens=None)
prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder_2.dtype, device=self.device)
return generated_prompt_embeds.to(dtype=self.model.language_model.dtype, device=self.device), prompt_embeds, (attention_mask.to(device=self.device) if attention_mask is not None else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=self.device))
def get_variance(self, timestep, prev_timestep):
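        # DDIM posterior variance: sigma_t^2 = (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)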
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
return ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * (1 - alpha_prod_t / alpha_prod_t_prev)
def get_alpha_prod_t_prev(self, prev_timestep):
return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.model.scheduler.final_alpha_cumprod
def unet_forward(self, sample, timestep, encoder_hidden_states, timestep_cond = None, class_labels = None, attention_mask = None, encoder_attention_mask = None, return_dict = True, cross_attention_kwargs = None, mid_block_additional_residual = None, replace_h_space = None, replace_skip_conns = None, zero_out_resconns = None):
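        # Re-implementation of the AudioLDM2 UNet forward that additionally exposes the
        # mid-block activation (h_space) and per-level skip connections, and lets callers
        # replace or zero them out for latent editing.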
encoder_hidden_states_1 = class_labels
class_labels = None
encoder_attention_mask_1 = encoder_attention_mask
encoder_attention_mask = None
default_overall_up_factor = 2 ** self.model.unet.num_upsamplers
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): forward_upsample_size = True
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
if encoder_attention_mask_1 is not None:
encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)
timesteps = timestep
if not torch.is_tensor(timesteps):
is_mps = sample.device.type == "mps"
            # follow diffusers' convention: mps has no float64/int64 support, so fall back to 32-bit there
            dtype = (torch.float32 if is_mps else torch.float64) if isinstance(timestep, float) else (torch.int32 if is_mps else torch.int64)
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device)
emb = self.model.unet.time_embedding(self.model.unet.time_proj(timesteps.expand(sample.shape[0])).to(dtype=sample.dtype), timestep_cond)
aug_emb = None
if self.model.unet.class_embedding is not None:
            if class_labels is None: raise ValueError("class_labels must be provided when the UNet is configured with class conditioning")
if self.model.unet.config.class_embed_type == "timestep": class_labels = self.model.unet.time_proj(class_labels).to(dtype=sample.dtype)
class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)
if self.model.unet.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1)
else: emb = emb + class_emb
emb = emb + aug_emb if aug_emb is not None else emb
if self.model.unet.time_embed_act is not None: emb = self.model.unet.time_embed_act(emb)
sample = self.model.unet.conv_in(sample)
down_block_res_samples = (sample,)
for downsample_block in self.model.unet.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
if self.model.unet.mid_block is not None: sample = self.model.unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
if replace_h_space is None: h_space = sample.clone()
else:
h_space = replace_h_space
sample = replace_h_space.clone()
if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual
extracted_res_conns = {}
for i, upsample_block in enumerate(self.model.unet.up_blocks):
is_final_block = i == len(self.model.unet.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if replace_skip_conns is not None and replace_skip_conns.get(i): res_samples = replace_skip_conns.get(i)
if zero_out_resconns is not None:
                if (isinstance(zero_out_resconns, int) and i >= (zero_out_resconns - 1)) or (isinstance(zero_out_resconns, list) and i in zero_out_resconns): res_samples = [torch.zeros_like(x) for x in res_samples]
extracted_res_conns[i] = res_samples
if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
else: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)
if self.model.unet.conv_norm_out: sample = self.model.unet.conv_act(self.model.unet.conv_norm_out(sample))
sample = self.model.unet.conv_out(sample)
if not return_dict: return (sample,)
return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
def load_model(model, device):
check_audioldm2(model)
ldm_stable = AudioLDM2(model_id=os.path.join("assets", "models", "audioldm2", model), device=device, double_precision=False)
ldm_stable.load_scheduler()
if torch.cuda.is_available(): torch.cuda.empty_cache()
elif torch.backends.mps.is_available(): torch.mps.empty_cache()
return ldm_stable
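# Minimal usage sketch (hypothetical model name; assumes the corresponding directory
# exists under assets/models/audioldm2 or can be fetched by check_audioldm2):
#   pipe = load_model("audioldm2-music", device="cuda" if torch.cuda.is_available() else "cpu")
#   mel_processor = pipe.get_melspectrogram()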