from timeit import default_timer as timer
from datetime import timedelta
from PIL import Image
import os
import itertools
import numpy as np
from einops import rearrange
import torch
import torch.nn.functional as F
from torchvision import transforms
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from packaging import version
import tqdm
from typing import Any, Callable, Dict, List, Optional, Union
from transformers import AutoTokenizer, PretrainedConfig
from APadapter.ap_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
import torchaudio
from audio_encoder.AudioMAE import AudioMAEConditionCTPoolRand, extract_kaldi_fbank_feature
from audioldm.utils import default_audioldm_config
from audioldm.audio import TacotronSTFT, read_wav_file
from audioldm.audio.tools import get_mel_from_wav, _pad_spec, normalize_wav, pad_wav
from transformers import (
    ClapFeatureExtractor,
    ClapModel,
    GPT2Model,
    RobertaTokenizer,
    RobertaTokenizerFast,
    SpeechT5HifiGan,
    T5EncoderModel,
    T5Tokenizer,
    T5TokenizerFast,
)
from diffusers.utils.torch_utils import randn_tensor
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model,
    PeftModel,
)
from torchviz import make_dot
import json
from matplotlib import pyplot as plt

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# check_min_version("0.17.0")
def wav_to_fbank(
    filename,
    target_length=1024,
    fn_STFT=None,
    augment_data=False,
    mix_data=False,
    snr=None,
):
    assert fn_STFT is not None

    waveform = read_wav_file(filename, target_length * 160)  # hop size is 160
    waveform = waveform[0, ...]
    waveform = torch.FloatTensor(waveform)

    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)

    fbank = torch.FloatTensor(fbank.T)
    log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)

    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
        log_magnitudes_stft, target_length
    )

    fbank = fbank.contiguous()
    log_magnitudes_stft = log_magnitudes_stft.contiguous()
    waveform = waveform.contiguous()

    return fbank, log_magnitudes_stft, waveform
def wav_to_mel(
    original_audio_file_path,
    duration,
    augment_data=False,
    mix_data=False,
    snr=None,
):
    config = default_audioldm_config()

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    mel, _, _ = wav_to_fbank(
        original_audio_file_path,
        target_length=int(duration * 102.4),
        fn_STFT=fn_STFT,
        augment_data=augment_data,
        mix_data=mix_data,
        snr=snr,
    )
    mel = mel.unsqueeze(0)
    return mel
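

# Example (hedged sketch): converting a local clip into the mel tensor expected by the VAE.
# The path below is a placeholder; wav_to_mel returns a tensor of shape (1, T, n_mels),
# which train_lora later unsqueezes to (1, 1, T, n_mels) before VAE encoding.
#
#   mel = wav_to_mel("example.wav", duration=10)
#   print(mel.shape)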
def prepare_inputs_for_generation(
    inputs_embeds,
    attention_mask=None,
    past_key_values=None,
    **kwargs,
):
    if past_key_values is not None:
        # only last token for inputs_embeds if past is defined in kwargs
        inputs_embeds = inputs_embeds[:, -1:]
    kwargs["use_cache"] = True

    return {
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
    }
def generate_language_model(
    language_model,
    inputs_embeds: torch.Tensor = None,
    max_new_tokens: int = 512,
    **model_kwargs,
):
    """
    Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.

    Parameters:
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            The sequence used as a prompt for the generation.
        max_new_tokens (`int`):
            Number of new tokens to generate.
        model_kwargs (`Dict[str, Any]`, *optional*):
            Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
            function of the model.

    Return:
        `inputs_embeds` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            The sequence of generated hidden-states.
    """
    max_new_tokens = max_new_tokens if max_new_tokens is not None else language_model.config.max_new_tokens
    model_kwargs = language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
    for _ in range(max_new_tokens):
        # prepare model inputs
        model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

        # forward pass to get next hidden states
        output = language_model(**model_inputs, return_dict=True)
        next_hidden_states = output.last_hidden_state

        # Update the model input
        inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)

        # Update generated hidden states, model inputs, and length for next step
        model_kwargs = language_model._update_model_kwargs_for_generation(output, model_kwargs)

    return inputs_embeds[:, -max_new_tokens:, :]
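

# Example (hedged sketch): rolling out the GPT-2 backbone on projected prompt embeddings,
# mirroring how encode_prompt calls this helper below. `projected_prompt_embeds` and
# `projected_attention_mask` are assumed to come from the AudioLDM2 projection model.
#
#   generated = generate_language_model(
#       GPT2,
#       projected_prompt_embeds,
#       attention_mask=projected_attention_mask,
#       max_new_tokens=8,
#   )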
def encode_prompt(
    tokenizer,
    tokenizer_2,
    text_encoder,
    text_encoder_2,
    projection_model,
    language_model,
    prompt,
    device,
    num_waveforms_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    generated_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    negative_attention_mask: Optional[torch.LongTensor] = None,
    max_new_tokens: Optional[int] = None,
):
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    tokenizers = [tokenizer, tokenizer_2]
    text_encoders = [text_encoder, text_encoder_2]

    if prompt_embeds is None:
        prompt_embeds_list = []
        attention_mask_list = []

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                prompt,
                padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                # logger.warning(
                #     f"The following part of your input was truncated because {text_encoder.config.model_type} can "
                #     f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
                # )

            text_input_ids = text_input_ids.to(device)
            attention_mask = attention_mask.to(device)

            if text_encoder.config.model_type == "clap":
                prompt_embeds = text_encoder.get_text_features(
                    text_input_ids,
                    attention_mask=attention_mask,
                )
                # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                prompt_embeds = prompt_embeds[:, None, :]
                # make sure that we attend to this single hidden-state
                attention_mask = attention_mask.new_ones((batch_size, 1))
            else:
                prompt_embeds = text_encoder(
                    text_input_ids,
                    attention_mask=attention_mask,
                )
                prompt_embeds = prompt_embeds[0]

            prompt_embeds_list.append(prompt_embeds)
            attention_mask_list.append(attention_mask)

        projection_output = projection_model(
            hidden_states=prompt_embeds_list[0],
            hidden_states_1=prompt_embeds_list[1],
            attention_mask=attention_mask_list[0],
            attention_mask_1=attention_mask_list[1],
        )
        projected_prompt_embeds = projection_output.hidden_states
        projected_attention_mask = projection_output.attention_mask

        generated_prompt_embeds = generate_language_model(
            language_model,
            projected_prompt_embeds,
            attention_mask=projected_attention_mask,
            max_new_tokens=max_new_tokens,
        )

    prompt_embeds = prompt_embeds.to(dtype=text_encoder_2.dtype, device=device)
    attention_mask = (
        attention_mask.to(device=device)
        if attention_mask is not None
        else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
    )
    generated_prompt_embeds = generated_prompt_embeds.to(dtype=language_model.dtype, device=device)

    bs_embed, seq_len, hidden_size = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)

    # duplicate attention mask for each generation per prompt
    attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
    attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)

    bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
    # duplicate generated embeddings for each generation per prompt, using mps friendly method
    generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
    generated_prompt_embeds = generated_prompt_embeds.view(
        bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
    )

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        negative_prompt_embeds_list = []
        negative_attention_mask_list = []
        max_length = prompt_embeds.shape[1]
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=tokenizer.model_max_length
                if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
                else max_length,
                truncation=True,
                return_tensors="pt",
            )

            uncond_input_ids = uncond_input.input_ids.to(device)
            negative_attention_mask = uncond_input.attention_mask.to(device)

            if text_encoder.config.model_type == "clap":
                negative_prompt_embeds = text_encoder.get_text_features(
                    uncond_input_ids,
                    attention_mask=negative_attention_mask,
                )
                # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                negative_prompt_embeds = negative_prompt_embeds[:, None, :]
                # make sure that we attend to this single hidden-state
                negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
            else:
                negative_prompt_embeds = text_encoder(
                    uncond_input_ids,
                    attention_mask=negative_attention_mask,
                )
                negative_prompt_embeds = negative_prompt_embeds[0]

            negative_prompt_embeds_list.append(negative_prompt_embeds)
            negative_attention_mask_list.append(negative_attention_mask)

        projection_output = projection_model(
            hidden_states=negative_prompt_embeds_list[0],
            hidden_states_1=negative_prompt_embeds_list[1],
            attention_mask=negative_attention_mask_list[0],
            attention_mask_1=negative_attention_mask_list[1],
        )
        negative_projected_prompt_embeds = projection_output.hidden_states
        negative_projected_attention_mask = projection_output.attention_mask

        negative_generated_prompt_embeds = generate_language_model(
            language_model,
            negative_projected_prompt_embeds,
            attention_mask=negative_projected_attention_mask,
            max_new_tokens=max_new_tokens,
        )

    if do_classifier_free_guidance:
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder_2.dtype, device=device)
        negative_attention_mask = (
            negative_attention_mask.to(device=device)
            if negative_attention_mask is not None
            else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
        )
        negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
            dtype=language_model.dtype, device=device
        )

        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)

        # duplicate unconditional attention mask for each generation per prompt
        negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
        negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)

        # duplicate unconditional generated embeddings for each generation per prompt
        seq_len = negative_generated_prompt_embeds.shape[1]
        negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
        negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
            batch_size * num_waveforms_per_prompt, seq_len, -1
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        attention_mask = torch.cat([negative_attention_mask, attention_mask])
        generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])

    return prompt_embeds, attention_mask, generated_prompt_embeds
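

# Example (hedged sketch): building classifier-free-guidance embeddings for a single prompt.
# All model objects are assumed to be loaded AudioLDM2 components (CLAP + T5 text encoders,
# projection model, GPT-2). With do_classifier_free_guidance=True the returned tensors stack
# the unconditional and conditional halves along the batch dimension.
#
#   prompt_embeds, attention_mask, generated_prompt_embeds = encode_prompt(
#       tokenizer, tokenizer_2, text_encoder, text_encoder_2,
#       projection_model, GPT2,
#       prompt="a dog barking",
#       device=device,
#       num_waveforms_per_prompt=1,
#       do_classifier_free_guidance=True,
#       negative_prompt="low quality",
#   )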
def prepare_latents(vae, vocoder, scheduler, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    shape = (
        batch_size,
        num_channels_latents,
        height // vae_scale_factor,
        vocoder.config.model_in_dim // vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * scheduler.init_noise_sigma
    return latents
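

# Example (hedged sketch): drawing initial latents for one waveform. The latent height is the
# target mel length divided by the VAE downsampling factor; dtype and device follow the prompt
# embeddings. The values below are placeholders.
#
#   latents = prepare_latents(
#       vae, vocoder, noise_scheduler,
#       batch_size=1,
#       num_channels_latents=unet.config.in_channels,
#       height=1024,
#       dtype=prompt_embeds.dtype,
#       device=device,
#       generator=torch.Generator(device=device).manual_seed(0),
#   )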
def plot_loss(loss_history, loss_plot_path, lora_steps):
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, lora_steps + 1), loss_history, label="Training Loss")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.title("Training Loss Over Steps")
    plt.legend()
    plt.grid(True)
    plt.savefig(loss_plot_path)
    plt.close()
    # print(f"Loss plot saved to {loss_plot_path}")
# audio_path: path of the reference audio used for training (raw wav, not pre-processed)
# height: target height of the mel/latent representation
# time_pooling / freq_pooling: pooling factors for the AudioMAE condition
# prompt / negative_prompt: the user input prompt and its negative counterpart
# guidance_scale: classifier-free guidance scale (CFG is used when > 1.0)
# save_lora_dir: the path to save the lora
# lora_steps: number of lora training steps
# lora_lr: learning rate of lora training
# lora_rank: the rank of lora
def train_lora(audio_path, height, time_pooling, freq_pooling, prompt, negative_prompt, guidance_scale, save_lora_dir,
               tokenizer=None, tokenizer_2=None, text_encoder=None, text_encoder_2=None, GPT2=None,
               projection_model=None, vocoder=None, vae=None, unet=None, noise_scheduler=None,
               lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
    # initialize accelerator
    # accelerator = Accelerator(
    #     gradient_accumulation_steps=1,
    #     mixed_precision='no'
    # )
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_seed(0)
    # set device and dtype
    # prepare accelerator
    # unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
    # optimizer = accelerator.prepare_optimizer(optimizer)
    # lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)

    # freeze every pretrained component; only the attention processors added below are trained
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    GPT2.requires_grad_(False)
    projection_model.requires_grad_(False)
    vocoder.requires_grad_(False)
    unet.requires_grad_(False)

    # sanity check: nothing in the frozen models should still require gradients
    for name, param in text_encoder_2.named_parameters():
        if param.requires_grad:
            print(name)
    for name, param in GPT2.named_parameters():
        if param.requires_grad:
            print(name)
    for name, param in vae.named_parameters():
        if param.requires_grad:
            print(name)
    for name, param in vocoder.named_parameters():
        if param.requires_grad:
            print(name)

    unet.to(device)
    vae.to(device)
    text_encoder.to(device)
    # initialize UNet LoRA / IP-Adapter attention processors
    unet_lora_attn_procs = {}
    i = 0  # Counter variable to iterate through the cross-attention dimension array.
    cross = [None, None, 768, 768, 1024, 1024, None, None]  # Predefined cross-attention dimensions for different layers.
    do_copy = False
    for name, attn_processor in unet.attn_processors.items():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            raise NotImplementedError("name must start with up_blocks, mid_block, or down_blocks")
        # if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
        #     lora_attn_processor_class = LoRAAttnAddedKVProcessor
        # else:
        #     lora_attn_processor_class = (
        #         LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
        #     )
        if cross_attention_dim is None:
            unet_lora_attn_procs[name] = AttnProcessor2_0()
        else:
            cross_attention_dim = cross[i % 8]
            i += 1
            if cross_attention_dim == 768:
                unet_lora_attn_procs[name] = IPAttnProcessor2_0(
                    hidden_size=hidden_size,
                    name=name,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=8,
                    do_copy=do_copy,
                ).to(device, dtype=torch.float32)
            else:
                unet_lora_attn_procs[name] = AttnProcessor2_0()
    unet.set_attn_processor(unet_lora_attn_procs)
    unet_lora_layers = AttnProcsLayers(unet.attn_processors)

    # Optimizer creation
    params_to_optimize = unet_lora_layers.parameters()
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=lora_lr,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=lora_steps,
        num_cycles=1,
        power=1.0,
    )
    do_classifier_free_guidance = guidance_scale > 1.0
    # initialize text embeddings
    with torch.no_grad():
        prompt_embeds, attention_mask, generated_prompt_embeds = encode_prompt(
            tokenizer,
            tokenizer_2,
            text_encoder,
            text_encoder_2,
            projection_model,
            GPT2,
            prompt,
            device,
            num_waveforms_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
        )
    waveform, sr = torchaudio.load(audio_path)
    fbank = torch.zeros((1024, 128))
    ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
    mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)

    model = AudioMAEConditionCTPoolRand().to(device).to(dtype=torch.float32)
    model.eval()
    mel_spect_tensor = mel_spect_tensor.to(device, dtype=next(model.parameters()).dtype)
    LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
    uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)
    LOA_embeds = LOA_embed[0]
    uncond_LOA_embeds = uncond_LOA_embed[0]
    bs_embed, seq_len, _ = LOA_embeds.shape
    num = prompt_embeds.shape[0] // 2
    LOA_embeds = LOA_embeds.view(bs_embed, seq_len, -1)
    LOA_embeds = LOA_embeds.repeat(num, 1, 1)
    uncond_LOA_embeds = uncond_LOA_embeds.view(bs_embed, seq_len, -1)
    uncond_LOA_embeds = uncond_LOA_embeds.repeat(num, 1, 1)

    # append the AudioMAE (LOA) embeddings behind the GPT-2 generated embeddings,
    # separately for the unconditional and conditional halves
    negative_g, g = generated_prompt_embeds.chunk(2)
    uncond = torch.cat([negative_g, uncond_LOA_embeds], dim=1)
    cond = torch.cat([g, LOA_embeds], dim=1)
    generated_prompt_embeds = torch.cat([uncond, cond], dim=0)

    model_dtype = next(unet.parameters()).dtype
    generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
    # num_channels_latents = unet.config.in_channels
    # batch_size = 1
    # num_waveforms_per_prompt = 1
    # generator = None
    # latents = None
    # latents = prepare_latents(
    #     vae,
    #     vocoder,
    #     noise_scheduler,
    #     batch_size * num_waveforms_per_prompt,
    #     num_channels_latents,
    #     height,
    #     prompt_embeds.dtype,
    #     device,
    #     generator,
    #     latents,
    # )
    loss_history = []
    if not os.path.exists(save_lora_dir):
        os.makedirs(save_lora_dir)
    weight_path = os.path.join(save_lora_dir, weight_name)
    base_name, _ = os.path.splitext(weight_path)
    save_image_path = f"{base_name}.png"
    print(f'Save image path: {save_image_path}')

    mel_spect_tensor = wav_to_mel(audio_path, duration=10).unsqueeze(0).to(next(vae.parameters()).dtype)
    for step in progress.tqdm(range(lora_steps), desc="Training LoRA..."):
        unet.train()
        # with accelerator.accumulate(unet):
        latents_dist = vae.encode(mel_spect_tensor.to(device)).latent_dist
        model_input = torch.cat([latents_dist.sample()] * 2) if do_classifier_free_guidance else latents_dist.sample()
        model_input = model_input * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(model_input).to(model_input.device)
        bsz, channels, height, width = model_input.shape

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
        )
        timesteps = timesteps.long()

        # Add noise to the model input according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

        generated_prompt_embeds = generated_prompt_embeds.to(device)
        prompt_embeds = prompt_embeds.to(device)
        attention_mask = attention_mask.to(device)

        # Predict the noise residual
        model_pred = unet(
            sample=noisy_model_input,
            timestep=timesteps,
            encoder_hidden_states=generated_prompt_embeds,
            encoder_hidden_states_1=prompt_embeds,
            encoder_attention_mask_1=attention_mask,
            return_dict=False,
        )[0]

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(model_input, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred, target, reduction="mean")
        loss_history.append(loss.item())
        # The loss already requires grad through the trainable attention-processor weights, so
        # forcing `loss.requires_grad = True` is unnecessary (it would raise on a non-leaf tensor
        # and would otherwise only mask a detached graph).
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # with open(loss_log_path, "w") as f:
        #     json.dump(loss_history, f)
        plot_loss(loss_history, save_image_path, step + 1)
    LoraLoaderMixin.save_lora_weights(
        save_directory=save_lora_dir,
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=None,
        weight_name=weight_name,
        safe_serialization=safe_serialization,
    )
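

# Example (hedged sketch): fine-tuning the IP-Adapter attention weights on one reference clip.
# All pipeline components (tokenizers, text encoders, GPT2, projection model, vocoder, VAE,
# UNet, noise scheduler) are assumed to be loaded from an AudioLDM2 checkpoint beforehand;
# the audio path, prompt, and weight name below are placeholders.
#
#   train_lora(
#       "reference.wav", height=1024, time_pooling=2, freq_pooling=2,
#       prompt="a violin melody", negative_prompt="low quality", guidance_scale=3.0,
#       save_lora_dir="./lora_out",
#       tokenizer=tokenizer, tokenizer_2=tokenizer_2,
#       text_encoder=text_encoder, text_encoder_2=text_encoder_2,
#       GPT2=GPT2, projection_model=projection_model, vocoder=vocoder,
#       vae=vae, unet=unet, noise_scheduler=noise_scheduler,
#       lora_steps=200, lora_lr=2e-4, lora_rank=16,
#       weight_name="reference_lora.bin",
#   )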
def load_lora(unet, lora_0, lora_1, alpha):
    attn_procs = unet.attn_processors
    for name, processor in attn_procs.items():
        if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
            weight_name_v = name + ".to_v_ip.weight"
            weight_name_k = name + ".to_k_ip.weight"
            if weight_name_v in lora_0 and weight_name_v in lora_1:
                v_weight = (1 - alpha) * lora_0[weight_name_v] + alpha * lora_1[weight_name_v]
                processor.to_v_ip.weight = torch.nn.Parameter(v_weight.half())
            if weight_name_k in lora_0 and weight_name_k in lora_1:
                k_weight = (1 - alpha) * lora_0[weight_name_k] + alpha * lora_1[weight_name_k]
                processor.to_k_ip.weight = torch.nn.Parameter(k_weight.half())
    unet.set_attn_processor(attn_procs)
    return unet
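

# Example (hedged sketch): interpolating between two trained adapter checkpoints. The state
# dicts are assumed to hold the ".to_k_ip.weight" / ".to_v_ip.weight" tensors saved by
# train_lora; alpha=0.0 keeps the first checkpoint, alpha=1.0 the second. Paths are placeholders.
#
#   lora_0 = torch.load("./lora_out/sound_a.bin", map_location="cpu")
#   lora_1 = torch.load("./lora_out/sound_b.bin", map_location="cpu")
#   unet = load_lora(unet, lora_0, lora_1, alpha=0.5)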