CSH-1220 committed
Commit 55f08a9 · 1 Parent(s): aef267d

Files update

app.py CHANGED
@@ -47,9 +47,13 @@ def morph_audio(audio_file1, audio_file2, prompt1, prompt2, negative_prompt1="Lo
47
  )
48
 
49
  # Collect the output file paths
50
- output_paths = [os.path.join(save_lora_dir, file) for file in os.listdir(save_lora_dir) if file.endswith(".wav")]
 
 
 
51
  return output_paths
52
 
 
53
  # Gradio interface function
54
  def interface(audio1, audio2, prompt1, prompt2):
55
  output_paths = morph_audio(audio1, audio2, prompt1, prompt2)
 
47
  )
48
 
49
  # Collect the output file paths
50
+ output_paths = sorted(
51
+ [os.path.join(save_lora_dir, file) for file in os.listdir(save_lora_dir) if file.endswith(".wav")],
52
+ key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
53
+ )
54
  return output_paths
55
 
56
+
57
  # Gradio interface function
58
  def interface(audio1, audio2, prompt1, prompt2):
59
  output_paths = morph_audio(audio1, audio2, prompt1, prompt2)
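
The app.py change above replaces the unsorted os.listdir() collection with a numeric sort on the filename stem, so frames come back in generation order rather than lexicographic order. A minimal sketch of why that matters, assuming the morphing outputs have purely numeric stems such as 0.wav, 1.wav, ..., 10.wav:

    import os

    files = ["10.wav", "2.wav", "0.wav", "1.wav"]

    # Lexicographic order puts "10.wav" before "2.wav"
    print(sorted(files))  # ['0.wav', '1.wav', '10.wav', '2.wav']

    # Sorting on the integer stem, as in morph_audio above, restores frame order
    key = lambda x: int(os.path.splitext(os.path.basename(x))[0])
    print(sorted(files, key=key))  # ['0.wav', '1.wav', '2.wav', '10.wav']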
download.py ADDED
@@ -0,0 +1,9 @@
1
+ from huggingface_hub import hf_hub_download
2
+ import torch
3
+
4
+ model_path = hf_hub_download(
5
+ repo_id="DennisHung/Pre-trained_AudioMAE_weights",
6
+ filename="pytorch_model.bin",
7
+ local_dir="./",
8
+ local_dir_use_symlinks=False
9
+ )
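
download.py fetches the pre-trained AP-adapter checkpoint into the working directory as pytorch_model.bin, the hard-coded path the pipeline later reads with torch.load. A quick, hedged sanity check of the downloaded file (run after download.py; the key suffixes mirror the ones consumed in morph_pipeline_successed_ver1.py):

    import torch

    # Load on CPU just to inspect the checkpoint contents
    state_dict = torch.load("pytorch_model.bin", map_location="cpu")
    ip_keys = [k for k in state_dict
               if k.endswith((".to_v_ip.weight", ".to_k_ip.weight"))]
    print(f"{len(state_dict)} tensors, {len(ip_keys)} IP-adapter projection weights")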
pipeline/morph_pipeline_successed_ver1.py CHANGED
@@ -49,64 +49,12 @@ if is_librosa_available():
49
  import librosa
50
  import warnings
51
  import matplotlib.pyplot as plt
52
- from huggingface_hub import hf_hub_download
53
  from .pipeline_audioldm2 import AudioLDM2Pipeline
54
 
55
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
56
-
57
- pipeline_trained = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
58
- pipeline_trained = pipeline_trained.to(DEVICE)
59
- layer_num = 0
60
- cross = [None, None, 768, 768, 1024, 1024, None, None]
61
- unet = pipeline_trained.unet
62
-
63
-
64
- attn_procs = {}
65
- for name in unet.attn_processors.keys():
66
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
67
- if name.startswith("mid_block"):
68
- hidden_size = unet.config.block_out_channels[-1]
69
- elif name.startswith("up_blocks"):
70
- block_id = int(name[len("up_blocks.")])
71
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
72
- elif name.startswith("down_blocks"):
73
- block_id = int(name[len("down_blocks.")])
74
- hidden_size = unet.config.block_out_channels[block_id]
75
-
76
- if cross_attention_dim is None:
77
- attn_procs[name] = AttnProcessor2_0()
78
- else:
79
- cross_attention_dim = cross[layer_num % 8]
80
- layer_num += 1
81
- if cross_attention_dim == 768:
82
- attn_procs[name] = IPAttnProcessor2_0(
83
- hidden_size=hidden_size,
84
- name=name,
85
- cross_attention_dim=cross_attention_dim,
86
- scale=0.5,
87
- num_tokens=8,
88
- do_copy=False
89
- ).to(DEVICE, dtype=torch.float32)
90
- else:
91
- attn_procs[name] = AttnProcessor2_0()
92
-
93
- adapter_weight = hf_hub_download(
94
- repo_id="DennisHung/Pre-trained_AudioMAE_weights",
95
- filename="pytorch_model.bin",
96
- )
97
-
98
- state_dict = torch.load(adapter_weight, map_location=DEVICE)
99
- for name, processor in attn_procs.items():
100
- if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
101
- weight_name_v = name + ".to_v_ip.weight"
102
- weight_name_k = name + ".to_k_ip.weight"
103
- processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
104
- processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
105
-
106
- unet.set_attn_processor(attn_procs)
107
- unet.to(DEVICE, dtype=torch.float32)
108
-
109
 
 
 
 
110
 
111
 
112
  def visualize_mel_spectrogram(mel_spect_tensor, output_path=None):
@@ -125,10 +73,6 @@ def visualize_mel_spectrogram(mel_spect_tensor, output_path=None):
125
  plt.show()
126
 
127
 
128
- warnings.filterwarnings("ignore", category=FutureWarning)
129
- warnings.filterwarnings("ignore", category=UserWarning)
130
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
131
-
132
  class StoreProcessor():
133
  def __init__(self, original_processor, value_dict, name):
134
  self.original_processor = original_processor
@@ -140,12 +84,9 @@ class StoreProcessor():
140
  def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
141
  # Is self attention
142
  if encoder_hidden_states is None:
143
- # Store hidden_states in value_dict under the key self.name
144
- # If encoder_hidden_states is None, this is a self-attention layer, so store the input hidden_states in value_dict.
145
  # print(f'In StoreProcessor: {self.name} {self.id}')
146
  self.value_dict[self.name][self.id] = hidden_states.detach()
147
  self.id += 1
148
- # Call the original processor to run the normal attention computation
149
  res = self.original_processor(attn, hidden_states, *args,
150
  encoder_hidden_states=encoder_hidden_states,
151
  attention_mask=attention_mask,
@@ -167,32 +108,26 @@ class LoadProcessor():
167
 
168
  def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
169
  # Is self attention
170
- # Check whether this is self-attention
171
  if encoder_hidden_states is None:
172
- # If the current index is below 10 * self.lamd, use the custom blending logic
173
  if self.id < 10 * self.lamd:
174
  map0 = self.aud1_dict[self.name][self.id]
175
  map1 = self.aud2_dict[self.name][self.id]
176
  cross_map = self.beta * hidden_states + \
177
  (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
178
- # Call the original processor, passing cross_map as encoder_hidden_states
179
  res = self.original_processor(attn, hidden_states, *args,
180
  encoder_hidden_states=cross_map,
181
  attention_mask=attention_mask,
182
  **kwargs)
183
  else:
184
- # Otherwise, use the original encoder_hidden_states (which may be None)
185
  res = self.original_processor(attn, hidden_states, *args,
186
  encoder_hidden_states=encoder_hidden_states,
187
  attention_mask=attention_mask,
188
  **kwargs)
189
 
190
  self.id += 1
191
- # When the index reaches the length of self.aud1_dict[self.name], reset it to 0
192
  if self.id == len(self.aud1_dict[self.name]):
193
  self.id = 0
194
  else:
195
- # For cross-attention (encoder_hidden_states is not None), call the original processor directly
196
  res = self.original_processor(attn, hidden_states, *args,
197
  encoder_hidden_states=encoder_hidden_states,
198
  attention_mask=attention_mask,
@@ -908,7 +843,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
908
  ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
909
  # print("ta_kaldi_fbank.shape",ta_kaldi_fbank.shape)
910
  mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
911
- model = AudioMAEConditionCTPoolRand().to(next(self.unet.parameters()).device)
912
  model.eval()
913
  LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
914
  uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)
@@ -932,16 +867,66 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
932
 
933
  return prompt_embeds, attention_mask, generated_prompt_embeds
934
935
  @torch.no_grad()
936
  def aud2latent(self, audio_path, audio_length_in_s):
937
  DEVICE = torch.device(
938
  "cuda") if torch.cuda.is_available() else torch.device("cpu")
939
-
940
- # waveform, sr = torchaudio.load(audio_path)
941
- # fbank = torch.zeros((height, 64))
942
- # ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank, num_mels=64)
943
- # mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0).unsqueeze(0)
944
-
945
  mel_spect_tensor = wav_to_mel(audio_path, duration=audio_length_in_s).unsqueeze(0)
946
  output_path = audio_path.replace('.wav', '_fbank.png')
947
  visualize_mel_spectrogram(mel_spect_tensor, output_path)
@@ -954,7 +939,8 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
954
  @torch.no_grad()
955
  def ddim_inversion(self, start_latents, prompt_embeds, attention_mask, generated_prompt_embeds, guidance_scale,num_inference_steps):
956
  start_step = 0
957
- num_inference_steps = num_inference_steps
 
958
  device = start_latents.device
959
  self.scheduler.set_timesteps(num_inference_steps, device=device)
960
  start_latents *= self.scheduler.init_noise_sigma
@@ -973,9 +959,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
973
  def generate_morphing_prompt(self, prompt_1, prompt_2, alpha):
974
  closer_prompt = prompt_1 if alpha <= 0.5 else prompt_2
975
  prompt = (
976
- f"A musical performance morphing between '{prompt_1}' and '{prompt_2}'. "
977
- f"The sound is closer to '{closer_prompt}' with an interpolation factor of alpha={alpha:.2f}, "
978
- f"where alpha=0 represents fully the {prompt_1} and alpha=1 represents fully {prompt_2}."
979
  )
980
  return prompt
981
 
@@ -983,8 +967,10 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
983
  def cal_latent(self,audio_length_in_s,time_pooling, freq_pooling,num_inference_steps, guidance_scale, aud_noise_1, aud_noise_2, prompt_1, prompt_2,
984
  prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2,
985
  alpha, original_processor,attn_processor_dict, use_morph_prompt, morphing_with_lora):
 
986
  latents = slerp(aud_noise_1, aud_noise_2, alpha, self.use_adain)
987
  if not use_morph_prompt:
 
988
  max_length = max(prompt_embeds_1.shape[1], prompt_embeds_2.shape[1])
989
  if prompt_embeds_1.shape[1] < max_length:
990
  pad_size = max_length - prompt_embeds_1.shape[1]
@@ -1033,13 +1019,13 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1033
  # attention_mask = (attention_mask > 0.5).long()
1034
 
1035
  if morphing_with_lora:
1036
- pipeline_trained.unet.set_attn_processor(attn_processor_dict)
1037
- waveform = pipeline_trained(
1038
  time_pooling= time_pooling,
1039
  freq_pooling= freq_pooling,
1040
  latents = latents,
1041
  num_inference_steps= num_inference_steps,
1042
- guidance_scale= guidance_scale,
1043
  num_waveforms_per_prompt= 1,
1044
  audio_length_in_s=audio_length_in_s,
1045
  prompt_embeds = prompt_embeds.chunk(2)[1],
@@ -1050,13 +1036,13 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1050
  negative_attention_mask = attention_mask.chunk(2)[0],
1051
  ).audios[0]
1052
  if morphing_with_lora:
1053
- pipeline_trained.unet.set_attn_processor(original_processor)
1054
  else:
1055
  latent_model_input = latents
1056
  morphing_prompt = self.generate_morphing_prompt(prompt_1, prompt_2, alpha)
1057
  if morphing_with_lora:
1058
- pipeline_trained.unet.set_attn_processor(attn_processor_dict)
1059
- waveform = pipeline_trained(
1060
  time_pooling= time_pooling,
1061
  freq_pooling= freq_pooling,
1062
  latents = latent_model_input,
@@ -1068,15 +1054,18 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1068
  negative_prompt= 'Low quality',
1069
  ).audios[0]
1070
  if morphing_with_lora:
1071
- pipeline_trained.unet.set_attn_processor(original_processor)
1072
 
1073
- return waveform
1074
 
1075
  @torch.no_grad()
1076
  def __call__(
1077
  self,
 
1078
  audio_file = None,
1079
  audio_file2 = None,
 
 
1080
  save_lora_dir = "./lora",
1081
  load_lora_path_1 = None,
1082
  load_lora_path_2 = None,
@@ -1100,7 +1089,6 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1100
  attn_beta=0,
1101
  lamd=0.6,
1102
  fix_lora=None,
1103
- save_intermediates=True,
1104
  num_frames=50,
1105
  max_new_tokens: Optional[int] = None,
1106
  callback_steps: Optional[int] = 1,
@@ -1108,6 +1096,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1108
  morphing_with_lora=False,
1109
  use_morph_prompt=False,
1110
  ):
 
1111
  device = "cuda" if torch.cuda.is_available() else "cpu"
1112
  # 0. Load the pre-trained AP-adapter model
1113
  layer_num = 0
@@ -1123,48 +1112,44 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1123
  elif name.startswith("down_blocks"):
1124
  block_id = int(name[len("down_blocks.")])
1125
  hidden_size = self.unet.config.block_out_channels[block_id]
1126
-
1127
  if cross_attention_dim is None:
1128
  attn_procs[name] = AttnProcessor2_0()
1129
  else:
1130
  cross_attention_dim = cross[layer_num % 8]
1131
  layer_num += 1
1132
  if cross_attention_dim == 768:
1133
- attn_procs[name] = IPAttnProcessor2_0(
1134
  hidden_size=hidden_size,
1135
  name=name,
1136
  cross_attention_dim=cross_attention_dim,
1137
- scale=0.5,
 
1138
  num_tokens=8,
1139
  do_copy=False
1140
- ).to(DEVICE, dtype=torch.float32)
1141
  else:
1142
  attn_procs[name] = AttnProcessor2_0()
1143
-
1144
- state_dict = torch.load(adapter_weight, map_location=device)
1145
  for name, processor in attn_procs.items():
1146
  if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
1147
  weight_name_v = name + ".to_v_ip.weight"
1148
  weight_name_k = name + ".to_k_ip.weight"
1149
- processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
1150
- processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
 
 
 
 
1151
  self.unet.set_attn_processor(attn_procs)
1152
- self.vae= self.vae.to(DEVICE, dtype=torch.float32)
1153
- self.unet = self.unet.to(DEVICE, dtype=torch.float32)
1154
- self.language_model = self.language_model.to(DEVICE, dtype=torch.float32)
1155
- self.projection_model = self.projection_model.to(DEVICE, dtype=torch.float32)
1156
- self.vocoder = self.vocoder.to(DEVICE, dtype=torch.float32)
1157
- self.text_encoder = self.text_encoder.to(DEVICE, dtype=torch.float32)
1158
- self.text_encoder_2 = self.text_encoder_2.to(DEVICE, dtype=torch.float32)
1159
 
1160
-
1161
-
1162
  # 1. Pre-check
1163
  height, original_waveform_length = self.pre_check(audio_length_in_s, prompt_1, callback_steps, negative_prompt_1)
1164
  _, _ = self.pre_check(audio_length_in_s, prompt_2, callback_steps, negative_prompt_2)
1165
  # print(f"height: {height}, original_waveform_length: {original_waveform_length}") # height: 1000, original_waveform_length: 160000
1166
 
1167
  # # 2. Define call parameters
 
1168
  do_classifier_free_guidance = guidance_scale > 1.0
1169
  self.use_lora = use_lora
1170
  self.use_adain = use_adain
@@ -1178,7 +1163,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1178
  weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt"
1179
  load_lora_path_1 = save_lora_dir + "/" + weight_name
1180
  if not os.path.exists(load_lora_path_1):
1181
- train_lora(audio_file ,height ,time_pooling ,freq_pooling ,prompt_1, negative_prompt_1, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
1182
  self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
1183
  self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
1184
  print(f"Load from {load_lora_path_1}.")
@@ -1193,7 +1178,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1193
  weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
1194
  load_lora_path_2 = save_lora_dir + "/" + weight_name
1195
  if not os.path.exists(load_lora_path_2):
1196
- train_lora(audio_file2 ,height,time_pooling ,freq_pooling ,prompt_2, negative_prompt_2, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
1197
  self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
1198
  self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
1199
  print(f"Load from {load_lora_path_2}.")
@@ -1212,75 +1197,29 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1212
 
1213
 
1214
  # 4. Prepare latent variables
1215
- # For the first audio file
1216
  original_processor = list(self.unet.attn_processors.values())[0]
1217
-
1218
  if noisy_latent_with_lora:
1219
  self.unet = load_lora(self.unet, lora_1, lora_2, 0)
1220
- # print(self.unet.attn_processors)
1221
  # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
1222
  audio_latent = self.aud2latent(audio_file, audio_length_in_s).to(device)
1223
- # mel_spectrogram = self.vae.decode(audio_latent).sample
1224
- # first_audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
1225
- # first_audio = first_audio[:, :original_waveform_length]
1226
- # torchaudio.save(f"{self.output_path}/{0:02d}_gt.wav", first_audio, 16000)
1227
-
1228
  # aud_noise_1 is the noisy latent representation of the audio file 1
1229
- aud_noise_1 = self.ddim_inversion(audio_latent, prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, guidance_scale, num_inference_steps)
1230
- # We use the pre-trained model to generate the audio file from the noisy latent representation
1231
- # waveform = pipeline_trained(
1232
- # audio_file = audio_file,
1233
- # time_pooling= 2,
1234
- # freq_pooling= 2,
1235
- # prompt= prompt_1,
1236
- # latents = aud_noise_1,
1237
- # negative_prompt= negative_prompt_1,
1238
- # num_inference_steps= 100,
1239
- # guidance_scale= guidance_scale,
1240
- # num_waveforms_per_prompt= 1,
1241
- # audio_length_in_s=10,
1242
- # ).audios
1243
- # file_path = os.path.join(self.output_path, f"{0:02d}_gt2.wav")
1244
- # scipy.io.wavfile.write(file_path, rate=16000, data=waveform[0])
1245
-
1246
  # After reconstructed the audio file 1, we set the original processor back
1247
  if noisy_latent_with_lora:
1248
  self.unet.set_attn_processor(original_processor)
1249
- # print(self.unet.attn_processors)
1250
 
1251
- # For the second audio file
1252
  if noisy_latent_with_lora:
1253
  self.unet = load_lora(self.unet, lora_1, lora_2, 1)
1254
- # print(self.unet.attn_processors)
1255
  # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
1256
  audio_latent = self.aud2latent(audio_file2, audio_length_in_s)
1257
- # mel_spectrogram = self.vae.decode(audio_latent).sample
1258
- # last_audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
1259
- # last_audio = last_audio[:, :original_waveform_length]
1260
- # torchaudio.save(f"{self.output_path}/{num_frames-1:02d}_gt.wav", last_audio, 16000)
1261
  # aud_noise_2 is the noisy latent representation of the audio file 2
1262
- aud_noise_2 = self.ddim_inversion(audio_latent, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2, guidance_scale, num_inference_steps)
1263
- # waveform = pipeline_trained(
1264
- # audio_file = audio_file2,
1265
- # time_pooling= 2,
1266
- # freq_pooling= 2,
1267
- # prompt= prompt_2,
1268
- # latents = aud_noise_2,
1269
- # negative_prompt= negative_prompt_2,
1270
- # num_inference_steps= 100,
1271
- # guidance_scale= guidance_scale,
1272
- # num_waveforms_per_prompt= 1,
1273
- # audio_length_in_s=10,
1274
- # ).audios
1275
- # file_path = os.path.join(self.output_path, f"{num_frames-1:02d}_gt2.wav")
1276
- # scipy.io.wavfile.write(file_path, rate=16000, data=waveform[0])
1277
  if noisy_latent_with_lora:
1278
  self.unet.set_attn_processor(original_processor)
1279
- # print(self.unet.attn_processors)
1280
  # After reconstructed the audio file 1, we set the original processor back
1281
  original_processor = list(self.unet.attn_processors.values())[0]
1282
-
1283
-
1284
  def morph(alpha_list, desc):
1285
  audios = []
1286
  # if attn_beta is not None:
@@ -1288,11 +1227,9 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1288
  self.unet = load_lora(
1289
  self.unet, lora_1, lora_2, 0 if fix_lora is None else fix_lora)
1290
  attn_processor_dict = {}
1291
- # print(self.unet.attn_processors)
1292
  for k in self.unet.attn_processors.keys():
1293
  # print(k)
1294
  if do_replace_attn(k):
1295
- # print(f"Since the key starts with *up*, we replace the processor with StoreProcessor.")
1296
  if self.use_lora:
1297
  attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
1298
  self.aud1_dict, k)
@@ -1300,16 +1237,8 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1300
  attn_processor_dict[k] = StoreProcessor(original_processor,
1301
  self.aud1_dict, k)
1302
  else:
1303
- attn_processor_dict[k] = self.unet.attn_processors[k]
1304
- # print(attn_processor_dict)
1305
-
1306
- # print(attn_processor_dict)
1307
-
1308
- # print(self.unet.attn_processors)
1309
- # self.unet.set_attn_processor(attn_processor_dict)
1310
- # print(self.unet.attn_processors)
1311
-
1312
- first_audio = self.cal_latent(
1313
  audio_length_in_s,
1314
  time_pooling,
1315
  freq_pooling,
@@ -1335,14 +1264,12 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1335
  self.unet.set_attn_processor(original_processor)
1336
  file_path = os.path.join(self.output_path, f"{0:02d}.wav")
1337
  scipy.io.wavfile.write(file_path, rate=16000, data=first_audio)
1338
-
1339
  if self.use_lora:
1340
  self.unet = load_lora(
1341
  self.unet, lora_1, lora_2, 1 if fix_lora is None else fix_lora)
1342
  attn_processor_dict = {}
1343
  for k in self.unet.attn_processors.keys():
1344
  if do_replace_attn(k):
1345
- # print(f"Since the key starts with *up*, we replace the processor with StoreProcessor.")
1346
  if self.use_lora:
1347
  attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
1348
  self.aud2_dict, k)
@@ -1351,8 +1278,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1351
  self.aud2_dict, k)
1352
  else:
1353
  attn_processor_dict[k] = self.unet.attn_processors[k]
1354
- # self.unet.set_attn_processor(attn_processor_dict)
1355
- last_audio = self.cal_latent(
1356
  audio_length_in_s,
1357
  time_pooling,
1358
  freq_pooling,
@@ -1376,6 +1302,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1376
  )
1377
  file_path = os.path.join(self.output_path, f"{num_frames-1:02d}.wav")
1378
  scipy.io.wavfile.write(file_path, rate=16000, data=last_audio)
 
1379
  self.unet.set_attn_processor(original_processor)
1380
 
1381
  for i in tqdm(range(1, num_frames - 1), desc=desc):
@@ -1395,8 +1322,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
1395
  original_processor, k, self.aud1_dict, self.aud2_dict, alpha, attn_beta, lamd)
1396
  else:
1397
  attn_processor_dict[k] = self.unet.attn_processors[k]
1398
- # self.unet.set_attn_processor(attn_processor_dict)
1399
- audio = self.cal_latent(
1400
  audio_length_in_s,
1401
  time_pooling,
1402
  freq_pooling,
 
49
  import librosa
50
  import warnings
51
  import matplotlib.pyplot as plt
 
52
  from .pipeline_audioldm2 import AudioLDM2Pipeline
53
 
54
 
55
+ warnings.filterwarnings("ignore", category=FutureWarning)
56
+ warnings.filterwarnings("ignore", category=UserWarning)
57
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
58
 
59
 
60
  def visualize_mel_spectrogram(mel_spect_tensor, output_path=None):
 
73
  plt.show()
74
 
75
 
 
 
 
 
76
  class StoreProcessor():
77
  def __init__(self, original_processor, value_dict, name):
78
  self.original_processor = original_processor
 
84
  def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
85
  # Is self attention
86
  if encoder_hidden_states is None:
 
 
87
  # print(f'In StoreProcessor: {self.name} {self.id}')
88
  self.value_dict[self.name][self.id] = hidden_states.detach()
89
  self.id += 1
 
90
  res = self.original_processor(attn, hidden_states, *args,
91
  encoder_hidden_states=encoder_hidden_states,
92
  attention_mask=attention_mask,
 
108
 
109
  def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
110
  # Is self attention
 
111
  if encoder_hidden_states is None:
 
112
  if self.id < 10 * self.lamd:
113
  map0 = self.aud1_dict[self.name][self.id]
114
  map1 = self.aud2_dict[self.name][self.id]
115
  cross_map = self.beta * hidden_states + \
116
  (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
 
117
  res = self.original_processor(attn, hidden_states, *args,
118
  encoder_hidden_states=cross_map,
119
  attention_mask=attention_mask,
120
  **kwargs)
121
  else:
 
122
  res = self.original_processor(attn, hidden_states, *args,
123
  encoder_hidden_states=encoder_hidden_states,
124
  attention_mask=attention_mask,
125
  **kwargs)
126
 
127
  self.id += 1
 
128
  if self.id == len(self.aud1_dict[self.name]):
129
  self.id = 0
130
  else:
 
131
  res = self.original_processor(attn, hidden_states, *args,
132
  encoder_hidden_states=encoder_hidden_states,
133
  attention_mask=attention_mask,
 
843
  ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
844
  # print("ta_kaldi_fbank.shape",ta_kaldi_fbank.shape)
845
  mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
846
+ model = AudioMAEConditionCTPoolRand().cuda()
847
  model.eval()
848
  LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
849
  uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)
 
867
 
868
  return prompt_embeds, attention_mask, generated_prompt_embeds
869
 
870
+ def init_trained_pipeline(self, model_path, device, dtype, ap_scale, text_ap_scale):
871
+ pipeline_trained = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=dtype).to(device)
872
+ layer_num = 0
873
+ cross = [None, None, 768, 768, 1024, 1024, None, None]
874
+ unet = pipeline_trained.unet
875
+ attn_procs = {}
876
+ for name in unet.attn_processors.keys():
877
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
878
+ if name.startswith("mid_block"):
879
+ hidden_size = unet.config.block_out_channels[-1]
880
+ elif name.startswith("up_blocks"):
881
+ block_id = int(name[len("up_blocks.")])
882
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
883
+ elif name.startswith("down_blocks"):
884
+ block_id = int(name[len("down_blocks.")])
885
+ hidden_size = unet.config.block_out_channels[block_id]
886
+
887
+ if cross_attention_dim is None:
888
+ attn_procs[name] = AttnProcessor2_0()
889
+ else:
890
+ cross_attention_dim = cross[layer_num % 8]
891
+ layer_num += 1
892
+ if cross_attention_dim == 768:
893
+ attn_procs[name] = IPAttnProcessor2_0(
894
+ hidden_size=hidden_size,
895
+ name=name,
896
+ flag='trained',
897
+ cross_attention_dim=cross_attention_dim,
898
+ text_scale=text_ap_scale,
899
+ scale=ap_scale,
900
+ num_tokens=8,
901
+ do_copy=False
902
+ ).to(device, dtype=dtype)
903
+ else:
904
+ attn_procs[name] = AttnProcessor2_0()
905
+
906
+ state_dict = torch.load(model_path, map_location=device)
907
+ for name, processor in attn_procs.items():
908
+ if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
909
+ weight_name_v = name + ".to_v_ip.weight"
910
+ weight_name_k = name + ".to_k_ip.weight"
911
+ if dtype == torch.float32:
912
+ processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].float())
913
+ processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].float())
914
+ elif dtype == torch.float16:
915
+ processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
916
+ processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
917
+ unet.set_attn_processor(attn_procs)
918
+ class _Wrapper(AttnProcsLayers):
919
+ def forward(self, *args, **kwargs):
920
+ return unet(*args, **kwargs)
921
+
922
+ unet = _Wrapper(unet.attn_processors)
923
+
924
+ return pipeline_trained
925
+
926
  @torch.no_grad()
927
  def aud2latent(self, audio_path, audio_length_in_s):
928
  DEVICE = torch.device(
929
  "cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 
 
 
 
 
930
  mel_spect_tensor = wav_to_mel(audio_path, duration=audio_length_in_s).unsqueeze(0)
931
  output_path = audio_path.replace('.wav', '_fbank.png')
932
  visualize_mel_spectrogram(mel_spect_tensor, output_path)
 
939
  @torch.no_grad()
940
  def ddim_inversion(self, start_latents, prompt_embeds, attention_mask, generated_prompt_embeds, guidance_scale,num_inference_steps):
941
  start_step = 0
942
+ # print(f"Scheduler timesteps: {self.scheduler.timesteps}")
943
+ num_inference_steps = min(num_inference_steps, int(max(self.scheduler.timesteps)))
944
  device = start_latents.device
945
  self.scheduler.set_timesteps(num_inference_steps, device=device)
946
  start_latents *= self.scheduler.init_noise_sigma
 
959
  def generate_morphing_prompt(self, prompt_1, prompt_2, alpha):
960
  closer_prompt = prompt_1 if alpha <= 0.5 else prompt_2
961
  prompt = (
962
+ f"Jazz style music"
 
 
963
  )
964
  return prompt
965
 
 
967
  def cal_latent(self,audio_length_in_s,time_pooling, freq_pooling,num_inference_steps, guidance_scale, aud_noise_1, aud_noise_2, prompt_1, prompt_2,
968
  prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2,
969
  alpha, original_processor,attn_processor_dict, use_morph_prompt, morphing_with_lora):
970
+ num_inference_steps = min(num_inference_steps, int(max(self.pipeline_trained.scheduler.timesteps)))
971
  latents = slerp(aud_noise_1, aud_noise_2, alpha, self.use_adain)
972
  if not use_morph_prompt:
973
+ print("Not using morphing prompt")
974
  max_length = max(prompt_embeds_1.shape[1], prompt_embeds_2.shape[1])
975
  if prompt_embeds_1.shape[1] < max_length:
976
  pad_size = max_length - prompt_embeds_1.shape[1]
 
1019
  # attention_mask = (attention_mask > 0.5).long()
1020
 
1021
  if morphing_with_lora:
1022
+ self.pipeline_trained.unet.set_attn_processor(attn_processor_dict)
1023
+ waveform = self.pipeline_trained(
1024
  time_pooling= time_pooling,
1025
  freq_pooling= freq_pooling,
1026
  latents = latents,
1027
  num_inference_steps= num_inference_steps,
1028
+ guidance_scale = guidance_scale,
1029
  num_waveforms_per_prompt= 1,
1030
  audio_length_in_s=audio_length_in_s,
1031
  prompt_embeds = prompt_embeds.chunk(2)[1],
 
1036
  negative_attention_mask = attention_mask.chunk(2)[0],
1037
  ).audios[0]
1038
  if morphing_with_lora:
1039
+ self.pipeline_trained.unet.set_attn_processor(original_processor)
1040
  else:
1041
  latent_model_input = latents
1042
  morphing_prompt = self.generate_morphing_prompt(prompt_1, prompt_2, alpha)
1043
  if morphing_with_lora:
1044
+ self.pipeline_trained.unet.set_attn_processor(attn_processor_dict)
1045
+ waveform = self.pipeline_trained(
1046
  time_pooling= time_pooling,
1047
  freq_pooling= freq_pooling,
1048
  latents = latent_model_input,
 
1054
  negative_prompt= 'Low quality',
1055
  ).audios[0]
1056
  if morphing_with_lora:
1057
+ self.pipeline_trained.unet.set_attn_processor(original_processor)
1058
 
1059
+ return waveform, latents
1060
 
1061
  @torch.no_grad()
1062
  def __call__(
1063
  self,
1064
+ dtype,
1065
  audio_file = None,
1066
  audio_file2 = None,
1067
+ ap_scale = 1.0,
1068
+ text_ap_scale = 1.0,
1069
  save_lora_dir = "./lora",
1070
  load_lora_path_1 = None,
1071
  load_lora_path_2 = None,
 
1089
  attn_beta=0,
1090
  lamd=0.6,
1091
  fix_lora=None,
 
1092
  num_frames=50,
1093
  max_new_tokens: Optional[int] = None,
1094
  callback_steps: Optional[int] = 1,
 
1096
  morphing_with_lora=False,
1097
  use_morph_prompt=False,
1098
  ):
1099
+ ap_adapter_path = 'pytorch_model.bin'
1100
  device = "cuda" if torch.cuda.is_available() else "cpu"
1101
  # 0. Load the pre-trained AP-adapter model
1102
  layer_num = 0
 
1112
  elif name.startswith("down_blocks"):
1113
  block_id = int(name[len("down_blocks.")])
1114
  hidden_size = self.unet.config.block_out_channels[block_id]
 
1115
  if cross_attention_dim is None:
1116
  attn_procs[name] = AttnProcessor2_0()
1117
  else:
1118
  cross_attention_dim = cross[layer_num % 8]
1119
  layer_num += 1
1120
  if cross_attention_dim == 768:
1121
+ attn_procs[name].scale = IPAttnProcessor2_0(
1122
  hidden_size=hidden_size,
1123
  name=name,
1124
  cross_attention_dim=cross_attention_dim,
1125
+ text_scale=100,
1126
+ scale=ap_scale,
1127
  num_tokens=8,
1128
  do_copy=False
1129
+ ).to(device, dtype=dtype)
1130
  else:
1131
  attn_procs[name] = AttnProcessor2_0()
1132
+ state_dict = torch.load(ap_adapter_path, map_location=device)
 
1133
  for name, processor in attn_procs.items():
1134
  if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
1135
  weight_name_v = name + ".to_v_ip.weight"
1136
  weight_name_k = name + ".to_k_ip.weight"
1137
+ if dtype == torch.float32:
1138
+ processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].float())
1139
+ processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].float())
1140
+ elif dtype == torch.float16:
1141
+ processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
1142
+ processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
1143
  self.unet.set_attn_processor(attn_procs)
1144
+ self.pipeline_trained = self.init_trained_pipeline(ap_adapter_path, device, dtype, ap_scale, text_ap_scale)
 
 
 
 
 
 
1145
 
 
 
1146
  # 1. Pre-check
1147
  height, original_waveform_length = self.pre_check(audio_length_in_s, prompt_1, callback_steps, negative_prompt_1)
1148
  _, _ = self.pre_check(audio_length_in_s, prompt_2, callback_steps, negative_prompt_2)
1149
  # print(f"height: {height}, original_waveform_length: {original_waveform_length}") # height: 1000, original_waveform_length: 160000
1150
 
1151
  # # 2. Define call parameters
1152
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1153
  do_classifier_free_guidance = guidance_scale > 1.0
1154
  self.use_lora = use_lora
1155
  self.use_adain = use_adain
 
1163
  weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt"
1164
  load_lora_path_1 = save_lora_dir + "/" + weight_name
1165
  if not os.path.exists(load_lora_path_1):
1166
+ train_lora(audio_file, dtype, time_pooling ,freq_pooling ,prompt_1, negative_prompt_1, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
1167
  self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
1168
  self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
1169
  print(f"Load from {load_lora_path_1}.")
 
1178
  weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
1179
  load_lora_path_2 = save_lora_dir + "/" + weight_name
1180
  if not os.path.exists(load_lora_path_2):
1181
+ train_lora(audio_file2, dtype,time_pooling ,freq_pooling ,prompt_2, negative_prompt_2, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
1182
  self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
1183
  self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
1184
  print(f"Load from {load_lora_path_2}.")
 
1197
 
1198
 
1199
  # 4. Prepare latent variables
1200
+ # ------- For the first audio file -------
1201
  original_processor = list(self.unet.attn_processors.values())[0]
 
1202
  if noisy_latent_with_lora:
1203
  self.unet = load_lora(self.unet, lora_1, lora_2, 0)
 
1204
  # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
1205
  audio_latent = self.aud2latent(audio_file, audio_length_in_s).to(device)
 
 
 
 
 
1206
  # aud_noise_1 is the noisy latent representation of the audio file 1
1207
+ aud_noise_1 = self.ddim_inversion(audio_latent, prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, guidance_scale, num_inference_steps = num_inference_steps)
 
1208
  # After reconstructed the audio file 1, we set the original processor back
1209
  if noisy_latent_with_lora:
1210
  self.unet.set_attn_processor(original_processor)
 
1211
 
1212
+ # ------- For the second audio file -------
1213
  if noisy_latent_with_lora:
1214
  self.unet = load_lora(self.unet, lora_1, lora_2, 1)
 
1215
  # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
1216
  audio_latent = self.aud2latent(audio_file2, audio_length_in_s)
 
 
 
 
1217
  # aud_noise_2 is the noisy latent representation of the audio file 2
1218
+ aud_noise_2 = self.ddim_inversion(audio_latent, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2, guidance_scale, num_inference_steps = num_inference_steps)
1219
  if noisy_latent_with_lora:
1220
  self.unet.set_attn_processor(original_processor)
 
1221
  # After reconstructed the audio file 1, we set the original processor back
1222
  original_processor = list(self.unet.attn_processors.values())[0]
 
 
1223
  def morph(alpha_list, desc):
1224
  audios = []
1225
  # if attn_beta is not None:
 
1227
  self.unet = load_lora(
1228
  self.unet, lora_1, lora_2, 0 if fix_lora is None else fix_lora)
1229
  attn_processor_dict = {}
 
1230
  for k in self.unet.attn_processors.keys():
1231
  # print(k)
1232
  if do_replace_attn(k):
 
1233
  if self.use_lora:
1234
  attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
1235
  self.aud1_dict, k)
 
1237
  attn_processor_dict[k] = StoreProcessor(original_processor,
1238
  self.aud1_dict, k)
1239
  else:
1240
+ attn_processor_dict[k] = self.unet.attn_processors[k]
1241
+ first_audio, first_latents = self.cal_latent(
 
 
 
 
 
 
 
 
1242
  audio_length_in_s,
1243
  time_pooling,
1244
  freq_pooling,
 
1264
  self.unet.set_attn_processor(original_processor)
1265
  file_path = os.path.join(self.output_path, f"{0:02d}.wav")
1266
  scipy.io.wavfile.write(file_path, rate=16000, data=first_audio)
 
1267
  if self.use_lora:
1268
  self.unet = load_lora(
1269
  self.unet, lora_1, lora_2, 1 if fix_lora is None else fix_lora)
1270
  attn_processor_dict = {}
1271
  for k in self.unet.attn_processors.keys():
1272
  if do_replace_attn(k):
 
1273
  if self.use_lora:
1274
  attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
1275
  self.aud2_dict, k)
 
1278
  self.aud2_dict, k)
1279
  else:
1280
  attn_processor_dict[k] = self.unet.attn_processors[k]
1281
+ last_audio, last_latents = self.cal_latent(
 
1282
  audio_length_in_s,
1283
  time_pooling,
1284
  freq_pooling,
 
1302
  )
1303
  file_path = os.path.join(self.output_path, f"{num_frames-1:02d}.wav")
1304
  scipy.io.wavfile.write(file_path, rate=16000, data=last_audio)
1305
+
1306
  self.unet.set_attn_processor(original_processor)
1307
 
1308
  for i in tqdm(range(1, num_frames - 1), desc=desc):
 
1322
  original_processor, k, self.aud1_dict, self.aud2_dict, alpha, attn_beta, lamd)
1323
  else:
1324
  attn_processor_dict[k] = self.unet.attn_processors[k]
1325
+ audio, latents = self.cal_latent(
 
1326
  audio_length_in_s,
1327
  time_pooling,
1328
  freq_pooling,
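
Taken together, the pipeline changes move the module-level AudioLDM2 setup into init_trained_pipeline, keep the result on self.pipeline_trained, and make dtype, ap_scale and text_ap_scale explicit __call__ arguments. A hedged usage sketch of the new entry point: it assumes download.py has already placed pytorch_model.bin in the working directory, that the morph pipeline is built from the same cvssp/audioldm2-large checkpoint, and that keyword names not visible in the hunks above (the prompt arguments, the placeholder file paths) follow the local names used in the code.

    import torch
    from pipeline.morph_pipeline_successed_ver1 import AudioLDM2MorphPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = AudioLDM2MorphPipeline.from_pretrained(
        "cvssp/audioldm2-large", torch_dtype=torch.float32
    ).to(device)

    # dtype is now threaded through train_lora / load_lora and the attention
    # processors; ap_scale / text_ap_scale feed IPAttnProcessor2_0.
    pipe(
        dtype=torch.float32,
        audio_file="a.wav",          # placeholder inputs
        audio_file2="b.wav",
        ap_scale=1.0,
        text_ap_scale=1.0,
        prompt_1="acoustic guitar",  # assumed keyword names
        prompt_2="piano",
        num_frames=5,
        morphing_with_lora=True,
        use_morph_prompt=True,
    )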
utils/lora_utils_successed_ver1.py CHANGED
@@ -449,7 +449,7 @@ def plot_loss(loss_history, loss_plot_path, lora_steps):
449
  # lora_steps: number of lora training step
450
  # lora_lr: learning rate of lora training
451
  # lora_rank: the rank of lora
452
- def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_prompt, guidance_scale, save_lora_dir, tokenizer=None, tokenizer_2=None,
453
  text_encoder=None, text_encoder_2=None, GPT2=None, projection_model=None, vocoder=None,
454
  vae=None, unet=None, noise_scheduler=None, lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
455
  time_pooling = time_pooling
@@ -534,7 +534,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
534
  scale=1.0,
535
  num_tokens=8,
536
  do_copy = do_copy
537
- ).to(device, dtype=torch.float32)
538
  else:
539
  unet_lora_attn_procs[name] = AttnProcessor2_0()
540
  unet.set_attn_processor(unet_lora_attn_procs)
@@ -580,7 +580,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
580
  fbank = torch.zeros((1024, 128))
581
  ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
582
  mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
583
- model = AudioMAEConditionCTPoolRand().to(device).to(dtype=torch.float32)
584
  model.eval()
585
  mel_spect_tensor = mel_spect_tensor.to(device, dtype=next(model.parameters()).dtype)
586
  LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
@@ -599,24 +599,6 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
599
  generated_prompt_embeds = torch.cat([uncond, cond], dim=0)
600
  model_dtype = next(unet.parameters()).dtype
601
  generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
602
-
603
- # num_channels_latents = unet.config.in_channels
604
- # batch_size = 1
605
- # num_waveforms_per_prompt = 1
606
- # generator = None
607
- # latents = None
608
- # latents = prepare_latents(
609
- # vae,
610
- # vocoder,
611
- # noise_scheduler,
612
- # batch_size * num_waveforms_per_prompt,
613
- # num_channels_latents,
614
- # height,
615
- # prompt_embeds.dtype,
616
- # device,
617
- # generator,
618
- # latents,
619
- # )
620
 
621
  loss_history = []
622
  if not os.path.exists(save_lora_dir):
@@ -683,7 +665,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
683
  safe_serialization=safe_serialization
684
  )
685
 
686
- def load_lora(unet, lora_0, lora_1, alpha):
687
  attn_procs = unet.attn_processors
688
  for name, processor in attn_procs.items():
689
  if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
@@ -691,10 +673,16 @@ def load_lora(unet, lora_0, lora_1, alpha):
691
  weight_name_k = name + ".to_k_ip.weight"
692
  if weight_name_v in lora_0 and weight_name_v in lora_1:
693
  v_weight = (1 - alpha) * lora_0[weight_name_v] + alpha * lora_1[weight_name_v]
694
- processor.to_v_ip.weight = torch.nn.Parameter(v_weight.half())
 
 
 
695
 
696
  if weight_name_k in lora_0 and weight_name_k in lora_1:
697
  k_weight = (1 - alpha) * lora_0[weight_name_k] + alpha * lora_1[weight_name_k]
698
- processor.to_k_ip.weight = torch.nn.Parameter(k_weight.half())
 
 
 
699
  unet.set_attn_processor(attn_procs)
700
  return unet
 
449
  # lora_steps: number of lora training step
450
  # lora_lr: learning rate of lora training
451
  # lora_rank: the rank of lora
452
+ def train_lora(audio_path ,dtype ,time_pooling ,freq_pooling ,prompt, negative_prompt, guidance_scale, save_lora_dir, tokenizer=None, tokenizer_2=None,
453
  text_encoder=None, text_encoder_2=None, GPT2=None, projection_model=None, vocoder=None,
454
  vae=None, unet=None, noise_scheduler=None, lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
455
  time_pooling = time_pooling
 
534
  scale=1.0,
535
  num_tokens=8,
536
  do_copy = do_copy
537
+ ).to(device, dtype=dtype)
538
  else:
539
  unet_lora_attn_procs[name] = AttnProcessor2_0()
540
  unet.set_attn_processor(unet_lora_attn_procs)
 
580
  fbank = torch.zeros((1024, 128))
581
  ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
582
  mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
583
+ model = AudioMAEConditionCTPoolRand().to(device).to(dtype=dtype)
584
  model.eval()
585
  mel_spect_tensor = mel_spect_tensor.to(device, dtype=next(model.parameters()).dtype)
586
  LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
 
599
  generated_prompt_embeds = torch.cat([uncond, cond], dim=0)
600
  model_dtype = next(unet.parameters()).dtype
601
  generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
 
602
 
603
  loss_history = []
604
  if not os.path.exists(save_lora_dir):
 
665
  safe_serialization=safe_serialization
666
  )
667
 
668
+ def load_lora(unet, lora_0, lora_1, alpha, dtype):
669
  attn_procs = unet.attn_processors
670
  for name, processor in attn_procs.items():
671
  if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
 
673
  weight_name_k = name + ".to_k_ip.weight"
674
  if weight_name_v in lora_0 and weight_name_v in lora_1:
675
  v_weight = (1 - alpha) * lora_0[weight_name_v] + alpha * lora_1[weight_name_v]
676
+ if dtype == torch.float32:
677
+ processor.to_v_ip.weight = torch.nn.Parameter(v_weight.float())
678
+ elif dtype == torch.float16:
679
+ processor.to_v_ip.weight = torch.nn.Parameter(v_weight.half())
680
 
681
  if weight_name_k in lora_0 and weight_name_k in lora_1:
682
  k_weight = (1 - alpha) * lora_0[weight_name_k] + alpha * lora_1[weight_name_k]
683
+ if dtype == torch.float32:
684
+ processor.to_k_ip.weight = torch.nn.Parameter(k_weight.float())
685
+ elif dtype == torch.float16:
686
+ processor.to_k_ip.weight = torch.nn.Parameter(k_weight.half())
687
  unet.set_attn_processor(attn_procs)
688
  return unet
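
In utils/lora_utils_successed_ver1.py, train_lora and load_lora now take an explicit dtype, and load_lora casts the interpolated IP-adapter weights to that dtype instead of always calling .half(). A hedged sketch of that interpolation step in isolation (checkpoint paths are placeholders):

    import torch

    lora_0 = torch.load("lora/example_lora_0.ckpt", map_location="cpu")
    lora_1 = torch.load("lora/example_lora_1.ckpt", map_location="cpu")
    alpha, dtype = 0.5, torch.float32

    # Blend only the IP-adapter projection weights, as load_lora does,
    # then cast to the requested dtype.
    blended = {
        k: ((1 - alpha) * lora_0[k] + alpha * lora_1[k]).to(dtype)
        for k in lora_0.keys() & lora_1.keys()
        if k.endswith((".to_v_ip.weight", ".to_k_ip.weight"))
    }
    print(f"blended {len(blended)} tensors at alpha={alpha}")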