# DiffmorpherXAp-adapter / utils/lora_utils_successed_ver1.py
from timeit import default_timer as timer
from datetime import timedelta
from PIL import Image
import os
import itertools
import numpy as np
from einops import rearrange
import torch
import torch.nn.functional as F
from torchvision import transforms
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from packaging import version
import tqdm
from typing import Any, Callable, Dict, List, Optional, Union
from transformers import AutoTokenizer, PretrainedConfig
from APadapter.ap_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
import torchaudio
from audio_encoder.AudioMAE import AudioMAEConditionCTPoolRand, extract_kaldi_fbank_feature
from audioldm.utils import default_audioldm_config
from audioldm.audio import TacotronSTFT, read_wav_file
from audioldm.audio.tools import get_mel_from_wav, _pad_spec, normalize_wav, pad_wav
from transformers import (
ClapFeatureExtractor,
ClapModel,
GPT2Model,
RobertaTokenizer,
RobertaTokenizerFast,
SpeechT5HifiGan,
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from diffusers.utils.torch_utils import randn_tensor
from peft import (
prepare_model_for_kbit_training,
LoraConfig,
get_peft_model,
PeftModel
)
from torchviz import make_dot
import json
from matplotlib import pyplot as plt
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# check_min_version("0.17.0")
def wav_to_fbank(
filename,
target_length=1024,
fn_STFT=None,
augment_data=False,
mix_data=False,
snr=None
):
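    """Read a wav file and return its mel fbank, log-STFT magnitudes, and waveform,
    with the spectrograms padded/cropped to `target_length` frames. `fn_STFT` must be
    a TacotronSTFT instance (the 160-sample hop matches the AudioLDM preprocessing)."""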
assert fn_STFT is not None
waveform = read_wav_file(filename, target_length * 160) # hop size is 160
waveform = waveform[0, ...]
waveform = torch.FloatTensor(waveform)
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
fbank = torch.FloatTensor(fbank.T)
log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
log_magnitudes_stft, target_length
)
fbank = fbank.contiguous()
log_magnitudes_stft = log_magnitudes_stft.contiguous()
waveform = waveform.contiguous()
return fbank, log_magnitudes_stft, waveform
def wav_to_mel(
original_audio_file_path,
duration,
augment_data=False,
mix_data=False,
snr=None):
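    """Compute a mel spectrogram for an audio file using the default AudioLDM STFT config."""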
    config = default_audioldm_config()
fn_STFT = TacotronSTFT(
config["preprocessing"]["stft"]["filter_length"],
config["preprocessing"]["stft"]["hop_length"],
config["preprocessing"]["stft"]["win_length"],
config["preprocessing"]["mel"]["n_mel_channels"],
config["preprocessing"]["audio"]["sampling_rate"],
config["preprocessing"]["mel"]["mel_fmin"],
config["preprocessing"]["mel"]["mel_fmax"],
)
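    # duration * 102.4 mel frames: e.g. a 10 s clip maps to the 1024-frame target length.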
mel, _, _ = wav_to_fbank(
original_audio_file_path,
target_length=int(duration * 102.4),
fn_STFT=fn_STFT,
augment_data=augment_data,
mix_data=mix_data,
snr=snr
)
mel = mel.unsqueeze(0)
return mel
def prepare_inputs_for_generation(
inputs_embeds,
attention_mask=None,
past_key_values=None,
**kwargs,
):
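    """Build the model inputs for one generation step; when a KV cache is present,
    only the last embedding position is fed to the model."""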
if past_key_values is not None:
# only last token for inputs_embeds if past is defined in kwargs
inputs_embeds = inputs_embeds[:, -1:]
kwargs["use_cache"] = True
return {
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
}
def generate_language_model(
language_model,
    inputs_embeds: Optional[torch.Tensor] = None,
max_new_tokens: int = 512,
**model_kwargs,
):
"""
Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
Parameters:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence used as a prompt for the generation.
max_new_tokens (`int`):
Number of new tokens to generate.
model_kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
function of the model.
Return:
`inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence of generated hidden-states.
"""
max_new_tokens = max_new_tokens if max_new_tokens is not None else language_model.config.max_new_tokens
model_kwargs = language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
for _ in range(max_new_tokens):
# prepare model inputs
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
# forward pass to get next hidden states
output = language_model(**model_inputs, return_dict=True)
next_hidden_states = output.last_hidden_state
# Update the model input
inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
# Update generated hidden states, model inputs, and length for next step
model_kwargs = language_model._update_model_kwargs_for_generation(output, model_kwargs)
return inputs_embeds[:, -max_new_tokens:, :]
def encode_prompt(
tokenizer,
tokenizer_2,
text_encoder,
text_encoder_2,
projection_model,
language_model,
prompt,
device,
num_waveforms_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
generated_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
negative_attention_mask: Optional[torch.LongTensor] = None,
max_new_tokens: Optional[int] = None,
):
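    """Encode `prompt` (and optionally `negative_prompt`) with the CLAP and T5 text encoders,
    fuse the two embeddings through the projection model, and run GPT-2 to produce the
    generated prompt embeddings. Follows `AudioLDM2Pipeline.encode_prompt` from diffusers,
    including classifier-free guidance handling.
    """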
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
tokenizers = [tokenizer, tokenizer_2]
text_encoders = [text_encoder, text_encoder_2]
if prompt_embeds is None:
prompt_embeds_list = []
attention_mask_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
text_inputs = tokenizer(
prompt,
padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
# logger.warning(
# f"The following part of your input was truncated because {text_encoder.config.model_type} can "
# f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
# )
text_input_ids = text_input_ids.to(device)
attention_mask = attention_mask.to(device)
if text_encoder.config.model_type == "clap":
prompt_embeds = text_encoder.get_text_features(
text_input_ids,
attention_mask=attention_mask,
)
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
prompt_embeds = prompt_embeds[:, None, :]
# make sure that we attend to this single hidden-state
attention_mask = attention_mask.new_ones((batch_size, 1))
else:
prompt_embeds = text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds_list.append(prompt_embeds)
attention_mask_list.append(attention_mask)
projection_output = projection_model(
hidden_states=prompt_embeds_list[0],
hidden_states_1=prompt_embeds_list[1],
attention_mask=attention_mask_list[0],
attention_mask_1=attention_mask_list[1],
)
projected_prompt_embeds = projection_output.hidden_states
projected_attention_mask = projection_output.attention_mask
generated_prompt_embeds = generate_language_model(
language_model,
projected_prompt_embeds,
attention_mask=projected_attention_mask,
max_new_tokens=max_new_tokens,
)
prompt_embeds = prompt_embeds.to(dtype=text_encoder_2.dtype, device=device)
attention_mask = (
attention_mask.to(device=device)
if attention_mask is not None
else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
)
generated_prompt_embeds = generated_prompt_embeds.to(dtype=language_model.dtype, device=device)
bs_embed, seq_len, hidden_size = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)
# duplicate attention mask for each generation per prompt
attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)
bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
# duplicate generated embeddings for each generation per prompt, using mps friendly method
generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
generated_prompt_embeds = generated_prompt_embeds.view(
bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
negative_prompt_embeds_list = []
negative_attention_mask_list = []
max_length = prompt_embeds.shape[1]
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
uncond_input = tokenizer(
uncond_tokens,
padding="max_length",
max_length=tokenizer.model_max_length
if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
else max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_input.input_ids.to(device)
negative_attention_mask = uncond_input.attention_mask.to(device)
if text_encoder.config.model_type == "clap":
negative_prompt_embeds = text_encoder.get_text_features(
uncond_input_ids,
attention_mask=negative_attention_mask,
)
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
negative_prompt_embeds = negative_prompt_embeds[:, None, :]
# make sure that we attend to this single hidden-state
negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
else:
negative_prompt_embeds = text_encoder(
uncond_input_ids,
attention_mask=negative_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_attention_mask_list.append(negative_attention_mask)
projection_output = projection_model(
hidden_states=negative_prompt_embeds_list[0],
hidden_states_1=negative_prompt_embeds_list[1],
attention_mask=negative_attention_mask_list[0],
attention_mask_1=negative_attention_mask_list[1],
)
negative_projected_prompt_embeds = projection_output.hidden_states
negative_projected_attention_mask = projection_output.attention_mask
negative_generated_prompt_embeds = generate_language_model(
language_model,
negative_projected_prompt_embeds,
attention_mask=negative_projected_attention_mask,
max_new_tokens=max_new_tokens,
)
if do_classifier_free_guidance:
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder_2.dtype, device=device)
negative_attention_mask = (
negative_attention_mask.to(device=device)
if negative_attention_mask is not None
else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
)
negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
dtype=language_model.dtype, device=device
)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)
# duplicate unconditional attention mask for each generation per prompt
negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)
# duplicate unconditional generated embeddings for each generation per prompt
seq_len = negative_generated_prompt_embeds.shape[1]
negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
batch_size * num_waveforms_per_prompt, seq_len, -1
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask])
generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])
return prompt_embeds, attention_mask, generated_prompt_embeds
def prepare_latents(vae, vocoder, scheduler, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
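    """Sample initial latents shaped for the VAE/vocoder and scale them by the scheduler's
    `init_noise_sigma`. Mirrors the latent preparation in diffusers' AudioLDM pipelines."""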
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
shape = (
batch_size,
num_channels_latents,
height // vae_scale_factor,
vocoder.config.model_in_dim // vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma
return latents
def plot_loss(loss_history, loss_plot_path, lora_steps):
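    """Plot the running training loss over `lora_steps` steps and save it to `loss_plot_path`."""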
plt.figure(figsize=(10, 6))
plt.plot(range(1, lora_steps + 1), loss_history, label="Training Loss")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss Over Steps")
plt.legend()
plt.grid(True)
plt.savefig(loss_plot_path)
plt.close()
# print(f"Loss plot saved to {loss_plot_path}")
# audio_path: path to the input audio file (not pre-processed)
# save_lora_dir: directory in which to save the trained weights
# prompt: the user input prompt
# lora_steps: number of LoRA training steps
# lora_lr: learning rate for LoRA training
# lora_rank: the rank of the LoRA update matrices
def train_lora(audio_path, height, time_pooling, freq_pooling, prompt, negative_prompt, guidance_scale, save_lora_dir,
               tokenizer=None, tokenizer_2=None, text_encoder=None, text_encoder_2=None, GPT2=None,
               projection_model=None, vocoder=None, vae=None, unet=None, noise_scheduler=None,
               lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
# initialize accelerator
# accelerator = Accelerator(
# gradient_accumulation_steps=1,
# mixed_precision='no'
# )
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_seed(0)
# set device and dtype
# prepare accelerator
# unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
# optimizer = accelerator.prepare_optimizer(optimizer)
# lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
GPT2.requires_grad_(False)
projection_model.requires_grad_(False)
vocoder.requires_grad_(False)
unet.requires_grad_(False)
    # Sanity check: confirm the frozen models have no trainable parameters left.
    for frozen_model in (text_encoder_2, GPT2, vae, vocoder):
        for name, param in frozen_model.named_parameters():
            if param.requires_grad:
                print(name)
unet.to(device)
vae.to(device)
text_encoder.to(device)
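    # text_encoder_2, GPT2, projection_model, and vocoder are assumed to already sit on `device`.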
# initialize UNet LoRA
unet_lora_attn_procs = {}
i = 0 # Counter variable to iterate through the cross-attention dimension array.
cross = [None, None, 768, 768, 1024, 1024, None, None] # Predefined cross-attention dimensions for different layers.
do_copy = False
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
else:
            raise NotImplementedError(f"name must start with up_blocks, mid_block, or down_blocks, got {name}")
# if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
# lora_attn_processor_class = LoRAAttnAddedKVProcessor
# else:
# lora_attn_processor_class = (
# LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
# )
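        # Only the cross-attention layers matched to the 768-dim generated-text branch get
        # trainable IP-Adapter processors; all other layers keep the stock SDPA processor.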
        if cross_attention_dim is None:
            unet_lora_attn_procs[name] = AttnProcessor2_0()
        else:
            cross_attention_dim = cross[i % 8]
            i += 1
            if cross_attention_dim == 768:
                unet_lora_attn_procs[name] = IPAttnProcessor2_0(
                    hidden_size=hidden_size,
                    name=name,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=8,
                    do_copy=do_copy,
                ).to(device, dtype=torch.float32)
            else:
                unet_lora_attn_procs[name] = AttnProcessor2_0()
unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
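    # Note: only the attention-processor weights collected here are trained; the `lora_rank`
    # argument is not used by this IP-Adapter-style setup.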
# Optimizer creation
    params_to_optimize = unet_lora_layers.parameters()
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=lora_lr,
betas=(0.9, 0.999),
weight_decay=1e-2,
eps=1e-08,
)
lr_scheduler = get_scheduler(
"constant",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=lora_steps,
num_cycles=1,
power=1.0,
)
do_classifier_free_guidance = guidance_scale > 1.0
# initialize text embeddings
with torch.no_grad():
prompt_embeds, attention_mask, generated_prompt_embeds = encode_prompt(
tokenizer,
tokenizer_2,
text_encoder,
text_encoder_2,
projection_model,
GPT2,
prompt,
device,
            num_waveforms_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
)
waveform, sr = torchaudio.load(audio_path)
fbank = torch.zeros((1024, 128))
ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
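    # Encode the reference audio with AudioMAE (time/freq pooled) to get the conditioning
    # embeddings, plus an all-zero spectrogram pass for the unconditional branch.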
model = AudioMAEConditionCTPoolRand().to(device).to(dtype=torch.float32)
model.eval()
mel_spect_tensor = mel_spect_tensor.to(device, dtype=next(model.parameters()).dtype)
LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)
LOA_embeds = LOA_embed[0]
uncond_LOA_embeds = uncond_LOA_embed[0]
bs_embed, seq_len, _ = LOA_embeds.shape
num = prompt_embeds.shape[0] // 2
    LOA_embeds = LOA_embeds.view(bs_embed, seq_len, -1)
    LOA_embeds = LOA_embeds.repeat(num, 1, 1)
    uncond_LOA_embeds = uncond_LOA_embeds.view(bs_embed, seq_len, -1)
    uncond_LOA_embeds = uncond_LOA_embeds.repeat(num, 1, 1)
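    # Splice the audio embeddings into the GPT-2 generated embeddings: the zero-spectrogram
    # branch joins the unconditional half and the reference-audio branch joins the conditional
    # half (this assumes classifier-free guidance, i.e. a doubled batch).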
negative_g, g = generated_prompt_embeds.chunk(2)
uncond = torch.cat([negative_g, uncond_LOA_embeds], dim=1)
cond = torch.cat([g, LOA_embeds], dim=1)
generated_prompt_embeds = torch.cat([uncond, cond], dim=0)
model_dtype = next(unet.parameters()).dtype
generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
# num_channels_latents = unet.config.in_channels
# batch_size = 1
# num_waveforms_per_prompt = 1
# generator = None
# latents = None
# latents = prepare_latents(
# vae,
# vocoder,
# noise_scheduler,
# batch_size * num_waveforms_per_prompt,
# num_channels_latents,
# height,
# prompt_embeds.dtype,
# device,
# generator,
# latents,
# )
loss_history = []
    if not os.path.exists(save_lora_dir):
        os.makedirs(save_lora_dir)
    assert weight_name is not None, "weight_name is required to build the save paths"
    weight_path = os.path.join(save_lora_dir, weight_name)
base_name, _ = os.path.splitext(weight_path)
save_image_path = f"{base_name}.png"
print(f'Save image path: {save_image_path}')
    mel_spect_tensor = wav_to_mel(audio_path, duration=10).unsqueeze(0).to(next(vae.parameters()).dtype)
for step in progress.tqdm(range(lora_steps), desc="Training LoRA..."):
unet.train()
# with accelerator.accumulate(unet):
latents_dist = vae.encode(mel_spect_tensor.to(device)).latent_dist
model_input = torch.cat([latents_dist.sample()] * 2) if do_classifier_free_guidance else latents_dist.sample()
model_input = model_input * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input).to(model_input.device)
        bsz = model_input.shape[0]  # batch size only; avoid shadowing the `height` argument
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
# Add noise to the model input according to the noise magnitude at each timestep (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
generated_prompt_embeds = generated_prompt_embeds.to(device)
prompt_embeds = prompt_embeds.to(device)
attention_mask = attention_mask.to(device)
# Predict the noise residual
model_pred = unet(sample=noisy_model_input,
timestep=timesteps,
encoder_hidden_states=generated_prompt_embeds,
encoder_hidden_states_1=prompt_embeds,
encoder_attention_mask_1=attention_mask,
return_dict=False,
)[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
        loss = F.mse_loss(model_pred, target, reduction="mean")
        loss_history.append(loss.item())
        # The loss already depends on the trainable attention-processor weights, so forcing
        # `loss.requires_grad = True` (as the original code did) is unnecessary and would
        # either error on a non-leaf tensor or silently mask a detached graph.
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
# with open(loss_log_path, "w") as f:
# json.dump(loss_history, f)
plot_loss(loss_history, save_image_path, step+1)
LoraLoaderMixin.save_lora_weights(
save_directory=save_lora_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=None,
weight_name=weight_name,
safe_serialization=safe_serialization
)
def load_lora(unet, lora_0, lora_1, alpha):
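    """Interpolate the IP-Adapter `to_k_ip`/`to_v_ip` weights of two state dicts
    ((1 - alpha) * lora_0 + alpha * lora_1) and install the result into `unet` in fp16."""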
attn_procs = unet.attn_processors
for name, processor in attn_procs.items():
if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
weight_name_v = name + ".to_v_ip.weight"
weight_name_k = name + ".to_k_ip.weight"
if weight_name_v in lora_0 and weight_name_v in lora_1:
v_weight = (1 - alpha) * lora_0[weight_name_v] + alpha * lora_1[weight_name_v]
processor.to_v_ip.weight = torch.nn.Parameter(v_weight.half())
if weight_name_k in lora_0 and weight_name_k in lora_1:
k_weight = (1 - alpha) * lora_0[weight_name_k] + alpha * lora_1[weight_name_k]
processor.to_k_ip.weight = torch.nn.Parameter(k_weight.half())
unet.set_attn_processor(attn_procs)
return unet
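
# Example usage (a minimal sketch, not part of the original training script): the components
# are assumed to come from a loaded AudioLDM 2 pipeline, and the file paths below are
# placeholders.
#
#   from diffusers import AudioLDM2Pipeline
#
#   pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
#   train_lora(
#       "reference.wav", height=1024, time_pooling=8, freq_pooling=8,
#       prompt="a dog barking", negative_prompt="low quality", guidance_scale=3.0,
#       save_lora_dir="./lora_out",
#       tokenizer=pipe.tokenizer, tokenizer_2=pipe.tokenizer_2,
#       text_encoder=pipe.text_encoder, text_encoder_2=pipe.text_encoder_2,
#       GPT2=pipe.language_model, projection_model=pipe.projection_model,
#       vocoder=pipe.vocoder, vae=pipe.vae, unet=pipe.unet,
#       noise_scheduler=pipe.scheduler, weight_name="reference_lora.bin",
#   )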