CSH-1220 committed
Commit 55f08a9 · 1 parent: aef267d

Files update
Files changed:
- app.py +5 -1
- download.py +9 -0
- pipeline/morph_pipeline_successed_ver1.py +101 -175
- utils/lora_utils_successed_ver1.py +12 -24
app.py
CHANGED
@@ -47,9 +47,13 @@ def morph_audio(audio_file1, audio_file2, prompt1, prompt2, negative_prompt1="Lo
     )

     # Collect the output file paths
-    output_paths =
+    output_paths = sorted(
+        [os.path.join(save_lora_dir, file) for file in os.listdir(save_lora_dir) if file.endswith(".wav")],
+        key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
+    )
     return output_paths

+
 # Gradio interface function
 def interface(audio1, audio2, prompt1, prompt2):
     output_paths = morph_audio(audio1, audio2, prompt1, prompt2)
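Note on the app.py change above: the new sorted(...) call orders the generated .wav frames numerically by their filename stem rather than lexicographically. A minimal standalone sketch of the same idea, with hypothetical file names (not taken from the repo):

import os

save_lora_dir = "./lora"  # hypothetical output directory
files = ["10.wav", "2.wav", "0.wav", "1.wav"]  # frames written by the morphing loop

# A plain lexicographic sort would give ['0.wav', '1.wav', '10.wav', '2.wav'];
# sorting on the integer stem restores the intended frame order.
output_paths = sorted(
    [os.path.join(save_lora_dir, f) for f in files if f.endswith(".wav")],
    key=lambda x: int(os.path.splitext(os.path.basename(x))[0]),
)
print(output_paths)  # ['./lora/0.wav', './lora/1.wav', './lora/2.wav', './lora/10.wav']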
download.py
ADDED
@@ -0,0 +1,9 @@
+from huggingface_hub import hf_hub_download
+import torch
+
+model_path = hf_hub_download(
+    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
+    filename="pytorch_model.bin",
+    local_dir="./",
+    local_dir_use_symlinks=False
+)
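As a quick sanity check, the checkpoint fetched by download.py can be opened with torch.load before the pipeline uses it. A minimal sketch, assuming the key layout referenced elsewhere in this commit (per-processor to_k_ip / to_v_ip weights):

import torch

# Map to CPU so the check also works on machines without a GPU.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# The pipeline below looks up "<processor name>.to_k_ip.weight" / ".to_v_ip.weight" entries;
# print a few keys and shapes to confirm the file downloaded correctly.
for key in list(state_dict)[:5]:
    print(key, tuple(state_dict[key].shape))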
pipeline/morph_pipeline_successed_ver1.py
CHANGED
@@ -49,64 +49,12 @@ if is_librosa_available():
 import librosa
 import warnings
 import matplotlib.pyplot as plt
-from huggingface_hub import hf_hub_download
 from .pipeline_audioldm2 import AudioLDM2Pipeline

-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-pipeline_trained = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
-pipeline_trained = pipeline_trained.to(DEVICE)
-layer_num = 0
-cross = [None, None, 768, 768, 1024, 1024, None, None]
-unet = pipeline_trained.unet
-
-
-attn_procs = {}
-for name in unet.attn_processors.keys():
-    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-    if name.startswith("mid_block"):
-        hidden_size = unet.config.block_out_channels[-1]
-    elif name.startswith("up_blocks"):
-        block_id = int(name[len("up_blocks.")])
-        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-    elif name.startswith("down_blocks"):
-        block_id = int(name[len("down_blocks.")])
-        hidden_size = unet.config.block_out_channels[block_id]
-
-    if cross_attention_dim is None:
-        attn_procs[name] = AttnProcessor2_0()
-    else:
-        cross_attention_dim = cross[layer_num % 8]
-        layer_num += 1
-        if cross_attention_dim == 768:
-            attn_procs[name] = IPAttnProcessor2_0(
-                hidden_size=hidden_size,
-                name=name,
-                cross_attention_dim=cross_attention_dim,
-                scale=0.5,
-                num_tokens=8,
-                do_copy=False
-            ).to(DEVICE, dtype=torch.float32)
-        else:
-            attn_procs[name] = AttnProcessor2_0()
-
-adapter_weight = hf_hub_download(
-    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
-    filename="pytorch_model.bin",
-)
-
-state_dict = torch.load(adapter_weight, map_location=DEVICE)
-for name, processor in attn_procs.items():
-    if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
-        weight_name_v = name + ".to_v_ip.weight"
-        weight_name_k = name + ".to_k_ip.weight"
-        processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
-        processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
-
-unet.set_attn_processor(attn_procs)
-unet.to(DEVICE, dtype=torch.float32)
-

+warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 def visualize_mel_spectrogram(mel_spect_tensor, output_path=None):
@@ -125,10 +73,6 @@ def visualize_mel_spectrogram(mel_spect_tensor, output_path=None):
     plt.show()


-warnings.filterwarnings("ignore", category=FutureWarning)
-warnings.filterwarnings("ignore", category=UserWarning)
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
 class StoreProcessor():
     def __init__(self, original_processor, value_dict, name):
         self.original_processor = original_processor
@@ -140,12 +84,9 @@ class StoreProcessor():
     def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
         # Is self attention
         if encoder_hidden_states is None:
-            # Store hidden_states in value_dict under the name self.name
-            # If there is no encoder_hidden_states, this is a self-attention layer, so store the incoming hidden_states in value_dict.
             # print(f'In StoreProcessor: {self.name} {self.id}')
             self.value_dict[self.name][self.id] = hidden_states.detach()
             self.id += 1
-        # Call the original processor to run the normal attention operation
         res = self.original_processor(attn, hidden_states, *args,
                                       encoder_hidden_states=encoder_hidden_states,
                                       attention_mask=attention_mask,
@@ -167,32 +108,26 @@ class LoadProcessor():

     def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
         # Is self attention
-        # Check whether this is self-attention
         if encoder_hidden_states is None:
-            # If the current index is less than 10 * self.lamd, use the custom blending logic
             if self.id < 10 * self.lamd:
                 map0 = self.aud1_dict[self.name][self.id]
                 map1 = self.aud2_dict[self.name][self.id]
                 cross_map = self.beta * hidden_states + \
                     (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
-                # Call the original processor, passing cross_map as encoder_hidden_states
                 res = self.original_processor(attn, hidden_states, *args,
                                               encoder_hidden_states=cross_map,
                                               attention_mask=attention_mask,
                                               **kwargs)
             else:
-                # Otherwise, use the original encoder_hidden_states (which may be None)
                 res = self.original_processor(attn, hidden_states, *args,
                                               encoder_hidden_states=encoder_hidden_states,
                                               attention_mask=attention_mask,
                                               **kwargs)

             self.id += 1
-            # When the index reaches the length of self.aud1_dict[self.name], reset it to 0
             if self.id == len(self.aud1_dict[self.name]):
                 self.id = 0
         else:
-            # For cross-attention (encoder_hidden_states is not None), use the original processor directly
             res = self.original_processor(attn, hidden_states, *args,
                                           encoder_hidden_states=encoder_hidden_states,
                                           attention_mask=attention_mask,
@@ -908,7 +843,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
         # print("ta_kaldi_fbank.shape",ta_kaldi_fbank.shape)
         mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
-        model = AudioMAEConditionCTPoolRand().
+        model = AudioMAEConditionCTPoolRand().cuda()
         model.eval()
         LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
         uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)
@@ -932,16 +867,66 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):

         return prompt_embeds, attention_mask, generated_prompt_embeds

+    def init_trained_pipeline(self, model_path, device, dtype, ap_scale, text_ap_scale):
+        pipeline_trained = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=dtype).to(device)
+        layer_num = 0
+        cross = [None, None, 768, 768, 1024, 1024, None, None]
+        unet = pipeline_trained.unet
+        attn_procs = {}
+        for name in unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = unet.config.block_out_channels[block_id]
+
+            if cross_attention_dim is None:
+                attn_procs[name] = AttnProcessor2_0()
+            else:
+                cross_attention_dim = cross[layer_num % 8]
+                layer_num += 1
+                if cross_attention_dim == 768:
+                    attn_procs[name] = IPAttnProcessor2_0(
+                        hidden_size=hidden_size,
+                        name=name,
+                        flag='trained',
+                        cross_attention_dim=cross_attention_dim,
+                        text_scale=text_ap_scale,
+                        scale=ap_scale,
+                        num_tokens=8,
+                        do_copy=False
+                    ).to(device, dtype=dtype)
+                else:
+                    attn_procs[name] = AttnProcessor2_0()
+
+        state_dict = torch.load(model_path, map_location=device)
+        for name, processor in attn_procs.items():
+            if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
+                weight_name_v = name + ".to_v_ip.weight"
+                weight_name_k = name + ".to_k_ip.weight"
+                if dtype == torch.float32:
+                    processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].float())
+                    processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].float())
+                elif dtype == torch.float16:
+                    processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
+                    processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
+        unet.set_attn_processor(attn_procs)
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return unet(*args, **kwargs)
+
+        unet = _Wrapper(unet.attn_processors)
+
+        return pipeline_trained
+
     @torch.no_grad()
     def aud2latent(self, audio_path, audio_length_in_s):
         DEVICE = torch.device(
             "cuda") if torch.cuda.is_available() else torch.device("cpu")
-
-        # waveform, sr = torchaudio.load(audio_path)
-        # fbank = torch.zeros((height, 64))
-        # ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank, num_mels=64)
-        # mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0).unsqueeze(0)
-
         mel_spect_tensor = wav_to_mel(audio_path, duration=audio_length_in_s).unsqueeze(0)
         output_path = audio_path.replace('.wav', '_fbank.png')
         visualize_mel_spectrogram(mel_spect_tensor, output_path)
@@ -954,7 +939,8 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
     @torch.no_grad()
     def ddim_inversion(self, start_latents, prompt_embeds, attention_mask, generated_prompt_embeds, guidance_scale,num_inference_steps):
         start_step = 0
-
+        # print(f"Scheduler timesteps: {self.scheduler.timesteps}")
+        num_inference_steps = min(num_inference_steps, int(max(self.scheduler.timesteps)))
         device = start_latents.device
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         start_latents *= self.scheduler.init_noise_sigma
@@ -973,9 +959,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
     def generate_morphing_prompt(self, prompt_1, prompt_2, alpha):
        closer_prompt = prompt_1 if alpha <= 0.5 else prompt_2
        prompt = (
-            f"
-            f"The sound is closer to '{closer_prompt}' with an interpolation factor of alpha={alpha:.2f}, "
-            f"where alpha=0 represents fully the {prompt_1} and alpha=1 represents fully {prompt_2}."
+            f"Jazz style music"
        )
        return prompt

@@ -983,8 +967,10 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
     def cal_latent(self,audio_length_in_s,time_pooling, freq_pooling,num_inference_steps, guidance_scale, aud_noise_1, aud_noise_2, prompt_1, prompt_2,
                    prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2,
                    alpha, original_processor,attn_processor_dict, use_morph_prompt, morphing_with_lora):
+        num_inference_steps = min(num_inference_steps, int(max(self.pipeline_trained.scheduler.timesteps)))
         latents = slerp(aud_noise_1, aud_noise_2, alpha, self.use_adain)
         if not use_morph_prompt:
+            print("Not using morphing prompt")
             max_length = max(prompt_embeds_1.shape[1], prompt_embeds_2.shape[1])
             if prompt_embeds_1.shape[1] < max_length:
                 pad_size = max_length - prompt_embeds_1.shape[1]
@@ -1033,13 +1019,13 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             # attention_mask = (attention_mask > 0.5).long()

             if morphing_with_lora:
-                pipeline_trained.unet.set_attn_processor(attn_processor_dict)
-            waveform = pipeline_trained(
+                self.pipeline_trained.unet.set_attn_processor(attn_processor_dict)
+            waveform = self.pipeline_trained(
                 time_pooling= time_pooling,
                 freq_pooling= freq_pooling,
                 latents = latents,
                 num_inference_steps= num_inference_steps,
-                guidance_scale= guidance_scale,
+                guidance_scale = guidance_scale,
                 num_waveforms_per_prompt= 1,
                 audio_length_in_s=audio_length_in_s,
                 prompt_embeds = prompt_embeds.chunk(2)[1],
@@ -1050,13 +1036,13 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 negative_attention_mask = attention_mask.chunk(2)[0],
             ).audios[0]
             if morphing_with_lora:
-                pipeline_trained.unet.set_attn_processor(original_processor)
+                self.pipeline_trained.unet.set_attn_processor(original_processor)
         else:
             latent_model_input = latents
             morphing_prompt = self.generate_morphing_prompt(prompt_1, prompt_2, alpha)
             if morphing_with_lora:
-                pipeline_trained.unet.set_attn_processor(attn_processor_dict)
-            waveform = pipeline_trained(
+                self.pipeline_trained.unet.set_attn_processor(attn_processor_dict)
+            waveform = self.pipeline_trained(
                 time_pooling= time_pooling,
                 freq_pooling= freq_pooling,
                 latents = latent_model_input,
@@ -1068,15 +1054,18 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 negative_prompt= 'Low quality',
             ).audios[0]
             if morphing_with_lora:
-                pipeline_trained.unet.set_attn_processor(original_processor)
+                self.pipeline_trained.unet.set_attn_processor(original_processor)

-        return waveform
+        return waveform, latents

     @torch.no_grad()
     def __call__(
         self,
+        dtype,
         audio_file = None,
         audio_file2 = None,
+        ap_scale = 1.0,
+        text_ap_scale = 1.0,
         save_lora_dir = "./lora",
         load_lora_path_1 = None,
         load_lora_path_2 = None,
@@ -1100,7 +1089,6 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         attn_beta=0,
         lamd=0.6,
         fix_lora=None,
-        save_intermediates=True,
         num_frames=50,
         max_new_tokens: Optional[int] = None,
         callback_steps: Optional[int] = 1,
@@ -1108,6 +1096,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         morphing_with_lora=False,
         use_morph_prompt=False,
     ):
+        ap_adapter_path = 'pytorch_model.bin'
         device = "cuda" if torch.cuda.is_available() else "cpu"
         # 0. Load the pre-trained AP-adapter model
         layer_num = 0
@@ -1123,48 +1112,44 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = self.unet.config.block_out_channels[block_id]
-
             if cross_attention_dim is None:
                 attn_procs[name] = AttnProcessor2_0()
             else:
                 cross_attention_dim = cross[layer_num % 8]
                 layer_num += 1
                 if cross_attention_dim == 768:
-                    attn_procs[name] = IPAttnProcessor2_0(
+                    attn_procs[name].scale = IPAttnProcessor2_0(
                         hidden_size=hidden_size,
                         name=name,
                         cross_attention_dim=cross_attention_dim,
-
+                        text_scale=100,
+                        scale=ap_scale,
                         num_tokens=8,
                         do_copy=False
-                    ).to(
+                    ).to(device, dtype=dtype)
                 else:
                     attn_procs[name] = AttnProcessor2_0()
-
-        state_dict = torch.load(adapter_weight, map_location=device)
+        state_dict = torch.load(ap_adapter_path, map_location=device)
         for name, processor in attn_procs.items():
             if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
                 weight_name_v = name + ".to_v_ip.weight"
                 weight_name_k = name + ".to_k_ip.weight"
-
-
+                if dtype == torch.float32:
+                    processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].float())
+                    processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].float())
+                elif dtype == torch.float16:
+                    processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
+                    processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
         self.unet.set_attn_processor(attn_procs)
-        self.
-        self.unet = self.unet.to(DEVICE, dtype=torch.float32)
-        self.language_model = self.language_model.to(DEVICE, dtype=torch.float32)
-        self.projection_model = self.projection_model.to(DEVICE, dtype=torch.float32)
-        self.vocoder = self.vocoder.to(DEVICE, dtype=torch.float32)
-        self.text_encoder = self.text_encoder.to(DEVICE, dtype=torch.float32)
-        self.text_encoder_2 = self.text_encoder_2.to(DEVICE, dtype=torch.float32)
+        self.pipeline_trained = self.init_trained_pipeline(ap_adapter_path, device, dtype, ap_scale, text_ap_scale)

-
-
         # 1. Pre-check
         height, original_waveform_length = self.pre_check(audio_length_in_s, prompt_1, callback_steps, negative_prompt_1)
         _, _ = self.pre_check(audio_length_in_s, prompt_2, callback_steps, negative_prompt_2)
         # print(f"height: {height}, original_waveform_length: {original_waveform_length}") # height: 1000, original_waveform_length: 160000

         # # 2. Define call parameters
+        device = "cuda" if torch.cuda.is_available() else "cpu"
         do_classifier_free_guidance = guidance_scale > 1.0
         self.use_lora = use_lora
         self.use_adain = use_adain
@@ -1178,7 +1163,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt"
                 load_lora_path_1 = save_lora_dir + "/" + weight_name
                 if not os.path.exists(load_lora_path_1):
-                    train_lora(audio_file ,
+                    train_lora(audio_file, dtype, time_pooling ,freq_pooling ,prompt_1, negative_prompt_1, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
                                self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
                                self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
                 print(f"Load from {load_lora_path_1}.")
@@ -1193,7 +1178,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
                 load_lora_path_2 = save_lora_dir + "/" + weight_name
                 if not os.path.exists(load_lora_path_2):
-                    train_lora(audio_file2 ,
+                    train_lora(audio_file2, dtype,time_pooling ,freq_pooling ,prompt_2, negative_prompt_2, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
                                self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
                                self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
                 print(f"Load from {load_lora_path_2}.")
@@ -1212,75 +1197,29 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):


         # 4. Prepare latent variables
-        # For the first audio file
+        # ------- For the first audio file -------
         original_processor = list(self.unet.attn_processors.values())[0]
-
         if noisy_latent_with_lora:
             self.unet = load_lora(self.unet, lora_1, lora_2, 0)
-            # print(self.unet.attn_processors)
         # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
         audio_latent = self.aud2latent(audio_file, audio_length_in_s).to(device)
-        # mel_spectrogram = self.vae.decode(audio_latent).sample
-        # first_audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
-        # first_audio = first_audio[:, :original_waveform_length]
-        # torchaudio.save(f"{self.output_path}/{0:02d}_gt.wav", first_audio, 16000)
-
         # aud_noise_1 is the noisy latent representation of the audio file 1
-        aud_noise_1 = self.ddim_inversion(audio_latent, prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, guidance_scale, num_inference_steps)
-        # We use the pre-trained model to generate the audio file from the noisy latent representation
-        # waveform = pipeline_trained(
-        #     audio_file = audio_file,
-        #     time_pooling= 2,
-        #     freq_pooling= 2,
-        #     prompt= prompt_1,
-        #     latents = aud_noise_1,
-        #     negative_prompt= negative_prompt_1,
-        #     num_inference_steps= 100,
-        #     guidance_scale= guidance_scale,
-        #     num_waveforms_per_prompt= 1,
-        #     audio_length_in_s=10,
-        # ).audios
-        # file_path = os.path.join(self.output_path, f"{0:02d}_gt2.wav")
-        # scipy.io.wavfile.write(file_path, rate=16000, data=waveform[0])
-
+        aud_noise_1 = self.ddim_inversion(audio_latent, prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, guidance_scale, num_inference_steps = num_inference_steps)
         # After reconstructed the audio file 1, we set the original processor back
         if noisy_latent_with_lora:
             self.unet.set_attn_processor(original_processor)
-            # print(self.unet.attn_processors)

-        # For the second audio file
+        # ------- For the second audio file -------
         if noisy_latent_with_lora:
             self.unet = load_lora(self.unet, lora_1, lora_2, 1)
-            # print(self.unet.attn_processors)
         # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
         audio_latent = self.aud2latent(audio_file2, audio_length_in_s)
-        # mel_spectrogram = self.vae.decode(audio_latent).sample
-        # last_audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
-        # last_audio = last_audio[:, :original_waveform_length]
-        # torchaudio.save(f"{self.output_path}/{num_frames-1:02d}_gt.wav", last_audio, 16000)
         # aud_noise_2 is the noisy latent representation of the audio file 2
-        aud_noise_2 = self.ddim_inversion(audio_latent, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2, guidance_scale, num_inference_steps)
-        # waveform = pipeline_trained(
-        #     audio_file = audio_file2,
-        #     time_pooling= 2,
-        #     freq_pooling= 2,
-        #     prompt= prompt_2,
-        #     latents = aud_noise_2,
-        #     negative_prompt= negative_prompt_2,
-        #     num_inference_steps= 100,
-        #     guidance_scale= guidance_scale,
-        #     num_waveforms_per_prompt= 1,
-        #     audio_length_in_s=10,
-        # ).audios
-        # file_path = os.path.join(self.output_path, f"{num_frames-1:02d}_gt2.wav")
-        # scipy.io.wavfile.write(file_path, rate=16000, data=waveform[0])
+        aud_noise_2 = self.ddim_inversion(audio_latent, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2, guidance_scale, num_inference_steps = num_inference_steps)
         if noisy_latent_with_lora:
             self.unet.set_attn_processor(original_processor)
-            # print(self.unet.attn_processors)
         # After reconstructed the audio file 1, we set the original processor back
         original_processor = list(self.unet.attn_processors.values())[0]
-
-
         def morph(alpha_list, desc):
             audios = []
             # if attn_beta is not None:
@@ -1288,11 +1227,9 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 self.unet = load_lora(
                     self.unet, lora_1, lora_2, 0 if fix_lora is None else fix_lora)
             attn_processor_dict = {}
-            # print(self.unet.attn_processors)
             for k in self.unet.attn_processors.keys():
                 # print(k)
                 if do_replace_attn(k):
-                    # print(f"Since the key starts with *up*, we replace the processor with StoreProcessor.")
                     if self.use_lora:
                         attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
                                                                 self.aud1_dict, k)
@@ -1300,16 +1237,8 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                         attn_processor_dict[k] = StoreProcessor(original_processor,
                                                                 self.aud1_dict, k)
                 else:
-                    attn_processor_dict[k] = self.unet.attn_processors[k]
-
-
-            # print(attn_processor_dict)
-
-            # print(self.unet.attn_processors)
-            # self.unet.set_attn_processor(attn_processor_dict)
-            # print(self.unet.attn_processors)
-
-            first_audio = self.cal_latent(
+                    attn_processor_dict[k] = self.unet.attn_processors[k]
+            first_audio, first_latents = self.cal_latent(
                 audio_length_in_s,
                 time_pooling,
                 freq_pooling,
@@ -1335,14 +1264,12 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             self.unet.set_attn_processor(original_processor)
             file_path = os.path.join(self.output_path, f"{0:02d}.wav")
             scipy.io.wavfile.write(file_path, rate=16000, data=first_audio)
-
             if self.use_lora:
                 self.unet = load_lora(
                     self.unet, lora_1, lora_2, 1 if fix_lora is None else fix_lora)
             attn_processor_dict = {}
             for k in self.unet.attn_processors.keys():
                 if do_replace_attn(k):
-                    # print(f"Since the key starts with *up*, we replace the processor with StoreProcessor.")
                     if self.use_lora:
                         attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
                                                                 self.aud2_dict, k)
@@ -1351,8 +1278,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                                                                 self.aud2_dict, k)
                 else:
                     attn_processor_dict[k] = self.unet.attn_processors[k]
-
-            last_audio = self.cal_latent(
+            last_audio, last_latents = self.cal_latent(
                 audio_length_in_s,
                 time_pooling,
                 freq_pooling,
@@ -1376,6 +1302,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             )
             file_path = os.path.join(self.output_path, f"{num_frames-1:02d}.wav")
             scipy.io.wavfile.write(file_path, rate=16000, data=last_audio)
+
             self.unet.set_attn_processor(original_processor)

             for i in tqdm(range(1, num_frames - 1), desc=desc):
@@ -1395,8 +1322,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                             original_processor, k, self.aud1_dict, self.aud2_dict, alpha, attn_beta, lamd)
                     else:
                         attn_processor_dict[k] = self.unet.attn_processors[k]
-
-                audio = self.cal_latent(
+                audio, latents = self.cal_latent(
                     audio_length_in_s,
                     time_pooling,
                     freq_pooling,
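The recurring pattern in this file is loading the AP-adapter projection weights into the attention processors with a dtype-dependent cast (float32 vs. float16). A minimal self-contained sketch of that idea; the module and key names here are illustrative, not the pipeline's API:

import torch

class DummyProcessor(torch.nn.Module):
    # Stand-in for an IP-adapter attention processor with to_k_ip / to_v_ip projections.
    def __init__(self, dim=8):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(dim, dim, bias=False)
        self.to_v_ip = torch.nn.Linear(dim, dim, bias=False)

def load_ip_weights(processor, state_dict, name, dtype):
    # Copy "<name>.to_k_ip.weight" / "<name>.to_v_ip.weight" into the processor,
    # casting to the requested precision, as the hunks above do.
    for attr in ("to_k_ip", "to_v_ip"):
        weight = state_dict[f"{name}.{attr}.weight"]
        weight = weight.float() if dtype == torch.float32 else weight.half()
        getattr(processor, attr).weight = torch.nn.Parameter(weight)

proc = DummyProcessor()
ckpt = {"blk.attn2.processor.to_k_ip.weight": torch.randn(8, 8),
        "blk.attn2.processor.to_v_ip.weight": torch.randn(8, 8)}
load_ip_weights(proc, ckpt, "blk.attn2.processor", torch.float16)
print(proc.to_k_ip.weight.dtype)  # torch.float16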
utils/lora_utils_successed_ver1.py
CHANGED
@@ -449,7 +449,7 @@ def plot_loss(loss_history, loss_plot_path, lora_steps):
 # lora_steps: number of lora training step
 # lora_lr: learning rate of lora training
 # lora_rank: the rank of lora
-def train_lora(audio_path ,
+def train_lora(audio_path ,dtype ,time_pooling ,freq_pooling ,prompt, negative_prompt, guidance_scale, save_lora_dir, tokenizer=None, tokenizer_2=None,
                text_encoder=None, text_encoder_2=None, GPT2=None, projection_model=None, vocoder=None,
                vae=None, unet=None, noise_scheduler=None, lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
     time_pooling = time_pooling
@@ -534,7 +534,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
                 scale=1.0,
                 num_tokens=8,
                 do_copy = do_copy
-            ).to(device, dtype=
+            ).to(device, dtype=dtype)
         else:
             unet_lora_attn_procs[name] = AttnProcessor2_0()
     unet.set_attn_processor(unet_lora_attn_procs)
@@ -580,7 +580,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
     fbank = torch.zeros((1024, 128))
     ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
     mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
-    model = AudioMAEConditionCTPoolRand().to(device).to(dtype=
+    model = AudioMAEConditionCTPoolRand().to(device).to(dtype=dtype)
     model.eval()
     mel_spect_tensor = mel_spect_tensor.to(device, dtype=next(model.parameters()).dtype)
     LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
@@ -599,24 +599,6 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
     generated_prompt_embeds = torch.cat([uncond, cond], dim=0)
     model_dtype = next(unet.parameters()).dtype
     generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
-
-    # num_channels_latents = unet.config.in_channels
-    # batch_size = 1
-    # num_waveforms_per_prompt = 1
-    # generator = None
-    # latents = None
-    # latents = prepare_latents(
-    #     vae,
-    #     vocoder,
-    #     noise_scheduler,
-    #     batch_size * num_waveforms_per_prompt,
-    #     num_channels_latents,
-    #     height,
-    #     prompt_embeds.dtype,
-    #     device,
-    #     generator,
-    #     latents,
-    # )

     loss_history = []
     if not os.path.exists(save_lora_dir):
@@ -683,7 +665,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
         safe_serialization=safe_serialization
     )

-def load_lora(unet, lora_0, lora_1, alpha):
+def load_lora(unet, lora_0, lora_1, alpha, dtype):
     attn_procs = unet.attn_processors
     for name, processor in attn_procs.items():
         if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
@@ -691,10 +673,16 @@ def load_lora(unet, lora_0, lora_1, alpha):
             weight_name_k = name + ".to_k_ip.weight"
             if weight_name_v in lora_0 and weight_name_v in lora_1:
                 v_weight = (1 - alpha) * lora_0[weight_name_v] + alpha * lora_1[weight_name_v]
-
+                if dtype == torch.float32:
+                    processor.to_v_ip.weight = torch.nn.Parameter(v_weight.float())
+                elif dtype == torch.float16:
+                    processor.to_v_ip.weight = torch.nn.Parameter(v_weight.half())

             if weight_name_k in lora_0 and weight_name_k in lora_1:
                 k_weight = (1 - alpha) * lora_0[weight_name_k] + alpha * lora_1[weight_name_k]
-
+                if dtype == torch.float32:
+                    processor.to_k_ip.weight = torch.nn.Parameter(k_weight.float())
+                elif dtype == torch.float16:
+                    processor.to_k_ip.weight = torch.nn.Parameter(k_weight.half())
     unet.set_attn_processor(attn_procs)
     return unet
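The load_lora change above blends two per-layer weight sets with the morphing factor alpha and casts the result to the requested dtype. A minimal sketch of that interpolation on plain tensors (the dictionary keys are illustrative):

import torch

def blend_weights(lora_0, lora_1, alpha, dtype):
    # (1 - alpha) * lora_0 + alpha * lora_1 for every key both checkpoints share,
    # cast to the requested precision, mirroring the logic in load_lora.
    blended = {}
    for key in lora_0.keys() & lora_1.keys():
        w = (1 - alpha) * lora_0[key] + alpha * lora_1[key]
        blended[key] = w.float() if dtype == torch.float32 else w.half()
    return blended

a = {"attn.to_k_ip.weight": torch.zeros(4, 4)}
b = {"attn.to_k_ip.weight": torch.ones(4, 4)}
mid = blend_weights(a, b, alpha=0.5, dtype=torch.float32)
print(mid["attn.to_k_ip.weight"][0, 0])  # tensor(0.5000)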